import { ScatterChart } from "@mui/x-charts/ScatterChart";
import { MentionType } from "../../state";
import { MakeOptional } from "@mui/x-charts/internals";
import { ScatterItemIdentifier, ScatterSeriesType, ScatterValueType } from "@mui/x-charts";
import { useState } from "react";
import { Box, Typography } from "@mui/material";
import { GlobalSizes } from "../../size";
import MentionFocus from "../MentionFocus";

/** A scatter point that also carries the mention it was plotted from, so handlers can recover it. */
type PointType = ScatterValueType & { mention: MentionType };

/** Series key for mentions that have no AI-filter results. */
const OTHER_KEY = "other";

/**
 * Builds one list of scatter points per AI-filter inquiry key.
 *
 * - A mention with several results gets a point in each matching group.
 * - A mention whose `ai_filter.results` is missing goes into the "other" group.
 * - A mention without a 2-D `reduced_embedding` is skipped; it has no position.
 *
 * The "other" group is inserted first so it is drawn beneath the colored series.
 * It is dropped entirely when it ends up empty.
 *
 * @param mentions - mentions to plot.
 * @returns insertion-ordered map from group key to that group's points.
 */
function groupPointsByInquiry(mentions: MentionType[]): Map<string, PointType[]> {
  const groups = new Map<string, PointType[]>([[OTHER_KEY, []]]);

  for (const mention of mentions) {
    const embedding = mention.reduced_embedding;
    // Without a 2-D projection there is nowhere to draw the mention; skip it instead of crashing.
    if (!embedding || embedding.length < 2) continue;

    // NOTE(review): an empty `results` array yields no group at all (kept from the original
    // behaviour); confirm whether such mentions should fall into "other" instead.
    const keys = mention.ai_filter?.results
      ? mention.ai_filter.results.map(({ key }) => String(key))
      : [OTHER_KEY];

    for (const key of keys) {
      // A fresh point per series, so no point object is shared between series.
      const point: PointType = { x: embedding[0], y: embedding[1], id: mention.url, mention };
      const group = groups.get(key);
      if (group) group.push(point);
      else groups.set(key, [point]);
    }
  }

  if (groups.get(OTHER_KEY)?.length === 0) groups.delete(OTHER_KEY);
  return groups;
}

/**
 * Tooltip text for a mention: "[user - ]source - summary".
 * The summary is `description_short`, falling back to `full_content` when that is empty.
 * Returns "" when no mention is given.
 */
function formatMention(mention: MentionType | undefined): string {
  if (!mention) return "";
  const userPrefix = mention.user ? `${mention.user} - ` : "";
  return `${userPrefix}${mention.source} - ${mention.description_short || mention.full_content}`;
}

/**
 * Scatter plot of mentions positioned by their 2-D reduced embedding.
 *
 * There is one series per AI-filter inquiry key. Each series is colored with
 * `stringToColor(key)`, while the "other" series is gray. Clicking a point opens it
 * in `MentionFocus`.
 *
 * @param mentions - mentions to plot; those without `reduced_embedding` are not drawn.
 * @param stringToColor - maps an inquiry key to a series color.
 */
function EmbeddingsScatterChart({ mentions, stringToColor }: { mentions: MentionType[]; stringToColor: (str: string) => string }) {
  const [seriesData, setSeriesData] = useState<ScatterItemIdentifier | null>(null);
  const [focusedMention, setFocusedMention] = useState<MentionType>();

  const groupedPoints = groupPointsByInquiry(mentions);

  const series: MakeOptional<ScatterSeriesType, "type">[] = Array.from(
    groupedPoints,
    ([key, points]): MakeOptional<ScatterSeriesType, "type"> => {
      // O(1) tooltip lookup instead of scanning every series on each hover.
      const mentionById = new Map<PointType["id"], MentionType>(points.map((p) => [p.id, p.mention]));
      return {
        label: key,
        id: key,
        color: key === OTHER_KEY ? "gray" : stringToColor(key),
        data: points,
        highlightScope: {
          highlight: "item",
        },
        markerSize: GlobalSizes.smallGap,
        valueFormatter: (value) => formatMention(mentionById.get(value.id)),
      };
    },
  );

  return (
    <Box>
      <Typography variant="h6" gutterBottom ml={GlobalSizes.gap}>
        AI Narrative Map
      </Typography>
      <ScatterChart
        height={600}
        series={series}
        slotProps={{
          legend: {
            hidden: true,
          },
        }}
        onItemClick={(_event: React.MouseEvent<SVGElement, MouseEvent>, item: ScatterItemIdentifier) => {
          setSeriesData(item);
          // Use the clicked `item` directly: `seriesData` still holds the previous render's value
          // here, because React state updates are not visible until the next render.
          setFocusedMention(groupedPoints.get(String(item.seriesId))?.[item.dataIndex]?.mention);
        }}
      />
      {seriesData && (
        <MentionFocus mention={focusedMention} setMention={setFocusedMention} />
      )}
    </Box>
  );
}

// The scatter-chart component is this module's default export.
export default EmbeddingsScatterChart;
