import { yupResolver } from '@hookform/resolvers/yup';
import ExpandIcon from '@mui/icons-material/ExpandMore';
import {
  Accordion,
  AccordionDetails,
  AccordionSummary,
  CircularProgress,
  Grid,
  Input,
  Slider,
  Typography,
} from '@mui/material';
import { useMutation, useQuery } from '@tanstack/react-query';
import { filter, forEach, get, isEmpty } from 'lodash';
import { useSnackbar } from 'notistack';
import React, { ReactElement, useState } from 'react';
import { Controller, SubmitHandler, useForm } from 'react-hook-form';
import * as yup from 'yup';

import { getArtifactsInference, runMultiplexThreshold } from 'api/platform';
import { getSlidesDataForArtifactsResults } from 'api/study';
import { OrchestrationBySlideByType } from 'components/Pages/CalculateFeatures';
import InferenceModelsForSlides from 'components/Pages/CalculateFeatures/InferenceModelsForSlides';
import { getModelsTypeByModelInferences } from 'components/Pages/CalculateFeatures/InferenceModelsForSlides/utils';
import { BinaryClassifierConfig } from 'interfaces/inferenceArtifacts';
import { MultiplexThresholdConfig } from 'interfaces/jobs/multiplex/thresholdParams';
import { humanize } from 'utils/helpers';
import { useCasesParams } from 'utils/useCasesParams';
import { encodeQueryParamsUsingSchema } from 'utils/useEncodedFilters';
import { PlatformStepper } from '../PlatformStepper';
import { JobNameAndDescriptionStep } from '../utils';

/**
 * Key of the persistent "waiting" snackbar shown while the threshold job is being started.
 * The same key is passed to `closeSnackbar` once the mutation settles, so it must be unique
 * to this job type; otherwise closing it could dismiss another job's snackbar.
 */
const SNACK_BAR_KEY_RUN_MULTIPLEX_THRESHOLD = 'RUN_MULTIPLEX_THRESHOLD';

/**
 * Initial threshold configuration used to seed the `configParams` field of the form.
 *
 * The artifact URL list starts empty and is filled in when the user selects
 * inference results; the numeric fields are the starting positions of the sliders.
 * NOTE(review): the exact meaning of each numeric parameter is defined by the
 * threshold job, not by this file — confirm against `MultiplexThresholdConfig`.
 */
export const defaultThresholdConfig: MultiplexThresholdConfig = {
  binaryClassifierInferenceArtifactUrls: [],
  rangeStart: 0.2,
  rangeEnd: 0.75,
  maxThreshold: 0.5,
  precisionDecimal: 3,
};

/** Shape of the values held by the multiplex-threshold form. */
export interface ThresholdFormValues {
  /** Name of the job; watched to build the second step's sub-label and the JSON summary. */
  jobName: string;
  /** Description of the job; shown in the JSON summary. */
  jobDescription: string;
  /** Threshold parameters, seeded from `defaultThresholdConfig`. */
  configParams: MultiplexThresholdConfig;
}

// Shown both when the artifact URL list is missing and when it is empty.
const BINARY_CLASSIFIER_RESULTS_REQUIRED_MESSAGE = 'Binary classifier results are required';

/** Schema for step 0: at least one binary classifier inference artifact URL must be selected. */
const thresholdConfigurationStepSchema = yup.object({
  configParams: yup.object({
    binaryClassifierInferenceArtifactUrls: yup
      .array()
      .required(BINARY_CLASSIFIER_RESULTS_REQUIRED_MESSAGE)
      .min(1, BINARY_CLASSIFIER_RESULTS_REQUIRED_MESSAGE),
  }),
});

/** Schema for step 1: no constraints are applied. */
const jobNameAndDescriptionStepSchema = yup.object({});

/** Per-step validation schemas, indexed by the stepper's step index. */
const validationSchema = [thresholdConfigurationStepSchema, jobNameAndDescriptionStepSchema];

/** Props accepted by `RunMultiplexThreshold`. */
export interface RunMultiplexThresholdProps {
  /** Invoked at the end of a successful submit, right after the job mutation has been fired. */
  onClose: () => void;
}

/**
 * Two-step wizard for starting a multiplex threshold job.
 *
 * Step 0 ("Threshold Configuration") lets the user pick inference results per slide and
 * adjust the numeric threshold parameters; step 1 ("Job Name and Description") collects
 * the job name and description. A JSON preview of the form values is rendered below the stepper.
 *
 * On submit, step 0 is validated again (the form resolver only covers the active step),
 * the `runMultiplexThreshold` mutation is fired with the cases params merged with the form
 * values, a persistent "waiting" snackbar is shown, and `onClose` is called.
 *
 * @param props.onClose - Called once the job has been submitted.
 */
export const RunMultiplexThreshold = (props: RunMultiplexThresholdProps): ReactElement => {
  const { onClose } = props;
  const { enqueueSnackbar, closeSnackbar } = useSnackbar();
  const [activeStep, setActiveStep] = useState(0);
  const currentValidationSchema = validationSchema[activeStep];

  // Per-step failure flags (keyed by step index) handed to the stepper.
  const [isStepFailed, setIsStepFailed] = useState<Record<number, boolean>>({});
  const { casesParams, schema, options } = useCasesParams();
  const getEncodedParams = () => encodeQueryParamsUsingSchema(casesParams, schema, options);

  const studyId = casesParams.filters?.studyId;
  // Both queries below depend on a study being selected.
  const hasStudyId = Boolean(studyId);

  const {
    data: inferenceArtifacts,
    isLoading: isInferenceArtifactsLoading,
    isFetching: isInferenceArtifactsFetching,
  } = useQuery({
    queryKey: [
      'artifactsInference',
      { studyId: studyId, caseParams: getEncodedParams(), configParams: BinaryClassifierConfig },
    ],
    queryFn: ({ signal }) =>
      getArtifactsInference({ studyId, caseParams: getEncodedParams(), configParams: BinaryClassifierConfig }, signal),
    enabled: hasStudyId,
  });

  const { data: slides, isLoading: isSlidesLoading } = useQuery({
    queryKey: ['slidesDataForArtifactsResults', { studyId, caseParams: getEncodedParams() }],
    queryFn: ({ signal }) => getSlidesDataForArtifactsResults(studyId, getEncodedParams(), signal),
    enabled: hasStudyId,
  });

  const [selectedOrchestrations, setSelectedOrchestrations] = useState<OrchestrationBySlideByType>({});

  const {
    register,
    handleSubmit,
    control,
    watch,
    setValue,
    trigger,
    formState: { errors },
  } = useForm<ThresholdFormValues>({
    mode: 'onChange',
    defaultValues: {
      configParams: defaultThresholdConfig,
    },
    resolver: yupResolver(currentValidationSchema),
  });

  /**
   * Stores the new selection and mirrors it into the form as a flat list of artifact URLs,
   * built by substituting each slide id into the selected orchestration's URL pattern.
   */
  const handleSelectedOrchestrationsChange = (orchestrations: OrchestrationBySlideByType) => {
    setSelectedOrchestrations(orchestrations);

    const finalInferenceResultsArtifactUrls: string[] = [];

    const slidesWithResults = filter(slides, (slide) => !isEmpty(orchestrations[slide.slideId]));

    forEach(slidesWithResults, (slide) => {
      forEach(orchestrations[slide.slideId], (orchestrationData) => {
        const urlPattern = orchestrationData?.orchestration.orchestrationResultArtifactUrlPattern;
        // Skip entries without a usable pattern so the list never contains `undefined`.
        if (!urlPattern) {
          return;
        }
        finalInferenceResultsArtifactUrls.push(urlPattern.replace('{slide_id}', slide.slideId));
      });
    });

    setValue('configParams.binaryClassifierInferenceArtifactUrls', finalInferenceResultsArtifactUrls);
    // Re-run validation for the field so the error message reflects the new selection.
    void trigger('configParams.binaryClassifierInferenceArtifactUrls');
  };

  /**
   * Validates `objectToValidate` against the schema of `stepIndex` and records the outcome
   * in `isStepFailed`.
   *
   * @returns `true` when validation passed, `false` otherwise. Never rejects.
   */
  const checkValidationAndSetIsStepFailed = async (stepIndex: number, objectToValidate: unknown): Promise<boolean> => {
    let isValid = true;
    try {
      await validationSchema[stepIndex].validate(objectToValidate);
    } catch {
      isValid = false;
    }

    setIsStepFailed((prev) => ({
      ...prev,
      [stepIndex]: !isValid,
    }));

    return isValid;
  };

  const runMultiplexThresholdMutation = useMutation(runMultiplexThreshold, {
    onError: () => {
      enqueueSnackbar('Error occurred, Multiplex Threshold failed', {
        variant: 'error',
      });
    },
    onSuccess: () => {
      enqueueSnackbar('Multiplex Threshold Started', { variant: 'success' });
    },
    onSettled() {
      // Dismiss the persistent "waiting" snackbar regardless of the outcome.
      closeSnackbar(SNACK_BAR_KEY_RUN_MULTIPLEX_THRESHOLD);
    },
  });

  const onSubmit: SubmitHandler<ThresholdFormValues> = async (data) => {
    // The resolver only validates the active (last) step at submit time, so step 0 is
    // re-checked here; a failure marks step 0 as failed and aborts the submit.
    const isThresholdStepValid = await checkValidationAndSetIsStepFailed(0, data);

    if (!isThresholdStepValid) {
      return;
    }

    runMultiplexThresholdMutation.mutate({
      ...casesParams,
      ...data,
    });

    enqueueSnackbar({
      variant: 'success',
      message: (
        <Grid container>
          <Grid item>
            <Typography>Waiting for Multiplex Threshold to start</Typography>
          </Grid>
          <Grid item>
            <CircularProgress sx={{ marginLeft: 10 }} color="inherit" size={20} />
          </Grid>
        </Grid>
      ),
      key: SNACK_BAR_KEY_RUN_MULTIPLEX_THRESHOLD,
      // Stays visible until explicitly closed when the mutation settles.
      autoHideDuration: null,
    });

    onClose();
  };

  // Fired by the stepper when leaving step 0; the check runs in the background and only
  // updates the step's failure flag.
  const validateThresholdParams = async () => {
    void checkValidationAndSetIsStepFailed(0, {
      configParams: watch('configParams'),
    });
  };

  const artifactUrlsError = get(errors, 'configParams.binaryClassifierInferenceArtifactUrls');

  const steps = [
    {
      label: 'Threshold Configuration',
      content: (
        <Grid container direction="column" spacing={2}>
          <Grid item>
            {artifactUrlsError && <Typography color="error">{humanize(artifactUrlsError?.message)}</Typography>}
            <InferenceModelsForSlides
              studyId={studyId}
              slides={slides}
              modelsType={getModelsTypeByModelInferences(inferenceArtifacts)}
              inferenceModels={inferenceArtifacts}
              isLoading={isInferenceArtifactsLoading || isInferenceArtifactsFetching || isSlidesLoading}
              selectedOrchestrations={selectedOrchestrations}
              setSelectedOrchestrations={handleSelectedOrchestrationsChange}
            />
          </Grid>
          <SliderInputWithLabel
            label="Range Start"
            fieldName="configParams.rangeStart"
            control={control}
            min={0}
            max={1}
            step={0.01}
          />
          <SliderInputWithLabel
            label="Range End"
            fieldName="configParams.rangeEnd"
            control={control}
            min={0}
            max={1}
            step={0.01}
          />
          <SliderInputWithLabel
            label="Max Threshold"
            fieldName="configParams.maxThreshold"
            control={control}
            min={0}
            max={1}
            step={0.01}
          />
          <SliderInputWithLabel
            label="Precision Decimal"
            fieldName="configParams.precisionDecimal"
            control={control}
            min={0}
            max={10}
            step={1}
          />
        </Grid>
      ),
      onNextOrBackClick: validateThresholdParams,
    },
    {
      label: 'Job Name and Description',
      subLabel: activeStep > 0 && watch('jobName'),
      optional: true,
      content: <JobNameAndDescriptionStep control={control} errors={errors} register={register} />,
    },
  ];

  return (
    <>
      <PlatformStepper
        handleSubmit={handleSubmit(onSubmit)}
        steps={steps}
        setActiveStepForValidation={setActiveStep}
        isStepFailed={isStepFailed}
      />
      <Accordion>
        <AccordionSummary expandIcon={<ExpandIcon />}>Threshold Params Summary (JSON)</AccordionSummary>
        <AccordionDetails>
          <Typography component="pre">
            {JSON.stringify(
              {
                jobName: watch('jobName'),
                jobDescription: watch('jobDescription'),
                configParams: watch('configParams'),
              },
              null,
              2
            )}
          </Typography>
        </AccordionDetails>
      </Accordion>
    </>
  );
};

/** Props for `SliderInputWithLabel`. */
interface SliderInputWithLabelProps {
  /** Text rendered next to the slider. */
  label: string;
  /** Form field name passed to both `Controller` instances. */
  fieldName: string;
  /**
   * Form control object forwarded to `Controller`.
   * NOTE(review): typed as `any`; a `Control<...>` type from react-hook-form would be safer — confirm and tighten.
   */
  control: any;
  /** Lower bound applied to the slider and the numeric input. */
  min: number;
  /** Upper bound applied to the slider and the numeric input. */
  max: number;
  /** Increment applied to the slider and the numeric input. */
  step: number;
}

/**
 * A labelled slider paired with a numeric text input, both bound to the same form field.
 *
 * The form value is always written as a number (or `''` transiently while the text input is
 * empty). When the text input loses focus, an empty or non-numeric value falls back to `min`
 * and out-of-range values are clamped to `[min, max]`.
 *
 * @param label - Text shown next to the controls; also used as their accessible label.
 * @param control - Form control forwarded to `Controller`.
 * @param fieldName - Name of the form field both controls read and write.
 * @param min - Smallest allowed value.
 * @param max - Largest allowed value.
 * @param step - Increment used by both controls.
 */
const SliderInputWithLabel = ({ label, control, fieldName, min, max, step }: SliderInputWithLabelProps) => {
  // Unique per field so each slider/input pair references its own visible label.
  const labelId = `${fieldName}-label`;

  return (
    <Grid item container direction="row" spacing={2} alignItems="center">
      <Grid item>
        <Typography id={labelId} variant="body2">
          {label}
        </Typography>
      </Grid>
      <Grid item>
        <Controller
          control={control}
          name={fieldName}
          render={({ field: { onChange, value } }) => (
            <Slider
              sx={{ width: 200 }}
              // The text input may hold '' while being edited; keep the slider on a valid number.
              value={typeof value === 'number' ? value : min}
              onChange={(_event, newValue) => onChange(newValue)}
              valueLabelDisplay="auto"
              min={min}
              max={max}
              step={step}
              aria-labelledby={labelId}
            />
          )}
        />
      </Grid>
      <Grid item>
        <Controller
          control={control}
          name={fieldName}
          render={({ field: { onChange, value } }) => (
            <Input
              value={value}
              size="small"
              onChange={(event) => {
                const rawValue = event.target.value;
                // DOM inputs report strings; store a number so the form never holds "0.3".
                // An empty string is kept as-is so the user can clear the field while typing.
                onChange(rawValue === '' ? '' : Number(rawValue));
              }}
              onBlur={() => {
                const numericValue = typeof value === 'number' && !Number.isNaN(value) ? value : min;
                const clampedValue = Math.min(max, Math.max(min, numericValue));
                if (clampedValue !== value) {
                  onChange(clampedValue);
                }
              }}
              inputProps={{
                step: step,
                min: min,
                max: max,
                type: 'number',
                'aria-labelledby': labelId,
              }}
            />
          )}
        />
      </Grid>
    </Grid>
  );
};
