import { T } from "vitest/dist/types-0373403c";

import {
  AnalysisType,
  CalculatorType,
  ConceptShiftMetric,
  DataPeriod,
  DataQualityMetric,
  MultivariateDriftMethod,
  PerformanceMetric,
  PerformanceType,
  SummaryStatsMetric,
  TimeSeriesResult,
  UnivariateDriftMethod,
} from "@/apis/nannyml";
import { ResultView } from "@/domains/monitoring";

// Builds a label lookup function for the given label table. The returned function upper-cases the metric name it
// receives before indexing into `labels`: metric names obtained from results are typically lowercase and have to be
// converted before they can be used as lookup keys. Names without an entry in the table resolve to `undefined`.
const generateMetricLabelGetter = <TLabels extends { [value: string]: string }>(labels: TLabels) => {
  return <V extends string>(metricName: V): V extends keyof TLabels ? string : string | undefined => {
    const lookupKey = metricName.toUpperCase() as keyof TLabels;
    return labels[lookupKey];
  };
};

// Human-readable label for each `ResultView` value.
export const resultViewLabels: Record<ResultView, string> = {
  [ResultView.ConceptDrift]: "Concept drift",
  [ResultView.DataQuality]: "Data quality",
  [ResultView.Performance]: "Performance",
  [ResultView.CovariateShift]: "Covariate shift",
};

// Human-readable label for each `DataPeriod` value.
export const dataPeriodLabels: Record<DataPeriod, string> = {
  [DataPeriod.Analysis]: "Analysis",
  [DataPeriod.Reference]: "Reference",
};

// Human-readable label for each `PerformanceMetric` value. The `Record` annotation makes the compiler require an
// entry for every member of the enum.
export const performanceMetricLabels: Record<PerformanceMetric, string> = {
  [PerformanceMetric.RocAuc]: "ROC AUC",
  [PerformanceMetric.F1]: "F1",
  [PerformanceMetric.Precision]: "Precision",
  [PerformanceMetric.Recall]: "Recall",
  [PerformanceMetric.Specificity]: "Specificity",
  [PerformanceMetric.Accuracy]: "Accuracy",
  [PerformanceMetric.BusinessValue]: "Business value",
  [PerformanceMetric.ConfusionMatrix]: "Confusion matrix",
  [PerformanceMetric.AveragePrecision]: "Average precision",
  [PerformanceMetric.Mae]: "MAE",
  [PerformanceMetric.Mape]: "MAPE",
  [PerformanceMetric.Mse]: "MSE",
  [PerformanceMetric.Rmse]: "RMSE",
  [PerformanceMetric.Msle]: "MSLE",
  [PerformanceMetric.Rmsle]: "RMSLE",
};

// Label lookup for standard performance metrics by metric name. The name is upper-cased before it is looked up in
// `performanceMetricLabels`; names without an entry there yield `undefined`.
export const getStandardPerformanceMetricLabel = generateMetricLabelGetter(performanceMetricLabels);

// Resolves the label of a performance metric, standard or custom. Standard metrics are resolved through the standard
// label lookup; custom performance metrics have no label defined, so the metric name itself is returned for them.
export const getPerformanceMetricLabel = (metricName: string): string => {
  const standardLabel = getStandardPerformanceMetricLabel(metricName);
  return standardLabel ?? metricName;
};

// Pair of labels for a metric component: a short abbreviation (e.g. "TP") and the full name (e.g. "True positive").
type ComponentLabel = {
  abbr: string;
  name: string;
};

// Abbreviated and full labels for the true/false positive/negative components, keyed by component name. Typed as
// `Record<string, ...>`, so lookups with any other component name are allowed and return `undefined` at runtime.
export const performanceComponentLabels: Record<string, ComponentLabel> = {
  true_positive: { abbr: "TP", name: "True positive" },
  true_negative: { abbr: "TN", name: "True negative" },
  false_positive: { abbr: "FP", name: "False positive" },
  false_negative: { abbr: "FN", name: "False negative" },
};

// Human-readable label for each `ConceptShiftMetric` value. Every metric except `Magnitude` is labelled as a
// "delta" of the corresponding performance metric.
export const conceptShiftMetricLabels: Record<ConceptShiftMetric, string> = {
  [ConceptShiftMetric.RocAuc]: "ROC AUC delta",
  [ConceptShiftMetric.Precision]: "Precision delta",
  [ConceptShiftMetric.Recall]: "Recall delta",
  [ConceptShiftMetric.F1]: "F1 delta",
  [ConceptShiftMetric.Accuracy]: "Accuracy delta",
  [ConceptShiftMetric.Specificity]: "Specificity delta",
  [ConceptShiftMetric.AveragePrecision]: "Average precision delta",
  [ConceptShiftMetric.Magnitude]: "Magnitude",
};

// Label lookup for concept shift metrics by metric name. The name is upper-cased before it is looked up in
// `conceptShiftMetricLabels`; names without an entry there yield `undefined`.
export const getConceptShiftMetricLabel = generateMetricLabelGetter(conceptShiftMetricLabels);

// Labels for the subset of `UnivariateDriftMethod` values grouped here as distance methods.
export const distanceMethodLabels = {
  [UnivariateDriftMethod.JensenShannon]: "Jensen Shannon",
  [UnivariateDriftMethod.Wasserstein]: "Wasserstein",
  [UnivariateDriftMethod.Hellinger]: "Hellinger",
  [UnivariateDriftMethod.LInfinity]: "L-infinity",
};

// Label lookup for univariate distance methods by method name. The name is upper-cased before it is looked up in
// `distanceMethodLabels`; names without an entry there yield `undefined`.
export const getUnivariateDistanceMethodLabel = generateMetricLabelGetter(distanceMethodLabels);

// Labels for the subset of `UnivariateDriftMethod` values grouped here as statistical methods.
export const statisticalMethodLabels = {
  [UnivariateDriftMethod.KolmogorovSmirnov]: "Kolmogorov Smirnov",
  [UnivariateDriftMethod.Chi2]: "Chi2",
};

// Label lookup for univariate statistical methods by method name. The name is upper-cased before it is looked up in
// `statisticalMethodLabels`; names without an entry there yield `undefined`.
export const getUnivariateStatisticalMethodLabel = generateMetricLabelGetter(statisticalMethodLabels);

// Labels for univariate drift methods: the distance method labels merged with the statistical method labels.
// NOTE: the `as` assertion means the compiler does not verify that every `UnivariateDriftMethod` has an entry in one
// of the two spread tables.
export const univariateDriftMethodLabels = {
  ...distanceMethodLabels,
  ...statisticalMethodLabels,
} as Record<UnivariateDriftMethod, string>;

// Label lookup for any univariate drift method by method name. The name is upper-cased before it is looked up in
// `univariateDriftMethodLabels`; names without an entry there yield `undefined` at runtime.
export const getUnivariateDriftMethodLabel = generateMetricLabelGetter(univariateDriftMethodLabels);

// Labels for multivariate drift methods. Each method is registered under two keys that share one label: its
// `MultivariateDriftMethod` value and a second, literal key (`RECONSTRUCTION_ERROR` / `DOMAIN_CLASSIFIER`).
export const multivariateDriftMethodLabels = {
  [MultivariateDriftMethod.PcaReconstructionError]: "PCA reconstruction error",
  // Metric name does not align with setting name, so handle it separately
  RECONSTRUCTION_ERROR: "PCA reconstruction error",
  [MultivariateDriftMethod.DomainClassifierAuroc]: "Domain classifier AUROC",
  DOMAIN_CLASSIFIER: "Domain classifier AUROC",
};

// Label lookup for multivariate drift methods by method or metric name. The name is upper-cased before it is looked
// up in `multivariateDriftMethodLabels`; names without an entry there yield `undefined`.
export const getMultivariateDriftMethodLabel = generateMetricLabelGetter(multivariateDriftMethodLabels);

// Human-readable label for each `DataQualityMetric` value.
export const dataQualityMetricLabels: Record<DataQualityMetric, string> = {
  [DataQualityMetric.MissingValues]: "Missing values",
  [DataQualityMetric.UnseenValues]: "Unseen values",
};

// Label lookup for data quality metrics by metric name. The name is upper-cased before it is looked up in
// `dataQualityMetricLabels`; names without an entry there yield `undefined`.
export const getDataQualityMetricLabel = generateMetricLabelGetter(dataQualityMetricLabels);

// Human-readable label for each `SummaryStatsMetric` value.
export const summaryStatsMetricLabels: Record<SummaryStatsMetric, string> = {
  [SummaryStatsMetric.RowsCount]: "Number of rows",
  [SummaryStatsMetric.SummaryStatsAvg]: "Average",
  [SummaryStatsMetric.SummaryStatsMedian]: "Median",
  [SummaryStatsMetric.SummaryStatsStd]: "Standard deviation",
  [SummaryStatsMetric.SummaryStatsSum]: "Sum",
};

// Label lookup for summary statistics metrics by metric name. The name is upper-cased before it is looked up in
// `summaryStatsMetricLabels`; names without an entry there yield `undefined`.
export const getSummaryStatsMetricLabel = generateMetricLabelGetter(summaryStatsMetricLabels);

// Maps each analysis type to the function that resolves a metric name to its label:
// - feature drift tries the univariate method labels first and falls back to the multivariate ones
// - both performance analyses use the getter that falls back to the metric name itself when no label is defined
// - distribution always resolves to an empty string
// The other getters return `undefined` for metric names they do not know.
export const metricLabels: Record<AnalysisType, (metricName: string) => string | undefined> = {
  [AnalysisType.ConceptShift]: getConceptShiftMetricLabel,
  [AnalysisType.DataQuality]: getDataQualityMetricLabel,
  [AnalysisType.FeatureDrift]: (m) => getUnivariateDriftMethodLabel(m) ?? getMultivariateDriftMethodLabel(m),
  [AnalysisType.SummaryStats]: getSummaryStatsMetricLabel,
  [AnalysisType.EstimatedPerformance]: getPerformanceMetricLabel,
  [AnalysisType.RealizedPerformance]: getPerformanceMetricLabel,
  [AnalysisType.Distribution]: () => "",
};

// Label for each `AnalysisType` value. Data quality, feature drift and distribution deliberately map to an empty
// string; `getResultTitles` treats an empty analysis label as absent when it builds a subtitle.
export const analysisLabels: Record<AnalysisType, string> = {
  [AnalysisType.ConceptShift]: "Concept drift",
  [AnalysisType.DataQuality]: "",
  [AnalysisType.EstimatedPerformance]: "Estimated performance",
  [AnalysisType.FeatureDrift]: "",
  [AnalysisType.RealizedPerformance]: "Realized performance",
  [AnalysisType.Distribution]: "",
  [AnalysisType.SummaryStats]: "Summary statistics",
};

// Human-readable label for each `CalculatorType` value. Custom calculator variants carry a "Custom" prefix.
export const calculatorLabels: Record<CalculatorType, string> = {
  [CalculatorType.Cbpe]: "CBPE",
  [CalculatorType.Dle]: "DLE",
  [CalculatorType.Pape]: "PAPE",
  [CalculatorType.MissingValues]: "Missing values",
  [CalculatorType.ReconstructionError]: "Reconstruction error",
  [CalculatorType.DomainClassifier]: "Domain classifier",
  [CalculatorType.PerformanceCalculation]: "Realized performance",
  [CalculatorType.Rcs]: "RCS",
  [CalculatorType.UnivariateDrift]: "Univariate drift",
  [CalculatorType.UnseenValues]: "Unseen values",
  [CalculatorType.ContinuousDistribution]: "Continuous distribution",
  [CalculatorType.CategoricalDistribution]: "Categorical distribution",
  [CalculatorType.SummaryStatsAvg]: "Summary statistics: average",
  [CalculatorType.SummaryStatsMedian]: "Summary statistics: median",
  [CalculatorType.SummaryStatsRowCount]: "Summary statistics: row count",
  [CalculatorType.SummaryStatsStd]: "Summary statistics: standard deviation",
  [CalculatorType.SummaryStatsSum]: "Summary statistics: sum",
  [CalculatorType.CustomCbpe]: "Custom CBPE",
  [CalculatorType.CustomDle]: "Custom DLE",
  [CalculatorType.CustomPape]: "Custom PAPE",
  [CalculatorType.CustomPerformanceCalculation]: "Custom realized performance",
};

// Algorithm label per calculator type. The table is partial: calculator types not listed here have no algorithm
// label and resolve to `undefined`. Custom CBPE/DLE/PAPE share the label of their non-custom counterpart.
export const algorithmLabels: Partial<Record<CalculatorType, string>> = {
  [CalculatorType.Cbpe]: "CBPE",
  [CalculatorType.Dle]: "DLE",
  [CalculatorType.Pape]: "PAPE",
  [CalculatorType.ReconstructionError]: "Reconstruction error",
  [CalculatorType.DomainClassifier]: "Domain classifier",
  [CalculatorType.Rcs]: "RCS",
  [CalculatorType.CustomCbpe]: "CBPE",
  [CalculatorType.CustomDle]: "DLE",
  [CalculatorType.CustomPape]: "PAPE",
};

// Human-readable label for each `PerformanceType` value.
export const performanceTypeLabels: Record<PerformanceType, string> = {
  [PerformanceType.Cbpe]: "CBPE",
  [PerformanceType.Dle]: "DLE",
  [PerformanceType.Pape]: "PAPE",
  [PerformanceType.Realized]: "Realized performance",
};

// Alert label for each `CalculatorType` value. The two distribution calculators map to an empty string. Custom
// calculator variants use the same label as their non-custom counterpart.
export const alertDetectedLabels: Record<CalculatorType, string> = {
  [CalculatorType.Cbpe]: "Estimated performance alert",
  [CalculatorType.Dle]: "Estimated performance alert",
  [CalculatorType.Pape]: "Estimated performance alert",
  [CalculatorType.MissingValues]: "Missing values detected",
  [CalculatorType.ReconstructionError]: "Multivariate drift detected",
  [CalculatorType.DomainClassifier]: "Multivariate drift detected",
  [CalculatorType.PerformanceCalculation]: "Realized performance alert",
  [CalculatorType.Rcs]: "Concept drift detected",
  [CalculatorType.UnivariateDrift]: "Drift detected",
  [CalculatorType.UnseenValues]: "Unseen values detected",
  [CalculatorType.ContinuousDistribution]: "",
  [CalculatorType.CategoricalDistribution]: "",
  [CalculatorType.SummaryStatsAvg]: "Drift detected",
  [CalculatorType.SummaryStatsMedian]: "Drift detected",
  [CalculatorType.SummaryStatsRowCount]: "Number of rows alert",
  [CalculatorType.SummaryStatsStd]: "Drift detected",
  [CalculatorType.SummaryStatsSum]: "Drift detected",
  [CalculatorType.CustomCbpe]: "Estimated performance alert",
  [CalculatorType.CustomDle]: "Estimated performance alert",
  [CalculatorType.CustomPape]: "Estimated performance alert",
  [CalculatorType.CustomPerformanceCalculation]: "Realized performance alert",
};

// Builds a compact label for a result: the metric label (or the raw metric name when no non-empty label resolves),
// followed by the component in parentheses when the result has a component name that differs from the metric name.
// Known components are shown by their abbreviation, unknown ones by their raw name.
export const getShortResultLabel = (
  result: Pick<TimeSeriesResult, "analysisType" | "metricName" | "componentName">
) => {
  const { analysisType, metricName, componentName } = result;

  // `||` rather than `??`: an empty label must fall back to the metric name as well
  const metricLabel = metricLabels[analysisType](metricName) || metricName;

  if (!componentName || componentName === metricName) {
    return metricLabel;
  }

  const componentLabel = performanceComponentLabels[componentName]?.abbr || componentName;
  return `${metricLabel} (${componentLabel})`;
};

// Collects all labels describing a result:
// - metric: the metric label (or the raw metric name when no non-empty label resolves), extended with " - <component>"
//   when the result has a component name that differs from the metric name
// - analysis / calculator / algorithm: looked up by analysis type and calculator type; algorithm is `undefined` for
//   calculator types without an algorithm label
// - columnName: the result's column name, with a nullish value normalised to `undefined`
export const getResultLabels = (
  result: Pick<TimeSeriesResult, "analysisType" | "calculatorType" | "metricName" | "componentName" | "columnName">
): {
  metric: string;
  analysis: string;
  calculator: string;
  columnName?: string;
  algorithm?: string;
} => {
  const { analysisType, calculatorType, metricName, componentName, columnName } = result;

  // `||` rather than `??`: an empty label must fall back to the metric name as well
  const baseMetric = metricLabels[analysisType](metricName) || metricName;
  const metric =
    componentName && componentName !== metricName
      ? `${baseMetric} - ${performanceComponentLabels[componentName]?.name || componentName}`
      : baseMetric;

  let analysis = analysisLabels[analysisType];
  let calculator = calculatorLabels[calculatorType];
  let algorithm = algorithmLabels[calculatorType];

  // RCS has two calculators mixed together, so we need to distinguish between them
  if (calculatorType === CalculatorType.Rcs) {
    const isMagnitude = metricName.toUpperCase() === ConceptShiftMetric.Magnitude;
    const analysisSuffix = isMagnitude ? " - Magnitude Estimation" : " - Performance Impact Estimation";
    const shortSuffix = isMagnitude ? "-ME" : "-PIE";

    analysis += analysisSuffix;
    calculator += shortSuffix;
    algorithm += shortSuffix;
  }

  return {
    metric,
    analysis,
    calculator,
    algorithm,
    columnName: columnName ?? undefined,
  };
};

// Derives the title and subtitle for a result from its labels:
// - title: the column name when one is set (anything but null/undefined), otherwise the metric label
// - subtitle: the metric label for results with a non-empty column name; otherwise the analysis label followed by the
//   algorithm label in parentheses, or whichever of the two is available, or an empty string when neither is
export const getResultTitles = (
  result: Pick<TimeSeriesResult, "analysisType" | "calculatorType" | "metricName" | "componentName" | "columnName">
): {
  title: string;
  subtitle: string;
} => {
  const { columnName, metric, analysis, algorithm } = getResultLabels(result);
  const title = columnName ?? metric;

  if (columnName) {
    return { title, subtitle: metric };
  }
  if (analysis && algorithm) {
    return { title, subtitle: `${analysis} (${algorithm})` };
  }
  if (analysis) {
    return { title, subtitle: analysis };
  }
  return { title, subtitle: algorithm ?? "" };
};
