import React, { useMemo } from 'react';
import { useParams } from 'react-router';

import { useDatasetContext } from '../../../../../contexts/DatasetContext';
import { useImageFilterContext } from '../../../../../contexts/ImageFilterContext';
import { usePublicDatasetContext } from '../../../../../contexts/PublicDatasetContextProvider';
import { useQueryContext } from '../../../../../contexts/QueryContext';
import { useSliceContext } from '../../../../../contexts/SliceContext';
import {
  ComparedGroup,
  useComparedImageScatterPointsQuery,
  useImageScatterPointsQuery,
} from '../../views/embedding/queries/embeddingQueries';
import { EmbeddingDatum } from '../../views/embedding/types';
import { useCompareModeContext } from './CompareModeContext';

/**
 * Shape of the value published through the image-scatter context.
 * Only the two loading flags are required; every other field may be absent.
 */
type ContextProps = {
  /** Points returned by `useImageScatterPointsQuery`; may be undefined when that query has produced no data. */
  data?: EmbeddingDatum[];
  /** Points returned by `useComparedImageScatterPointsQuery`; the provider substitutes an empty array when the query yields nothing. */
  dataCompare?: EmbeddingDatum[];
  /** `isFetching` flag of the main scatter-points query. */
  isLoadingProjections: boolean;
  /** `isFetching` flag of the compared scatter-points query. */
  isLoadingProjectionsCompare: boolean;
  /** Number of entries in `data`; the provider reports 0 when `data` is undefined. */
  sampledCount?: number;
};

// Default value seen only by consumers rendered outside the provider.
// It is a real `ContextProps` (both required flags present, nothing loading)
// rather than an empty object cast to the type, so the static type stays
// truthful for such consumers.
const Context = React.createContext<ContextProps>({
  isLoadingProjections: false,
  isLoadingProjectionsCompare: false,
});

/** A selectable comparison target: the group it belongs to plus its id and display name. */
export type CompareOption = { group: ComparedGroup; id: string; name: string };

/**
 * Looks up a comparison option by id.
 *
 * @param options - Candidate options, searched in order.
 * @param selectedId - Id to match with strict equality; may be omitted.
 * @returns The first option whose `id` equals `selectedId`, or `undefined` when none does.
 */
function getSetInfo(options: CompareOption[], selectedId?: string) {
  for (const candidate of options) {
    if (candidate.id === selectedId) {
      return candidate;
    }
  }
  return undefined;
}

/**
 * Builds the value exposed by the image-scatter context.
 *
 * Reads the current slice, query string, dataset, compare-mode settings,
 * public-dataset flag, cluster level and the `datasetId` route param, then
 * runs the main scatter-points query and the compared scatter-points query
 * with those inputs.
 *
 * @returns The points of both queries, their `isFetching` flags and the
 * number of points in the main result (0 when there is none).
 */
const useProvider = (): ContextProps => {
  const { sliceInfo } = useSliceContext();
  const { queryString } = useQueryContext();
  const { datasetInfo } = useDatasetContext();
  const { showComparedPoints, comparedId } = useCompareModeContext();
  const { showPublicDatasets } = usePublicDatasetContext();
  const { datasetId } = useParams<{ datasetId: string }>();

  // NOTE(review): the cast hides that `id`/`name` are undefined while
  // `sliceInfo` or `datasetInfo` is not set. Lookups below only run with a
  // truthy `comparedId`, so such entries never match.
  const compareOptions = [
    { group: 'slice', id: sliceInfo?.id, name: sliceInfo?.name },
    { group: 'dataset', id: datasetInfo?.id, name: datasetInfo?.name },
  ] as CompareOption[];

  const { clusterLevel } = useImageFilterContext();

  // Query arguments shared by both scatter queries. `query` and `slice` are
  // only included when they have a truthy value.
  // NOTE(review): only `sliceInfo.name` is read here, yet the whole
  // `sliceInfo` object is a dependency; kept as-is in case callers rely on
  // the arguments object changing identity when the slice object changes.
  const scatterDependencies = useMemo(
    () => ({
      datasetId,
      fromPublicDatasets: showPublicDatasets,
      ...(queryString && { query: queryString }),
      ...(sliceInfo?.name && { slice: sliceInfo.name }),
      clusterLevel,
    }),
    [datasetId, showPublicDatasets, queryString, clusterLevel, sliceInfo],
  );

  const { data: points, isFetching: isLoadingProjections } =
    useImageScatterPointsQuery(scatterDependencies);

  const { data: dataCompare, isFetching: isLoadingProjectionsCompare } =
    useComparedImageScatterPointsQuery({
      ...scatterDependencies,
      comparedOption: comparedId ? getSetInfo(compareOptions, comparedId) : undefined,
      showComparedPoints,
    });

  // Memoize the context value so consumers are not re-rendered on every
  // provider render; the object (and the empty-array fallback) only change
  // identity when one of the query outputs changes.
  return useMemo<ContextProps>(
    () => ({
      data: points,
      dataCompare: dataCompare ?? [],
      isLoadingProjections,
      isLoadingProjectionsCompare,
      sampledCount: points?.length ?? 0,
    }),
    [points, dataCompare, isLoadingProjections, isLoadingProjectionsCompare],
  );
};

/**
 * Reads the current image-scatter context value.
 *
 * @returns The value supplied by the nearest enclosing provider of this
 * module's context, or the context's default value when there is none.
 */
export const useImageScatterContext = (): ContextProps => React.useContext(Context);

/**
 * Provider component that computes the image-scatter value via `useProvider`
 * and makes it available to its descendants.
 *
 * `children` is declared explicitly instead of relying on `React.FC` to add
 * it implicitly, so the component type-checks with React typings that no
 * longer include implicit children.
 */
export const ImageScatterProvider: React.FC<{ children?: React.ReactNode }> = ({ children }) => {
  const value = useProvider();
  return <Context.Provider value={value}>{children}</Context.Provider>;
};
