import React from 'react';
import {ScrollableFlex, WrapperFlex} from '../../utils/scrolling';
import {
  TableContainer,
  Table,
  Text,
  Thead,
  Tr,
  Th,
  Tbody,
  Td,
  HStack,
  Box,
  Button,
  Tag,
  useDisclosure,
  Tooltip,
  Skeleton,
  Menu,
  MenuButton,
  MenuList,
  MenuItem,
  Link,
  VStack,
} from '@chakra-ui/react';
import {
  useReactTable,
  getCoreRowModel,
  flexRender,
  createColumnHelper,
} from '@tanstack/react-table';
import {
  useCreateModelTestControlList,
  useDeleteModelTestList,
  useGetModelTestDefinition,
  useGetModelTestListResults,
  useGetModelTestResults,
  useRefreshModelTestControlList,
} from '../../../hooks/api/scoringModel';
import {useMarketFromContext} from '../MarketProvider';
import _ from 'lodash';
import {
  ModelTestResults,
  ModelTestSignalCounts,
  TierCounts,
} from '../../../shared/scoring';
import {ObjectId} from 'bson';
import {TierPieChart} from './TierPieChart';
import {AccountTier, AccountTiers} from '../../../shared/scoredAccounts';
import {UseQueryResult} from '@tanstack/react-query';
import {DelayedSpinner} from '../../DelayedSpinner';
import {ModelTestListImportModal} from './ModelTestListImport';
import {MaxModelTestLists} from '../../../shared/api/definitions';
import {Renew, TrashCan} from '@carbon/icons-react';
import {Parser} from '@json2csv/plainjs';
import FileSaver from 'file-saver';
import {useCustomer} from '../../../hooks/api/metadata';
import {useScoringSignalResolver} from '../../../hooks/scoringSignal';
import {ScoringSignalResolver} from '../../../shared/signals';
import {pluralize} from '../../../lib/helpers';
import {TestList} from '../../../shared/testLists';
import {WithId} from '../../../shared/util';

/**
 * Carries the tier-count results query for the enclosing list row.
 * Defaults to null, which is what cells see when no ListResultsQueryProvider
 * is mounted above them.
 */
const ListResultsQueryContext = React.createContext<
  UseQueryResult<TierCounts> | null
>(null);

/** Returns the nearest list-results query, or null outside a provider. */
function useListResultsQuery() {
  return React.useContext(ListResultsQueryContext);
}

/**
 * Runs the tier-count results query for a single test list (in the current
 * market) and exposes it to all descendants via ListResultsQueryContext.
 */
const ListResultsQueryProvider = ({
  listId,
  children,
}: React.PropsWithChildren<{listId: ObjectId}>) => {
  const market = useMarketFromContext();
  const listResults = useGetModelTestListResults({
    listId,
    marketId: market.id,
  });

  return (
    <ListResultsQueryContext.Provider value={listResults}>
      {children}
    </ListResultsQueryContext.Provider>
  );
};

/**
 * Wraps a row's cells in a ListResultsQueryProvider unless the list is still
 * processing; processing lists render their children with no provider, so
 * useListResultsQuery() returns null for them.
 */
const ConditionalListQueryProvider = ({
  list,
  children,
}: React.PropsWithChildren<{list: WithId<TestList>}>) =>
  list.status.type === 'processing' ? (
    <>{children}</>
  ) : (
    <ListResultsQueryProvider listId={list._id}>
      {children}
    </ListResultsQueryProvider>
  );

/**
 * Rounded percentage (0-100) of a list's accounts that fall into `tier`.
 *
 * Each tier is rounded independently, except the last column, which is
 * 100 minus the other tiers' rounded values so that the row always sums to
 * exactly 100%.
 *
 * Returns 0 when the list has no accounts, instead of the NaN that 0 / 0
 * would produce. A tier missing from `counts` is treated as 0.
 */
const tierPercentage = (
  counts: TierCounts,
  tier: AccountTier,
  isLastColumn: boolean
): number => {
  const total = _.sum(Object.values(counts));
  if (!total) {
    return 0;
  }

  const roundedShare = (currentTier: AccountTier) =>
    Math.round(((counts[currentTier] ?? 0) / total) * 100);

  if (!isLastColumn) {
    return roundedShare(tier);
  }

  // Last column absorbs rounding error so the row adds up to 100.
  let remainder = 100;
  for (const currentTier of AccountTiers) {
    if (currentTier !== tier) {
      remainder -= roundedShare(currentTier);
    }
  }
  return remainder;
};

/**
 * Table cell showing the share of the row's list that landed in `tier`.
 *
 * Renders nothing when no results query is provided (the list is still
 * processing), and a skeleton placeholder while counts are loading or
 * refetching.
 *
 * @param tier the account tier this column represents
 * @param isLastColumn set on the right-most tier column so the row totals 100%
 */
const TierCell = ({
  tier,
  isLastColumn = false,
}: {
  tier: AccountTier;
  isLastColumn?: boolean;
}) => {
  const query = useListResultsQuery();
  const counts = query?.data;

  if (!query) {
    return <></>;
  }

  if (!counts || query.isFetching) {
    return (
      <Skeleton
        isLoaded={false}
        style={{width: 'auto', display: 'inline-block'}}
      >
        00%
      </Skeleton>
    );
  }

  return <>{tierPercentage(counts, tier, isLastColumn)}%</>;
};

/**
 * Outline "trash" button that deletes the given test list in the current
 * market. Disabled while the delete mutation is in flight.
 */
const DeleteCell = ({listId}: {listId: ObjectId}) => {
  const deleteList = useDeleteModelTestList();
  const market = useMarketFromContext();

  const handleDelete = () =>
    deleteList.mutate({listId, marketId: market.id});

  return (
    <Button
      key="delete"
      colorScheme="red"
      variant="outline"
      isDisabled={deleteList.isLoading}
      onClick={handleDelete}
    >
      <TrashCan />
    </Button>
  );
};

/**
 * Status column for a list: a "Processing" tag while the list is being
 * processed, otherwise a tier pie chart once the counts have loaded (and
 * nothing while they are loading or refetching).
 */
const ListStatus = ({list}: {list: TestList}) => {
  const query = useListResultsQuery();

  if (list.status.type === 'processing') {
    return (
      <Tag
        key="default"
        size={'md'}
        fontWeight={500}
        fontSize={12}
        textTransform="uppercase"
      >
        Processing
      </Tag>
    );
  }

  const tierCounts = query?.data;
  if (!tierCounts || query?.isFetching) {
    return <></>;
  }

  return <TierPieChart tierCounts={tierCounts} />;
};

const columnHelper = createColumnHelper<WithId<TestList>>();

/**
 * Builds the column showing the rounded share of a list's accounts in `tier`.
 * Pass `isLastColumn` for the right-most tier so TierCell makes the row
 * total exactly 100%.
 */
const tierColumn = (tier: AccountTier, isLastColumn?: boolean) =>
  columnHelper.accessor((row) => row, {
    id: `tier${tier}`,
    header: tier,
    cell: () => <TierCell tier={tier} isLastColumn={isLastColumn} />,
    size: 100,
  });

/** Table columns: list name, status/pie chart, tiers A-D, delete action. */
const columns = [
  columnHelper.accessor((row) => row, {
    id: 'name',
    header: 'List Name',
    size: 300,
    // Capitalised named function: it is rendered as a component and calls hooks.
    cell: function Cell(info) {
      const list = info.getValue();
      const {status, label, timestamp, type, _id: listId} = list;
      const listResults = useListResultsQuery();
      const {samDefinition, id: marketId} = useMarketFromContext();
      const refreshList = useRefreshModelTestControlList();

      const tierCounts = listResults?.data;
      const isFetching = listResults?.isFetching;

      const tooltipLabel =
        tierCounts && !isFetching
          ? pluralize(_.sum(Object.values(tierCounts)), 'account', 'accounts')
          : 'Loading...';

      // A SAM-sample control list is flagged when the SAM definition was
      // modified after the list's lastRefreshed (falling back to timestamp).
      const needsRefresh =
        type === 'samSample' &&
        samDefinition.lastModified &&
        samDefinition.lastModified > (list.lastRefreshed ?? timestamp);

      const refresh = () => refreshList.mutate({marketId, listId});

      return (
        <HStack spacing={3}>
          <Tooltip
            label={status.type === 'processing' ? '' : tooltipLabel}
            shouldWrapChildren={true}
            placement="bottom-start"
          >
            <Text>{label}</Text>
          </Tooltip>
          {needsRefresh && (
            <Tooltip label="SAM definition has been updated. Refresh recommended.">
              <Button
                onClick={refresh}
                colorScheme="blue"
                variant="ghost"
                isLoading={refreshList.isLoading}
              >
                <Renew />
              </Button>
            </Tooltip>
          )}
        </HStack>
      );
    },
  }),
  columnHelper.accessor((row) => row, {
    id: 'pie',
    header: '',
    cell: (info) => <ListStatus list={info.getValue()} />,
    size: 125,
  }),
  tierColumn('A'),
  tierColumn('B'),
  tierColumn('C'),
  tierColumn('D', true),
  columnHelper.accessor('_id', {
    header: '',
    id: 'delete',
    size: 50,
    cell: (info) => <DeleteCell listId={info.getValue()} />,
  }),
];

/**
 * Renders the test lists as a table, one row per list, using the module-level
 * `columns` definition.
 *
 * Each row is wrapped in ConditionalListQueryProvider so that, for lists that
 * are not processing, all cells of the row share one tier-count query.
 *
 * The row carries `data-group` so the cells' `_groupHover` style fires when the
 * row is hovered. Chakra's group selectors match `[role=group]`, `[data-group]`
 * or `.group`. The previous `role="list"` matched none of these and also
 * overrode the row's native table semantics.
 */
const ModelTestTable = ({lists}: {lists: WithId<TestList>[]}) => {
  const table = useReactTable({
    columns,
    data: lists,
    getCoreRowModel: getCoreRowModel(),
  });

  return (
    <TableContainer
      border="1px"
      borderColor="kgray.200"
      bgColor="kgray.50"
      borderRadius="md"
      width="100%"
    >
      <Table variant="simple">
        <Thead bg="kgray.100">
          {table.getHeaderGroups().map((headerGroup) => (
            <Tr key={headerGroup.id}>
              {headerGroup.headers.map((header) => (
                <Th key={header.id}>
                  {flexRender(
                    header.column.columnDef.header,
                    header.getContext()
                  )}
                </Th>
              ))}
            </Tr>
          ))}
        </Thead>
        <Tbody>
          {table.getRowModel().rows.map((row) => (
            <ConditionalListQueryProvider key={row.id} list={row.original}>
              <Tr data-group height="100px">
                {row.getVisibleCells().map((cell) => (
                  <Td
                    key={cell.id}
                    _groupHover={{bg: 'kgray.100'}}
                    width={
                      cell.column.getSize()
                        ? `${cell.column.getSize()}px`
                        : 'auto'
                    }
                  >
                    {flexRender(cell.column.columnDef.cell, cell.getContext())}
                  </Td>
                ))}
              </Tr>
            </ConditionalListQueryProvider>
          ))}
        </Tbody>
      </Table>
    </TableContainer>
  );
};

/** Column order of the per-account results CSV. */
const RESULTS_CSV_FIELDS = ['Domain', 'Overall Fit', 'Tier', 'Test Lists'];

/**
 * Builds a per-account CSV (domain, fit score to 2 decimals, tier, and the
 * comma-joined test lists the account appears in) and saves it as
 * `testResults-<name>.csv`.
 *
 * Fields are passed explicitly so the header row is written even when
 * `results` is empty; json2csv infers columns from the rows when `fields` is
 * omitted. The `.csv` suffix is appended explicitly because `name` may
 * contain dots (e.g. "Acme Inc."), which would otherwise be treated as the
 * file extension.
 *
 * @param results per-account model test results
 * @param name suffix for the downloaded file name
 */
const downloadResultsCsv = async (results: ModelTestResults, name: string) => {
  const rows = results.map(({domain, score, tier, lists}) => ({
    Domain: domain,
    'Overall Fit': score.toFixed(2),
    Tier: tier,
    'Test Lists': lists.join(', '),
  }));

  const csv = new Parser({fields: RESULTS_CSV_FIELDS}).parse(rows);
  const blob = new Blob([csv], {type: 'text/csv'});
  FileSaver.saveAs(blob, `testResults-${name}.csv`);
};

/**
 * Builds a CSV of signal counts and saves it as `signalCounts-<name>.csv`.
 *
 * The first data row ("List Count") spreads `listCounts`. Every following row
 * is one entry of `signalCounts`, labelled with the resolver's category/label
 * (falling back to the raw signal when unresolved) and followed by that
 * entry's own `listCounts`. The count columns are presumably one per test
 * list; verify against the API.
 *
 * Does nothing when `signalCounts` is falsy.
 *
 * @param signalCounts per-signal count entries
 * @param listCounts counts for the "List Count" row
 * @param name suffix for the downloaded file name
 * @param resolver maps a raw signal to its display metadata
 */
const downloadSignalCountsCsv = async (
  {
    signalCounts,
    listCounts,
    name,
  }: {
    signalCounts: ModelTestSignalCounts;
    listCounts: Record<string, number>;
    name: string;
  },
  resolver: ScoringSignalResolver
) => {
  if (!signalCounts) {
    return;
  }

  const listCountRow = {
    'Signal Category': '',
    'Signal Name': 'List Count',
    ...listCounts,
  };

  const signalCountRows = signalCounts.map(
    ({signal, listCounts: countsForSignal}) => {
      const resolvedSignal = resolver(signal);

      return {
        'Signal Category': resolvedSignal?.category ?? '',
        'Signal Name': resolvedSignal?.label ?? signal,
        ...countsForSignal,
      };
    }
  );

  const rows = [listCountRow, ...signalCountRows];

  // Pin the label columns first. JS orders integer-like object keys
  // (e.g. a list named "2024") ahead of all string keys, so relying on the
  // rows' key order could push count columns in front of the labels.
  const fields = _.union(
    ['Signal Category', 'Signal Name'],
    ...rows.map((row) => Object.keys(row))
  );

  const csv = new Parser({fields}).parse(rows);
  const blob = new Blob([csv], {type: 'text/csv'});
  FileSaver.saveAs(blob, `signalCounts-${name}.csv`);
};

/**
 * Model test page for the current market.
 *
 * Lists the market's test lists in a table, lets the user add a test list (via
 * the import modal) or a SAM-sample control list, and downloads either the
 * per-account results or the signal counts as CSV. With no lists, an empty
 * state links to the import modal instead.
 */
export const ModelTest = () => {
  const resolver = useScoringSignalResolver();
  const customer = useCustomer();
  const {isOpen, onOpen, onClose} = useDisclosure();
  const createControlList = useCreateModelTestControlList();
  const {id: marketId, label: marketLabel} = useMarketFromContext();
  const {data, isInitialLoading} = useGetModelTestDefinition({
    marketId,
  });
  // Which CSV the next successful results fetch should produce. The results
  // query below is disabled and only fetched on demand, so this records the
  // menu choice that triggered the fetch.
  const [resultsType, setResultsType] = React.useState<
    'results' | 'counts' | null
  >(null);

  const lists = data?.lists ?? [];
  const tooManyLists = lists.length >= MaxModelTestLists;
  const hasControlList = lists.some(({type}) => type === 'samSample');
  const pendingLists = lists.some(({status}) => status.type === 'processing');

  const {refetch: fetchModelTestResults, isFetching: fetchingModelTestResults} =
    useGetModelTestResults(
      {
        marketId,
      },
      {
        enabled: false,
        onSuccess(data) {
          const name = `${customer.name}-${marketLabel}`;
          if (resultsType === 'results') {
            downloadResultsCsv(data.results, name);
          } else if (resultsType === 'counts') {
            const {signalCounts, listCounts} = data;
            downloadSignalCountsCsv({signalCounts, listCounts, name}, resolver);
          }
          setResultsType(null);
        },
      }
    );

  /** Records which export was requested, then fetches the results for it. */
  const requestDownload = (type: 'results' | 'counts') => {
    setResultsType(type);
    fetchModelTestResults();
  };

  return (
    <>
      <WrapperFlex>
        <ScrollableFlex px={6}>
          <Box
            mt={6}
            width="50%"
            minWidth={1000}
            alignSelf="center"
            display="flex"
          >
            {isInitialLoading && (
              <HStack width="100%" justifyContent="center">
                <DelayedSpinner />
              </HStack>
            )}
            {!isInitialLoading &&
              (lists.length ? (
                <VStack width="100%" alignItems="flex-start">
                  <HStack mb={4}>
                    <Tooltip
                      isDisabled={!tooManyLists}
                      label={`Maximum of ${MaxModelTestLists} groups allowed`}
                    >
                      <Button
                        isDisabled={tooManyLists}
                        onClick={onOpen}
                        colorScheme="kbuttonblue"
                        fontSize="sm"
                        fontWeight="normal"
                      >
                        Add Test List
                      </Button>
                    </Tooltip>
                    <Button
                      isDisabled={tooManyLists || hasControlList}
                      isLoading={createControlList.isLoading}
                      onClick={() => createControlList.mutate({marketId})}
                      colorScheme="kbuttonblue"
                      variant="outline"
                      fontSize="sm"
                      fontWeight="normal"
                    >
                      Add Control List
                    </Button>
                    <Menu>
                      <Tooltip
                        isDisabled={!pendingLists}
                        label="Please wait for lists to finish processing"
                      >
                        <MenuButton
                          as={Button}
                          isDisabled={pendingLists}
                          colorScheme="kbuttonblue"
                          fontSize="sm"
                          fontWeight="normal"
                          isLoading={fetchingModelTestResults}
                          variant="outline"
                        >
                          Download
                        </MenuButton>
                      </Tooltip>

                      <MenuList zIndex={2} minWidth={32}>
                        <MenuItem
                          onClick={() => requestDownload('results')}
                          pl={2}
                        >
                          Results
                        </MenuItem>
                        <MenuItem
                          onClick={() => requestDownload('counts')}
                          pl={2}
                        >
                          Signal Counts
                        </MenuItem>
                      </MenuList>
                    </Menu>
                  </HStack>
                  <ModelTestTable lists={lists} />
                </VStack>
              ) : (
                <Box
                  px={6}
                  pt={6}
                  bgColor="kgray.50"
                  borderColor="kgray.200"
                  borderWidth="1px"
                  borderRadius="lg"
                  width="100%"
                  display="flex"
                  justifyContent="center"
                >
                  <Text
                    pb={6}
                    lineHeight="1.5"
                    fontSize="lg"
                    textAlign="center"
                    mx={8}
                  >
                    To get started,{' '}
                    <Link onClick={onOpen} color="kblue.300">
                      upload a test list
                    </Link>
                    . For best results we recommend lists of size 500-2000, with
                    separate lists for different customer segments as well as a
                    list of disqualified accounts.
                  </Text>
                </Box>
              ))}
          </Box>
        </ScrollableFlex>
      </WrapperFlex>
      <ModelTestListImportModal isOpen={isOpen} onClose={onClose} />
    </>
  );
};
