import { useMutation, useSuspenseQuery } from "@apollo/client";
import { ResultOf } from "@graphql-typed-document-node/core";
import _ from "lodash";
import { EyeIcon, Loader2Icon, PencilIcon, PlusIcon, SaveIcon, SearchIcon, Trash2Icon, XIcon } from "lucide-react";
import React from "react";
import { Link, useSearchParams } from "react-router-dom";

import { Input } from "@/DesignSystem/shadcn/Input";
import { Label } from "@/DesignSystem/shadcn/Label/Label";
import { CreateMetricInput, gql, ProblemType, RunState, useFragment } from "@/apis/nannyml";
import { confirm, Dialog, DialogContent, DialogTrigger } from "@/components/Dialog";
import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "@/components/Table";
import { Button } from "@/components/common/Button";
import { RequestStateLayout } from "@/components/dashboard/RequestStateLayout/RequestStateLayout";
import { MonitoringModelName } from "@/components/monitoring";
import { problemTypeLabels } from "@/formatters/models";
import { useIsFormValid } from "@/hooks/form";
import { useSearchList } from "@/hooks/useSearchList";
import { formatISODateTime } from "@/lib/dateUtils";

import { ClassificationMetric, convertClassificationMetricToInput } from "./ClassificationMetric";
import { convertRegressionMetricToInput, RegressionMetric } from "./RegressionMetric";

/**
 * Fields shared by every custom metric, plus the problem-type specific details pulled in through the
 * `ClassificationMetricDetails` / `RegressionMetricDetails` fragments (presumably declared next to the
 * `ClassificationMetric` / `RegressionMetric` components — confirm).
 *
 * `__typename` and `id` are selected because the cache code in this file identifies metrics with `cache.identify`.
 */
const monitoringMetricDetailsFragment = gql(/* GraphQL */ `
  fragment MonitoringMetricDetails on Metric {
    __typename
    id
    name
    description
    problemType
    ... on ClassificationMetric {
      ...ClassificationMetricDetails
    }
    ... on RegressionMetric {
      ...RegressionMetricDetails
    }
  }
`);

/**
 * Lists every custom metric. The create-metric cache update inserts new metrics into this `monitoring_metrics`
 * field and keeps it sorted by name.
 */
const getMonitoringMetricsQuery = gql(/* GraphQL */ `
  query GetMonitoringMetrics {
    monitoring_metrics {
      ...MonitoringMetricDetails
    }
  }
`);

/**
 * Fetches the models that use the metric `metricId`. For each model it also fetches the state and completion time of
 * its latest run, which the "Used by" dialog displays.
 */
const getMonitoringMetricModelsQuery = gql(/* GraphQL */ `
  query GetMonitoringMetricModels($metricId: Int!) {
    monitoring_metric(metricId: $metricId) {
      models {
        id
        name
        latestRun {
          state
          completedAt
        }
      }
    }
  }
`);

/** Creates a custom metric and returns it with the shared `MonitoringMetricDetails` fields. */
const createMonitoringMetricMutation = gql(/* GraphQL */ `
  mutation CreateMonitoringMetric($metric: CreateMetricInput!) {
    create_monitoring_metric(metric: $metric) {
      ...MonitoringMetricDetails
    }
  }
`);

/**
 * Edits an existing custom metric.
 *
 * `__typename` is selected at the top level because the result may be an `EditMetricRequiresInvalidation` object
 * rather than a metric. The edit form checks for it and resubmits with `allowInvalidation: true` once the user
 * confirms.
 */
const editMonitoringMetricMutation = gql(/* GraphQL */ `
  mutation EditMonitoringMetric($metric: EditMetricInput!) {
    edit_monitoring_metric(metric: $metric) {
      __typename
      ...MonitoringMetricDetails
    }
  }
`);

/**
 * Deletes the metric `metricId`. The deleted metric is returned (with `__typename` and `id`) so the client can
 * identify it and evict it from the Apollo cache.
 */
const deleteMonitoringMetricMutation = gql(/* GraphQL */ `
  mutation DeleteMonitoringMetric($metricId: Int!) {
    delete_monitoring_metric(metricId: $metricId) {
      ...MonitoringMetricDetails
    }
  }
`);

/** Result shape of the `MonitoringMetricDetails` fragment, i.e. one custom metric as listed on this page. */
type MetricType = ResultOf<typeof monitoringMetricDetailsFragment>;

/**
 * Props of the problem-type specific metric editors (`ClassificationMetric`, `RegressionMetric`).
 *
 * @typeParam T - Input shape being edited; defaults to the generic create input.
 */
interface MetricConfiguratorProps<T extends CreateMetricInput = CreateMetricInput> {
  /** The metric input currently shown in the editor. */
  metric: T;
  /** Called with the updated input when the editor changes it. */
  onMetricChange: (metric: T) => void;
}

/**
 * Custom metrics page: an intro text followed by one searchable metrics table per problem type.
 *
 * Uses `useSuspenseQuery`, so it suspends while the metric list loads and needs a `React.Suspense` boundary above it.
 */
export const CustomMetrics = () => {
  const { data: metricsData } = useSuspenseQuery(getMonitoringMetricsQuery);
  const allMetrics = useFragment(monitoringMetricDetailsFragment, metricsData?.monitoring_metrics ?? []);
  const problemTypes = Object.values(ProblemType);

  return (
    <div className="p-8 flex flex-col">
      <h1 className="text-3xl font-semibold mb-2">Custom metrics</h1>
      <span className="text-gray-400 text-sm">
        Custom metrics are performance metrics you define using python code. They can be used to measure business
        impact, model performance, or any other metric that is important to your use case.
      </span>
      {problemTypes.map((type) => {
        const metricsOfType = allMetrics.filter((metric) => metric.problemType === type);
        return <MetricsTable key={type} metrics={metricsOfType} problemType={type} />;
      })}
    </div>
  );
};

/**
 * Searchable table of the custom metrics of one problem type, with Edit / Used by / Delete actions per row.
 *
 * The `problemType` and `id` URL search params seed the initial open state of the create dialog (for the matching
 * problem type) and the edit dialog (for the matching metric). A failed deletion is reported below the table.
 *
 * @param metrics - Metrics to list; entries of another problem type are filtered out.
 * @param problemType - Problem type shown in this table.
 */
const MetricsTable = ({ metrics, problemType }: { metrics: MetricType[]; problemType: ProblemType }) => {
  const { search, setSearch, results } = useSearchList(
    metrics.filter((m) => m.problemType === problemType),
    "name"
  );
  const [deleteMetric, { error: deleteError }] = useMutation(deleteMonitoringMetricMutation, {
    update: (cache, { data }) => {
      if (data) {
        // Drop the deleted metric from the normalized cache, then garbage-collect anything that was only reachable
        // through it.
        cache.evict({ id: cache.identify(data.delete_monitoring_metric) });
        cache.gc();
      }
    },
  });
  const [params] = useSearchParams();
  const queryMetricId = params.get("id");
  const queryProblemType = params.get("problemType");

  const onDeleteMetric = (metric: MetricType) => {
    confirm({
      title: "Delete metric",
      message: (
        <div className="flex flex-col gap-4">
          <span>Are you sure you want to delete the '{metric.name}' metric?</span>
          <span>
            This will delete the metric and any associated results for all{" "}
            {problemTypeLabels[problemType].toLowerCase()} models. Once deleted, the data cannot be restored.
          </span>
        </div>
      ),
      confirmIntent: "reject",
    }).then((confirmed) => {
      if (confirmed) {
        // The mutate promise rejects on failure. The error is rendered from `deleteError`, so the rejection is only
        // swallowed here to avoid an unhandled promise rejection.
        deleteMetric({ variables: { metricId: metric.id } }).catch(() => undefined);
      }
    });
  };

  return (
    <div className="container flex flex-col gap-4">
      <h3 className="text-xl font-semibold mt-12">{problemTypeLabels[problemType]}</h3>
      <div className="flex justify-between">
        <Label className="relative w-[300px] text-slate-400 cursor-text">
          <SearchIcon className="absolute left-3 top-3" size={16} />
          <Input
            className="pl-9 border-gray-600"
            placeholder="Search metrics..."
            value={search}
            onChange={(e) => setSearch(e.target.value)}
          />
        </Label>
        <CreateMetricDialog problemType={problemType} defaultOpen={problemType === queryProblemType} />
      </div>
      <div className="border rounded-md border-gray-600">
        <Table className="odd:[&_tr]:bg-oddBg">
          <TableHeader>
            <TableRow>
              <TableHead className="w-64">Metric</TableHead>
              <TableHead className="px-8 w-full">Description</TableHead>
              <TableHead />
            </TableRow>
          </TableHeader>
          <TableBody>
            {results.length === 0 ? (
              <TableRow>
                <TableCell colSpan={3} className="text-center text-gray-400 italic">
                  No metrics found
                </TableCell>
              </TableRow>
            ) : (
              results.map((metric) => (
                <TableRow key={metric.id}>
                  <TableCell className="whitespace-nowrap">{metric.name}</TableCell>
                  <TableCell className="px-8">{metric.description}</TableCell>
                  <TableCell>
                    <div className="flex gap-2 w-max">
                      <EditMetricDialog metric={metric} defaultOpen={metric.id.toString() === queryMetricId} />
                      <UsedByDialog metric={metric} />
                      <Button
                        cva={{ intent: "reject", size: "small" }}
                        title="Delete metric"
                        onClick={() => onDeleteMetric(metric)}
                      >
                        <Trash2Icon size={20} strokeWidth={1} />
                        Delete
                      </Button>
                    </div>
                  </TableCell>
                </TableRow>
              ))
            )}
          </TableBody>
        </Table>
      </div>
      {deleteError && <p className="text-center text-red-500">{deleteError.message}</p>}
    </div>
  );
};

type CreateMetricDialogProps = {
  problemType: ProblemType;
  className?: string;
  defaultOpen?: boolean;
};

/**
 * "Add new metric" button that opens a dialog containing the create-metric form for `problemType`.
 *
 * @param problemType - Problem type of the metric to create.
 * @param className - Extra classes applied to the dialog trigger.
 * @param defaultOpen - Initial open state of the dialog (only read on mount).
 */
const CreateMetricDialog = ({ problemType, className, defaultOpen = false }: CreateMetricDialogProps) => {
  const [open, setOpen] = React.useState(defaultOpen);
  const closeDialog = () => setOpen(false);

  return (
    <Dialog open={open} onOpenChange={setOpen}>
      <DialogTrigger className={className} asChild>
        <Button cva={{ intent: "primary", size: "mediumLong" }}>
          <PlusIcon size={16} />
          Add new metric
        </Button>
      </DialogTrigger>
      <DialogContent className="max-w-[95%] max-h-[95%] w-fit overflow-auto">
        <CreateMetric problemType={problemType} onCancel={closeDialog} onCreated={closeDialog} />
      </DialogContent>
    </Dialog>
  );
};

/**
 * Form for creating a new custom metric of the given problem type.
 *
 * The metric-specific fields are rendered by the configurator registered for `problemType` in `metricConfigurators`.
 * On success, the new metric is added to the cached `monitoring_metrics` list (kept sorted by name) and `onCreated` is
 * called. On failure, the error message is shown below the buttons.
 *
 * @param problemType - Problem type of the metric to create; selects the configurator and the default input.
 * @param onCancel - Called when the user clicks "Cancel".
 * @param onCreated - Called once the create mutation has completed.
 */
const CreateMetric = ({
  problemType,
  onCancel,
  onCreated,
}: {
  problemType: ProblemType;
  onCancel: () => void;
  onCreated: () => void;
}) => {
  const configurator = metricConfigurators[problemType];
  // Lazy initializer: the default input is only built on mount, not on every render.
  const [metricInput, setMetricInput] = React.useState(() => configurator.createDefault());
  const [ref, isFormValid] = useIsFormValid();
  const [createMetric, { loading, error }] = useMutation(createMonitoringMetricMutation, {
    onCompleted: onCreated,
    update: (cache, { data }) => {
      const createdMetric = data?.create_monitoring_metric;
      if (!createdMetric) {
        return;
      }
      // Add the new metric to the cached metric list. The list is kept sorted by metric name, which is expected to
      // match the ordering the backend returns.
      cache.modify({
        fields: {
          monitoring_metrics(existingMetrics = [], { readField }) {
            const newMetricRef = cache.writeFragment({
              // NOTE(review): cast kept from the original; presumably the (masked) mutation result type does not
              // match the fragment's data type.
              data: createdMetric as any,
              fragment: monitoringMetricDetailsFragment,
              fragmentName: "MonitoringMetricDetails",
            });
            if (!newMetricRef) {
              return existingMetrics;
            }
            const newMetricId = readField("id", newMetricRef);
            // Don't add a duplicate if the list already contains the metric (e.g. it was refetched meanwhile).
            if (_.some(existingMetrics, (metricRef) => readField("id", metricRef) === newMetricId)) {
              return existingMetrics;
            }
            return _.sortBy(existingMetrics.concat(newMetricRef), (metricRef) => readField("name", metricRef));
          },
        },
      });
    },
  });

  const onSubmit = (e: React.FormEvent) => {
    e.preventDefault();
    // The mutate promise rejects on failure. The message is rendered from `error`, so the rejection is only swallowed
    // here to avoid an unhandled promise rejection.
    createMetric({ variables: { metric: metricInput } }).catch(() => undefined);
  };

  return (
    <form ref={ref} onSubmit={onSubmit}>
      <h3 className="text-xl font-semibold mb-2">Create new {problemTypeLabels[problemType].toLowerCase()} metric</h3>
      <p className="text-gray-400 text-sm max-w-[800px] mb-4">
        Custom metrics are performance metrics you define using python code. They can be used to measure business
        impact, model performance, or any other metric that is important to your use case.
      </p>
      <configurator.Component metric={metricInput} onMetricChange={setMetricInput} />
      <div className="flex justify-center items-center gap-2 mt-6">
        <Button
          cva={{ intent: "primary", size: "mediumLong" }}
          type="submit"
          className="gap-2"
          disabled={!isFormValid || loading}
        >
          {loading ? <Loader2Icon size={20} className="animate-spin" /> : <SaveIcon size={20} />}
          Save metric
        </Button>
        {/* Explicit type="button": a native button inside a form defaults to "submit", which would create the metric. */}
        <Button
          cva={{ intent: "reject", size: "mediumLong" }}
          type="button"
          className="gap-2"
          onClick={onCancel}
          disabled={loading}
        >
          <XIcon size={20} />
          Cancel
        </Button>
      </div>
      {error && <p className="mt-4 text-center text-red-500">{error.message}</p>}
    </form>
  );
};

type EditMetricDialogProps = {
  metric: MetricType;
  className?: string;
  defaultOpen?: boolean;
};

/**
 * "Edit" button that opens a dialog containing the edit form for `metric`.
 *
 * @param metric - Metric to edit.
 * @param className - Extra classes applied to the dialog trigger.
 * @param defaultOpen - Initial open state of the dialog (only read on mount).
 */
const EditMetricDialog = ({ metric, className, defaultOpen = false }: EditMetricDialogProps) => {
  const [open, setOpen] = React.useState(defaultOpen);
  const closeDialog = () => setOpen(false);

  return (
    <Dialog open={open} onOpenChange={setOpen}>
      <DialogTrigger className={className} asChild>
        <Button cva={{ intent: "primary", size: "small" }}>
          <PencilIcon size={20} strokeWidth={1} />
          Edit
        </Button>
      </DialogTrigger>
      <DialogContent className="max-w-[95%] max-h-[95%] w-fit overflow-auto">
        <EditMetric metric={metric} onCancel={closeDialog} onCreated={closeDialog} />
      </DialogContent>
    </Dialog>
  );
};

/**
 * Form for editing an existing custom metric.
 *
 * When it mounts, the form is seeded from `metric`, converted to an input by the configurator for its problem type.
 * If the backend answers with `EditMetricRequiresInvalidation`, the user is asked to confirm. On confirmation, the edit
 * is resubmitted with `allowInvalidation: true`. Any other completed response calls `onCreated`.
 *
 * @param metric - The metric being edited.
 * @param onCancel - Called when the user clicks "Cancel".
 * @param onCreated - Called once the edit has been saved.
 */
const EditMetric = ({
  metric,
  onCancel,
  onCreated,
}: {
  metric: MetricType;
  onCancel: () => void;
  onCreated: () => void;
}) => {
  const configurator = metricConfigurators[metric.problemType];
  // Lazy initializer: the conversion only runs on mount; later renders keep the user's edits.
  const [metricInput, setMetricInput] = React.useState(() => ({
    id: metric.id,
    // NOTE(review): presumably cast because `convertToInput` is typed per problem type and TS cannot correlate it with
    // `metric.problemType` here.
    ...configurator.convertToInput(metric as any),
  }));
  const [ref, isFormValid] = useIsFormValid();
  const [editMetric, { loading, error }] = useMutation(editMonitoringMetricMutation, {
    onCompleted: (data) => {
      if (data.edit_monitoring_metric.__typename === "EditMetricRequiresInvalidation") {
        confirm({
          title: "Edit requires invalidation",
          message: (
            <div className="flex flex-col gap-4">
              <span>
                The changes you made require invalidation of all models using this metric. This means the next NannyML
                run for the associated models will re-analyze all data. This may take a long time and uses additional
                compute.
              </span>
              <span>Are you sure you want to make this change?</span>
            </div>
          ),
          confirmIntent: "reject",
        }).then((confirmed) => {
          if (confirmed) {
            // Failures are rendered from `error`; the rejection is only swallowed to avoid an unhandled promise.
            editMetric({ variables: { metric: { ...metricInput, allowInvalidation: true } } }).catch(() => undefined);
          }
        });
      } else {
        onCreated();
      }
    },
  });

  const onSubmit = (e: React.FormEvent) => {
    e.preventDefault();
    // Failures are rendered from `error`; the rejection is only swallowed to avoid an unhandled promise.
    editMetric({ variables: { metric: metricInput } }).catch(() => undefined);
  };

  return (
    <form ref={ref} onSubmit={onSubmit}>
      <h3 className="text-xl font-semibold mb-2">Edit {metric.name} metric</h3>
      <p className="text-gray-400 text-sm max-w-[800px] mb-4">
        Custom metrics are performance metrics you define using python code. They can be used to measure business
        impact, model performance, or any other metric that is important to your use case.
      </p>
      <configurator.Component metric={metricInput} onMetricChange={(v) => setMetricInput({ id: metric.id, ...v })} />
      <div className="flex justify-center items-center gap-2 mt-6">
        <Button
          cva={{ intent: "primary", size: "mediumLong" }}
          type="submit"
          className="gap-2"
          disabled={!isFormValid || loading}
        >
          {loading ? <Loader2Icon size={20} className="animate-spin" /> : <SaveIcon size={20} />}
          Save metric
        </Button>
        {/* Explicit type="button": a native button inside a form defaults to "submit", which would save the edit. */}
        <Button
          cva={{ intent: "reject", size: "mediumLong" }}
          type="button"
          className="gap-2"
          onClick={onCancel}
          disabled={loading}
        >
          <XIcon size={20} />
          Cancel
        </Button>
      </div>
      {error && <p className="mt-4 text-center text-red-500">{error.message}</p>}
    </form>
  );
};

/**
 * "Used by" button that opens a dialog listing the models that use `metric`.
 *
 * The list component suspends while it fetches; a loader is shown as the Suspense fallback meanwhile.
 *
 * @param metric - Metric whose consumers are listed.
 * @param className - Extra classes applied to the dialog trigger.
 */
const UsedByDialog = ({ metric, className }: { metric: MetricType; className?: string }) => {
  const loadingFallback = <RequestStateLayout isLoading loaderText="Looking up models..." />;

  return (
    <Dialog>
      <DialogTrigger className={className} asChild>
        <Button cva={{ intent: "secondary", size: "small" }} title="See which models use this metric">
          <EyeIcon size={20} strokeWidth={1} />
          Used by
        </Button>
      </DialogTrigger>
      <DialogContent className="max-w-[95%] max-h-[95%] w-fit overflow-auto">
        <React.Suspense fallback={loadingFallback}>
          <UsedBy metric={metric} />
        </React.Suspense>
      </DialogContent>
    </Dialog>
  );
};

/**
 * Lists the models that use `metric`. Each entry shows a short status of the model's latest run and a link to the
 * model page.
 *
 * Uses `useSuspenseQuery`, so it suspends while the models are fetched and needs a `React.Suspense` boundary above it.
 *
 * @param metric - Metric whose consumers are listed.
 */
const UsedBy = ({ metric }: { metric: MetricType }) => {
  const { data } = useSuspenseQuery(getMonitoringMetricModelsQuery, { variables: { metricId: metric.id } });
  const models = data.monitoring_metric.models;

  /** Human-readable status line for a model's latest run. */
  const describeLatestRun = (latestRun: (typeof models)[number]["latestRun"]) => {
    if (!latestRun) {
      return "Model not run yet";
    }
    if (latestRun.state === RunState.Running) {
      return "Currently running";
    }
    // Guard against a non-running run without a completion time (e.g. a queued run, if the backend reports those):
    // `new Date(null)` is the Unix epoch and `new Date(undefined)` is an invalid date.
    if (!latestRun.completedAt) {
      return "Last run has not completed yet";
    }
    return `Last run at ${formatISODateTime(new Date(latestRun.completedAt))}`;
  };

  return (
    <div className="grid grid-cols-[auto_auto] divide-y divide-gray-600">
      <div className="col-span-full">
        <h3 className="text-xl font-semibold">Models using {metric.name}</h3>
        {models.length > 0 ? (
          <span className="text-gray-400 inline-block pb-4">
            The '{metric.name}' metric is used by the following models that you monitor
          </span>
        ) : (
          <span className="text-gray-400">The '{metric.name}' metric is not used by any models yet</span>
        )}
      </div>
      {models.map((model) => (
        <React.Fragment key={model.id}>
          <div className="flex flex-col py-3">
            <MonitoringModelName modelId={model.id} />
            <span className="text-gray-400">{describeLatestRun(model.latestRun)}</span>
          </div>
          <div className="flex items-center">
            <Link to={`/monitoring/model/${model.id}`}>
              <Button cva={{ intent: "primary", size: "small" }} title="View model">
                <EyeIcon size={20} strokeWidth={1} />
                View
              </Button>
            </Link>
          </div>
        </React.Fragment>
      ))}
    </div>
  );
};

/**
 * GraphQL `__typename` of the metric object for each problem type. Used in the `metricConfigurators` type below to
 * select the matching member of the `MetricType` union for each `convertToInput`.
 */
const problemTypeMapping = {
  [ProblemType.BinaryClassification]: "ClassificationMetric",
  [ProblemType.MulticlassClassification]: "ClassificationMetric",
  [ProblemType.Regression]: "RegressionMetric",
} as const;

/**
 * Builds the configurator shared by binary and multiclass classification; only the `problemType` of the default
 * input differs between them.
 */
const classificationConfigurator = (problemType: ProblemType) => ({
  Component: ClassificationMetric,
  convertToInput: convertClassificationMetricToInput,
  createDefault: (): CreateMetricInput => ({
    name: "",
    problemType,
    classification: { calculateFn: "", estimateFn: null },
  }),
});

/**
 * Per problem type, the pieces needed by the create and edit forms:
 * - `Component`: the editor for the type-specific metric fields;
 * - `convertToInput`: turns a fetched metric into an input for editing;
 * - `createDefault`: builds the blank input used when creating a metric.
 */
const metricConfigurators: {
  [P in ProblemType]: {
    Component: React.FC<MetricConfiguratorProps>;
    convertToInput: (metric: Extract<MetricType, { __typename: (typeof problemTypeMapping)[P] }>) => CreateMetricInput;
    createDefault: () => CreateMetricInput;
  };
} = {
  [ProblemType.BinaryClassification]: classificationConfigurator(ProblemType.BinaryClassification),
  [ProblemType.MulticlassClassification]: classificationConfigurator(ProblemType.MulticlassClassification),
  [ProblemType.Regression]: {
    Component: RegressionMetric,
    convertToInput: convertRegressionMetricToInput,
    createDefault: () => ({
      name: "",
      problemType: ProblemType.Regression,
      regression: { aggregateFn: "", lossFn: "" },
    }),
  },
};
