import React, { useMemo } from 'react';
import { useParams } from 'react-router';

import { useDatasetContext } from '../../../../../contexts/DatasetContext';
import { useObjectClusterContext } from '../../../../../contexts/ObjectClusterContext';
import { useObjectFilterContext } from '../../../../../contexts/ObjectFilterContext';
import { usePublicDatasetContext } from '../../../../../contexts/PublicDatasetContextProvider';
import { useQueryContext } from '../../../../../contexts/QueryContext';
import { useSliceContext } from '../../../../../contexts/SliceContext';
import {
  ComparedGroup,
  useComparedObjectScatterPointsQuery,
  useObjectScatterPointsQuery,
} from '../../views/embedding/queries/embeddingQueries';
import { ObjectEmbeddingDatum } from '../../views/embedding/types';
import { useCompareModeContext } from './CompareModeContext';

/** Value exposed to consumers of the object-scatter context. */
type ContextProps = {
  /** Scatter points for the current dataset/slice/query/filter selection; may be undefined. */
  data?: ObjectEmbeddingDatum[];
  /** Initial-loading flag of the primary scatter-points query. */
  isLoadingProjections: boolean;
  /** Scatter points of the selected compare target; an empty array when there are none. */
  dataCompare: ObjectEmbeddingDatum[];
  /** Initial-loading flag of the compared scatter-points query. */
  isLoadingProjectionsCompare: boolean;
  /** Whether the primary scatter-points query is currently fetching. */
  isFetching: boolean;
  /** Number of primary scatter points currently loaded (0 when `data` is undefined). */
  sampledCount: number;
};

/**
 * Fallback value returned to components that read the context outside of
 * `ObjectScatterProvider`. Using real empty values (instead of `{} as ContextProps`)
 * keeps consumers from reading `undefined` out of fields typed as non-optional.
 */
const defaultContextValue: ContextProps = {
  isLoadingProjections: false,
  dataCompare: [],
  isLoadingProjectionsCompare: false,
  isFetching: false,
  sampledCount: 0,
};

const Context = React.createContext<ContextProps>(defaultContextValue);

export type CompareOption = { group: ComparedGroup; id: string; name: string };

/**
 * Finds the compare option whose `id` strictly equals `selectedId`.
 *
 * @param options - Candidate compare targets, scanned in order.
 * @param selectedId - Id to look for.
 * @returns The first matching option, or `undefined` if none matches.
 */
function getSetInfo(options: CompareOption[], selectedId?: string) {
  for (const candidate of options) {
    if (candidate.id === selectedId) {
      return candidate;
    }
  }
  return undefined;
}

/**
 * Builds the value exposed through the object-scatter context.
 *
 * Reads the slice, query, dataset, filter, cluster, public-dataset and compare-mode
 * contexts plus the `datasetId` route param, then:
 * - calls the primary scatter-points query with the derived parameters, and
 * - calls the compared scatter-points query with the same parameters plus the selected
 *   compare option (`undefined` when no `comparedId` is set) and `showComparedPoints`.
 *
 * @returns The context value, memoized so that it keeps the same reference while its
 *   inputs are unchanged (avoids re-rendering every consumer on each provider render).
 */
const useProvider = (): ContextProps => {
  const { sliceInfo } = useSliceContext();
  const { queryString } = useQueryContext();
  const { datasetInfo } = useDatasetContext();
  const { appliedFilters } = useObjectFilterContext();
  const { clusterLevel } = useObjectClusterContext();
  const { showPublicDatasets } = usePublicDatasetContext();
  const { datasetId } = useParams<{ datasetId: string }>();
  const { comparedId, showComparedPoints } = useCompareModeContext();

  // Depend on the primitive fields actually read, not on the object identities.
  const sliceId = sliceInfo?.id;
  const sliceName = sliceInfo?.name;
  const datasetInfoId = datasetInfo?.id;
  const datasetName = datasetInfo?.name;

  // Resolve the selected compare target only when its inputs change, instead of
  // rebuilding the option list (and a fresh option object) on every render.
  const comparedOption = useMemo(() => {
    if (!comparedId) {
      return undefined;
    }
    const compareOptions = [
      { group: 'slice', id: sliceId, name: sliceName },
      { group: 'dataset', id: datasetInfoId, name: datasetName },
    ] as CompareOption[];
    return getSetInfo(compareOptions, comparedId);
  }, [comparedId, sliceId, sliceName, datasetInfoId, datasetName]);

  const scatterDependencies = useMemo(
    () => ({
      datasetId,
      fromPublicDatasets: showPublicDatasets,
      clusterLevel,
      // Optional keys are only present when they have a truthy value.
      ...(queryString && { query: queryString }),
      ...(sliceName && { slice: sliceName }),
      ...(appliedFilters && { appliedFilters: appliedFilters }),
    }),
    [datasetId, showPublicDatasets, queryString, sliceName, clusterLevel, appliedFilters],
  );

  const {
    data: points,
    isInitialLoading: isLoadingProjections,
    isFetching,
  } = useObjectScatterPointsQuery(scatterDependencies);

  const { data: dataCompare, isInitialLoading: isLoadingProjectionsCompare } =
    useComparedObjectScatterPointsQuery({
      ...scatterDependencies,
      comparedOption,
      showComparedPoints,
    });

  return useMemo(
    () => ({
      data: points,
      dataCompare: dataCompare ?? [],
      isLoadingProjections,
      isLoadingProjectionsCompare,
      isFetching,
      sampledCount: points?.length ?? 0,
    }),
    [points, dataCompare, isLoadingProjections, isLoadingProjectionsCompare, isFetching],
  );
};

/**
 * Returns the object-scatter context value from the nearest `ObjectScatterProvider`
 * (or the context's default value when rendered outside any provider).
 */
export const useObjectScatterContext = (): ContextProps => React.useContext(Context);

/**
 * Computes the object-scatter context value (scatter points, compare points, loading
 * flags and sampled count) and makes it available to `children`.
 *
 * `children` is declared explicitly because `React.FC` no longer adds it implicitly in
 * @types/react 18; the explicit prop also type-checks under older typings.
 */
export const ObjectScatterProvider: React.FC<{ children?: React.ReactNode }> = ({ children }) => {
  const info = useProvider();
  return <Context.Provider value={info}>{children}</Context.Provider>;
};
