import { useQuery } from '@tanstack/react-query';
import { isBoolean } from 'lodash';

import { DEFAULT_EVALUATION_THRESHOLDS } from '../components/datasets/dataset/modelDiagnosis/const';
import { decodeConfusionMatrix } from '../components/datasets/dataset/modelDiagnosis/diagnosis/analytics/transformer';
import {
  AnalyticsFilterSchema,
  AnalyticsFilterSchemaSnakeCase,
} from '../components/datasets/dataset/modelDiagnosis/diagnosis/filterSchema';
import {
  HistogramResponse,
  ModelSource,
  PRCurveDatum,
} from '../components/datasets/dataset/modelDiagnosis/diagnosis/types';
import { QUERY_KEY } from '../const/QueryKey';
import { usePublicDatasetContext } from '../contexts/PublicDatasetContextProvider';
import { useDiagnosisAnalyticsService } from '../services/DiagnosisAnalyticsService';
import { Split } from '../types/evaluationTypes';

/**
 * Inputs shared by the diagnosis analytics query hooks in this module.
 * Besides the fields listed here, any subset of the `AnalyticsFilterSchema`
 * fields (for example `predictionClass`) may be supplied.
 */
export type DiagnosisDependencies = {
  // Dataset to query; every hook here stays disabled while this is empty.
  datasetId: string;
  // When this is a boolean it overrides the `showPublicDatasets` flag read from
  // the public-dataset context; otherwise the context flag is used.
  fromPublicDatasets?: boolean;
  // Diagnosis to query; every hook here stays disabled while this is empty.
  diagnosisId: string;
  // Forwarded as `query` to the analytics service by the hooks that pass it.
  // NOTE(review): presumably a search/filter expression; confirm against the service.
  query?: string;
  // `'external'` makes the filter body set `filter_by_optimal_confidence` to false.
  modelSource?: ModelSource;
  // Forwarded as `sliceId` to the analytics service by the hooks that pass it.
  sliceId?: string;
} & Partial<AnalyticsFilterSchema>;

/**
 * Builds the snake_case filter body sent along with analytics requests.
 *
 * - `filter_by_optimal_confidence: false` is included only for the `'external'` model source.
 * - `prediction_class` is included whenever `predictionClass` is not `undefined`.
 *
 * @param modelSource origin of the model being diagnosed
 * @param predictionClass class to restrict the analytics to
 * @param fromPublicDatasets accepted for signature compatibility; not read by this function
 * @returns the filter body; an empty object when neither condition applies
 */
function getFilterBody(
  modelSource?: ModelSource,
  predictionClass?: AnalyticsFilterSchema['predictionClass'],
  fromPublicDatasets?: boolean,
): AnalyticsFilterSchemaSnakeCase {
  const confidencePart = modelSource === 'external' ? { filter_by_optimal_confidence: false } : {};
  const classPart = predictionClass !== undefined ? { prediction_class: predictionClass } : {};
  // Spread order keeps the same key order as assigning the two fields one after another.
  return { ...confidencePart, ...classPart } as AnalyticsFilterSchemaSnakeCase;
}

/**
 * Query hook for the performance curve of a diagnosis: precision, recall and
 * f1 as a function of model confidence.
 *
 * The query is disabled until both `datasetId` and `diagnosisId` are truthy.
 *
 * @param dependencies identifiers and optional filters for the request
 * @returns the `useQuery` result wrapping the `getPerformanceCurve` response
 */
export function usePerformanceCurveQuery(dependencies: DiagnosisDependencies) {
  const { datasetId, fromPublicDatasets, diagnosisId, query, predictionClass, sliceId } =
    dependencies;
  const { getPerformanceCurve } = useDiagnosisAnalyticsService();
  const { showPublicDatasets } = usePublicDatasetContext();

  // An explicit boolean from the caller wins over the context flag.
  const resolvedFromPublic = isBoolean(fromPublicDatasets)
    ? fromPublicDatasets
    : showPublicDatasets;
  // No model source is passed here, so only `prediction_class` can end up in the filter.
  const classFilter = getFilterBody(undefined, predictionClass);

  return useQuery({
    queryKey: ['performance-curve', datasetId, diagnosisId, query, predictionClass, sliceId],
    queryFn: () =>
      getPerformanceCurve({
        datasetId,
        fromPublicDatasets: resolvedFromPublic,
        diagnosisId,
        query,
        filter: classFilter,
        sliceId,
      }),
    enabled: !!datasetId && !!diagnosisId,
  });
}

/**
 * Query hook for the precision-recall curve points of a diagnosis.
 *
 * The query is disabled until both `datasetId` and `diagnosisId` are truthy.
 *
 * @param dependencies identifiers and optional filters for the request
 * @returns the `useQuery` result wrapping the `getPRCurve` response
 */
export function usePRCurveQuery(dependencies: DiagnosisDependencies) {
  const { datasetId, fromPublicDatasets, diagnosisId, predictionClass, query, sliceId } =
    dependencies;
  const { getPRCurve } = useDiagnosisAnalyticsService();
  const { showPublicDatasets } = usePublicDatasetContext();

  // An explicit boolean from the caller wins over the context flag.
  const resolvedFromPublic = isBoolean(fromPublicDatasets)
    ? fromPublicDatasets
    : showPublicDatasets;
  // No model source is passed here, so only `prediction_class` can end up in the filter.
  const classFilter = getFilterBody(undefined, predictionClass);

  return useQuery({
    queryKey: ['pr-curve', datasetId, diagnosisId, query, sliceId, predictionClass],
    queryFn: () =>
      getPRCurve({
        datasetId,
        fromPublicDatasets: resolvedFromPublic,
        diagnosisId,
        query,
        filter: classFilter,
        sliceId,
      }),
    enabled: !!datasetId && !!diagnosisId,
  });
}

/**
 * Query hook that fetches one precision-recall curve per entry of
 * `predictionClasses` and returns them keyed by class name.
 *
 * The per-class requests are independent, so they are issued concurrently.
 * A class whose response carries no `data` maps to an empty array.
 *
 * The query is disabled until `datasetId`, `diagnosisId` and
 * `predictionClasses` are all truthy.
 *
 * @param dependencies identifiers, optional filters and the classes to compare
 * @returns the `useQuery` result wrapping a record of class name to curve points
 */
export function useComparedPRCurvesQuery(
  dependencies: DiagnosisDependencies & { predictionClasses: string[] },
) {
  const { datasetId, fromPublicDatasets, diagnosisId, predictionClasses, query, sliceId } =
    dependencies;
  const { getPRCurve } = useDiagnosisAnalyticsService();
  const { showPublicDatasets } = usePublicDatasetContext();
  return useQuery({
    // `sliceId` is sent with every request, so it has to be part of the key;
    // otherwise switching slices would keep serving the previous slice's curves.
    queryKey: ['compared-pr-curves', datasetId, diagnosisId, query, predictionClasses, sliceId],
    queryFn: async () => {
      const prCurves: Record<string, PRCurveDatum[]> = {};
      // `enabled` normally guards this, but a manual refetch can still invoke the
      // query function. Resolve with an empty record rather than `undefined`,
      // which the query client rejects as invalid query data.
      if (!predictionClasses) return prCurves;

      const responses = await Promise.all(
        predictionClasses.map(predictionClass =>
          getPRCurve({
            datasetId,
            fromPublicDatasets: isBoolean(fromPublicDatasets)
              ? fromPublicDatasets
              : showPublicDatasets,
            diagnosisId,
            query,
            filter: { prediction_class: predictionClass },
            sliceId,
          }),
        ),
      );

      // Fill in request order so key order matches `predictionClasses`.
      predictionClasses.forEach((predictionClass, index) => {
        prCurves[predictionClass] = responses[index]?.data ?? [];
      });
      return prCurves;
    },
    enabled: Boolean(datasetId && diagnosisId && predictionClasses),
  });
}

/**
 * Keeps only the histogram bins whose key is at least `threshold`, and zeroes
 * the count of the bin with key 0 (IoU 0 entries are false positives and
 * false negatives).
 *
 * @param data histogram response; a nullish value yields `undefined`
 * @param threshold minimum bin key to keep (inclusive)
 * @returns new `{ key, count }` objects; the input bins are not modified
 */
function excludeIouLessThan(data: HistogramResponse, threshold: number) {
  // Filtering first is equivalent to mapping first: the mapping never changes `key`.
  return data?.histogram
    .filter(bin => bin.key >= threshold)
    .map(({ key, count }) => ({ key, count: key === 0 ? 0 : count }));
}

/**
 * Query hook for the IoU histogram of a diagnosis.
 *
 * The raw response is post-processed by `excludeIouLessThan` with
 * `DEFAULT_EVALUATION_THRESHOLDS.iouThreshold`, so consumers only see bins at
 * or above that threshold.
 *
 * The query is disabled until both `datasetId` and `diagnosisId` are truthy.
 *
 * @param dependencies identifiers and optional filters for the request
 * @returns the `useQuery` result wrapping the filtered histogram bins
 */
export function useIouHistogramQuery(dependencies: DiagnosisDependencies) {
  const { datasetId, fromPublicDatasets, diagnosisId, predictionClass, modelSource, sliceId } =
    dependencies;
  const { getMetricHistogram } = useDiagnosisAnalyticsService();
  const { showPublicDatasets } = usePublicDatasetContext();
  const filter = getFilterBody(modelSource, predictionClass);
  return useQuery({
    queryKey: [
      'iou-histogram',
      diagnosisId,
      predictionClass,
      // query, -- not sent with the request below, so it must not split the cache:
      // keeping it here would refetch identical data on every search change.
      datasetId,
      modelSource,
      sliceId,
    ],
    queryFn: () => {
      return getMetricHistogram({
        datasetId,
        fromPublicDatasets: isBoolean(fromPublicDatasets) ? fromPublicDatasets : showPublicDatasets,
        diagnosisId,
        // query,
        metric: 'iou',
        filter,
        sliceId,
      });
    },
    select: data => excludeIouLessThan(data, DEFAULT_EVALUATION_THRESHOLDS.iouThreshold),
    enabled: Boolean(datasetId && diagnosisId),
  });
}

/**
 * Query hook for the confidence histogram of a diagnosis.
 *
 * The query is disabled until both `datasetId` and `diagnosisId` are truthy.
 *
 * @param dependencies identifiers and optional filters for the request
 * @returns the `useQuery` result wrapping the `getMetricHistogram` response
 */
export function useConfidenceHistogramQuery(dependencies: DiagnosisDependencies) {
  const { datasetId, fromPublicDatasets, diagnosisId, predictionClass, modelSource, sliceId } =
    dependencies;
  const { getMetricHistogram } = useDiagnosisAnalyticsService();
  const { showPublicDatasets } = usePublicDatasetContext();
  const filter = getFilterBody(modelSource, predictionClass);
  return useQuery({
    queryKey: [
      'confidence-histogram',
      diagnosisId,
      predictionClass,
      // query,
      datasetId,
      // The filter body depends on the model source, so the key must too;
      // otherwise switching model source would serve the other source's histogram.
      modelSource,
      sliceId,
    ],
    queryFn: () => {
      return getMetricHistogram({
        datasetId,
        fromPublicDatasets: isBoolean(fromPublicDatasets) ? fromPublicDatasets : showPublicDatasets,
        diagnosisId,
        // query,
        metric: 'confidence',
        filter,
        sliceId,
      });
    },
    enabled: Boolean(datasetId && diagnosisId),
  });
}

/**
 * Query hook for the confusion matrix of a diagnosis. The raw response is
 * passed through `decodeConfusionMatrix` before it reaches consumers.
 *
 * The query is disabled until both `datasetId` and `diagnosisId` are truthy.
 *
 * @param dependencies identifiers, optional filters and an optional list of splits
 * @returns the `useQuery` result wrapping the decoded confusion matrix
 */
export function useConfusionMatrixQuery(
  dependencies: DiagnosisDependencies & { splitIn?: Split[] },
) {
  const { datasetId, fromPublicDatasets, diagnosisId, query, modelSource, splitIn, sliceId } =
    dependencies;
  const { getConfusionMatrix } = useDiagnosisAnalyticsService();
  const { showPublicDatasets } = usePublicDatasetContext();
  const filter = getFilterBody(modelSource);

  // An explicit boolean from the caller wins over the context flag.
  const resolvedFromPublic = isBoolean(fromPublicDatasets)
    ? fromPublicDatasets
    : showPublicDatasets;
  // The filter is attached only for external models when the context flag is off.
  // NOTE(review): this gate reads the context flag directly rather than the
  // resolved `fromPublicDatasets` value; confirm that is intended.
  const attachFilter = modelSource === 'external' && !showPublicDatasets;

  return useQuery({
    queryKey: [
      QUERY_KEY.diagnosisConfusionMatrix,
      diagnosisId,
      query,
      datasetId,
      modelSource,
      splitIn,
      sliceId,
    ],
    queryFn: () =>
      getConfusionMatrix({
        datasetId,
        fromPublicDatasets: resolvedFromPublic,
        diagnosisId,
        query,
        ...(attachFilter ? { filter } : {}),
        splitIn,
        sliceId,
      }),
    enabled: !!datasetId && !!diagnosisId,
    select: data => decodeConfusionMatrix(data),
  });
}

/**
 * Query hook for the performance table of a diagnosis.
 *
 * The query is disabled until both `datasetId` and `diagnosisId` are truthy,
 * and failed requests are retried up to 2 times.
 *
 * @param dependencies only `datasetId`, `diagnosisId` and `fromPublicDatasets` are read
 * @returns the `useQuery` result wrapping the `getPerformanceTable` response
 */
export function usePerformanceTable(dependencies: DiagnosisDependencies) {
  const { datasetId, fromPublicDatasets, diagnosisId } = dependencies;
  const { getPerformanceTable } = useDiagnosisAnalyticsService();
  const { showPublicDatasets } = usePublicDatasetContext();

  // An explicit boolean from the caller wins over the context flag.
  const resolvedFromPublic = isBoolean(fromPublicDatasets)
    ? fromPublicDatasets
    : showPublicDatasets;

  return useQuery({
    queryKey: [
      QUERY_KEY.diagnosisAnalyticsPerformanceTable,
      datasetId,
      diagnosisId,
      // query,
    ],
    queryFn: () =>
      getPerformanceTable({
        datasetId,
        fromPublicDatasets: resolvedFromPublic,
        diagnosisId,
        // query,
      }),
    enabled: !!datasetId && !!diagnosisId,
    retry: 2,
  });
}
