import { ApolloClient, useApolloClient, useMutation } from "@apollo/client";
import { Dispatch, SetStateAction, useMemo, useState } from "react";
import { useNavigate } from "react-router-dom";

import {
  Chunking,
  ColumnInput,
  ColumnType,
  CreateDataSourceInput,
  CreateModelInput,
  DataSourceInspectData,
  ProblemType,
  ProductType,
  StorageInput,
  gql,
} from "@/apis/nannyml";
import { DatasetStorageDetails, inspectDataSource } from "@/components/DatasetStorage";
import { inspectCategoricalColumn } from "@/components/DatasetStorage/Inspect";
import { SchemaConfiguration, SchemaTable } from "@/components/Schema";
import {
  Wizard,
  WizardContent,
  WizardControlButtons,
  WizardOverview,
  WizardStep,
  getWizardControls,
} from "@/components/Wizard";
import { ModelContextProvider } from "@/components/monitoring/ModelContext";
import { problemTypeColumns } from "@/domains/monitoring";
import { cn } from "@/lib/utils";

import { ClassMapping } from "./ClassMapping";
import { DataRequirements } from "./DataRequirements";
import { MetricConfig } from "./MetricConfig";
import { ModelDetails } from "./ModelDetails";
import { Review } from "./Review";
import { TargetData } from "./TargetData";

// Names under which this wizard stores its data sources in `CreateModelInput.dataSources`. Every lookup in this file
// goes through these names. Note that the monitored dataset is stored under the name "analysis".
const REFERENCE_DATASET_NAME = "reference";
const MONITORED_DATASET_NAME = "analysis";
const TARGET_DATASET_NAME = "target";

// Mutation submitted when the wizard finishes. Only the new model's `id` is requested; it is used to navigate to the
// created model. NOTE(review): the document text is passed to the project's `gql` helper — if that is a codegen-generated
// function keyed on the exact string, change this text only together with re-running codegen.
const createMonitoringModelMutation = gql(/* GraphQL */ `
  mutation CreateMonitoringModel($input: CreateModelInput!) {
    create_monitoring_model(model: $input) {
      id
    }
  }
`);

/**
 * Updates a data source in the create model input. If no data source with a matching name exists, `input` is
 * appended as a new data source. Otherwise `input` is merged into the existing data source in place, so the order of
 * data sources stays stable across updates.
 *
 * The updater only returns `dataSources`; like the other updaters in this file, it relies on the caller's setter to
 * keep the remaining model input fields.
 *
 * @param setModelInput Callback to update the model input
 * @param input Values to update the data source with; `name` identifies the data source
 */
const updateDataSource = (
  setModelInput: Dispatch<SetStateAction<Partial<CreateModelInput>>>,
  input: Pick<CreateDataSourceInput, "name"> & Partial<CreateDataSourceInput>
) => {
  setModelInput((prev) => {
    const dataSources = prev.dataSources ?? [];
    const index = dataSources.findIndex((ds) => ds.name === input.name);

    if (index === -1) {
      // New data source. Callers may pass a partial input; the remaining fields are filled in by later updates.
      return { dataSources: [...dataSources, input as CreateDataSourceInput] };
    }

    // Existing data source: merge in place instead of moving it to the end of the list.
    return {
      dataSources: dataSources.map((ds, i) => (i === index ? { ...ds, ...input } : ds)),
    };
  });
};

/**
 * Builds an `onChange` handler that writes new storage info into a data source of the create model input.
 *
 * Any previously known columns are cleared, because they describe the old storage. The new storage must be inspected
 * again to fill them in.
 *
 * @param setModelInput Callback to update the model input
 * @param input Data source name and flags identifying/describing the data source to update
 * @returns Handler accepting the new storage info (or `null`/`undefined` when none is set)
 */
const updateDataSourceStorageInfo = (
  setModelInput: Dispatch<SetStateAction<Partial<CreateModelInput>>>,
  input: Omit<CreateDataSourceInput, "columns" | "storageInfo">
) => {
  const applyStorageInfo = (storageInfo: StorageInput | null | undefined) => {
    updateDataSource(setModelInput, { ...input, storageInfo, columns: [] });
  };
  return applyStorageInfo;
};

/**
 * Inspects one of the wizard's data sources and stores the results.
 *
 * Discovered columns are written back into the model input, and the returned head is kept in `dataSourceHeads` under
 * the data source name. Every data source other than the reference one is inspected against the reference columns.
 *
 * @param client Apollo client used for the inspection request
 * @param name Name of the data source to inspect
 * @param modelInput Current model input containing the data source
 * @param setModelInput Callback to update the model input
 * @param setDataSourceHeads Callback to store the inspected head per data source name
 * @returns Whatever `inspectDataSource` returns
 * @throws Error when no data source named `name` exists in `modelInput`
 */
const inspectMonitoringDataset = (
  client: ApolloClient<any>,
  name: string,
  modelInput: Partial<CreateModelInput>,
  setModelInput: Dispatch<SetStateAction<Partial<CreateModelInput>>>,
  setDataSourceHeads: Dispatch<SetStateAction<Record<string, DataSourceInspectData["head"]>>>
) => {
  const selected = modelInput.dataSources?.find((candidate) => candidate.name === name);
  if (!selected) {
    throw new Error(`Data source '${name}' not found`);
  }

  // The reference dataset defines the schema, so only the other datasets are checked against it.
  const referenceSchema =
    name === REFERENCE_DATASET_NAME
      ? undefined
      : modelInput.dataSources?.find((candidate) => candidate.name === REFERENCE_DATASET_NAME)?.columns;

  return inspectDataSource(
    client,
    selected,
    ProductType.Monitoring,
    modelInput.problemType!,
    (columns) => updateDataSource(setModelInput, { name, columns }),
    (head) => setDataSourceHeads((heads) => ({ ...heads, [name]: head })),
    referenceSchema
  );
};

/**
 * Replaces the reference schema in the create model input.
 *
 * The reference data source receives `schema`. All other data sources get an empty column list so they have to be
 * re-inspected against the new schema.
 *
 * @param setModelInput Callback to update the model input
 * @param schema New reference columns
 * @param impactsRuntimeConfig When true, the runtime config is cleared because it may depend on the schema
 */
const updateSchema = (
  setModelInput: Dispatch<SetStateAction<Partial<CreateModelInput>>>,
  schema: ColumnInput[],
  impactsRuntimeConfig: boolean
) =>
  setModelInput((current) => {
    const dataSources = current.dataSources!.map((source) => {
      const columns: ColumnInput[] = source.name === REFERENCE_DATASET_NAME ? schema : [];
      return { ...source, columns };
    });
    const runtimeConfig = impactsRuntimeConfig ? undefined : current.runtimeConfig;
    return { dataSources, runtimeConfig };
  });

/**
 * Wizard page for adding a monitoring model.
 *
 * Steps, in order:
 * - Model details
 * - Reference data
 * - Schema configuration
 * - Class mapping (multiclass classification only)
 * - Monitored data
 * - Target data (only when the inspected monitored data has no target column)
 * - Metrics
 * - Final review
 *
 * On finish it calls the `create_monitoring_model` mutation and navigates to the created model.
 */
export const AddMonitoringModelWizard = () => {
  const navigate = useNavigate();
  const client = useApolloClient();
  const [createModel] = useMutation(createMonitoringModelMutation);
  const [modelInput, setModelInput] = useState<Partial<CreateModelInput>>({});
  // Heads returned by data source inspection, keyed by data source name
  const [dataSourceHeads, setDataSourceHeads] = useState<Record<string, DataSourceInspectData["head"]>>({});
  const [chunking, setChunking] = useState<Chunking>();
  const [nrOfRows, setNrOfRows] = useState<number | null>(null);
  // Classes obtained by inspecting the reference target column (multiclass classification only)
  const [classes, setClasses] = useState<string[]>([]);

  // Some steps are included conditionally based on the outer model input, hence the dependency on it.
  const steps: WizardStep<CreateModelInput>[] = useMemo(
    () =>
      (
        [
          {
            title: "Data requirements",
            subtitle: "Find out what data you need to provide",
            isCompleted: () => true,
            render: () => <DataRequirements />,
          },
          {
            title: "Model details",
            subtitle: "Define details of the model",
            render: (modelInput, setModelInput) => (
              <ModelDetails
                modelInput={modelInput}
                onModelInputChange={setModelInput}
                chunking={chunking}
                onChunkingChange={setChunking}
                nrOfRows={nrOfRows}
                onNrOfRowsChange={setNrOfRows}
              />
            ),
          },
          {
            title: "Reference data",
            subtitle: "Provide reference data for the model",
            // Storage info is required as well (like the monitored data step): `onChange` may report null/undefined
            // storage, which cannot be inspected in `onComplete`.
            isCompleted: (modelInput) =>
              Boolean(modelInput.dataSources?.some((ds) => ds.hasReferenceData && ds.storageInfo)),
            onComplete: (modelInput, setModelInput) =>
              inspectMonitoringDataset(client, REFERENCE_DATASET_NAME, modelInput, setModelInput, setDataSourceHeads),
            render: (modelInput, setModelInput) => (
              <DatasetStorageDetails
                name="reference"
                value={modelInput.dataSources?.find((ds) => ds.hasReferenceData)?.storageInfo ?? null}
                onChange={updateDataSourceStorageInfo(setModelInput, {
                  name: REFERENCE_DATASET_NAME,
                  hasReferenceData: true,
                  hasAnalysisData: false,
                })}
              />
            ),
          },
          {
            title: "Configure schema",
            subtitle: "Define columns for the model",
            isCompleted: (modelInput) =>
              Boolean(
                modelInput.dataSources?.some((ds) => ds.name === REFERENCE_DATASET_NAME && ds.columns.length > 0)
              ),
            // For multiclass models, collect the classes from the reference target column for the class mapping step
            onComplete:
              modelInput.problemType !== ProblemType.MulticlassClassification
                ? undefined
                : (modelInput) => {
                    const refDs = modelInput.dataSources!.find((ds) => ds.name === REFERENCE_DATASET_NAME)!;
                    return inspectCategoricalColumn(
                      client,
                      refDs.storageInfo!,
                      refDs.columns.find((col) => col.columnType === ColumnType.Target)!.name,
                      refDs.columns.filter((col) => col.columnType === ColumnType.PredictionScore).length
                    ).then(setClasses);
                  },
            render: (modelInput, setModelInput) => {
              const schema = modelInput.dataSources?.find((ds) => ds.name === REFERENCE_DATASET_NAME)!.columns!;
              const columnConfig = problemTypeColumns[modelInput.problemType!];
              const onSchemaChange = (schema: ColumnInput[]) => updateSchema(setModelInput, schema, true);

              return (
                <div className="self-stretch flex flex-col justify-around max-w-full gap-8">
                  <SchemaConfiguration columnConfig={columnConfig} schema={schema} onSchemaChange={onSchemaChange} />
                  <SchemaTable
                    columnConfig={columnConfig}
                    schema={schema}
                    head={dataSourceHeads[REFERENCE_DATASET_NAME] ?? []}
                    onSchemaChange={onSchemaChange}
                  />
                </div>
              );
            },
          },
          modelInput.problemType === ProblemType.MulticlassClassification && {
            title: "Class mapping",
            subtitle: "Map class labels to columns",
            onComplete: (modelInput) => {
              const mappedClasses = modelInput
                .dataSources!.find((ds) => ds.name === REFERENCE_DATASET_NAME)!
                .columns.filter((col) => col.columnType === ColumnType.PredictionScore)
                .map((col) => col.className);

              const uniqueClasses = new Set(mappedClasses);
              if (uniqueClasses.size !== mappedClasses.length) {
                return Promise.reject(
                  "Duplicate class names found. Each predicted probability column must map to a different class"
                );
              }
            },
            render: (modelInput, setModelInput) => (
              <ClassMapping
                schema={modelInput.dataSources?.find((ds) => ds.name === REFERENCE_DATASET_NAME)?.columns ?? []}
                classes={classes}
                onSchemaChange={(schema) => updateSchema(setModelInput, schema, false)}
              />
            ),
          },
          {
            title: "Monitored data",
            subtitle: "Provide monitored data for the model",
            isCompleted: (modelInput) =>
              Boolean(modelInput.dataSources?.some((ds) => ds.name === MONITORED_DATASET_NAME && ds.storageInfo)),
            onComplete: (modelInput, setModelInput) =>
              inspectMonitoringDataset(client, MONITORED_DATASET_NAME, modelInput, setModelInput, setDataSourceHeads),
            render: (modelInput, setModelInput) => (
              <DatasetStorageDetails
                name="monitored"
                value={modelInput.dataSources?.find((ds) => ds.name === MONITORED_DATASET_NAME)?.storageInfo ?? null}
                onChange={updateDataSourceStorageInfo(setModelInput, {
                  name: MONITORED_DATASET_NAME,
                  hasReferenceData: false,
                  hasAnalysisData: true,
                })}
                reference={modelInput.dataSources?.find((ds) => ds.name === REFERENCE_DATASET_NAME)?.storageInfo}
              />
            ),
          },
          // Only shown once the monitored data has been inspected and turned out not to contain a target column
          modelInput.dataSources?.some(
            (ds) =>
              ds.name === MONITORED_DATASET_NAME &&
              ds.columns.length > 0 &&
              !ds.columns.some((column) => column.columnType === ColumnType.Target)
          ) && {
            title: "Target data",
            subtitle: "Optionally provide target data for the model",
            isCompleted: (modelInput) =>
              Boolean(
                modelInput.dataSources?.some((ds) => ds.name === TARGET_DATASET_NAME && ds.storageInfo !== undefined)
              ),
            onComplete: (modelInput, setModelInput) => {
              const targetDs = modelInput.dataSources?.find((ds) => ds.name === TARGET_DATASET_NAME);
              if (targetDs?.storageInfo) {
                return inspectMonitoringDataset(
                  client,
                  TARGET_DATASET_NAME,
                  modelInput,
                  setModelInput,
                  setDataSourceHeads
                );
              }

              // No targets provided. Create default data source with just target & identifier columns
              const schema = modelInput.dataSources?.find((ds) => ds.name === REFERENCE_DATASET_NAME)?.columns;
              updateDataSource(setModelInput, {
                name: TARGET_DATASET_NAME,
                columns: schema?.filter(
                  (col) => col.columnType === ColumnType.Target || col.columnType === ColumnType.Identifier
                ),
              });
            },
            render: (modelInput, setModelInput) => (
              <TargetData
                name="target"
                value={modelInput.dataSources?.find((ds) => ds.name === TARGET_DATASET_NAME)?.storageInfo}
                onChange={updateDataSourceStorageInfo(setModelInput, {
                  name: TARGET_DATASET_NAME,
                  hasReferenceData: false,
                  hasAnalysisData: true,
                })}
                reference={modelInput.dataSources?.find((ds) => ds.name === REFERENCE_DATASET_NAME)?.storageInfo}
              />
            ),
          },
          {
            title: "Metrics",
            subtitle: "Configure metrics to evaluate",
            isCompleted: (modelInput) => Boolean(modelInput.runtimeConfig && modelInput.kpm),
            render: (modelInput, setModelInput) => {
              const schema = {
                columns: modelInput
                  .dataSources!.find((ds) => ds.name === REFERENCE_DATASET_NAME)!
                  .columns.map(({ name, columnType }) => ({ name, columnType })),
                hasAnalysisTargets: modelInput.dataSources!.some(
                  (ds) => ds.hasAnalysisData && ds.columns.some((col) => col.columnType === ColumnType.Target)
                ),
              };
              return (
                <ModelContextProvider name={modelInput.name!} problemType={modelInput.problemType!} schema={schema}>
                  <MetricConfig
                    modelInput={modelInput as CreateModelInput}
                    onModelInputChange={setModelInput}
                    chunking={chunking!}
                    nrOfRows={nrOfRows}
                    classes={classes}
                  />
                </ModelContextProvider>
              );
            },
          },
          {
            title: "Review",
            subtitle: "Review your model settings",
            isCompleted: (_, ctx) => getWizardControls(ctx).arePreviousStepsCompleted(),
            render: (modelInput) => <Review settings={modelInput as CreateModelInput} />,
          },
        ] as WizardStep<CreateModelInput>[]
      ).filter(Boolean), // drop conditional steps whose condition evaluated falsy
    [client, dataSourceHeads, chunking, nrOfRows, classes, modelInput.dataSources, modelInput.problemType]
  );

  /**
   * Submits the collected model input and navigates to the created model.
   * On failure it rejects with a single user-facing message string, in the same form as the step-level rejections
   * above.
   */
  const onComplete = (input: CreateModelInput) =>
    createModel({ variables: { input } })
      .then(({ data, errors }) => {
        if (errors?.length || !data) {
          // Throw only the underlying reason; the user-facing prefix is added once in `catch` below.
          throw new Error(errors?.[0]?.message ?? "no data returned");
        }

        navigate(`/monitoring/model/${data.create_monitoring_model.id}`);
      })
      .catch((error: unknown) =>
        Promise.reject(`Failed to add monitoring model: ${error instanceof Error ? error.message : String(error)}`)
      );

  return (
    <div className="flex h-full bg-dark overflow-y-auto py-4">
      <Wizard
        className={cn(
          "rounded-xl border border-gray-600 bg-primaryBg px-8",
          "m-auto min-h-[max(75%,750px)] h-fit w-4/5 2xl:w-9/12"
        )}
        value={modelInput}
        onChange={setModelInput}
        steps={steps}
        onComplete={onComplete}
        preventForwardJump
      >
        <h3 className="text-2xl text-center">Add monitoring model</h3>
        <WizardOverview />
        <WizardContent />
        <WizardControlButtons finishLabel="Add model" />
      </Wizard>
    </div>
  );
};
