import { useEffect, useState } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { useParams } from 'react-router';

import { InfoCircleOutline } from '@superb-ai/icons';
import { Box, Chip, Icon, Tooltip, Typography } from '@superb-ai/ui';
import { lowerCase, mapValues, snakeCase } from 'lodash';

import FileUtils from '../../../../../../../../../utils/FileUtils';
import { useDiagnosisAnalyticsFilterContext } from '../../../../../../../contexts/DiagnosisAnalyticsFilterContext';
import { useDiagnosisCommonFilterContext } from '../../../../../../../contexts/DiagnosisCommonFilterContext';
import { useDiagnosisModelContext } from '../../../../../../../contexts/DiagnosisModelContext';
import { usePublicDatasetContext } from '../../../../../../../contexts/PublicDatasetContextProvider';
import {
  useComparedPRCurvesQuery,
  usePerformanceTable,
  usePRCurveQuery,
} from '../../../../../../../queries/diagnosisAnalyticsQueries';
import { useDiagnosisSchema } from '../../../../../../../queries/diagnosisModelQueries';
import { PredictionTypeEnum } from '../../types';
import { ChartCard } from '../ChartCard';
import PrecisionRecallCurve from '../charts/PrecisionRecallCurve';
import { formatPRCurveDataDownload } from '../download';
import MultiSelectClassDropdown from '../elements/MultiSelectClassDropdown';
import { getDisplayClassName, scaleAPToHundred, transformPRCurve } from '../transformer';
import { CardProps } from './ConfusionMatrixCard';

/** Maximum number of classes that can be overlaid on the chart for comparison. */
const COMPARE_CLASS_LIMIT = 4;

/**
 * Card rendering the precision-recall curve for the currently selected class
 * (or all classes when `selectedClass === 'all'`), optionally overlaid with up to
 * `COMPARE_CLASS_LIMIT` comparison classes chosen from a dropdown.
 *
 * @param diagnosis - the diagnosis whose `id` scopes every query on this card and
 *   whose `modelName` prefixes the downloaded CSV file name.
 */
const PRCurveCard = ({ diagnosis }: CardProps) => {
  const { t } = useTranslation();

  const { targetIou } = useDiagnosisModelContext();
  const { selectedClass } = useDiagnosisAnalyticsFilterContext();
  const { showPublicDatasets } = usePublicDatasetContext();

  const { datasetId } = useParams<{ datasetId: string }>();
  const { sliceId } = useDiagnosisCommonFilterContext();

  const [comparedClasses, setComparedClasses] = useState<string[]>([]);
  const [showInterpolated, setShowInterpolated] = useState<boolean>(true);

  // Parameters shared by every diagnosis query issued from this card.
  const dependencies = {
    datasetId,
    fromPublicDatasets: showPublicDatasets,
    diagnosisId: diagnosis.id,
  };

  const metadataQuery = useDiagnosisSchema(dependencies);
  const prCurveQuery = usePRCurveQuery({
    ...dependencies,
    // No predictionClass is sent when "all" is selected.
    ...(selectedClass !== 'all' && { predictionClass: selectedClass }),
    sliceId,
  });

  const comparedPrCurveQuery = useComparedPRCurvesQuery({
    ...dependencies,
    predictionClasses: comparedClasses,
    sliceId,
  });
  const performanceTableQuery = usePerformanceTable(dependencies);
  const data = prCurveQuery.data?.data ?? [];
  const interpolatedData = transformPRCurve(data, showInterpolated);

  /**
   * AP@0.5 for the selected class, or the overall value when "all" is selected,
   * passed through `scaleAPToHundred`.
   * Returns 0 while the performance table is unavailable and `undefined` when the
   * selected class has no entry. Only referenced by the commented-out mini card below.
   */
  const meanAP = () => {
    const performanceTable = performanceTableQuery?.data;
    if (!performanceTable) return 0;
    const ap50 =
      selectedClass === 'all'
        ? performanceTable.metrics.overallPerformance.ap50
        : performanceTable.metrics.classPerformance.find(d => d.className === selectedClass)
            ?.ap50;
    // Explicit null check so a real AP of 0 is not reported as missing.
    return ap50 == null ? undefined : scaleAPToHundred(ap50);
  };

  /** Exports the primary (non-compared) PR curve points as a CSV file. */
  const handleDownload = () => {
    // TODO: add compared data to download
    // `data` defaults to [] above, so test for emptiness rather than falsiness;
    // otherwise a header-only CSV is exported while loading or when there are no points.
    if (data.length === 0) return;
    const columnNames = {
      recall: lowerCase(t('curate.diagnosis.metric.recall')),
      precision: lowerCase(t('curate.diagnosis.metric.precision')),
      confidenceThreshold: snakeCase(t('curate.diagnosis.metric.confidenceThreshold')),
    };
    const timestamp = new Date().toISOString();
    const downloadData = formatPRCurveDataDownload({
      data: data,
      sortBy: columnNames.recall,
    });
    const fileName = `${diagnosis.modelName}_${
      selectedClass === 'all' ? 'all_classes' : selectedClass
    }_${timestamp}_precision_recall_chart`;

    FileUtils.exportToCsv(
      downloadData,
      [columnNames.recall, columnNames.precision, columnNames.confidenceThreshold],
      fileName,
    );
  };

  /** Updates the comparison classes, ignoring selections beyond the limit. */
  const handleSelectComparedClasses = (values: string[]) => {
    if (values.length > COMPARE_CLASS_LIMIT) return;
    setComparedClasses(values);
  };

  // One transformed curve per compared class, keyed the same way as the query result.
  const comparedData = comparedPrCurveQuery?.data
    ? mapValues(comparedPrCurveQuery.data, points => transformPRCurve(points, showInterpolated))
    : {};

  return (
    <ChartCard
      handleDownload={handleDownload}
      style={{ height: '362px', paddingRight: '5px', paddingLeft: '2px' }}
      isLoading={prCurveQuery.isLoading ?? true}
      chartTitle={t('curate.diagnosis.chart.precisionRecall.title')}
      headerLeftArea={
        <Box display="flex" gap={1} alignItems="center">
          <Tooltip
            content={
              <Trans
                i18nKey="curate.diagnosis.chart.precisionRecall.info"
                components={{ bold: <Typography color="gray-300" /> }}
                values={{ targetIou }}
              />
            }
          >
            <Icon icon={InfoCircleOutline} />
          </Tooltip>
          <Chip color="primary">{getDisplayClassName(selectedClass, t)}</Chip>
          vs
          <MultiSelectClassDropdown
            classList={metadataQuery.data?.classList ?? []}
            selectedClasses={comparedClasses}
            handleSelectClasses={handleSelectComparedClasses}
          />
        </Box>
      }
      chartComponent={
        <Box display="flex" flexDirection="row" width="100%" height="100%">
          {/* <MeanAvpMiniCard
            metric={
              {
                value: meanAP(),
                name:
                  selectedClass === 'all'
                    ? t('curate.diagnosis.metric.meanAveragePrecisionAbbr')
                    : t('curate.diagnosis.classMetrics.apAbbr'),
                threshold: '@0.5',
                tooltipText:
                  selectedClass === 'all' ? (
                    <Trans i18nKey="curate.diagnosis.chart.precisionRecall.meanAPTooltip">
                      Mean average precision measured at <strong> 0.5 IoU</strong>
                    </Trans>
                  ) : (
                    <Trans i18nKey="curate.diagnosis.chart.precisionRecall.classAPTooltip">
                      Average precision of <strong> {{ className: selectedClass }} </strong>
                    </Trans>
                  ),
              } as MetricInfo
            }
            color={modelColors['prediction'][0]}
            style={{ width: '165px', marginLeft: '5px' }}
          /> */}
          <PrecisionRecallCurve
            data={{ [selectedClass ?? 'all']: interpolatedData }}
            comparedData={comparedData}
            predictionSetType={PredictionTypeEnum.BASE}
          />
        </Box>
      }
    />
  );
};

export default PRCurveCard;
