import { useSuspenseQuery } from "@apollo/client";
import _ from "lodash";
import { useMemo } from "react";

import { Column, DataSource, DataSourceSchema, ProblemType, useFragment } from "@/apis/nannyml";
import { getModelSchema } from "@/apis/nannyml/queries/getModelSchema";
import { getModelSchemaDetails } from "@/apis/nannyml/queries/getModelSchemaDetails";

/**
 * Summary of a model's schema.
 *
 * The fields mirror the object assembled by `useModelSchema` in this file
 * (that hook relies on type inference and does not annotate this alias).
 */
export type ModelSchema = {
  /** Identifier of the model the schema belongs to. */
  modelId: number;
  /** Name of the model. */
  name: string;
  /** Problem type of the model. */
  problemType: ProblemType;
  /** Column definitions looked up by a string key (`useModelSchema` keys them by column name). */
  columns: Record<string, Column>;
};

/**
 * Model schema together with its data sources.
 *
 * NOTE(review): `columns` is declared here as a string-keyed record, but
 * `useModelSchemaDetails` in this file builds `columns` with `_.uniqBy`, which
 * produces an array. Confirm which shape is intended before annotating the hook
 * with this type.
 */
export type ModelSchemaDetails = {
  /** Identifier of the model the schema belongs to. */
  modelId: number;
  /** Name of the model. */
  name: string;
  /** Problem type of the model. */
  problemType: ProblemType;
  /**
   * Data sources with `events` dropped and `head` retyped as an array of
   * objects with string keys (presumably one object per row, keyed by column
   * name — confirm against the API).
   */
  dataSources: (Omit<DataSource, "head" | "events"> & { head: { [column: string]: any }[] })[];
  /** Column definitions looked up by a string key. */
  columns: Record<string, Column>;
};

/**
 * Loads a model's schema and flattens the columns of its data sources into a
 * lookup keyed by column name.
 *
 * The `getModelSchema` query is executed through `useSuspenseQuery`, restricted
 * to data sources matching `hasReferenceData: true`.
 *
 * @param modelId - Identifier of the model, passed as the query's `modelId` variable.
 * @returns A memoized object with `modelId`, `name`, `problemType` and `columns`.
 *   When several data sources contain a column with the same name, the one
 *   encountered last is kept (lodash `keyBy` semantics).
 * @throws Error with message "Model not found" when the query result has no
 *   `monitoring_model`.
 */
export const useModelSchema = (modelId: number) => {
  const { data } = useSuspenseQuery(getModelSchema, {
    variables: {
      modelId,
      dataSourceFilter: {
        hasReferenceData: true,
      },
    },
  });

  return useMemo(() => {
    const model = data.monitoring_model;
    if (!model) {
      throw new Error("Model not found");
    }

    // NOTE(review): `useFragment` is presumably the generated fragment-unmasking
    // helper rather than a React hook, since it is invoked inside a `useMemo`
    // callback — confirm.
    const dataSources = useFragment(DataSourceSchema, model.dataSources);
    return {
      modelId,
      name: model.name,
      problemType: model.problemType,
      columns: _.keyBy(
        dataSources.flatMap((dataSource) => dataSource.columns),
        (column) => column.name
      ),
    };
    // `modelId` is read inside the callback, so it must be listed alongside
    // `data`; otherwise a stale id could be returned if `data` keeps its identity.
  }, [data, modelId]);
};

/**
 * Loads a model's schema along with its data sources, decoding each data
 * source's `head` from a JSON string.
 *
 * The `getModelSchemaDetails` query is executed through `useSuspenseQuery`.
 *
 * @param modelId - Identifier of the model, passed as the query's `modelId` variable.
 * @param nrRows - Passed as the query's `nrRows` variable; defaults to 10
 *   (presumably the number of rows returned in each `head` — confirm against the API).
 * @returns A memoized object with `modelId`, `name`, `problemType`, the data
 *   sources (with `head` parsed) and `columns` as an array de-duplicated by
 *   column name, keeping the first occurrence (lodash `uniqBy` semantics).
 * @throws Error with message "Model not found" when the query result has no
 *   `monitoring_model`; `JSON.parse` throws a `SyntaxError` if a `head` value
 *   is not valid JSON.
 */
export const useModelSchemaDetails = (modelId: number, nrRows: number = 10) => {
  const { data } = useSuspenseQuery(getModelSchemaDetails, {
    variables: {
      modelId,
      nrRows,
    },
  });

  return useMemo(() => {
    const model = data.monitoring_model;
    if (!model) {
      throw new Error("Model not found");
    }

    // `head` is parsed from a string here so consumers get structured data.
    const dataSources = model.dataSources.map((dataSource) => ({
      ...dataSource,
      head: JSON.parse(dataSource.head),
    }));

    return {
      modelId,
      name: model.name,
      problemType: model.problemType,
      dataSources,
      // NOTE(review): this is an array, whereas the `ModelSchemaDetails` type
      // declares `columns` as a record. Kept as an array so existing callers
      // keep working — confirm the intended shape.
      columns: _.uniqBy(
        dataSources.flatMap((dataSource) => dataSource.columns),
        (column) => column.name
      ),
    };
    // `modelId` is read inside the callback, so it must be listed alongside
    // `data`; otherwise a stale id could be returned if `data` keeps its identity.
  }, [data, modelId]);
};
