import { Ref, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useHistory } from 'react-router';

import { Box } from '@superb-ai/ui';

import { useImageFilterContext } from '../../../../../../contexts/ImageFilterContext';
import { useImageScopeContext } from '../../../../../../contexts/ImageScopeContext';
import { useObjectFilterContext } from '../../../../../../contexts/ObjectFilterContext';
import { useObjectScopeContext } from '../../../../../../contexts/ObjectScopeContext';
import { LoadingIndicatorDiv } from '../../../analytics/components/LoadingIndicatorDiv';
import { useImageScatterContext } from '../../../analytics/contexts/ImageScatterContext';
import { useFilteredPoints, usePoints } from '../scatterView/providers/DataProvider';
import { EmbeddingDatum, ObjectEmbeddingDatum } from '../types';
import AllImagesGrid from './AllImagesGrid';
import ImageCell from './ImageCell';
import { ImagePreviewHeader } from './ImagePreviewHeader';
import ObjectImageCell from './ObjectImageCell';
import SampleImagesGrid from './SampleImagesGrid';

/**
 * Type guard used when filtering image-scope points by cluster.
 *
 * Checks only that a `clusterId` key is present. Object points are also
 * indexed by `clusterId` when the cluster tab is active, so this check alone
 * may not distinguish the two point types. It is only relied on in image scope.
 */
function isImagePoint(point: EmbeddingDatum | ObjectEmbeddingDatum): point is EmbeddingDatum {
  return 'clusterId' in point;
}

/** Type guard for object-scope points: checks for an `annotationClass` key. */
function isObjectPoint(
  point: EmbeddingDatum | ObjectEmbeddingDatum,
): point is ObjectEmbeddingDatum {
  return 'annotationClass' in point;
}

/**
 * Preview panel showing a header (column count / tab controls) and a grid of
 * image or object cells for the current embedding points.
 *
 * The displayed points are chosen in priority order:
 *   1. points returned by `useFilteredPoints()` when it is non-empty;
 *   2. points matching the selected clusters/classes when a selection exists
 *      (or when any point matches);
 *   3. otherwise all points from `usePoints()`.
 *
 * The scope (`image` vs `object`) is read from the `scope` URL search param
 * and defaults to `'image'`.
 *
 * @param width - CSS width of the whole panel.
 * @param height - CSS height of the whole panel; also forwarded to the header.
 * @param cardStyle - extra styles merged first into the panel style.
 * @param style - extra styles merged after `cardStyle` (overrides it).
 * @param isGridAreaExpanded - expanded state, forwarded to the header.
 * @param setIsGridAreaExpanded - setter for the expanded state, forwarded to the header.
 */
export const ImagePreviewArea = ({
  width,
  height,
  cardStyle,
  style,
  isGridAreaExpanded,
  setIsGridAreaExpanded,
}: {
  width: number | string;
  height: number | string;
  cardStyle?: React.CSSProperties;
  style?: React.CSSProperties;
  isGridAreaExpanded: boolean;
  setIsGridAreaExpanded: (isGridAreaExpanded: boolean) => void;
}) => {
  const { t } = useTranslation();
  const history = useHistory();
  const searchParams = new URLSearchParams(history.location.search);
  const scope = searchParams.get('scope') || 'image';
  const { isLoadingProjections } = useImageScatterContext();
  const { selectedSuperClusters: selectedImageClusters } = useImageFilterContext();
  const {
    selectedSuperClusters: selectedObjectClusters,
    selectedClasses,
    classClusterFilterTab,
  } = useObjectFilterContext();

  const defaultColumns = 3;
  const [columnsCount, setColumnsCount] = useState(defaultColumns);

  const containerRef = useRef<HTMLDivElement>();
  // Explicit CSS height for the grid container. The setter is handed to
  // ImagePreviewHeader, which decides the value; undefined means "auto".
  const [calculatedContentAreaHeight, setCalculatedContentAreaHeight] = useState<string>();
  const { totalCount: imageTotalCount } = useImageScopeContext();
  const { totalCount: objectTotalCount } = useObjectScopeContext();

  // Ids of the currently selected groups: image clusters in image scope,
  // classes or flattened object cluster ids in object scope (per active tab).
  const groupIds = useMemo(() => {
    if (scope === 'image') return selectedImageClusters;
    else if (scope === 'object' && classClusterFilterTab === 'class') {
      return selectedClasses;
    } else if (scope === 'object' && classClusterFilterTab === 'cluster') {
      return selectedObjectClusters?.flatMap(v => v.cluster_id_in);
    }
    return [];
  }, [
    scope,
    selectedImageClusters,
    classClusterFilterTab,
    selectedClasses,
    selectedObjectClusters,
  ]);

  const points = usePoints();
  const filteredPoints = useFilteredPoints();

  // Points belonging to the selected groups. The type guards are module-level
  // pure functions, so they do not need to be listed as dependencies.
  const selectedPoints = useMemo(() => {
    if (scope === 'image')
      return points.filter(point => isImagePoint(point) && groupIds?.includes(point.clusterId));
    else if (scope === 'object') {
      const key = classClusterFilterTab === 'class' ? 'annotationClass' : 'clusterId';
      return points.filter(point => isObjectPoint(point) && groupIds?.includes(point[key]));
    }
    return points;
  }, [classClusterFilterTab, points, scope, groupIds]);

  // Filtered points win; then the selection (even if empty, as long as some
  // selection is active); otherwise fall back to every point.
  const getDisplayedPoints = () => {
    if (filteredPoints.length > 0) return filteredPoints;
    if (
      selectedPoints.length > 0 ||
      selectedClasses?.length ||
      selectedImageClusters?.length ||
      selectedObjectClusters?.length
    )
      return selectedPoints;
    return points;
  };

  const displayedPoints = getDisplayedPoints();

  const [tab, setTab] = useState<'sample' | 'all'>('sample');

  // Renders the grid for the active tab, sized from the container element.
  // The -8 / -16 px offsets presumably compensate for the inner Box padding.
  // NOTE(review): the container is measured during render via the ref, so the
  // first render draws no grid (ref not yet attached), and a render caused by a
  // height change reads the DOM size from before that change is committed.
  // Consider measuring in a layout effect / ResizeObserver instead.
  const renderGrid = (container: HTMLDivElement) => {
    const areaWidth = container.clientWidth - 8;
    const areaHeight = container.clientHeight - 16;
    return tab === 'sample' ? (
      <SampleImagesGrid
        displayedPoints={displayedPoints}
        areaWidth={areaWidth}
        areaHeight={areaHeight}
        columnsCount={columnsCount}
        CellComponent={scope === 'image' ? ImageCell : ObjectImageCell}
      />
    ) : (
      <AllImagesGrid areaWidth={areaWidth} areaHeight={areaHeight} columnsCount={columnsCount} />
    );
  };

  return (
    <Box
      display="flex"
      style={{ ...cardStyle, ...style, width, height }}
      flexDirection="column"
      alignItems="flex-start"
    >
      <ImagePreviewHeader
        height={height}
        defaultColumns={defaultColumns}
        columnsCount={columnsCount}
        setColumnsCount={setColumnsCount}
        samplingPointCount={displayedPoints.length}
        isGridAreaExpanded={isGridAreaExpanded}
        setIsGridAreaExpanded={setIsGridAreaExpanded}
        totalCount={(scope === 'image' ? imageTotalCount : objectTotalCount) ?? 0}
        tab={tab}
        setTab={setTab}
        // dataCount={displayedPoints?.length ?? 0}
        setCalculatedContentAreaHeight={setCalculatedContentAreaHeight}
      />
      <Box
        ref={containerRef as Ref<HTMLElement>}
        flexDirection="column"
        display="flex"
        width="100%"
        style={{ height: calculatedContentAreaHeight }}
      >
        <Box p={1} pr={0} pb={0}>
          {isLoadingProjections ? (
            <LoadingIndicatorDiv
              style={{
                height: '400px',
              }}
            />
          ) : (
            containerRef.current && renderGrid(containerRef.current)
          )}
        </Box>
      </Box>
    </Box>
  );
};
