import qs from 'qs';

import { AfterLoginCallback, apiCallAfterLogin, useFetcher } from '../../../services';
import {
  AnalyticsFilterSchema,
  AnalyticsFilterSchemaSnakeCase,
} from '../components/datasets/dataset/modelDiagnosis/diagnosis/filterSchema';
import {
  HistogramResponse,
  PerformanceCurveResponse,
  PRCurveResponse,
} from '../components/datasets/dataset/modelDiagnosis/diagnosis/types';
import { Split } from '../types/evaluationTypes';
import { ModelSource } from './DiagnosisModelService';

/**
 * Appends `data` to `endpoint` as a query string (arrays encoded as `key[]=v`).
 * When `data` is falsy the endpoint is returned untouched.
 */
function appendQueryParams(endpoint: string, data: Record<string, any> | undefined): string {
  if (data) {
    const search = qs.stringify(data, { arrayFormat: 'brackets', addQueryPrefix: true });
    return `${endpoint}${search}`;
  }
  return endpoint;
}

/**
 * Parameters accepted by the diagnosis analytics fetchers in this module.
 *
 * Combines the dataset/diagnosis identifiers with the optional camelCase fields
 * of `AnalyticsFilterSchema` (with any `predictionClass` key removed), plus an
 * optional snake_case `filter` object that most fetchers spread into the request
 * body and an optional `sliceId`.
 */
export type AnalyticsParams = Omit<
  {
    datasetId: string;
    // When true, endpoints are built under `public-datasets` instead of `datasets`.
    fromPublicDatasets: boolean;
    diagnosisId: string;
    // Forwarded as the `query` query-string parameter by fetchers that call buildQueryParams.
    query?: string;
    modelSource?: ModelSource;
  } & Partial<AnalyticsFilterSchema>,
  'predictionClass'
> & {
  filter?: AnalyticsFilterSchemaSnakeCase;
  sliceId?: string;
};

/**
 * Input of `buildQueryParams`: the analytics options minus the identifiers that are
 * encoded in the URL path, intersected with `AnalyticsFilterSchema`.
 */
// NOTE(review): this intersects the non-Partial schema, yet callers pass only `{ query }`;
// that compiles only if every AnalyticsFilterSchema field is optional — confirm.
type QueryParams = Omit<AnalyticsParams, 'datasetId' | 'diagnosisId' | 'fromPublicDatasets'> &
  AnalyticsFilterSchema;

/**
 * Builds the URL path of a model-diagnosis stats endpoint.
 *
 * @param params - Identifiers of the dataset and of the diagnosis being queried.
 * @param path - Stats sub-path, appended verbatim (callers supply any trailing slash).
 * @param fromPublicDatasets - When true the path is rooted at `public-datasets`
 *   instead of `datasets`.
 * @returns `/curate/model-diagnosis/<root>/<datasetId>/diagnoses/<diagnosisId>/stats/<path>`.
 */
export function buildStatsEndpoint(
  params: { datasetId: string; diagnosisId: string },
  path: string,
  fromPublicDatasets: boolean,
): string {
  const datasetsRoot = fromPublicDatasets ? 'public-datasets' : 'datasets';
  const datasetPath = `/curate/model-diagnosis/${datasetsRoot}/${params.datasetId}`;
  return `${datasetPath}/diagnoses/${params.diagnosisId}/stats/${path}`;
}

/**
 * Converts camelCase analytics options into the snake_case query-string
 * parameters sent to the stats endpoints.
 *
 * Only truthy values are forwarded, so an empty `query` or a zero interval is
 * omitted rather than sent.
 *
 * @param params - Options to convert; only `query`, `histogramInterval` and
 *   `binningInterval` are read.
 * @returns A plain object keyed by `query`, `histogram_interval` and/or `binning_interval`.
 */
export function buildQueryParams(params: QueryParams): QueryParams {
  const queryParams: Record<string, any> = {};
  if (params?.query) queryParams.query = params.query;
  if (params?.histogramInterval) queryParams.histogram_interval = params.histogramInterval;
  // The binning interval needs its own key. Writing it to `histogram_interval` would
  // clobber the histogram interval and never send `binning_interval` at all.
  if (params?.binningInterval) queryParams.binning_interval = params.binningInterval;
  // NOTE(review): the declared return type is the camelCase QueryParams, but the object
  // carries snake_case keys; the signature is kept unchanged for existing callers.
  return queryParams;
}

/** Value sent as `histogram_interval` when requesting the confidence histogram. */
export const CONFIDENCE_HISTOGRAM_INTERVAL = 0.1;
/** Value sent as `histogram_interval` when requesting the IoU histogram. */
export const IOU_HISTOGRAM_INTERVAL = 0.1;

/**
 * Fetches the confidence or IoU histogram of a diagnosis.
 *
 * POSTs to `stats/<metric>-histogram/` with a `histogram_interval` chosen by metric,
 * the optional `slice_id`, and the snake_case `filter` fields. `filter` is spread last,
 * so its keys win on collision.
 * Resolves to `undefined` without sending a request when `params.data` is absent.
 */
const getMetricHistogram: AfterLoginCallback<
  HistogramResponse,
  AnalyticsParams & { metric: 'confidence' | 'iou' }
> = async params => {
  const request = params.data;
  if (!request) return;

  const url = buildStatsEndpoint(
    { datasetId: request.datasetId, diagnosisId: request.diagnosisId },
    `${request.metric}-histogram/`,
    request.fromPublicDatasets,
  );
  const interval =
    request.metric === 'confidence' ? CONFIDENCE_HISTOGRAM_INTERVAL : IOU_HISTOGRAM_INTERVAL;
  const body = {
    histogram_interval: interval,
    slice_id: request.sliceId,
    ...request.filter,
  };

  const response = await apiCallAfterLogin({
    method: 'post',
    url,
    hasPublicApi: false,
    isCurateUrl: true,
    ...params,
    data: body,
  });
  return response.data;
};

/** Value sent as `binning_interval` in the performance-curve and PR-curve request bodies. */
const CONFIDENCE_BINNING_INTERVAL = 0.01;

/**
 * Fetches the performance curve of a diagnosis.
 *
 * POSTs to `stats/performance-curve`, forwarding `query` as a query-string parameter.
 * The body holds `binning_interval`, the snake_case `filter` fields, and `slice_id`.
 * `slice_id` is placed last, so it overrides any `slice_id` inside `filter`.
 * Resolves to `undefined` without sending a request when `params.data` is absent.
 */
const getPerformanceCurve: AfterLoginCallback<
  PerformanceCurveResponse,
  AnalyticsParams
> = async params => {
  const request = params.data;
  if (!request) return;

  const { datasetId, diagnosisId } = request;
  // NOTE(review): unlike the sibling stats paths this one has no trailing slash — confirm
  // it matches the backend route.
  const statsUrl = buildStatsEndpoint(
    { datasetId, diagnosisId },
    'performance-curve',
    request.fromPublicDatasets,
  );
  const search = buildQueryParams({ query: request.query });
  const body = {
    binning_interval: CONFIDENCE_BINNING_INTERVAL,
    ...request.filter,
    slice_id: request.sliceId,
  };

  const response = await apiCallAfterLogin({
    method: 'post',
    url: appendQueryParams(statsUrl, search),
    hasPublicApi: false,
    isCurateUrl: true,
    ...params,
    data: body,
  });
  return response.data;
};

/**
 * Fetches the precision/recall curve of a diagnosis.
 *
 * POSTs to `stats/pr-curve/` with `binning_interval`, the snake_case `filter` fields,
 * and `slice_id`. `slice_id` comes last, so it overrides any `slice_id` inside `filter`.
 * `query` is not forwarded by this fetcher.
 * Resolves to `undefined` without sending a request when `params.data` is absent.
 */
const getPRCurve: AfterLoginCallback<PRCurveResponse, AnalyticsParams> = async params => {
  const request = params.data;
  if (!request) return;

  const url = buildStatsEndpoint(
    { datasetId: request.datasetId, diagnosisId: request.diagnosisId },
    'pr-curve/',
    request.fromPublicDatasets,
  );
  const body = {
    binning_interval: CONFIDENCE_BINNING_INTERVAL,
    ...request.filter,
    slice_id: request.sliceId,
  };

  const response = await apiCallAfterLogin({
    method: 'post',
    url,
    hasPublicApi: false,
    isCurateUrl: true,
    ...params,
    data: body,
  });
  return response.data;
};

/** One confusion-matrix cell as a `[actual, predicted, count]` triple; each element may be null. */
export type ConfusionMatrixDatum = [string | null, string | null, number | null]; // actual, predicted, count

/** Response payload of the confusion-matrix stats endpoint. */
export type ConfusionMatrixResponse = {
  // Class labels; entries may be null.
  classList: (string | null)[];
  // Triples whose third element has the datum count type (`number | null`).
  // NOTE(review): the two leading numbers are presumably indices into `classList`
  // (actual, predicted) — confirm against the backend.
  matrix: [number, number, ConfusionMatrixDatum[2]][];
};
/**
 * Fetches the confusion matrix of a diagnosis.
 *
 * POSTs to `stats/confusion-matrix/`, forwarding `query` as a query-string parameter.
 * The body holds the snake_case `filter` fields followed by `splitIn` and `sliceId`.
 * Resolves to `undefined` without sending a request when `params.data` is absent.
 */
const getConfusionMatrix: AfterLoginCallback<
  ConfusionMatrixResponse,
  AnalyticsParams & { splitIn?: Split[]; sliceId?: string }
> = async params => {
  const request = params.data;
  if (!request) return;

  const { datasetId, diagnosisId, splitIn, sliceId } = request;
  const statsUrl = buildStatsEndpoint(
    { datasetId, diagnosisId },
    'confusion-matrix/',
    request.fromPublicDatasets,
  );
  const search = buildQueryParams({ query: request.query });
  // NOTE(review): these keys are camelCase while sibling fetchers send `slice_id`; presumably
  // the request layer converts case or the backend accepts both — confirm.
  const body = { ...request.filter, splitIn, sliceId };

  const response = await apiCallAfterLogin({
    method: 'post',
    url: appendQueryParams(statsUrl, search),
    hasPublicApi: false,
    isCurateUrl: true,
    ...params,
    data: body,
  });
  return response.data;
};

/** Aggregate metrics across all classes in the performance table. */
export type OverallPerformance = {
  // NOTE(review): presumably average precision and AP at IoU 0.5 — confirm with the backend.
  ap: number;
  ap50: number;
};

/** Per-class metrics row of the performance table. */
export type ClassPerformance = {
  className: string;
  ap: number;
  ap50: number;
  precision: number;
  recall: number;
  // NOTE(review): presumably the score threshold the precision/recall were computed at — confirm.
  scoreThres: number;
};
/** Response payload of the performance-table stats endpoint. */
export type PerformanceTableResponse = {
  metrics: {
    overallPerformance: OverallPerformance;
    classPerformance: ClassPerformance[];
  };
};

/**
 * Fetches the performance table of a diagnosis.
 *
 * POSTs an empty body to `stats/performance-table/`, forwarding only `query` as a
 * query-string parameter.
 * Resolves to `undefined` without sending a request when `params.data` is absent.
 */
const getPerformanceTable: AfterLoginCallback<
  PerformanceTableResponse,
  AnalyticsParams
> = async params => {
  const request = params.data;
  if (!request) return;

  const statsUrl = buildStatsEndpoint(
    { datasetId: request.datasetId, diagnosisId: request.diagnosisId },
    'performance-table/',
    request.fromPublicDatasets,
  );
  const search = buildQueryParams({ query: request.query });

  // NOTE(review): `filter` and `sliceId` are not forwarded here, unlike sibling fetchers —
  // confirm this is intentional.
  const response = await apiCallAfterLogin({
    method: 'post',
    url: appendQueryParams(statsUrl, search),
    hasPublicApi: false,
    isCurateUrl: true,
    ...params,
    data: {},
  });
  return response.data;
};

/**
 * Returns the diagnosis analytics fetchers, each wrapped with the `afterLoginFetcher`
 * obtained from `useFetcher()`.
 */
export function useDiagnosisAnalyticsService() {
  const { afterLoginFetcher: withLogin } = useFetcher();
  const service = {
    getConfusionMatrix: withLogin(getConfusionMatrix),
    getPerformanceCurve: withLogin(getPerformanceCurve),
    getPRCurve: withLogin(getPRCurve),
    getMetricHistogram: withLogin(getMetricHistogram),
    getPerformanceTable: withLogin(getPerformanceTable),
  };
  return service;
}
