import { TFunction, Trans, useTranslation } from 'react-i18next';
import { useHistory, useParams } from 'react-router';

import {
  HeatmapBandsColorScale,
  HeatmapElementEvent,
  RecursivePartial,
  Theme,
} from '@elastic/charts';
import { Download, InfoCircleOutline } from '@superb-ai/icons';
import { Box, Button, Chip, Icon, Tooltip, Typography } from '@superb-ai/ui';
import { lowerCase, startCase } from 'lodash';

import analyticsTracker from '../../../../../../../../../analyticsTracker';
import FileUtils from '../../../../../../../../../utils/FileUtils';
import { useDatasetContext } from '../../../../../../../contexts/DatasetContext';
import { useDiagnosisAnalyticsFilterContext } from '../../../../../../../contexts/DiagnosisAnalyticsFilterContext';
import { useDiagnosisCommonFilterContext } from '../../../../../../../contexts/DiagnosisCommonFilterContext';
import { useDiagnosisModelContext } from '../../../../../../../contexts/DiagnosisModelContext';
import { usePublicDatasetContext } from '../../../../../../../contexts/PublicDatasetContextProvider';
import { useConfusionMatrixQuery } from '../../../../../../../queries/diagnosisAnalyticsQueries';
import { useDiagnosisSchema } from '../../../../../../../queries/diagnosisModelQueries';
import { ConfusionMatrixDatum } from '../../../../../../../services/DiagnosisAnalyticsService';
import { DiagnosisDetail } from '../../../../../../../services/DiagnosisModelService';
import { GRID_VIEW } from '../../../../../../../types/viewTypes';
import { getSearchQueryRoute } from '../../../../../../../utils/routeUtils';
import { DEFAULT_EVALUATION_THRESHOLDS } from '../../../const';
import { ChartCard } from '../ChartCard';
import { ConfusionMatrix } from '../charts/ConfusionMatrix';
import { getHeatmapStyle } from '../charts/customTheme';
import { chartColors, heatmapGradientRed5, heatmapQuantizeColorScale } from '../colorScale';
import { formatConfusionMatrixDownload } from '../download';
import {
  DIAGNOSIS_DISPLAY_NULL_VALUE,
  getConfusionMatrixValues,
  getDisplayClassName,
  transformConfusionMatrix,
} from '../transformer';

/**
 * Computes the CSS height of the confusion-matrix chart.
 *
 * Each row is given 55px; the result is clamped to the range
 * [300px, 20000px].
 *
 * @param n - Number of rows (and columns) in the matrix.
 * @returns The height as a CSS pixel string, e.g. `"550px"`.
 */
export const getMatrixHeight = (n: number): string => {
  const pxPerRow = 55;
  const floorPx = 300;
  // Upper cap; the original note says it targets a 501 x 501 matrix (unverified).
  const ceilingPx = 20000;

  let heightPx = n * pxPerRow;
  if (heightPx < floorPx) heightPx = floorPx;
  if (heightPx > ceilingPx) heightPx = ceilingPx;

  return `${heightPx}px`;
};

/**
 * Computes the CSS width of the confusion-matrix chart.
 *
 * Matrices with up to 200 rows/columns fill 95% of their container; larger
 * ones get 10px per column (the original note marks this sizing as unverified
 * for a 501 x 501 matrix).
 *
 * @param n - Number of rows (and columns) in the matrix.
 * @returns A CSS width string, either `"95%"` or a pixel value such as `"2010px"`.
 */
export const getMatrixWidth = (n: number): string => {
  const fitsContainer = n <= 200;
  return fitsContainer ? '95%' : `${n * 10}px`;
};

/**
 * Builds the translated, start-cased axis titles for the confusion matrix.
 *
 * @param t - Translation function used to resolve the i18n keys.
 * @returns An object with the `predicted` and `true` (ground truth) axis titles.
 */
export const getConfusionMatrixAxisTitles = (t: TFunction) => {
  const predictedTitle = startCase(t('curate.diagnosis.chart.confusionMatrix.predictedClass'));
  const groundTruthTitle = startCase(t('curate.diagnosis.chart.confusionMatrix.groundTruthClass'));

  return { predicted: predictedTitle, true: groundTruthTitle };
};

/**
 * Derives the evaluation-result code for a confusion-matrix cell from its
 * predicted and ground-truth class names.
 *
 * A class equal to `DIAGNOSIS_DISPLAY_NULL_VALUE` is treated as "absent":
 * - both absent                         -> `undefined` (no filter applies)
 * - prediction absent only              -> `'FN'`
 * - ground truth absent only            -> `'FP'`
 * - both present and equal              -> `'TP'`
 * - both present and different          -> `'MC'` (presumably "misclassified")
 *
 * @param prediction - Predicted class name of the cell.
 * @param groundTruth - Ground-truth class name of the cell.
 * @returns The evaluation-result code, or `undefined` when neither side has a class.
 */
const getEvaluationResult = (
  prediction: string,
  groundTruth: string,
): 'TP' | 'FP' | 'FN' | 'MC' | undefined => {
  // Strict comparison throughout: both operands are strings, so coercion is never wanted.
  const hasPrediction = prediction !== DIAGNOSIS_DISPLAY_NULL_VALUE;
  const hasGroundTruth = groundTruth !== DIAGNOSIS_DISPLAY_NULL_VALUE;

  if (!hasPrediction && !hasGroundTruth) return undefined;
  if (!hasPrediction) return 'FN';
  if (!hasGroundTruth) return 'FP';
  return prediction === groundTruth ? 'TP' : 'MC';
};

/** Props for the chart area that renders the confusion-matrix heatmap. */
export type MatrixProps = {
  /**
   * Cells of the matrix. The chart area treats the data as a square matrix and
   * derives the row/column count from the square root of its length.
   * NOTE(review): element [0] appears to be the ground-truth (y) class and
   * element [1] the predicted (x) class — confirm against ConfusionMatrixDatum.
   */
  matrixData: ConfusionMatrixDatum[];
  /** Color scale forwarded to the heatmap as its `colorScale`. */
  heatmapColors: HeatmapBandsColorScale;
  /** Invoked by the CSV download button shown when the matrix is too large to render. */
  handleDownload: () => void;
};

/** Props for the confusion-matrix card. */
export type CardProps = {
  /** Diagnosis whose `id` selects the queried data and whose `modelName` prefixes the CSV file name. */
  diagnosis: DiagnosisDetail;
};

/**
 * Chart area of the confusion-matrix card.
 *
 * Renders the heatmap for matrices of up to 200 rows/columns. For larger
 * matrices it shows an explanatory message and a CSV download button instead.
 * Clicking a heatmap cell navigates to the search route filtered by that
 * cell's predicted / ground-truth classes and reports a chart-click event.
 */
const ConfusionMatrixChartArea = ({ matrixData, heatmapColors, handleDownload }: MatrixProps) => {
  const { t } = useTranslation();
  const { accountName } = useParams<{ accountName: string }>();
  const { selectedClass } = useDiagnosisAnalyticsFilterContext();
  const history = useHistory();

  // Row & column count of the (square) matrix; per the original note this is
  // the number of classes + 1 (for 'none').
  const classCount = Math.ceil(Math.sqrt(matrixData.length));

  // Too large to draw: offer the CSV download instead of the chart.
  if (classCount > 200) {
    return (
      <Box
        style={{ height: '300px' }}
        display="flex"
        alignItems="center"
        justifyContent="center"
        flexDirection="column"
        gap={3}
      >
        <Typography color="gray-300" variant="m-regular" textAlign="center">
          <Trans t={t} i18nKey="curate.diagnosis.chart.confusionMatrix.moreThan200" />
        </Typography>
        <Button onClick={handleDownload} size="s" variant="soft-fill">
          <Icon icon={Download} />
          {t('curate.diagnosis.chart.confusionMatrix.csvDownload')}
        </Button>
      </Box>
    );
  }

  // When a specific class is selected, highlight every cell whose second
  // tuple element matches it; otherwise no highlight is passed at all.
  let highlightedData: { x: string[]; y: string[] } | undefined;
  if (selectedClass !== 'all') {
    const matchingCells = matrixData.filter(datum => datum[1] === selectedClass);
    highlightedData = {
      x: matchingCells.map(datum => datum[1]),
      y: matchingCells.map(datum => datum[0]),
    };
  }

  // Navigates to the filtered search view for the clicked cell, then tracks the click.
  const handleElementClick = (elements: any) => {
    const [clickedCell] = elements[0] as HeatmapElementEvent;
    const predictedClass = clickedCell.datum.x as string;
    const trueClass = clickedCell.datum.y as string;
    const evaluationResult = getEvaluationResult(predictedClass, trueClass);

    const targetRoute = getSearchQueryRoute(history, {
      ...(evaluationResult && { evaluation_result: evaluationResult }),
      prediction: predictedClass,
      annotation: trueClass,
      view: GRID_VIEW,
    });
    history.push(targetRoute);

    analyticsTracker.chartClicked({
      accountId: accountName,
      chartName: 'confusion-matrix',
      feature: 'model-diagnosis',
    });
  };

  const { predicted: predictedTitle, true: groundTruthTitle } = getConfusionMatrixAxisTitles(t);

  // TODO: move theme config to heatmapTheme
  const chartTheme = {
    background: { color: 'white' },
    heatmap: getHeatmapStyle(matrixData.length),
    axes: {
      axisTitle: {
        fontFamily: 'Inter',
        fontSize: 12,
        fill: chartColors.axisTitle,
        padding: 8,
      },
    },
  } as RecursivePartial<Theme>;

  return (
    <Box justifyContent="center" width="100%">
      <ConfusionMatrix
        chartStyle={{
          zIndex: -1,
          marginTop: '13px',
          marginLeft: '8px',
          alignContent: 'center',
          justifyContent: 'center',
          background: 'white',
        }}
        data={matrixData}
        chartHeight={getMatrixHeight(classCount)}
        chartWidth={getMatrixWidth(classCount)}
        xAxis={{ title: predictedTitle }}
        yAxis={{ title: groundTruthTitle }}
        colorScale={heatmapColors}
        optional={{ tooltipFooterText: t('curate.diagnosis.action.clickToInspectObjects') }}
        heatmapProps={{
          ...(highlightedData && { highlightedData }),
          xSortPredicate: 'dataIndex',
          ySortPredicate: 'dataIndex',
        }}
        onElementListeners={{ onElementClick: handleElementClick }}
        theme={chartTheme}
      />
    </Box>
  );
};

/**
 * Card that loads the confusion matrix for a diagnosis and renders it inside a
 * `ChartCard`, together with an info tooltip, the currently selected class and
 * a CSV download action.
 *
 * @param diagnosis - Diagnosis whose confusion matrix is queried and displayed.
 */
const ConfusionMatrixCard = ({ diagnosis }: CardProps) => {
  const { t } = useTranslation();
  const { accountName } = useParams<{ accountName: string }>();
  const { selectedDiagnosis, targetIou } = useDiagnosisModelContext();
  const { datasetId } = useDatasetContext();
  const { selectedClass } = useDiagnosisAnalyticsFilterContext();
  const { showPublicDatasets } = usePublicDatasetContext();
  const { splitIn, sliceId } = useDiagnosisCommonFilterContext();

  const metadataQuery = useDiagnosisSchema({
    datasetId,
    fromPublicDatasets: showPublicDatasets,
    diagnosisId: diagnosis.id,
  });

  const query = useConfusionMatrixQuery({
    datasetId,
    diagnosisId: diagnosis.id,
    modelSource: selectedDiagnosis?.modelSource,
    splitIn,
    sliceId,
  });

  // Empty until the matrix query has data; the class list falls back to [] while
  // the schema query is still pending.
  const matrixData: ConfusionMatrixDatum[] = query.data
    ? transformConfusionMatrix(query.data, metadataQuery?.data?.classList ?? [])
    : [];

  /** Exports the current matrix as CSV and reports the download to analytics. */
  const handleDownload = () => {
    // `matrixData` is always an array, so a truthiness check can never fail;
    // check the length so nothing is exported (or tracked) without data.
    if (matrixData.length === 0) return;
    const columnNames = {
      predictedClass: lowerCase(t('curate.diagnosis.chart.confusionMatrix.predictedClass')),
      trueClass: lowerCase(t('curate.diagnosis.chart.confusionMatrix.groundTruthClass')),
      count: t('curate.diagnosis.chart.confusionMatrix.count'),
    };
    const downloadData = formatConfusionMatrixDownload(matrixData, columnNames);
    const timestamp = new Date().toISOString();
    const fileName = `${diagnosis.modelName}_${timestamp}_confusion_matrix`;
    FileUtils.exportToCsv(downloadData, Object.values(columnNames), fileName);
    analyticsTracker.chartDownloaded({
      accountId: accountName,
      chartName: 'confusion-matrix',
      feature: 'model-diagnosis',
    });
  };

  // Color bands are derived from the raw query values (which may still be undefined while loading).
  const heatmapColors = heatmapQuantizeColorScale(
    getConfusionMatrixValues(query.data),
    heatmapGradientRed5,
  );

  return (
    <ChartCard
      handleDownload={handleDownload}
      chartContainerProps={{ overflow: 'auto' }}
      isLoading={query.isLoading ?? true}
      chartTitle={t('curate.diagnosis.chart.confusionMatrix.title')}
      chartComponent={
        <ConfusionMatrixChartArea
          matrixData={matrixData}
          heatmapColors={heatmapColors}
          handleDownload={handleDownload}
        />
      }
      headerLeftArea={
        <Box display="flex" gap={1} alignItems="center">
          <Tooltip
            placement="right"
            content={
              <Trans
                i18nKey="curate.diagnosis.chart.confusionMatrix.info"
                components={{ bold: <Typography color="gray-100" /> }}
                values={{
                  targetIou,
                  targetConfidence: DEFAULT_EVALUATION_THRESHOLDS.confidenceThreshold,
                }}
              />
            }
          >
            <Icon icon={InfoCircleOutline} />
          </Tooltip>
          <Chip color="primary">{getDisplayClassName(selectedClass, t)}</Chip>
        </Box>
      }
    />
  );
};

export default ConfusionMatrixCard;
