import { ResultOf } from "@graphql-typed-document-node/core";
import _ from "lodash";
import * as Plotly from "plotly.js";
import { useMemo } from "react";
import Plot from "react-plotly.js";

import { getDateStep, getEndDate, getStartDate } from "@/adapters/monitoring";
import { FragmentType, gql, useFragment } from "@/apis/nannyml";
import { PlotConfig, PlotDataset, usePlotConfig } from "@/components/monitoring/PlotConfig";
import * as colors from "@/constants/colors";
import { PlotElements, PlotType } from "@/constants/enums";
import { DateLike, formatISODate } from "@/lib/dateUtils";

import { getAxisName, getPlotLayout, selectDataset } from "./ResultPlot.utils";

/**
 * GraphQL fragment with everything needed to plot a KDE distribution result.
 *
 * It selects the column name and, per chunk:
 * - whether the chunk belongs to the analysis set;
 * - its start/end timestamps and number of data points;
 * - the density curve points (`data`);
 * - the index markers (`indices`, each with a value, density and cumulative density).
 *
 * NOTE(review): the `gql` helper and the leading GraphQL marker comment suggest this text is processed by GraphQL
 * code generation. Presumably the query text must not be edited without regenerating types — confirm.
 */
const kdeDistributionResultDetails = gql(/* GraphQL */ `
  fragment KdeDistributionResultDetails on KdeDistributionResult {
    column: columnName
    data: chunks {
      isAnalysis
      startTimestamp
      endTimestamp
      nrDataPoints
      data {
        value
        density
      }
      indices {
        value
        density
        cumulativeDensity
      }
    }
  }
`);

/** TypeScript type of the data selected by the `KdeDistributionResultDetails` fragment above. */
type KdeDistributionResultDetails = ResultOf<typeof kdeDistributionResultDetails>;

/**
 * Renders the kernel density estimation (KDE) distribution of a result's column as a Plotly chart.
 *
 * Traces depend only on the result, the plot configuration and the alert flags, so they are memoized. The layout is
 * rebuilt on every render because it also depends on `dateRange` and `width`. A `layout` prop, when given, is
 * shallowly merged over the generated layout, so its top-level keys take precedence.
 */
export const KdeDistributionPlot = ({
  dateRange,
  className,
  result: resultFragment,
  width,
  alerts,
  layout,
  onUpdate,
}: {
  dateRange?: [DateLike, DateLike];
  className?: string;
  result: FragmentType<typeof kdeDistributionResultDetails>;
  alerts?: (Boolean | null)[];
  width?: number;
  layout?: any;
  onUpdate?: (figure: any, graphDiv: any) => void;
}) => {
  const kdeResult = useFragment(kdeDistributionResultDetails, resultFragment);
  const plotConfig = usePlotConfig();

  // Memoized so that e.g. a date range change does not regenerate every trace
  const traces = useMemo(() => getPlotData(kdeResult, plotConfig, alerts), [kdeResult, plotConfig, alerts]);

  const plotLayout = {
    ...getKdeDistributionPlotLayout(kdeResult, plotConfig, dateRange, width),
    ...layout,
  };

  return (
    <Plot
      className={className}
      data={traces}
      layout={plotLayout}
      config={{ displayModeBar: false }}
      onUpdate={onUpdate}
    />
  );
};

/**
 * Get plot traces for the given result.
 *
 * Traces are generated for each dataset selected in `config`. When `config.subplotPerDataset` is set, each dataset
 * uses the `any` trace options and its traces are moved to that dataset's own subplot axes. Legend entries are
 * de-duplicated by trace name, so a trace repeated across datasets shows up only once in the legend.
 *
 * @param result The results to be plotted
 * @param config Plot configuration to be used
 * @param alerts Optional per-chunk alert flags, matched to `result.data` by index
 * @returns Plotly traces
 * @throws Error if `config.type` is not a distribution plot
 */
const getPlotData = (
  result: KdeDistributionResultDetails,
  config: PlotConfig,
  alerts?: (Boolean | null)[]
): Partial<Plotly.PlotData>[] => {
  if (config.type !== PlotType.Distribution) {
    throw new Error(`Unsupported plot type: ${config.type}`);
  }

  // Attach alert info to each chunk in a copy instead of reassigning the `result` parameter. The static type stays
  // `KdeDistributionResultDetails`; `hasAlert` is carried along at runtime and read when building traces.
  const resultWithAlerts: KdeDistributionResultDetails = {
    ...result,
    data: result.data.map((chunk, idx) => ({
      ...chunk,
      hasAlert: alerts?.[idx],
    })),
  };

  // Get traces for all datasets
  const traces = config.datasets.flatMap((dataset, idx) => {
    const chunks = selectDataset(resultWithAlerts, dataset);
    if (!chunks.length) {
      return [];
    }

    const traceOption = traceOptions[config.subplotPerDataset ? "any" : dataset];
    const datasetTraces = getPlotTraces(resultWithAlerts, chunks, config.elements, traceOption);
    if (config.subplotPerDataset) {
      // Move this dataset's traces to its own subplot
      for (const trace of datasetTraces) {
        trace.xaxis = getAxisName("x", idx);
        trace.yaxis = getAxisName("y", idx);
      }
    }

    return datasetTraces;
  });

  // Disable duplicate legend entries: only the first trace with a given name keeps its legend item
  const legendEntries = new Set<Partial<Plotly.PlotData>["name"]>();
  for (const trace of traces) {
    trace.showlegend &&= !legendEntries.has(trace.name);

    if (trace.showlegend) {
      legendEntries.add(trace.name);
    }
  }

  return traces;
};

/** Styling and naming applied to the traces generated for one dataset. */
interface TraceOptions {
  /** Line color for non-alert traces; also colors the dataset label in index hover text */
  traceColor: string;
  /** Fill color of the distribution area for non-alert traces */
  fillColor: string;
  /** Line color for traces of chunks with an alert */
  alertColor: string;
  /** Fill color for traces of chunks with an alert */
  alertFillColor: string;
  /** Maps a base trace name to the name shown in the legend */
  nameFn: (name: string) => string;
}

/** Builds trace options for a dataset; alert styling is the same for every dataset. */
const buildTraceOptions = (
  traceColor: string,
  fillColor: string,
  nameFn: TraceOptions["nameFn"]
): TraceOptions => ({
  traceColor,
  fillColor,
  alertColor: colors.alertColor,
  alertFillColor: colors.alertFill,
  nameFn,
});

/** Trace options per dataset; `any` is used when every dataset is drawn in its own subplot. */
const traceOptions: Record<PlotDataset | "any", TraceOptions> = {
  any: buildTraceOptions(colors.referenceLineColor, colors.referenceConfidenceBandColor, (name) => name),
  [PlotDataset.Analysis]: buildTraceOptions(
    colors.analysisLineColor,
    colors.analysisConfidenceBandColor,
    (name) => `${name} (analysis)`
  ),
  [PlotDataset.Reference]: buildTraceOptions(
    colors.referenceLineColor,
    colors.referenceConfidenceBandColor,
    (name) => `${name} (reference)`
  ),
};

/**
 * Generates traces for a kernel density estimation (KDE) distribution result.
 *
 * Each group of chunks yields two traces:
 * - a filled density curve drawn at every chunk's start timestamp;
 * - dotted horizontal segments marking the chunk's distribution indices, with hover information.
 *
 * When alerts are displayed, chunks with and without an alert are split into separately colored trace pairs.
 *
 * @param result Result to be plotted (only its column name is used, for hover text)
 * @param data Chunks to be plotted, optionally annotated with alert information
 * @param plotElements Elements to be displayed in the plot
 * @param options Colors and naming used for the traces
 * @returns A list of Plotly traces
 */
const getPlotTraces = (
  result: Omit<KdeDistributionResultDetails, "data">,
  data: (KdeDistributionResultDetails["data"][0] & { hasAlert?: Boolean })[],
  plotElements: PlotElements[],
  options: TraceOptions
): Partial<Plotly.PlotData>[] => {
  const [hoverTemplate, getHoverTemplateData] = generateIndexHoverTemplate(result, options.traceColor);

  const getTraces = (
    name: string,
    chunks: typeof data,
    traceColor: string,
    fillColor: string
  ): Partial<Plotly.PlotData>[] => {
    // Chunks without distribution points have no curve to draw. Drop them before building BOTH `x` and `y`:
    // `y` has nothing to emit for such chunks, so `x` must not emit anything either or the arrays misalign.
    const distributionChunks = chunks.filter((chunk) => chunk.data.length > 0);

    return [
      {
        name: name,
        legendgroup: name,
        mode: "lines",
        line: {
          color: traceColor,
        },
        hoverinfo: "skip",
        showlegend: true,
        fill: "toself",
        fillcolor: fillColor,
        /* Distributions are plotted as a single trace for performance reasons. Plotly is much slower when using
         * multiple traces. To achieve this, we use a scatter plot with null values to separate the chunks. An extra
         * data point is added at the start and end of every distribution to ensure that the fill is closed at the
         * appropriate timestamp.
         */
        x: distributionChunks.flatMap((chunk) => {
          // x = chunk start + date step * density; the start time is the same for every point of the chunk
          const chunkStart = new Date(chunk.startTimestamp).getTime();
          const chunkPeriod = getDateStep(chunk);
          return [
            chunk.startTimestamp,
            ...chunk.data.map((dp) => chunkStart + chunkPeriod * dp.density),
            chunk.startTimestamp,
            null,
          ];
        }),
        y: distributionChunks.flatMap((chunk) => [
          chunk.data[0].value,
          ...chunk.data.map((dp) => dp.value),
          chunk.data.at(-1)!.value,
          null,
        ]),
        type: "scatter",
      },
      {
        name: "Indices",
        legendgroup: "indices",
        mode: "lines",
        line: {
          color: traceColor,
          dash: "dot",
        },
        hovertemplate: hoverTemplate,
        // Each index is a two-point segment followed by a gap, so hover data is repeated for both points
        customdata: chunks.flatMap((chunk) =>
          chunk.indices.flatMap((index) => [getHoverTemplateData(chunk, index), getHoverTemplateData(chunk, index), []])
        ),
        showlegend: false,
        x: chunks.flatMap((chunk) => {
          const chunkStart = new Date(chunk.startTimestamp).getTime();
          const chunkPeriod = getDateStep(chunk);
          return chunk.indices.flatMap((index) => [
            chunk.startTimestamp,
            chunkStart + chunkPeriod * index.density,
            null,
          ]);
        }),
        y: chunks.flatMap((chunk) => chunk.indices.flatMap((index) => [index.value, index.value, null])),
      },
    ];
  };

  const traceName = options.nameFn("Kernel Density Estimation");

  if (plotElements.includes(PlotElements.Alerts)) {
    return [
      ...getTraces(
        traceName,
        data.filter((chunk) => !chunk.hasAlert),
        options.traceColor,
        options.fillColor
      ),
      ...getTraces(
        "Alerts",
        data.filter((chunk) => chunk.hasAlert),
        options.alertColor,
        options.alertFillColor
      ),
    ];
  } else {
    return getTraces(traceName, data, options.traceColor, options.fillColor);
  }
};

/**
 * Builds the Plotly hover template for distribution index markers, plus a function that produces the matching
 * `customdata` entry for one index.
 *
 * `customdata` layout:
 * - [0] column name
 * - [1] dataset label
 * - [2] alert badge (empty string when the chunk has no alert)
 * - [3] chunk start date
 * - [4] chunk end date
 * - [5] number of data points
 * - [6] index value
 * - [7] cumulative density at that value
 *
 * @param result Result whose column name is shown in the hover text
 * @param metricColor Color used for the dataset label in the hover text
 * @returns A template string and a function to generate hover `customdata` input
 */
const generateIndexHoverTemplate = (
  result: Pick<KdeDistributionResultDetails, "column">,
  metricColor: string
): [
  string,
  (
    chunk: KdeDistributionResultDetails["data"][0] & { hasAlert?: Boolean },
    index: KdeDistributionResultDetails["data"][0]["indices"][0]
  ) => (string | number)[]
] => {
  const hoverTemplate = `<b style="color:${metricColor}">%{customdata[1]}</b> &nbsp; &nbsp; %{customdata[2]}<br />
    Chunk: <b>%{customdata[3]} - %{customdata[4]}</b> (%{customdata[5]} rows)<br />
    %{customdata[0]}: %{customdata[7]:,.0%} of values are smaller than <b>%{customdata[6]:.4f}</b><br />
    <extra></extra>`;

  return [
    hoverTemplate,
    (chunk, index) => {
      const datasetLabel = chunk.isAnalysis ? "Analysis" : "Reference";
      const alertBadge = chunk.hasAlert ? `<b style="color:${colors.alertColor}">⚠️ Drift detected</b>` : "";

      return [
        result.column,
        datasetLabel,
        alertBadge,
        formatISODate(getStartDate(chunk)),
        formatISODate(getEndDate(chunk)),
        chunk.nrDataPoints,
        index.value,
        index.cumulativeDensity,
      ];
    },
  ];
};

/**
 * Builds the plot layout for a KDE distribution result.
 *
 * Starts from the shared default layout and sets the result's column name as the title of every y-axis present in
 * that layout.
 *
 * @param result The results to be plotted
 * @param config Plot configuration to be used
 * @param dateRange Optional date range to be used for the plot
 * @param width Optional width to be used for the plot
 * @returns Plotly layout
 */
const getKdeDistributionPlotLayout = (
  result: KdeDistributionResultDetails,
  config: PlotConfig,
  dateRange?: [DateLike, DateLike],
  width?: number
): Partial<Plotly.Layout> => {
  const baseLayout = getPlotLayout([result], config, dateRange, width);

  // Only the axis title is set here; all other axis settings come from the base layout
  const yAxisTitles = Object.fromEntries(
    Object.keys(baseLayout)
      .filter((key) => key.startsWith("yaxis"))
      .map((axisKey) => [axisKey, { title: { text: result.column } }] as const)
  );

  return _.merge(baseLayout, yAxisTitles);
};
