import { ComponentProps, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { useHistory, useParams } from 'react-router-dom';

import { Box, Typography } from '@superb-ai/ui';

import analyticsTracker from '../../../analyticsTracker';
import { ONE_CREDIT } from '../../../queries/meteringLogic';
import { getUrl } from '../../../routes/util';
import RegexUtils from '../../../utils/RegexUtils';
import { useDatasetDataCountQuery } from '../../Curate/queries/dataQueries';
import { useSliceQueriesWithNames } from '../../Curate/queries/sliceQueries';
import { useCurateDatasetService } from '../../Curate/services/DatasetService';
import {
  MAXIMUM_MODEL_NAME_LENGTH,
  MINIMUM_MODEL_NAME_LENGTH,
  RECOMMENDED_GEN_AI_TRAIN_SET_IMAGE_NUMBER,
  RECOMMENDED_RECOGNITION_AI_TRAIN_SET_IMAGE_NUMBER_IN_AUTOMATICALLY_SAMPLE,
  RECOMMENDED_RECOGNITION_AI_TRAIN_SET_IMAGE_NUMBER_IN_MANUALLY_SELECT,
  RECOMMENDED_RECOGNITION_AI_VALIDATION_SET_IMAGE_NUMBER_IN_MANUALLY_SELECT,
  TRAINING_CREDIT,
} from '../constant';
import { GenerationAIDetailsMenuItem } from '../gen-ai/details/MenuItem';
import { GENERATION_MODEL, RECOGNITION_MODEL } from '../path';
import { useCreateModelMutation, useEstimatedTrainingTimeMutation } from '../queries/modelQueries';
import { RecognitionAIDetailMenuItem } from '../recognition-ai/detail/MenuItem';
import { RecognitionQueryKeyword } from '../recognition-ai/queries';
import { CreateModelParams } from '../services/types';
import { useBaselineContext } from './contexts/BaselineContext';
import { useDatasetClassContext } from './contexts/DatasetClassContext';
import { useModelSettingContext } from './contexts/ModelSettingContext';
import { Estimates } from './Estimates';
import { useTrainUrlParams } from './queries';
import { BaselineModelStep } from './steps/BaselineModelStep';
import { DatasetClassStep } from './steps/DatasetClassStep';
import { ModelSettingStep } from './steps/ModelSettingStep';
import { Step, TrainModelStepper } from './steps/TrainModelStepper';
import { TrainValidationSplitStep } from './steps/TrainValidationSplitStep';

// Hourly training rate: the TRAINING_CREDIT amount scaled by the ONE_CREDIT unit.
const hourlyRate = ONE_CREDIT * TRAINING_CREDIT;

/**
 * Full-screen stepper dialog for training a model on a Curate dataset.
 *
 * State comes from the baseline, dataset/class and model-setting contexts. The
 * footer shows image-count and training-time estimates. The last step creates the
 * model (manual split, random split, or generation AI) and then navigates to the
 * new model's detail page on success.
 */
export function TrainModelDialog() {
  const params = useTrainUrlParams();
  const { t } = useTranslation();
  const history = useHistory();

  const { mutate: createModel } = useCreateModelMutation();
  const { accountName } = useParams<{
    accountName: string;
  }>();
  const { getDataset, createSlice } = useCurateDatasetService();

  const {
    publicModelId,
    summary: baselineSummary,
    selectedPublicModel,
    myModelId,
    selectedMyModel,
    uniqueNameFilter,
    isModelNameUnique,
    modelPurpose,
  } = useBaselineContext();

  // Annotation types of the chosen baseline; a retrained "my model" takes precedence.
  const annotationType =
    selectedMyModel?.baselineModel.annotationType ?? selectedPublicModel?.annotationType ?? [];

  const {
    datasetId,
    setDatasetName,
    trainingSetNames,
    validationSetNames,
    selectedAnnotationClasses,
    datasetClassSummary,
    splitType,
    trainSetSliceName,
    validationSetSliceName,
    isValidSliceNames,
  } = useDatasetClassContext();
  const { memo, selectedTags } = useModelSettingContext();

  // Classes sorted by name, so the submitted list has a deterministic order.
  // Uses copy-then-sort because `toSorted` is ES2023 and missing in older browsers.
  const sortedSelectedAnnotationClasses = [...selectedAnnotationClasses].sort((a, b) =>
    a.name.localeCompare(b.name),
  );

  const sliceQueries = useSliceQueriesWithNames({
    datasetId: datasetId,
    fromPublicDatasets: false,
    enabled: Boolean(datasetId),
    names: [...trainingSetNames, ...validationSetNames],
  });

  // Resolve the selected training / validation slice names to `{ id, name }` pairs.
  const selectedSourceIdNames = sliceQueries
    .flatMap(_ => _.data?.results)
    .filter(x => trainingSetNames.some(name => name === x?.name))
    .map(x => ({ id: x?.id ?? '', name: x?.name ?? '' }));

  const selectedValidationSetIdNames = sliceQueries
    .flatMap(_ => _.data?.results)
    .filter(x => validationSetNames?.some(name => name === x?.name))
    .map(x => ({ id: x?.id ?? '', name: x?.name ?? '' }));

  const hasAllData = Boolean(
    (publicModelId || myModelId) &&
      datasetId &&
      trainingSetNames?.length &&
      uniqueNameFilter &&
      selectedAnnotationClasses.length > 0,
  );

  // Image count over all selected slices (training + validation), shown in the footer.
  const datasetDataCountQuery = useDatasetDataCountQuery({
    datasetId: datasetId ?? '',
    fromPublicDatasets: false,
    queryString: queryBuilder({
      OR: [...trainingSetNames, ...validationSetNames],
      annotationType: annotationType,
    }),
    // Same gating as the sibling count queries: no request without a dataset.
    disabled: !datasetId,
  });
  const notFetchdatasetDataCountQuery = [...trainingSetNames, ...validationSetNames].length === 0;

  // Training images: in any training slice but in none of the validation slices.
  const trainingSetDataCountQuery = useDatasetDataCountQuery({
    datasetId: datasetId ?? '',
    fromPublicDatasets: false,
    queryString: queryBuilder({
      OR: trainingSetNames,
      AND: validationSetNames,
      annotationType: annotationType,
    }),
    disabled: !datasetId || trainingSetNames.length === 0,
  });

  const validationSetDataCountQuery = useDatasetDataCountQuery({
    datasetId: datasetId ?? '',
    fromPublicDatasets: false,
    queryString: queryBuilder({ OR: validationSetNames, annotationType: annotationType }),
    disabled: Boolean(!datasetId),
  });

  const { data: estimatedTrainingTime, mutate: estimateTrainingTime } =
    useEstimatedTrainingTimeMutation();

  // Look up the selected dataset's name and publish it to the dataset/class context.
  // NOTE(review): only datasetId is a dependency; getDataset / setDatasetName are
  // presumably stable — verify.
  useEffect(() => {
    // Nothing to look up yet; avoid requesting a dataset with an empty id.
    if (!datasetId) return undefined;
    // Drop a response that arrives after the selection changed, so a slow earlier
    // request cannot overwrite the current dataset's name.
    let isStale = false;
    void (async () => {
      const dataset = await getDataset({ datasetId, fromPublicDatasets: false });
      if (!isStale) setDatasetName(dataset.name ?? null);
    })();
    return () => {
      isStale = true;
    };
  }, [datasetId]);

  /** Create a model from user-selected training and validation slices. */
  async function onStartTrainingByManualSplit() {
    if (!datasetId) return;
    // Create data
    const createModelParams: CreateModelParams = {
      name: uniqueNameFilter,
      baselineModelId: selectedMyModel ? selectedMyModel.baselineModel.id : publicModelId ?? '',
      parentModelId: selectedMyModel ? selectedMyModel.id : null,
      referenceId: datasetId ?? '',
      type: 'curate',
      trainingSetList: selectedSourceIdNames ?? [],
      validationSetList: selectedValidationSetIdNames ?? [],
      autoTrainingSet: null,
      autoValidationSet: null,
      // reference: https://superb-ai.slack.com/archives/C054MDHV820/p1709772593795159
      annotationClassList: selectedMyModel
        ? selectedMyModel.modelSetting.annotationClassList
        : sortedSelectedAnnotationClasses,
      modelTraining: {
        numTrainingSet: trainingSetDataCountQuery.data?.count ?? 0,
        numValidationSet: validationSetDataCountQuery.data?.count ?? 0,
      },
      memo,
      modelTags: selectedTags,
    };

    const modelTrainingRequestedEventTracking = () => {
      analyticsTracker.modelTrainingRequested({
        accountId: accountName,
        modelPurpose,
        datasetId,
        source: selectedPublicModel?.source ?? selectedMyModel?.baselineModel.source,
      });
    };

    createModel(createModelParams, {
      onSuccess(data) {
        history.push(
          params.from === 'recognition'
            ? getUrl(
                [accountName, RECOGNITION_MODEL, RecognitionAIDetailMenuItem.path],
                {
                  id: data.id,
                },
                {
                  [RecognitionQueryKeyword.DetailTab]: 'progress',
                },
              )
            : getUrl(
                [accountName, GENERATION_MODEL, GenerationAIDetailsMenuItem.path],
                {
                  id: data.id,
                },
                {
                  [RecognitionQueryKeyword.DetailTab]: 'progress',
                },
              ),
        );
      },
    });
    // Tracked as soon as the request is issued, not on success.
    modelTrainingRequestedEventTracking();
  }

  const trackDataSliceCreated = useCallback(
    (newSliceId: string, referrer: 'model-train-training' | 'train-model-validation') => {
      analyticsTracker.dataSliceCreated({
        accountId: accountName,
        sliceId: newSliceId,
        datasetId,
        dataCount: 'UNKNOWN',
        dataType: 'image',
        referrer: referrer,
      });
    },
    [accountName, datasetId],
  );

  /** Create the auto train/validation slices, then a model that uses them. */
  async function onStartTrainingByRandomSplit() {
    if (!datasetId) return;

    // Slices are only created for a random split.
    // NOTE(review): if the second createSlice fails, the first slice is left behind.
    const trainSetSlice = await createSlice({
      datasetId,
      name: trainSetSliceName,
      description: '',
    });
    trackDataSliceCreated(trainSetSlice.id, 'model-train-training');
    const validationSetSlice = await createSlice({
      datasetId,
      name: validationSetSliceName,
      description: '',
    });
    trackDataSliceCreated(validationSetSlice.id, 'train-model-validation');

    // Create data
    const createModelParams: CreateModelParams = {
      name: uniqueNameFilter,
      baselineModelId: selectedMyModel ? selectedMyModel.baselineModel.id : publicModelId ?? '',
      parentModelId: selectedMyModel ? selectedMyModel.id : null,
      referenceId: datasetId ?? '',
      type: 'curate',
      trainingSetList: selectedSourceIdNames ?? [],
      validationSetList: selectedValidationSetIdNames ?? [],
      autoTrainingSet: { id: trainSetSlice.id, name: trainSetSlice.name },
      autoValidationSet: { id: validationSetSlice.id, name: validationSetSlice.name },
      // reference: https://superb-ai.slack.com/archives/C054MDHV820/p1709772593795159
      annotationClassList: selectedMyModel
        ? selectedMyModel.modelSetting.annotationClassList
        : sortedSelectedAnnotationClasses,
      // 80/20 split of the training-set count.
      // NOTE(review): these may be fractional — confirm the backend accepts non-integers.
      modelTraining: {
        numTrainingSet: (trainingSetDataCountQuery.data?.count ?? 0) * 0.8,
        numValidationSet: (trainingSetDataCountQuery.data?.count ?? 0) * 0.2,
      },
      memo,
      modelTags: selectedTags,
    };

    const modelTrainingRequestedEventTracking = () => {
      analyticsTracker.modelTrainingRequested({
        accountId: accountName,
        modelPurpose,
        datasetId,
        trainSliceId: [trainSetSlice.id],
        validationSliceId: [validationSetSlice.id],
        source: selectedPublicModel?.source ?? selectedMyModel?.baselineModel.source,
      });
    };

    createModel(createModelParams, {
      onSuccess(data) {
        history.push(
          params.from === 'recognition'
            ? getUrl(
                [accountName, RECOGNITION_MODEL, RecognitionAIDetailMenuItem.path],
                {
                  id: data.id,
                },
                {
                  [RecognitionQueryKeyword.DetailTab]: 'progress',
                },
              )
            : getUrl(
                [accountName, GENERATION_MODEL, GenerationAIDetailsMenuItem.path],
                {
                  id: data.id,
                },
                {
                  [RecognitionQueryKeyword.DetailTab]: 'progress',
                },
              ),
        );
      },
    });
    modelTrainingRequestedEventTracking();
  }

  /** Create a generation-AI model; always navigates to the generation detail page. */
  async function onStartTrainingByGenAI() {
    if (!datasetId) return;
    // Create data
    const createModelParams: CreateModelParams = {
      name: uniqueNameFilter,
      baselineModelId: selectedMyModel ? selectedMyModel.baselineModel.id : publicModelId ?? '',
      parentModelId: selectedMyModel ? selectedMyModel.id : null,
      referenceId: datasetId ?? '',
      type: 'curate',
      trainingSetList: selectedSourceIdNames ?? [],
      validationSetList: selectedValidationSetIdNames ?? [],
      autoTrainingSet: null,
      autoValidationSet: null,
      annotationClassList: sortedSelectedAnnotationClasses,
      // NOTE(review): same fractional 80/20 split as the random-split path.
      modelTraining: {
        numTrainingSet: (trainingSetDataCountQuery.data?.count ?? 0) * 0.8,
        numValidationSet: (trainingSetDataCountQuery.data?.count ?? 0) * 0.2,
      },
      memo,
      modelTags: selectedTags,
    };

    const modelTrainingRequestedEventTracking = () => {
      analyticsTracker.modelTrainingRequested({
        accountId: accountName,
        modelPurpose,
        datasetId,
        source: selectedPublicModel?.source ?? selectedMyModel?.baselineModel.source,
      });
    };

    createModel(createModelParams, {
      onSuccess(data) {
        history.push(
          getUrl(
            [accountName, GENERATION_MODEL, GenerationAIDetailsMenuItem.path],
            {
              id: data.id,
            },
            {
              [RecognitionQueryKeyword.DetailTab]: 'progress',
            },
          ),
        );
      },
    });
    modelTrainingRequestedEventTracking();
  }

  const stepBaselineModel: Step = {
    title: t('model.myModels.baselineModel'),
    description: t('model.train.stepBaselineModelDescription'),
    content: <BaselineModelStep />,
    summary: baselineSummary,
    isButtonEnabled: Boolean(publicModelId) || Boolean(selectedMyModel),
    eventTrackerParams: {
      step: 'choose-baseline-model',
      modelPurpose: modelPurpose,
      source: selectedMyModel ? 'my-model' : 'new-model',
    },
  };

  const stepTrainValidationSplit: Step = {
    title: t('model.train.trainValidationSplit'),
    description: t('model.train.stepTrainValidationSplitDescription'),
    content: <TrainValidationSplitStep />,
    summary: splitType ? (
      <Typography variant="m-regular">
        {splitType === 'manual' ? t('model.train.manualSplit') : t('model.train.randomSplit')}
      </Typography>
    ) : undefined,
    isButtonEnabled: Boolean(splitType),
    eventTrackerParams: {
      step: 'split-train-val',
    },
  };

  // "Next" is enabled only when slices/classes are chosen and the image counts
  // (once loaded) meet the recommended minimum for the model purpose / split type.
  const stepDatasetClass: Step = {
    title: t('model.train.datasetClass'),
    description: t('model.train.stepDatasetClassDescription'),
    content: <DatasetClassStep />,
    summary: datasetClassSummary,
    isButtonEnabled:
      modelPurpose === 'generation'
        ? Boolean(
            datasetId &&
              trainingSetNames.length &&
              selectedAnnotationClasses.length &&
              !trainingSetDataCountQuery.isLoading &&
              RECOMMENDED_GEN_AI_TRAIN_SET_IMAGE_NUMBER <=
                (trainingSetDataCountQuery.data?.count ?? 0),
          )
        : splitType === 'manual'
        ? Boolean(
            datasetId &&
              trainingSetNames.length &&
              validationSetNames.length &&
              selectedAnnotationClasses.length &&
              !trainingSetDataCountQuery.isLoading &&
              RECOMMENDED_RECOGNITION_AI_TRAIN_SET_IMAGE_NUMBER_IN_MANUALLY_SELECT <=
                (trainingSetDataCountQuery.data?.count ?? 0) &&
              !validationSetDataCountQuery.isLoading &&
              RECOMMENDED_RECOGNITION_AI_VALIDATION_SET_IMAGE_NUMBER_IN_MANUALLY_SELECT <=
                (validationSetDataCountQuery.data?.count ?? 0),
          )
        : Boolean(
            datasetId &&
              trainingSetNames.length &&
              selectedAnnotationClasses.length &&
              isValidSliceNames &&
              !trainingSetDataCountQuery.isLoading &&
              RECOMMENDED_RECOGNITION_AI_TRAIN_SET_IMAGE_NUMBER_IN_AUTOMATICALLY_SAMPLE <=
                (trainingSetDataCountQuery.data?.count ?? 0),
          ),
    eventTrackerParams: {
      step: 'select-dataset-and-class',
    },
  };

  // All inputs present, and the model name is unique, within length bounds and
  // free of special symbols; recognition models additionally need a split type.
  const canSubmit =
    hasAllData &&
    isModelNameUnique &&
    uniqueNameFilter.length <= MAXIMUM_MODEL_NAME_LENGTH &&
    uniqueNameFilter.length >= MINIMUM_MODEL_NAME_LENGTH &&
    !RegexUtils.HAS_SPECIAL_SYMBOLS(uniqueNameFilter) &&
    (modelPurpose === 'generation' || Boolean(splitType));

  // Model Name
  const stepModelSetting: Step = {
    title: t('model.train.modelSetting'),
    description: t('model.train.stepModelSettingDescription'),
    content: <ModelSettingStep />,
    isButtonEnabled: canSubmit,
    eventTrackerParams: {
      step: 'complete-model-setting',
    },
  };

  // The split step is shown only for recognition models that do not come from a
  // "my model" already having validation sets.
  function getConfiguredSteps(): ComponentProps<typeof TrainModelStepper>['steps'] {
    if (modelPurpose === 'generation') {
      return {
        stepBaselineModel,
        stepDatasetClass,
        stepModelSetting,
      };
    } else {
      if (selectedMyModel && selectedMyModel.trainingSet.validationSetList.length > 0) {
        return {
          stepBaselineModel,
          stepDatasetClass,
          stepModelSetting,
        };
      } else {
        return {
          stepBaselineModel,
          stepTrainValidationSplit,
          stepDatasetClass,
          stepModelSetting,
        };
      }
    }
  }

  // Start past the baseline step when the URL pre-selects a public or trained model.
  const getStartingStepIndex = () => (params.publicModelId || params.trainedModelId ? 1 : 0);

  // Request the estimated training time once both count queries report a non-zero count.
  useEffect(() => {
    if (!trainingSetDataCountQuery.data?.count || !validationSetDataCountQuery.data?.count) {
      return;
    }
    if (selectedPublicModel) {
      estimateTrainingTime({
        baselineModel: selectedPublicModel.name,
        // With no slice selected, trainingSetDataCountQuery returns the count of all images.
        numTrainingSet: trainingSetNames.length > 0 ? trainingSetDataCountQuery.data?.count : 0,
        numValidationSet:
          validationSetNames.length > 0 ? validationSetDataCountQuery.data?.count : 0,
      });
    }
    if (selectedMyModel) {
      estimateTrainingTime({
        baselineModel: selectedMyModel.baselineModel.name,
        // With no slice selected, trainingSetDataCountQuery returns the count of all images.
        numTrainingSet: trainingSetNames.length > 0 ? trainingSetDataCountQuery.data?.count : 0,
        numValidationSet:
          validationSetNames.length > 0 ? validationSetDataCountQuery.data?.count : 0,
      });
    }
  }, [
    estimateTrainingTime,
    selectedPublicModel,
    selectedMyModel,
    trainingSetDataCountQuery.data?.count,
    trainingSetNames.length,
    validationSetDataCountQuery.data?.count,
    validationSetNames.length,
  ]);

  return (
    <>
      <Box
        position="fixed"
        top="0"
        bottom="0"
        left={'0'}
        right={'0'}
        style={{ zIndex: 1299 }}
        backgroundColor="white"
      >
        <TrainModelStepper
          steps={getConfiguredSteps()}
          startingStepIndex={getStartingStepIndex()}
          lastStepButton={{
            text: t('model.train.startTraining'),
            onClick:
              modelPurpose === 'generation'
                ? onStartTrainingByGenAI
                : splitType === 'manual'
                ? onStartTrainingByManualSplit
                : onStartTrainingByRandomSplit,
          }}
          footer={
            <>
              <Box display="flex" alignItems="center">
                <Estimates
                  imageCount={notFetchdatasetDataCountQuery ? 0 : datasetDataCountQuery.data?.count}
                  trainingTimeSeconds={estimatedTrainingTime?.totalEstimatedTime ?? 0}
                  creditsPerHour={hourlyRate / ONE_CREDIT}
                  isLoading={
                    datasetDataCountQuery.isFetching ||
                    trainingSetDataCountQuery.isFetching ||
                    validationSetDataCountQuery.isFetching
                  }
                />
              </Box>
            </>
          }
        />
      </Box>
    </>
  );
}

/**
 * Builds a slice-filter query string for the Curate data-count queries.
 *
 * - Every name in `OR` contributes `(slice = "<name>")`; these are OR-ed together.
 * - Every name in `AND` contributes `(slice != "<name>")`; these are AND-ed together
 *   and intersected with the `OR` part. If `OR` is empty, the exclusions alone
 *   form the filter.
 * - If a slice filter exists, every entry of `annotationType` adds
 *   `(annotations.type.<type>.count > 0)`.
 *
 * An empty string is returned when neither `OR` nor `AND` contributes a clause.
 * NOTE(review): the data-count query appears to treat an empty query as "all data" — confirm.
 *
 * @param OR - slice names to include (any of them).
 * @param AND - slice names to exclude (all of them).
 * @param annotationType - annotation types the data must contain (all of them).
 * @returns the query string, or `''` when there is no slice filter.
 */
export function queryBuilder({
  OR,
  AND,
  annotationType,
}: {
  OR: string[];
  AND?: string[];
  annotationType?: string[];
}) {
  // NOTE(review): slice names are interpolated verbatim, so a name containing `"`
  // would break the query. Confirm the query language's escape syntax before escaping.
  const includeQuery = OR.map(item => `(slice = "${item}")`).join(' OR ');
  const excludeQuery = (AND ?? []).map(item => `(slice != "${item}")`).join(' AND ');

  let sourceQuery = includeQuery;
  if (excludeQuery.length > 0) {
    // With no include clauses, use the exclusions on their own; wrapping an empty
    // include part would yield the malformed `() AND ...`.
    sourceQuery = includeQuery.length > 0 ? `(${includeQuery}) AND ${excludeQuery}` : excludeQuery;
  }
  if (sourceQuery.length > 0 && annotationType?.length) {
    const annotationTypeQuery = annotationType
      .map(type => `(annotations.type.${type}.count > 0)`)
      .join(' AND ');
    sourceQuery = `(${sourceQuery}) AND ${annotationTypeQuery}`;
  }

  return sourceQuery;
}
