import { ReactElement, useState } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { useRouteMatch } from 'react-router-dom';

import { Axis, Chart, LineSeries, Position, Settings, Tooltip, TooltipType } from '@elastic/charts';
import { Box, Select, Tab, TabList, TabPanel, Tabs, Typography } from '@superb-ai/ui';

import analyticsTracker from '../../../../analyticsTracker';
import { Row } from '../../../../components/elements/Row';
import { lineChartTheme } from '../../../Curate/components/datasets/dataset/modelDiagnosis/diagnosis/analytics/charts/customTheme';
import DropdownSearchInput from '../../components/components';
import { isRecognitionAIModelTraining } from '../../services/modelTrainingTypeGuards';
import { MyModelDetail, RecognitionAIModelDetailTraining } from '../../services/types';

/**
 * Performance panel on the model detail page.
 *
 * With training progress, it shows two tabs ("AP" and "Loss"). Each tab renders
 * its chart only when the model's training data is recognition-model training data.
 * Without training progress, it shows a bordered placeholder with the
 * "before one epoch" message.
 *
 * @param data - Model detail. `status` is sent with tab-click analytics and
 *   `modelTraining` feeds the charts.
 * @param hasTrainingProcess - Selects the chart tabs (true) or the placeholder (false).
 */
export const PerformanceChart = ({
  data,
  hasTrainingProcess,
}: {
  data: MyModelDetail;
  hasTrainingProcess: boolean;
}) => {
  const { t } = useTranslation();
  const {
    params: { accountName },
  } = useRouteMatch<{ accountName: string }>();

  // Both tabs report the same analytics event.
  const handleTabClick = () => {
    analyticsTracker.modelDetailClicked({
      accountId: accountName,
      clickedButton: 'click-performance-chart-tab',
      modelStatus: data.status,
    });
  };

  // The type guard only runs when the tabs are shown. Non-recognition training
  // data leaves both tab panels empty.
  const recognitionTraining =
    hasTrainingProcess && isRecognitionAIModelTraining(data.modelTraining)
      ? data.modelTraining
      : undefined;

  const placeholder = (
    <Box
      width="100%"
      height="100%"
      display="flex"
      alignItems="center"
      justifyContent="center"
      textAlign="center"
    >
      <Typography variant="m-regular" color="gray-300">
        <Trans t={t} i18nKey={'model.myModelDetail.performanceChart.beforeOneEpoch'} />
      </Typography>
    </Box>
  );

  return (
    <Box>
      <Box
        borderRadius="2px"
        border={hasTrainingProcess ? undefined : '1px solid'}
        borderColor={hasTrainingProcess ? undefined : 'gray-200'}
        backgroundColor={hasTrainingProcess ? undefined : 'gray-100'}
        style={{ height: 308 }}
      >
        {!hasTrainingProcess ? (
          placeholder
        ) : (
          <Tabs>
            <TabList variant="fill" color="gray">
              <Tab onClick={handleTabClick}>{t('model.myModelDetail.performanceChart.ap')}</Tab>
              <Tab onClick={handleTabClick}>{t('model.myModelDetail.performanceChart.loss')}</Tab>
            </TabList>
            <TabPanel style={{ padding: 16 }}>
              {recognitionTraining && <APChart modelTrainingData={recognitionTraining} />}
            </TabPanel>
            <TabPanel style={{ padding: 16 }}>
              {recognitionTraining && <LossChart modelTrainingData={recognitionTraining} />}
            </TabPanel>
          </Tabs>
        )}
      </Box>
    </Box>
  );
};

/**
 * AP (average precision) chart for a recognition-model training run.
 *
 * Plots the overall mAP for every epoch. Once the user picks a class from the
 * searchable dropdown, it overlays that class's AP on the same axes.
 *
 * @param modelTrainingData - Training record whose `modelTrainingEpochs` each
 *   carry a `detailTrainingResult`. The array index is used as the zero-based
 *   epoch on the x axis.
 */
const APChart = ({
  modelTrainingData,
}: {
  modelTrainingData: RecognitionAIModelDetailTraining;
}) => {
  const { t } = useTranslation();
  const defaultAPColor = '#ff625a';
  const classAPColor = '#5A7BFF';
  const [searchClassName, setSearchClassName] = useState('');
  const [selectedClassName, setSelectedClassName] = useState<string | null>(null);

  const modelDetailTrainingResult = modelTrainingData.modelTrainingEpochs.map(
    epoch => epoch.detailTrainingResult,
  );
  // The class list comes from the first epoch. Fall back to no options when
  // there are no epochs yet, instead of throwing on `undefined.classPerformance`.
  const classSelectOption = (modelDetailTrainingResult[0]?.classPerformance ?? []).map(x => ({
    label: x.className,
    value: x.className,
  }));
  const filteredOptions = classSelectOption.filter(option =>
    option.label.includes(searchClassName),
  );

  const mAPChartData = modelDetailTrainingResult.map((result, index) => ({
    epochs: index,
    ap: result.overallPerformance.ap,
  }));
  // `ap` is undefined for epochs where the selected class is missing.
  const selectedClassChartData = modelDetailTrainingResult.map((result, index) => ({
    epochs: index,
    ap: result.classPerformance.find(x => x.className === selectedClassName)?.ap,
  }));

  return (
    <>
      <Row justifyContent="space-between" mb={2}>
        <Box display="grid" style={{ width: 140 }}>
          <Select
            variant="soft-fill"
            size="s"
            placeholder={t('model.myModelDetail.selectClass')}
            data={filteredOptions}
            onChangeValue={v => {
              setSelectedClassName(v);
            }}
            value={selectedClassName}
            prefix={
              <DropdownSearchInput
                placeholder={t('model.myModelDetail.performanceChart.placeholder')}
                setSearchValue={setSearchClassName}
              />
            }
          />
        </Box>
        <Row gap={1.5}>
          <Row style={{ gap: 6 }}>
            <Box backgroundColor={'primary'} style={{ width: 6, height: 6, borderRadius: 3 }} />
            {t('model.myModelDetail.performanceChart.mAP')}
          </Row>
          {selectedClassName && (
            <Row style={{ gap: 6 }}>
              <Box backgroundColor={'secondary'} style={{ width: 6, height: 6, borderRadius: 3 }} />
              {selectedClassName}
            </Row>
          )}
        </Row>
      </Row>
      <Chart size={{ width: '100%', height: 204 }}>
        <Settings theme={lineChartTheme()} />
        <Tooltip
          type={TooltipType.Crosshairs}
          headerFormatter={({ datum }) => <TooltipHeaderComponent datum={datum} />}
        />
        <Axis
          id="bottom"
          position={Position.Bottom}
          title={t('model.myModelDetail.epochs')}
          style={{ tickLabel: { fill: 'black' }, axisTitle: { visible: false } }}
          // Data uses zero-based epoch indices; display them one-based.
          tickFormat={d => String(Number(d) + 1)}
        />
        <Axis
          id="left"
          title={'mAP'}
          position={Position.Left}
          style={{ tickLabel: { fill: 'black' }, axisTitle: { visible: false } }}
        />
        <LineSeries
          id="mAP"
          data={mAPChartData}
          xAccessor={'epochs'}
          yAccessors={['ap']}
          color={defaultAPColor}
          tickFormat={d => Number(d * 100).toFixed(1)}
          lineSeriesStyle={{
            point: {
              visible: false,
            },
          }}
        />
        {/* Mount the class series only when a class is selected. This avoids a
            series with an empty-string id and no data. */}
        {selectedClassName ? (
          <LineSeries
            id={selectedClassName}
            data={selectedClassChartData}
            xAccessor={'epochs'}
            yAccessors={['ap']}
            color={classAPColor}
            tickFormat={d => Number(d * 100).toFixed(1)}
            lineSeriesStyle={{
              point: {
                visible: false,
              },
            }}
          />
        ) : null}
      </Chart>
    </>
  );
};

/**
 * Loss chart for a recognition-model training run.
 *
 * Plots training-set loss (`trainLoss`) and validation-set loss (`evalLoss`)
 * for every epoch, with a small color legend above the chart.
 *
 * @param modelTrainingData - Training record whose `modelTrainingEpochs` each
 *   carry a `detailTrainingResult`. The array index is used as the zero-based
 *   epoch on the x axis.
 */
const LossChart = ({
  modelTrainingData,
}: {
  modelTrainingData: RecognitionAIModelDetailTraining;
}) => {
  const { t } = useTranslation();
  const trainSetColor = '#3479FF';
  const validSetColor = '#FFB600';

  const modelDetailTrainingResult = modelTrainingData.modelTrainingEpochs.map(
    epoch => epoch.detailTrainingResult,
  );

  const trainSetLossChartData = modelDetailTrainingResult.map((result, index) => ({
    epochs: index,
    loss: result.overallPerformance.trainLoss,
  }));
  const validationSetLossChartData = modelDetailTrainingResult.map((result, index) => ({
    epochs: index,
    loss: result.overallPerformance.evalLoss,
  }));

  return (
    <>
      <Row justifyContent="flex-end" gap={1.5} mb={2} style={{ height: 24 }}>
        <Row style={{ gap: 6 }}>
          <Box backgroundColor={'yellow-400'} style={{ width: 6, height: 6, borderRadius: 3 }} />
          {t('model.train.validationSet')}
        </Row>
        <Row style={{ gap: 6 }}>
          <Box backgroundColor={'blue-400'} style={{ width: 6, height: 6, borderRadius: 3 }} />
          {t('model.train.trainSet')}
        </Row>
      </Row>
      <Chart size={{ width: '100%', height: 204 }}>
        <Settings theme={lineChartTheme()} />
        <Tooltip
          type={TooltipType.Crosshairs}
          headerFormatter={({ datum }) => <TooltipHeaderComponent datum={datum} />}
        />
        <Axis
          id="bottom"
          position={Position.Bottom}
          // Same translated title as the AP chart and the tooltip header.
          title={t('model.myModelDetail.epochs')}
          style={{ tickLabel: { fill: 'black' }, axisTitle: { visible: false } }}
          // Data uses zero-based epoch indices; display them one-based.
          tickFormat={d => String(Number(d) + 1)}
        />
        <Axis
          id="left"
          title={'Loss'}
          position={Position.Left}
          style={{ tickLabel: { fill: 'black' }, axisTitle: { visible: false } }}
        />
        <LineSeries
          id={t('model.train.validationSet')}
          data={validationSetLossChartData}
          xAccessor={'epochs'}
          yAccessors={['loss']}
          color={validSetColor}
          tickFormat={d => Number(d).toFixed(1)}
          lineSeriesStyle={{
            point: {
              visible: false,
            },
          }}
        />
        <LineSeries
          id={t('model.train.trainSet')}
          data={trainSetLossChartData}
          xAccessor={'epochs'}
          yAccessors={['loss']}
          color={trainSetColor}
          tickFormat={d => Number(d).toFixed(1)}
          lineSeriesStyle={{
            point: {
              visible: false,
            },
          }}
        />
      </Chart>
    </>
  );
};

/**
 * Tooltip header for the performance charts: a translated "epochs" label,
 * followed by the hovered point's epoch shown one-based.
 *
 * @param datum - Hovered chart datum, read for its zero-based `epochs` field.
 *   When the computed value is falsy (e.g. NaN for a missing datum), a dash is shown.
 */
function TooltipHeaderComponent({ datum }: { datum: Record<string, any> }): ReactElement {
  const { t } = useTranslation();
  const displayEpoch = datum?.epochs + 1;

  return (
    <Box display="flex" style={{ width: '182px', marginRight: '-10px' }}>
      <Box>{t('model.myModelDetail.epochs')}</Box>
      <Box display="flex" ml="auto" pr={1}>
        {displayEpoch || '-'}
      </Box>
    </Box>
  );
}
