import { useQuery } from '@tanstack/react-query';
import { sortBy } from 'lodash';

import {
  DatasetObjects,
  PostGetObjectRequest,
  useCurateDatasetObjectService,
} from '../../../../../../services/DatasetObjectService';
import {
  DatasetBriefData,
  useCurateDatasetService,
} from '../../../../../../services/DatasetService';
import {
  Cluster,
  GetClustersRequestParams,
  useDataOpsEmbeddingService,
} from '../../../../../../services/EmbeddingService';
import {
  convertSelectedSuperClustersToQueryString,
  stringifyFilters,
} from '../../../../../../utils/filterUtils';
import { calculatePercent } from '../../../../../../utils/numberUtils';
import {
  AnnotationFilterSchema,
  CLUSTER_LEVEL_2_SIZE_8,
  CLUSTER_LEVEL_3_SIZE_16,
  ClusterLevel,
  ImageFilterSchema,
  SuperClusters,
} from '../../../filter/types';
import { SUPER_CLUSTER_COLORS_8, SUPER_CLUSTER_COLORS_16 } from '../scatterView/const';
import { generateFilterDotColors } from '../scatterView/utils/color';
import {
  EmbeddingDatum,
  GeoTileAggConfig,
  ImageTileRequestParams,
  ImageTileResponse,
  ObjectEmbeddingDatum,
  ObjectTileRequestParams,
  ObjectTileResponse,
  ScatterParams,
  TileBucket,
} from '../types';

/**
 * An option the scatter view can be compared against. `group` tells the kind of
 * option apart (the compared-points hooks below check for `'dataset'`).
 */
export type CompareOption = {
  group: string;
  id: string;
  name: string;
};

/**
 * Rectangular bounds given as four numeric edges.
 * NOTE(review): presumably a region in projection space — confirm with callers.
 */
type ProjectionInParams = Record<'top' | 'bottom' | 'left' | 'right', number>;

/** Inputs identifying a dataset (optionally narrowed by query, slice and bounds). */
export type Dependencies = {
  datasetId: string;
  fromPublicDatasets: boolean;
  query?: string;
  slice?: string;
  projectionIn?: ProjectionInParams;
};

/** Inputs identifying a single job by id. */
export type JobDependencies = {
  jobId: string;
};

/**
 * Enriches image scatter points with thumbnail URLs and cluster metadata.
 *
 * Point ids are looked up in chunks of 100 through `getDatasetBriefDataList`
 * (requesting thumbnails, leaf cluster and super clusters). The per-image info
 * is accumulated across all chunks, so a chunk whose response is missing does
 * not discard what earlier chunks returned.
 *
 * @param datasetId - dataset the points belong to.
 * @param fromPublicDatasets - forwarded to the brief-data request.
 * @param points - scatter points to enrich; their `id`s are used as image ids.
 * @param getDatasetBriefDataList - service call used for the lookup.
 * @param sliceName - optional slice forwarded to the lookup.
 * @param query - optional query forwarded to the lookup.
 * @returns one entry per input point, in input order. Points whose id was not
 *   returned by any lookup are kept, with the enrichment fields left undefined,
 *   instead of failing the whole scatter query.
 */
async function addImageUrlsClusterInfosToPoints(
  datasetId: string,
  fromPublicDatasets: boolean,
  points: EmbeddingDatum[],
  getDatasetBriefDataList: ReturnType<typeof useCurateDatasetService>['getDatasetBriefDataList'],
  sliceName?: string,
  query?: string,
) {
  type ImageInfo = {
    small: string;
    medium: string;
    leafClusterId: string;
    leafClusterSize: number;
    superClusters: SuperClusters;
  };

  const imageInfoById: Record<string, ImageInfo> = {};
  const chunkSize = 100;
  for (let i = 0; i < points.length; i += chunkSize) {
    const chunk = points.slice(i, i + chunkSize);
    // No need to include filter params here, since image ids are given
    // we could still potentially send filter params though
    const datasetImages = await getDatasetBriefDataList({
      datasetId,
      fromPublicDatasets,
      searchAfter: undefined,
      expand: ['small_thumbnail_url', 'thumbnail_url', 'leaf_cluster', 'super_clusters'],
      size: chunkSize,
      query,
      idIn: chunk.map(item => item.id),
      ...(sliceName && { sliceName }),
    });
    // Write into the shared map instead of reassigning it, so a missing
    // response cannot wipe out the entries gathered from earlier chunks.
    datasetImages?.results.forEach((image: DatasetBriefData) => {
      imageInfoById[image.id] = {
        small: image.smallThumbnailUrl || '',
        medium: image.thumbnailUrl,
        leafClusterId: image.leafCluster.id,
        leafClusterSize: image.leafCluster.size,
        superClusters: image.superClusters,
      };
    });
  }

  return points.map((point: EmbeddingDatum) => {
    const info = imageInfoById[point.id];
    return {
      ...point,
      imageThumbnailUrl: info?.medium,
      smallImageThumbnailUrl: info?.small,
      leafClusterId: info?.leafClusterId,
      leafClusterSize: info?.leafClusterSize,
      superClusters: info?.superClusters,
    };
  });
}

/**
 * Geo-tile aggregation settings spread into every scatter tile request in this
 * module.
 */
const ES_CONFIG: GeoTileAggConfig = {
  /** Bucket pagination: number of tile buckets requested per page. */
  size: 50,
  /** Tile precision; raising it increases the number of tiles. */
  tilePrecision: 12,
  /** Upper bound on documents returned for each tile. */
  maxNumDocs: 30,
};

/** Extra options accepted by the "compared" scatter-points hooks. */
type ComparedScatterQuery = {
  /** Option whose points are fetched for comparison. */
  comparedOption?: CompareOption;
  /** Gates whether the compared points query runs. */
  showComparedPoints?: boolean;
};

/** Parameters of one tile request, optionally carrying filters of type `T`. */
type TileRequest<T> = ScatterParams & GeoTileAggConfig & { appliedFilters?: T };

/** A service call that fetches one page of scatter tiles. */
type TileGetter<T extends AnnotationFilterSchema | ImageFilterSchema | undefined> = (
  apiParams: TileRequest<T>,
) => Promise<ImageTileResponse>;

/**
 * Pages through all scatter tiles and returns every point they contain.
 *
 * Each page is requested with the same `apiParams` (including any
 * `appliedFilters`), and the `searchAfter` cursor from the previous response is
 * added. Paging stops when a page has no buckets, or when a page carries no
 * cursor. Re-requesting without a cursor would restart from the first page and
 * loop forever.
 *
 * @param apiParams - tile request parameters, including aggregation config.
 * @param getTiles - service call that fetches one page of tiles.
 * @returns all points from all pages, flattened, in page order.
 */
async function getPointsInAllTiles<
  T extends AnnotationFilterSchema | ImageFilterSchema | undefined,
>(
  apiParams: ScatterParams & GeoTileAggConfig & { appliedFilters?: T },
  getTiles: TileGetter<T>,
): Promise<EmbeddingDatum[]> {
  const output: EmbeddingDatum[] = [];
  let response = await getTiles(apiParams);

  // terminate when response has no buckets
  while ((response?.results?.length ?? 0) > 0) {
    output.push(...extractPointsFromBuckets(response));
    const nextCursor = response.searchAfter;
    if (nextCursor == null) {
      // No cursor means there is no next page to ask for.
      break;
    }
    response = await getTiles({
      ...apiParams,
      searchAfter: nextCursor,
    });
  }
  return output;
}

/**
 * React Query hook that loads every image point of the scatter view and
 * enriches it with thumbnails and cluster info (`addImageUrlsClusterInfosToPoints`).
 *
 * Every value sent to the tile API (dataset, slice, query, cluster level) is
 * part of the query key, so changing any of them triggers a fresh fetch.
 * Without that, cached data for another cluster level would be served, because
 * `refetchOnMount` is off.
 */
export function useImageScatterPointsQuery({
  datasetId,
  fromPublicDatasets,
  clusterLevel,
  slice,
  query,
  appliedFilters,
}: ImageTileRequestParams & {
  appliedFilters?: ImageFilterSchema | undefined;
}) {
  const { getImageScatterTiles } = useDataOpsEmbeddingService();
  const { getDatasetBriefDataList } = useCurateDatasetService();

  const queryFn = async () => {
    const apiParams = {
      datasetId,
      fromPublicDatasets,
      clusterLevel,
      ...(query && { query }),
      ...(slice && { slice }),
      ...ES_CONFIG,
    };
    const points = await getPointsInAllTiles(apiParams, getImageScatterTiles);
    return addImageUrlsClusterInfosToPoints(
      datasetId,
      fromPublicDatasets,
      points,
      getDatasetBriefDataList,
      slice,
      query,
    );
  };

  return useQuery({
    // NOTE(review): `appliedFilters` keys the cache but is not sent to the tile
    // API here (unlike the object variant) — confirm this is intended.
    queryKey: ['image-scatter-points', datasetId, slice, query, appliedFilters, clusterLevel],
    queryFn,
    retry: 2,
    refetchOnWindowFocus: false,
    refetchOnMount: false,
  });
}

/**
 * React Query hook that loads every object (annotation) point of the scatter
 * view, then enriches the points with object details through
 * `addObjectInfoToPoints`.
 *
 * The filters take part in the cache key in their stringified form.
 */
export function useObjectScatterPointsQuery({
  datasetId,
  slice,
  query,
  clusterLevel,
  appliedFilters,
  fromPublicDatasets,
}: ObjectTileRequestParams & {
  appliedFilters?: AnnotationFilterSchema;
}) {
  const { getObjectScatterTiles } = useDataOpsEmbeddingService();
  const { postGetObjectList } = useCurateDatasetObjectService();
  const stringifiedFilters = stringifyFilters(appliedFilters);

  // Loosely typed on purpose: the tile and object helpers are fed the same
  // untyped filter value as before.
  const filters: any = appliedFilters;

  const fetchObjectPoints = async () => {
    const requestParams = {
      datasetId,
      fromPublicDatasets,
      clusterLevel,
      ...(query && { query }),
      ...(slice && { slice }),
      ...(filters && { appliedFilters: filters }),
      ...ES_CONFIG,
    };

    const rawPoints = await getPointsInAllTiles(requestParams, getObjectScatterTiles);
    if (!rawPoints) {
      return rawPoints;
    }
    return addObjectInfoToPoints({ points: rawPoints, ...requestParams, postGetObjectList });
  };

  const queryKey = [
    'object-scatter-points',
    datasetId,
    slice,
    query,
    stringifiedFilters,
    clusterLevel,
  ];

  return useQuery({
    queryKey,
    queryFn: fetchObjectPoints,
    retry: 2,
    refetchOnWindowFocus: false,
    refetchOnMount: false,
  });
}

export type ComparedGroup = 'slice' | 'dataset';

/**
 * React Query hook that loads the image points of the compared option: the
 * whole dataset when `comparedOption.group === 'dataset'`, otherwise the slice
 * named `comparedOption.name`. No filters are applied to compared points.
 *
 * The cache key is prefixed and includes `datasetId`, so entries are never
 * shared with the object variant or with another dataset. The query only runs
 * when `showComparedPoints` is truthy and an option id is set. The value is
 * coerced to a boolean because TanStack Query treats `enabled: undefined` as
 * enabled.
 */
export function useComparedImageScatterPointsQuery({
  datasetId,
  fromPublicDatasets,
  slice,
  query,
  appliedFilters,
  clusterLevel,
  comparedOption,
  showComparedPoints,
}: ImageTileRequestParams & ComparedScatterQuery) {
  const { getImageScatterTiles } = useDataOpsEmbeddingService();

  const comparedQueryFnWrapper = (slice?: string) => async () => {
    const apiParams = {
      datasetId,
      fromPublicDatasets,
      appliedFilters: undefined,
      clusterLevel,
      ...ES_CONFIG,
      ...(slice && { slice }),
    };
    return getPointsInAllTiles<undefined>(apiParams, getImageScatterTiles);
  };

  return useQuery({
    queryKey: [
      'compared-image-scatter-points',
      datasetId,
      slice,
      query,
      appliedFilters,
      showComparedPoints,
      comparedOption?.id,
      clusterLevel,
    ],
    queryFn:
      comparedOption?.group === 'dataset'
        ? comparedQueryFnWrapper()
        : comparedQueryFnWrapper(comparedOption?.name),
    enabled: Boolean(showComparedPoints) && Boolean(comparedOption?.id),
    retry: 2,
    refetchOnWindowFocus: false,
  });
}

/**
 * React Query hook that loads the object points of the compared option: the
 * whole dataset when `comparedOption.group === 'dataset'`, otherwise the slice
 * named `comparedOption.name`. No filters are applied to compared points.
 *
 * The cache key is prefixed and includes `datasetId`, so entries are never
 * shared with the image variant or with another dataset. The query only runs
 * when `showComparedPoints` is truthy and an option id is set. The value is
 * coerced to a boolean because TanStack Query treats `enabled: undefined` as
 * enabled.
 */
export function useComparedObjectScatterPointsQuery({
  datasetId,
  slice,
  query,
  appliedFilters,
  clusterLevel,
  comparedOption,
  showComparedPoints,
  fromPublicDatasets,
}: ObjectTileRequestParams & ComparedScatterQuery) {
  const { getObjectScatterTiles } = useDataOpsEmbeddingService();

  const comparedQueryFnWrapper = (slice?: string) => async () => {
    const apiParams = {
      datasetId,
      fromPublicDatasets,
      appliedFilters: undefined,
      clusterLevel,
      ...ES_CONFIG,
      ...(slice && { slice }),
    };

    const points = await getPointsInAllTiles<undefined>(apiParams, getObjectScatterTiles);
    return points as ObjectEmbeddingDatum[];
  };

  return useQuery({
    queryKey: [
      'compared-object-scatter-points',
      datasetId,
      slice,
      query,
      appliedFilters,
      showComparedPoints,
      comparedOption?.id,
      clusterLevel,
    ],
    queryFn:
      comparedOption?.group === 'dataset'
        ? comparedQueryFnWrapper()
        : comparedQueryFnWrapper(comparedOption?.name),
    enabled: Boolean(showComparedPoints) && Boolean(comparedOption?.id),
    retry: 2,
    refetchOnWindowFocus: false,
  });
}

/**
 * Collects the points of every bucket in a tile response into one flat list.
 * Each bucket's `data` is flattened one level before being appended, in bucket
 * order.
 */
function extractPointsFromBuckets(
  response: ImageTileResponse | ObjectTileResponse,
): EmbeddingDatum[] {
  const collected: EmbeddingDatum[] = [];
  response.results.forEach((bucket: TileBucket) => {
    for (const point of bucket.data.flat()) {
      collected.push(point);
    }
  });
  return collected;
}

/**
 * Enriches object scatter points with details from the object list endpoint.
 *
 * Points are processed in chunks of 100. Each chunk is sorted by id and sent as
 * `id_in`, together with the slice, query and annotation filters when present.
 * Only objects the endpoint returns appear in the result, in the order the
 * endpoint returns them. Each one is merged onto the matching input point.
 *
 * @param params - tile request params plus the points and the list service call.
 * @returns the enriched points.
 */
async function addObjectInfoToPoints(
  params: ObjectTileRequestParams & {
    points: EmbeddingDatum[];
    postGetObjectList: ReturnType<typeof useCurateDatasetObjectService>['postGetObjectList'];
  },
): Promise<ObjectEmbeddingDatum[]> {
  const { datasetId, fromPublicDatasets, points, slice, query, appliedFilters, postGetObjectList } =
    params;
  const batchSize = 100;
  const sharedBody: PostGetObjectRequest = {
    size: batchSize,
    expand: ['image_thumbnail_url', 'original_image_size', 'small_image_thumbnail_url'],
    ...(slice && { slice }),
    ...(query && { query }),
    ...(appliedFilters && { annotation_filters: appliedFilters }),
  };
  const enriched: ObjectEmbeddingDatum[] = [];

  for (let start = 0; start < points.length; start += batchSize) {
    const orderedBatch = sortBy(points.slice(start, start + batchSize), 'id'); // ascending

    const pointById: Record<string, ObjectEmbeddingDatum> = {};
    for (const point of orderedBatch) {
      pointById[point.id] = point as ObjectEmbeddingDatum;
    }

    const response = await postGetObjectList({
      dataset_id: datasetId,
      fromPublicDatasets,
      ...sharedBody,
      id_in: orderedBatch.map(point => point.id),
    });

    response?.results.forEach((object: DatasetObjects) => {
      enriched.push({
        ...pointById[object.id],
        annotationClass: object.annotationClass,
        imageThumbnailUrl: object.imageThumbnailUrl,
        smallImageThumbnailUrl: object.smallImageThumbnailUrl,
        originalImageSize: object.originalImageSize,
        annotationValue: object.annotationValue,
        roi: object.roi,
        annotationType: object.annotationType,
      });
    });
  }

  return enriched;
}

/**
 * React Query hook that loads the image super clusters at `clusterLevel`. Each
 * cluster gets a display `color` (from `generateFilterDotColors`) and an
 * integer percentage `share` of the summed cluster sizes.
 *
 * For `CLUSTER_LEVEL_3_SIZE_16` the level-2 (`CLUSTER_LEVEL_2_SIZE_8`) cluster
 * ids are fetched too, and the id list handed to `generateFilterDotColors` is
 * reordered:
 * - level-2 ids go first into the slots where `SUPER_CLUSTER_COLORS_8` colors
 *   appear in `SUPER_CLUSTER_COLORS_16`;
 * - the remaining level-3 ids fill the other slots.
 *
 * NOTE(review): presumably this keeps colors stable between the two levels,
 * assuming `generateFilterDotColors` assigns colors by position. Confirm.
 *
 * Missing service responses are treated as empty id lists rather than crashing
 * the reordering.
 */
export function useImageSuperClustersQuery({
  datasetId,
  clusterLevel,
  sliceName,
  fromPublicDatasets,
}: GetClustersRequestParams) {
  const { getImageClusters } = useDataOpsEmbeddingService();
  const queryFn = async () => {
    const clusters = await getImageClusters({
      datasetId,
      clusterLevel,
      sliceName,
      fromPublicDatasets,
    });
    // Default to [] so the level-3 branch below (which reads `.length`) cannot
    // throw on a missing response.
    let clusterIds = clusters?.results.map(d => d.id) ?? [];
    if (clusterLevel === CLUSTER_LEVEL_3_SIZE_16) {
      const level2clusters = await getImageClusters({
        datasetId,
        clusterLevel: CLUSTER_LEVEL_2_SIZE_8,
        sliceName,
        fromPublicDatasets,
      });
      const level2ClusterIds =
        level2clusters?.results.map(d => d.id).filter(d => d !== null) ?? [];

      // Slots to fill first: the position of each 8-palette color inside the
      // 16-palette, keeping only positions that exist in `clusterIds`.
      const indexOrder = SUPER_CLUSTER_COLORS_8.map(color =>
        SUPER_CLUSTER_COLORS_16.indexOf(color),
      ).filter(index => index < clusterIds.length && index >= 0);
      const sortedClusterIds = new Array(clusterIds.length);

      // Place the level-2 ids, in order, into those slots.
      indexOrder.forEach((index, i) => {
        if (i < level2ClusterIds.length) {
          sortedClusterIds[index] = level2ClusterIds[i];
        }
      });

      // Fill every remaining slot with the next level-3 id that is not already
      // one of the level-2 ids, preserving the original order.
      let j = 0;
      for (let i = 0; i < sortedClusterIds.length; i++) {
        if (sortedClusterIds[i] === undefined) {
          while (j < clusterIds.length && level2ClusterIds.includes(clusterIds[j])) {
            j++;
          }
          if (j < clusterIds.length) {
            sortedClusterIds[i] = clusterIds[j];
            j++;
          }
        }
      }
      clusterIds = sortedClusterIds.filter(id => id !== undefined); // Remove any undefined slots
    }
    const colorMap = generateFilterDotColors(clusterIds, clusterLevel);
    const totalCount = clusters?.results.reduce((acc, cluster) => {
      acc += cluster.size;
      return acc;
    }, 0);
    return clusters?.results.map(d => {
      return {
        ...d,
        color: colorMap[d.id],
        share: calculatePercent({ numerator: d.size, denominator: totalCount, nearest: 'int' }),
      };
    });
  };

  return useQuery({
    queryKey: ['image-clusters', datasetId, clusterLevel, sliceName],
    queryFn,
    enabled: typeof datasetId !== 'undefined',
    retry: 2,
    refetchOnWindowFocus: false,
    refetchOnMount: false,
  });
}

/**
 * React Query hook that loads image leaf clusters through
 * `getImageLeafClusters`, optionally narrowed by slice and query string.
 * The hook resolves to the `results` of the response.
 *
 * The query stays disabled while `datasetId` is falsy.
 */
export function useImageLeafClustersQuery({
  datasetId,
  sliceName,
  fromPublicDatasets,
  queryString,
}: GetClustersRequestParams) {
  const { getImageLeafClusters } = useDataOpsEmbeddingService();

  const fetchLeafClusters = async () => {
    const response = await getImageLeafClusters({
      datasetId,
      sliceName,
      fromPublicDatasets,
      queryString,
    });
    return response?.results;
  };

  const queryKey = ['image-leaf-clusters', datasetId, sliceName, queryString];

  return useQuery({
    queryKey,
    queryFn: fetchLeafClusters,
    enabled: Boolean(datasetId),
    retry: 2,
    refetchOnWindowFocus: false,
    refetchOnMount: false,
  });
}

/**
 * React Query hook that loads the leaf clusters of each selected super cluster.
 *
 * One `getImageLeafClusters` request is issued per super cluster id, in
 * parallel. Each request is narrowed by a query string built from that id and
 * `clusterLevel`. Each result entry is a pair: `[0]` is the leaf cluster id and
 * `[1]` its size. The pairs are mapped to `{ id, size, superClusterId }`, and
 * all of them are returned as one flat list.
 *
 * A super cluster whose response has no results contributes nothing. Previously
 * it produced an `undefined` element, because `flat()` keeps `undefined`.
 */
export function useImageLeavesInSuperClustersQuery({
  datasetId,
  sliceName,
  fromPublicDatasets,
  superClusterIds,
  clusterLevel,
  disabled,
}: {
  datasetId: string;
  clusterLevel: ClusterLevel;
  sliceName?: string;
  fromPublicDatasets: boolean;
  superClusterIds: string[];
  disabled?: boolean;
}) {
  const { getImageLeafClusters } = useDataOpsEmbeddingService();

  const queryFn = async () => {
    const allClusters = await Promise.all(
      superClusterIds.map(async superClusterId => {
        const queryString = convertSelectedSuperClustersToQueryString(
          [superClusterId],
          clusterLevel,
        );
        const clusters = await getImageLeafClusters({
          datasetId,
          sliceName,
          fromPublicDatasets,
          queryString,
        });
        return clusters?.results?.map(d => ({ id: d[0], size: d[1], superClusterId })) ?? [];
      }),
    );
    // Flatten the per-super-cluster arrays into one list
    return allClusters.flat();
  };

  return useQuery({
    queryKey: [
      'image-leaf-clusters-of-selected-super-cluster',
      datasetId,
      sliceName,
      superClusterIds,
      clusterLevel,
    ],
    queryFn,
    enabled: !!datasetId && !!superClusterIds && !disabled,
    retry: 2,
    refetchOnWindowFocus: false,
    refetchOnMount: false,
  });
}

/** Presentation attributes that can be attached to a cluster. */
export type ClusterDisplayInfo = {
  /** Display color of the cluster. */
  color: string;
  /** Optional variant of `color` with opacity applied. */
  colorWithOpacity?: string;
  /** Filled in later to provide a translated name. */
  name?: string;
  /** Optional share value, e.g. an integer percentage of the total size. */
  share?: number;
};

/** A cluster merged with its presentation attributes. */
export type ClusterWithDisplayInfo = Cluster & ClusterDisplayInfo;
