import { mapValues } from 'lodash';
import { TFunction } from 'next-i18next';

import {
  ClassPerformance,
  ConfusionMatrixDatum,
  ConfusionMatrixResponse,
  OverallPerformance,
} from '../../../../../../services/DiagnosisAnalyticsService';
import { ClassListResponse } from '../../../../../../services/DiagnosisModelService';
import { HistogramDatum } from '../../../analytics/types';
import {
  FormattedPerformanceCurveDatum,
  FormattedPRCurveDatum,
  PerformanceCurveDatum,
  PRCurveDatum,
} from '../types';

/** Sentinel compared against `classList` entries of a confusion-matrix response to detect a null class. */
const DIAGNOSIS_NULL_VALUE = null;
/** Label substituted for a null class name when building and sorting confusion-matrix cells. */
export const DIAGNOSIS_DISPLAY_NULL_VALUE = '(null)';

/** Confusion matrix */
/**
 * Extracts the count (third element) of every confusion-matrix cell, in the
 * same order as `matrix`.
 *
 * @param matrix cells of the form `[actualClass, predictedClass, count]`
 * @returns the counts coerced with `Number`; a `null`/missing count becomes 0.
 *   Returns `[]` when `matrix` is not provided.
 */
export function getConfusionMatrixValues(matrix?: ConfusionMatrixDatum[]): number[] {
  if (!matrix) return [];
  // Read the count directly rather than coercing the one-element array
  // `row.slice(2)` through its string form; `?? 0` keeps the old result (0)
  // for null/missing counts.
  return matrix.map(row => Number(row[2] ?? 0));
}

/**
 * Comparator for class labels: alphabetical via `localeCompare`, with the
 * null-class placeholder (`DIAGNOSIS_DISPLAY_NULL_VALUE`) always last.
 *
 * Two placeholders compare equal (0). The previous inline comparators returned
 * 1 in that case, which is an inconsistent comparator: the ECMAScript spec then
 * leaves the sort order implementation-defined, and the two-pass sort below
 * depends on a well-defined, stable order.
 */
function compareClassNames(a: string, b: string): number {
  const aIsNull = a === DIAGNOSIS_DISPLAY_NULL_VALUE;
  const bIsNull = b === DIAGNOSIS_DISPLAY_NULL_VALUE;
  if (aIsNull || bIsNull) return Number(aIsNull) - Number(bIsNull);
  return a.localeCompare(b);
}

/**
 * Sorts confusion-matrix cells IN PLACE by predicted class (index 1) using
 * `compareClassNames`, and returns the same array. Because
 * `Array.prototype.sort` is stable, cells sharing a predicted class keep their
 * incoming relative order.
 */
function sortConfusionXAxis(data: ConfusionMatrixDatum[]) {
  return data.sort((a, b) => compareClassNames(a[1], b[1]));
}

/**
 * Returns a sorted copy of confusion-matrix cells `[actual, predicted, count]`.
 *
 * Cells are ordered by predicted class first and actual class second, both
 * alphabetically with the null-class placeholder last, e.g.:
 *   ['book',   'book',   2000],
 *   ['car',    'book',   13],
 *   ['(null)', 'book',   130],
 *   ['book',   'car',    1],
 *   ['car',    'car',    9000],
 *   ['(null)', 'car',    0],
 *   ['book',   '(null)', 102],
 *   ['car',    '(null)', 30],
 *
 * The input array is not modified.
 */
export function sortConfusionMatrixDatum(data: ConfusionMatrixDatum[]) {
  // First pass: order by actual class (y axis) on a copy, so the caller's
  // array is not reordered as a side effect.
  const sortedY = [...data].sort((a, b) => compareClassNames(a[0], b[0]));
  // Second, stable pass: predicted class (x axis) becomes the primary key and
  // the actual-class order from the first pass is kept as the tie-breaker.
  return sortConfusionXAxis(sortedY);
}

/**
 * Expands an index-based confusion-matrix response into labelled cells.
 *
 * Each `[i, j, count]` entry of `response.matrix` becomes
 * `[classList[i], classList[j], count]`, where `null` class names are replaced
 * by `DIAGNOSIS_DISPLAY_NULL_VALUE`. Returns `[]` when the response, its
 * class list, or its matrix is missing.
 */
export function decodeConfusionMatrix(response?: ConfusionMatrixResponse): ConfusionMatrixDatum[] {
  const classes = response?.classList;
  const cells = response?.matrix;
  if (!classes || !cells) return [];

  const labels = classes.map(name =>
    name === DIAGNOSIS_NULL_VALUE ? DIAGNOSIS_DISPLAY_NULL_VALUE : name,
  );

  const decoded: ConfusionMatrixDatum[] = [];
  for (const [actualIdx, predictedIdx, count] of cells) {
    decoded.push([labels[actualIdx], labels[predictedIdx], count]);
  }
  return decoded;
}

// TODO: remove if not needed
/**
 * F1 score: the harmonic mean of precision and recall.
 *
 * When precision and recall are both 0, returns 0 instead of the `NaN` that
 * `0 / 0` would produce. This follows the usual convention that no true
 * positives means F1 = 0, and keeps charts from receiving `NaN` points.
 */
export function calculateF1Score(precision: number, recall: number): number {
  const denominator = precision + recall;
  if (denominator === 0) return 0;
  return (2 * precision * recall) / denominator;
}

/**
 * Renames performance-curve fields to the display labels used by the chart
 * (`Precision`, `Recall`, `F1 Score`, `Confidence`), taking the F1 score from
 * each datum's `f1Score`. Returns `[]` when `data` is falsy.
 */
export function getPrecisionRecallF1Points(
  data: PerformanceCurveDatum[],
): FormattedPerformanceCurveDatum[] {
  return data
    ? data.map(({ precision, recall, f1Score, confidence }) => ({
        Precision: precision,
        Recall: recall,
        'F1 Score': f1Score,
        Confidence: confidence,
      }))
    : [];
}

/**
 * Maps PR-curve points to the chart's display keys, keeping only
 * `Precision` and `Recall`; any other fields on the input are dropped.
 * Returns `[]` when `data` is falsy.
 */
export function convertKeysToDisplayNames<T extends PRCurveDatum>(
  data: T[],
): FormattedPRCurveDatum[] {
  if (!data) {
    return [];
  }
  return data.map(({ precision, recall }) => ({ Precision: precision, Recall: recall }));
}

/**
 * Maps performance-curve points to the chart's display keys. The F1 score is
 * recomputed from precision and recall with `calculateF1Score`; any
 * `f1Score` field on the input is ignored. Returns `[]` when `data` is falsy.
 */
export function transformPerformanceCurvePoints(
  data: PerformanceCurveDatum[],
): FormattedPerformanceCurveDatum[] {
  if (!data) return [];
  return data.map(point => {
    const { precision, recall, confidence } = point;
    return {
      Precision: precision,
      Recall: recall,
      // TODO: remove after api integration
      'F1 Score': calculateF1Score(precision, recall),
      Confidence: confidence,
    };
  });
}

/**
 * Builds a dense confusion matrix over `metadataClasses`.
 *
 * Every (actual, predicted) pair drawn from `metadataClasses` gets a cell. Its
 * count comes from `data` when present and is 0 otherwise. A `null` class is
 * labelled with `DIAGNOSIS_DISPLAY_NULL_VALUE`, and the (null, null) cell is
 * omitted. Cells in `data` whose classes are not in `metadataClasses` are
 * dropped. If `data` repeats a pair, the last count wins.
 *
 * Counts are indexed with `Map` rather than a plain object. With a plain
 * object, class names such as "constructor" or "toString" resolve to
 * `Object.prototype` members, so lookups could yield functions instead of 0.
 */
export function interpolateMissingClasses(
  metadataClasses: ClassListResponse['classList'],
  data: ConfusionMatrixDatum[],
) {
  const counts = new Map<
    ConfusionMatrixDatum[0],
    Map<ConfusionMatrixDatum[1], ConfusionMatrixDatum[2]>
  >();
  for (const [actual, predicted, count] of data) {
    let row = counts.get(actual);
    if (!row) {
      row = new Map();
      counts.set(actual, row);
    }
    row.set(predicted, count);
  }

  const interpolated: ConfusionMatrixDatum[] = [];
  for (const actual of metadataClasses) {
    for (const predicted of metadataClasses) {
      // The (null, null) cell is intentionally left out of the matrix.
      if (actual === null && predicted === null) continue;
      const actualClass = actual ?? DIAGNOSIS_DISPLAY_NULL_VALUE;
      const predictedClass = predicted ?? DIAGNOSIS_DISPLAY_NULL_VALUE;
      const count = counts.get(actualClass)?.get(predictedClass) ?? 0;
      interpolated.push([actualClass, predictedClass, count]);
    }
  }
  return interpolated;
}

/**
 * Produces the confusion-matrix cells to render. It first fills in every class
 * pair from `classList` with `interpolateMissingClasses` (pairs absent from
 * `data` get count 0), then orders the cells with `sortConfusionMatrixDatum`.
 */
export function transformConfusionMatrix(
  data: ConfusionMatrixDatum[],
  classList: ClassListResponse['classList'],
) {
  return sortConfusionMatrixDatum(interpolateMissingClasses(classList, data));
}

/**PR Curve */
/**
 * Extends a precision-recall curve so it reaches recall 0 and recall 1.
 *
 * If the lowest-recall point has recall > 0, a point `{recall: 0}` with that
 * point's precision is added beside it. If the highest-recall point has
 * recall < 1, a point `{recall: 1}` with that point's precision is added beside
 * it. The curve's direction is preserved:
 *  - descending recall (as produced by `interpolateCurve`): the recall-1 point
 *    is prepended and the recall-0 point appended;
 *  - ascending recall: the recall-0 point is prepended and the recall-1 point
 *    appended.
 * When the first and last recall are equal (including a single point), the
 * curve is treated as descending.
 *
 * example (ascending recall):
 * @param data
 *  [
 *    {recall: 0.1, precision: 0.9}, ..., {recall: 0.95, precision: 0.59}
 *  ]
 * @returns a new array (the input is not modified), e.g.
 *  [
 *    {recall: 0.0, precision: 0.9},
 *    {recall: 0.1, precision: 0.9},
 *    ...
 *    {recall: 0.95, precision: 0.59},
 *    {recall: 1.0, precision: 0.59},
 *  ]
 *  A falsy `data` is returned unchanged.
 */
export function addBoundaryPoints(data: PRCurveDatum[]) {
  if (!data) return data;
  // Work on a copy: the previous version pushed/unshifted into the caller's
  // array (e.g. cached query data).
  const result = [...data];
  const first = result[0];
  const last = result[result.length - 1];
  const ascending = !!first && !!last && first.recall < last.recall;
  const lowest = ascending ? first : last;
  const highest = ascending ? last : first;

  if (lowest && lowest.recall > 0) {
    const point = { precision: lowest.precision, recall: 0 };
    if (ascending) result.unshift(point);
    else result.push(point);
  }
  if (highest && highest.recall < 1.0) {
    const point = { precision: highest.precision, recall: 1.0 };
    if (ascending) result.push(point);
    else result.unshift(point);
  }
  return result;
}

/**
 * Reduces a PR curve to its step-shaped precision envelope (intended to imitate
 * scikit-learn's interpolated curve).
 *
 * Walks `data` from the last point to the first. The first visited point is
 * always kept; after that, a point is kept only if its precision is strictly
 * greater than that of the most recently kept point. The result is therefore
 * in reverse input order.
 */
export function interpolateCurve(data: PRCurveDatum[]) {
  const envelope: PRCurveDatum[] = [];
  for (let idx = data.length - 1; idx >= 0; idx -= 1) {
    const point = data[idx];
    const lastKept = envelope[envelope.length - 1];
    if (envelope.length === 0 || point?.precision > lastKept?.precision) {
      envelope.push(point);
    }
  }
  return envelope;
}

/**
 * Prepares PR-curve points for the chart. The pipeline is: optional envelope
 * interpolation (`interpolateCurve`), then recall-0/recall-1 boundary points
 * (`addBoundaryPoints`), then renaming to display keys
 * (`convertKeysToDisplayNames`).
 */
export function transformPRCurve(data: PRCurveDatum[], showInterpolated: boolean) {
  const curve = showInterpolated ? interpolateCurve(data) : data;
  const bounded = addBoundaryPoints(curve);
  return convertKeysToDisplayNames(bounded);
}

/** Histogram */
/**
 * Returns copies of the histogram data, each with an extra `keyNumeric` field
 * holding `Number(key)`. All other fields are kept.
 */
export function formatHistogramData(data: HistogramDatum[]) {
  return data.map(datum => {
    const keyNumeric = Number(datum.key);
    return { ...datum, keyNumeric };
  });
}

/** Converts an AP value on a 0–1 scale to a 0–100 scale. */
export const scaleAPToHundred = (ap: number): number => 100 * ap;

/**
 * Scales every value of the overall-performance object with `scaleAPToHundred`
 * (lodash `mapValues`). Returns `undefined` when `data` is not provided.
 */
export function scaleOverallAPs(data?: OverallPerformance) {
  return data ? mapValues(data, scaleAPToHundred) : undefined;
}

/**
 * Returns copies of the per-class rows with `ap` and `ap50` scaled by
 * `scaleAPToHundred`; other fields are kept unchanged. Returns `[]` when
 * `data` is falsy.
 */
export function scaleClassAPs(data?: ClassPerformance[]) {
  return data
    ? data.map(row => ({
        ...row,
        ap: scaleAPToHundred(row.ap),
        ap50: scaleAPToHundred(row.ap50),
      }))
    : [];
}

export type NumericColumns = keyof Omit<ClassPerformance, 'className'>;

/**
 * Number of decimal places shown for each numeric column. Columns not listed
 * fall back to 2 decimals in `formatDecimalPrecision`.
 *
 * Declared with a type annotation instead of an `as` assertion so the compiler
 * rejects keys that are not real `NumericColumns`. It is hoisted to module
 * level so it is not rebuilt on every call.
 */
const DECIMAL_PLACES: Partial<Record<NumericColumns, number>> = {
  ap: 1,
  ap50: 1,
  precision: 2,
  recall: 2,
  scoreThres: 2,
};

/**
 * Formats `value` with the number of decimals configured for column `key`
 * (2 when the column has no entry).
 * @returns the string produced by `Number.prototype.toFixed`
 */
export function formatDecimalPrecision(value: number, key: NumericColumns) {
  return value.toFixed(DECIMAL_PLACES[key] ?? 2);
}

/**
 * Returns copies of the per-class rows in which `ap`, `ap50`, `precision`,
 * `recall` and `scoreThres` are replaced by display strings from
 * `formatDecimalPrecision`. Returns `[]` when `data` is falsy.
 */
export function formatDecimalPrecisionClassAPs(data?: ClassPerformance[]) {
  if (!data) {
    return [];
  }
  return data.map(row => {
    const { ap, ap50, precision, recall, scoreThres } = row;
    return {
      ...row,
      ap: formatDecimalPrecision(ap, 'ap'),
      ap50: formatDecimalPrecision(ap50, 'ap50'),
      precision: formatDecimalPrecision(precision, 'precision'),
      recall: formatDecimalPrecision(recall, 'recall'),
      scoreThres: formatDecimalPrecision(scoreThres, 'scoreThres'),
    };
  });
}
/**
 * Returns the label to show for a class. The aggregate class `'all'` is
 * translated via `t('curate.diagnosis.text.allPredictedClass')`; any other
 * name is returned unchanged.
 */
export function getDisplayClassName(className: string, t: TFunction) {
  if (className !== 'all') return className;
  return t('curate.diagnosis.text.allPredictedClass');
}
