import { ColorBand, HeatmapBandsColorScale } from '@elastic/charts';
import { NamedColor, NamedColorWeight } from '@superb-ai/ui';
import { scaleQuantile, scaleQuantize } from 'd3-scale';

import { MetricType } from '../types';

// PR Curve
/**
 * Ordered palette for comparing classes on the PR curve.
 * Entries mix hex and `rgb()` notation; both are plain CSS color strings.
 */
export const diagnosisComparedClassColors: string[] = [
  '#FF625A',
  'rgb(66, 133, 244)',
  '#FBBC04',
  'rgb(36, 193, 224)',
  // Orange ('rgb(250, 123, 23)') is currently left out of the palette.
  'rgb(132, 48, 206)',
];

// Design-system color names for each annotation source.
const GT_COLOR = 'mint-400'; // ground truth
const PREDICTION_COLOR = 'red-400'; // primary prediction
const PREDICTION2_COLOR = 'blue-400'; // second (compared) prediction

/**
 * Color pair per annotation set, as `[design-system color name, hex string]`.
 *
 * NOTE: the object is asserted to `Record<string, …>`, so `AnnotationSetType`
 * below is a wide index-key type, not the union of the three literal keys.
 */
export const modelColors = {
  groundTruth: [GT_COLOR, '#4AE2B9'],
  prediction: [PREDICTION_COLOR, '#FF625A'],
  prediction2: [PREDICTION2_COLOR, '#5C7CFA'],
} as Record<string, [NamedColor | NamedColorWeight, string]>;

/** Key type of `modelColors` (wide index-key type because of the `Record` assertion). */
export type AnnotationSetType = keyof typeof modelColors;

/**
 * Returns the design-system color name for an annotation-count group.
 *
 * @param key - Group identifier: `'annotationCount'` (drawn with the ground-truth color)
 *   or `'predictionCount'` (drawn with the prediction color).
 * @returns The named color for that group.
 * @throws Error if `key` is not one of the known groups.
 */
export const getAnnotationGroupColor = (key: string): NamedColor | NamedColorWeight => {
  const colorMap = {
    annotationCount: [GT_COLOR, '#4AE2B9'],
    predictionCount: [PREDICTION_COLOR, '#FF625A'],
    // comparedPredictionCount: [PREDICTION2_COLOR, '#5C7CFA'],
  } as Record<string, [NamedColor | NamedColorWeight, string]>;

  // Own-property check so inherited members (e.g. 'toString') are not treated as groups.
  const entry = Object.prototype.hasOwnProperty.call(colorMap, key) ? colorMap[key] : undefined;
  if (entry === undefined) {
    throw new Error(`getAnnotationGroupColor: unknown annotation group "${key}"`);
  }
  return entry[0];
};

/**
 * Colors for the performance (AP) table.
 * `apColumnsWithOpactity` keeps its existing (misspelled) key; renaming it would break
 * any caller that reads it by name.
 */
export const PERFORMANCE_TABLE_COLORS = {
  apThreshold: '#B3B3B3',
  apColumns: 'rgba(232, 232, 232, 1)',
  apColumnsWithOpactity: 'rgba(0, 0, 0, 0.04)',
};

/** AP threshold color; single-sourced from `PERFORMANCE_TABLE_COLORS.apThreshold`. */
export const AP_THRESHOLD_COLOR = PERFORMANCE_TABLE_COLORS.apThreshold;

/**
 * Red heatmap palette, ordered from lightest to darkest.
 *
 * Several shades appear more than once. Presumably this gives those shades a larger share
 * of the value range when a scale assigns one bucket per entry (confirm intent).
 * No white entry is included: the heatmap band builders in this file add their own white
 * band for values below 0.01.
 */
export const heatmapGradientRed5: string[] = [
  '#FFF8F8',
  '#FFE7E7',
  '#FFD7D5',
  '#FFD7D5',
  '#FFB5B0',
  '#FFB5B0',
  '#FF8F87',
  '#FF8F87',
  '#FF8F87',
  '#FF625A',
  '#FF625A',
];

/**
 * Four-step blue heatmap palette, lightest (`blue-100`) to darkest (`blue-400`).
 * Each entry is a `[design-system color name, hex string]` pair.
 */
export const heatmapGradientBlue4 = [
  ['blue-100', '#E3F0FF'],
  ['blue-200', '#BDD4FF'],
  ['blue-300', '#8EA9FF'],
  ['blue-400', '#5C7CFA'],
];

/**
 * Shared chart palette.
 * Single-string entries are CSS color strings. Two-element entries are
 * `[design-system color name, hex string]` pairs.
 */
export const chartColors = {
  text: '#333333',
  // Must be a bare hex string: no surrounding whitespace that a strict color parser could reject.
  axisTitle: '#B3B3B3',
  brushArea: '#E3E3E3',
  brushStroke: 'rgba(0, 0, 0, 0.09)',
  singleModel: ['red-400', '#FF625A'],
  singleModelLight: ['red-200', '#FFB5B0'],
  comparedDiagnosis: ['blue-400', '#5C7CFA'],
  // Metrics.
  // NOTE(review): the precision/recall hex values look much lighter than the '-400' weights
  // they are paired with. Confirm which of the two the charts are meant to use.
  precision: ['purple-400', '#EECAFF'],
  recall: ['sky-400', '#CFF1FF'],
  f1Score: ['green-400', '#82DB24'],
  // Class distribution.
  class: ['mint-400', '#2DCE89'],
  classCompared: ['yellow-400', '#FFC107'],
};

/**
 * Picks the line color for a performance-curve series.
 *
 * @param metric - Series label: 'Precision', 'Recall' or 'F1 Score' get their metric color.
 * @returns The hex string (index 1 of the `chartColors` pair) for known metrics,
 *   otherwise the default text color.
 */
export function getPerformanceCurveSeriesColor(
  metric: Capitalize<MetricType> | 'F1 Score',
): string {
  if (metric === 'Precision') return chartColors.precision[1];
  if (metric === 'Recall') return chartColors.recall[1];
  if (metric === 'F1 Score') return chartColors.f1Score[1];
  // Any other series falls back to the neutral text color.
  return chartColors.text;
}

type Input = number[];

/**
 * Computes equal-width bucket boundaries for `colors.length` buckets spanning
 * `[min(data), max(data)]`, via d3 `scaleQuantize().thresholds()`.
 *
 * @param data - Values whose range is split.
 * @param colors - One color per bucket (the scale's output range).
 * @returns The `colors.length - 1` inner boundaries, ascending.
 */
function getThresholds(data: Input, colors: string[]): number[] {
  // reduce instead of Math.min/max(...data): spreading a very large array into call
  // arguments can throw a RangeError. Math.min/max keep the same NaN semantics.
  const min = data.reduce((acc, value) => Math.min(acc, value), Infinity);
  const max = data.reduce((acc, value) => Math.max(acc, value), -Infinity);
  // Only the output (Range) type is specified. d3's second type parameter is the
  // `unknown()` output type and is left at its default.
  const scale = scaleQuantize<string>().domain([min, max]).range(colors);
  return scale.thresholds();
}

/**
 * Turns ascending band boundaries into contiguous color bands.
 *
 * Band `i` runs from the previous boundary (0.01 for the first band) to `thresholds[i]`
 * and is painted `colors[i]`. The last band's end is raised by `topPadding`. Presumably
 * this keeps the maximum value inside the band if band ends are exclusive; confirm
 * against @elastic/charts band semantics.
 */
function createColorBands(thresholds: number[], colors: string[], topPadding: number): ColorBand[] {
  const bands: ColorBand[] = [];
  let start = 0.01;
  thresholds.forEach((boundary, index) => {
    const isLast = index === thresholds.length - 1;
    bands.push({
      start,
      end: isLast ? boundary + topPadding : boundary,
      color: colors[index],
    });
    start = boundary;
  });
  return bands;
}

/**
 * Builds an @elastic/charts `bands` color scale that splits the value range of `data`
 * into equal-width buckets, one per entry in `colors`.
 *
 * A leading white band covers values below 0.01 (such as zero-valued cells). The colored
 * bands start at 0.01, and the last one ends `topPadding` above the data maximum.
 *
 * @param data - Heatmap cell values.
 * @param colors - Bucket colors, lowest bucket first.
 * @returns A `bands` color scale. For empty `data` only the white band is returned, because
 *   there is no range to split (the quantize boundaries would otherwise all be NaN).
 */
export function heatmapQuantizeColorScale(data: Input, colors: string[]): HeatmapBandsColorScale {
  const belowMinimumBand: ColorBand = { start: -Infinity, end: 0.01, color: '#FFFFFF' };
  if (data.length === 0) {
    return { type: 'bands', bands: [belowMinimumBand] };
  }

  const topPadding = 0.1;
  // getThresholds yields the inner boundaries. The data maximum closes the last band.
  // reduce avoids the RangeError that Math.max(...data) can hit on very large arrays.
  const thresholds = getThresholds(data, colors);
  thresholds.push(data.reduce((acc, value) => Math.max(acc, value), -Infinity));

  return {
    type: 'bands',
    bands: [belowMinimumBand, ...createColorBands(thresholds, colors, topPadding)],
  };
}

/**
 * Returns the quantile boundaries of `data` for `colors.length` buckets (d3
 * `scaleQuantile().quantiles()`), followed by `max(data) + topPadding` as the upper
 * end of the last bucket.
 *
 * @param data - Sample values that define the quantiles.
 * @param colors - One color per bucket.
 * @param topPadding - Amount added to the data maximum for the final boundary.
 * @returns `colors.length` ascending boundaries.
 */
function computeQuantiles(data: Input, colors: string[], topPadding: number): number[] {
  const quantiles = scaleQuantile<string>().domain(data).range(colors).quantiles();
  // reduce instead of Math.max(...data): spreading a very large array can throw a RangeError.
  const max = data.reduce((acc, value) => Math.max(acc, value), -Infinity);
  return [...quantiles, max + topPadding];
}

/**
 * Builds an @elastic/charts `bands` color scale whose boundaries are quantiles of `data`,
 * so each color covers roughly the same number of values.
 *
 * A leading white band covers values below 0.01. The colored bands start at 0.01, and the
 * last one ends `topPadding` above the data maximum, matching `heatmapQuantizeColorScale`.
 *
 * @param data - Heatmap cell values.
 * @param colors - Bucket colors, lowest bucket first.
 * @returns A `bands` color scale. For empty `data` only the white band is returned, because
 *   quantiles of an empty sample are undefined.
 */
export function heatmapQuantileColorScale(data: Input, colors: string[]): HeatmapBandsColorScale {
  const belowMinimumBand: ColorBand = { start: -Infinity, end: 0.01, color: '#FFFFFF' };
  if (data.length === 0) {
    return { type: 'bands', bands: [belowMinimumBand] };
  }

  const topPadding = 0.1;
  // Ask computeQuantiles for the raw maximum (padding 0). createColorBands already raises
  // the last band's end by topPadding, so passing it to both would pad twice.
  const quantiles = computeQuantiles(data, colors, 0);

  return {
    type: 'bands',
    bands: [belowMinimumBand, ...createColorBands(quantiles, colors, topPadding)],
  };
}
