import { yupResolver } from '@hookform/resolvers/yup';
import ExpandIcon from '@mui/icons-material/ExpandMore';
import {
  Accordion,
  AccordionDetails,
  AccordionSummary,
  Checkbox,
  CircularProgress,
  FormControl,
  FormControlLabel,
  FormHelperText,
  Grid,
  InputLabel,
  MenuItem,
  Select,
  TextField,
  Typography,
} from '@mui/material';
import { useMutation, useQuery } from '@tanstack/react-query';
import { first, flatMap, get, isEmpty, map, uniq, values } from 'lodash';
import { useSnackbar } from 'notistack';
import React, { ReactElement, useCallback, useEffect, useMemo, useState } from 'react';
import { Controller, SubmitHandler, useForm } from 'react-hook-form';
import * as yup from 'yup';

import { getArtifactsInference, runMultiplexThreshold } from 'api/platform';
import { getSlidesDataForArtifactsResults } from 'api/study';
import { AnnotationAssignmentAutocomplete } from 'components/atoms/AnnotationAssignmentAutocomplete';
import { CellRulesPanelSelect } from 'components/atoms/CellRulesPanelSelect';
import { OrchestrationBySlideByType } from 'components/Pages/CalculateFeatures';
import InferenceModelsForSlides from 'components/Pages/CalculateFeatures/InferenceModelsForSlides';
import { getModelsTypeByModelInferences } from 'components/Pages/CalculateFeatures/InferenceModelsForSlides/utils';
import { BinaryClassifierConfig } from 'interfaces/inferenceArtifacts';
import { MultiplexThresholdConfig, ThresholdMethod } from 'interfaces/jobs/multiplex/thresholdParams';
import { MULTIPLEX_STAIN_ID } from 'interfaces/stainType';
import { humanize } from 'utils/helpers';
import useCellRules from 'utils/queryHooks/cellRule/useCellRules';
import { useCasesParams } from 'utils/useCasesParams';
import { encodeQueryParamsUsingSchema } from 'utils/useEncodedFilters';
import { PlatformStepper } from '../PlatformStepper';
import { convertToCellRules } from './helpers';

/**
 * notistack key of the persistent "Waiting for Multiplex Threshold to start" snackbar; the run
 * mutation's `onSettled` closes it by this key. The value is specific to the threshold job so that
 * closing it cannot dismiss (or be dismissed together with) another job's snackbar that happens to
 * reuse a generic key such as the normalization flow's.
 */
const SNACK_BAR_KEY_RUN_MULTIPLEX_THRESHOLD = 'RUN_MULTIPLEX_THRESHOLD';

/**
 * Initial `configParams` of the threshold form (passed to `useForm` as default values).
 *
 * `inferenceResultsArtifactUrl` starts empty; the component fills it in from the selected
 * orchestration. Key order is preserved by the object spreads in the component and is therefore
 * visible in the "Threshold Params Summary (JSON)" panel.
 */
export const defaultThresholdConfig: MultiplexThresholdConfig = {
  inferenceResultsArtifactUrl: '',
  thresholdMethod: ThresholdMethod.AUTO,
  // NOTE(review): presumably micrometres, judging by the `Um` suffix — confirm with the backend.
  annotationToInferenceMatchingDistUm: 4,
  thresholdPerMarker: true,
  saveThresholdedResultsToCsv: false,
};

/** Props of {@link RunMultiplexThreshold}. */
export interface RunMultiplexThresholdProps {
  /** Called right after the threshold job request has been dispatched. */
  onClose: () => void;
}

/**
 * Stepper form that configures and launches a multiplex threshold job.
 *
 * - Step 0 ("Threshold Configuration"): pick the inference results via `InferenceModelsForSlides`
 *   (all selections must share a single orchestration), a threshold method and — for
 *   `BALANCED_ACCURACY` / `F1` — a cell rules panel plus annotation assignments.
 * - Step 1 ("Job Name and Description"): optional job metadata.
 *
 * On submit the configuration is re-validated against the step-0 schema. If it is valid,
 * `runMultiplexThreshold` is called with the current cases params merged in. A persistent "waiting"
 * snackbar is shown until the request settles, and `onClose` is invoked.
 *
 * @param props.onClose - invoked right after the job request has been dispatched.
 */
export const RunMultiplexThreshold = (props: RunMultiplexThresholdProps): ReactElement => {
  const { onClose } = props;
  const { enqueueSnackbar, closeSnackbar } = useSnackbar();
  const [activeStep, setActiveStep] = useState(0);
  // The form resolver follows the stepper: only the active step's schema is applied on change.
  const currentValidationSchema = validationSchema[activeStep];

  const [isStepFailed, setIsStepFailed] = useState<Record<number, boolean>>({});
  const { casesParams, schema, options } = useCasesParams();
  // Encoded once per render and shared by both queries' keys and fetchers.
  const encodedParams = encodeQueryParamsUsingSchema(casesParams, schema, options);

  const studyId = casesParams.filters?.studyId;
  // Both queries below are study-scoped, so neither runs until a study id is available.
  const hasStudyId = Boolean(studyId);

  const {
    data: inferenceArtifacts,
    isLoading: isInferenceArtifactsLoading,
    isFetching: isInferenceArtifactsFetching,
  } = useQuery({
    queryKey: [
      'artifactsInference',
      { studyId: studyId, caseParams: encodedParams, configParams: BinaryClassifierConfig },
    ],
    queryFn: ({ signal }) =>
      getArtifactsInference({ studyId, caseParams: encodedParams, configParams: BinaryClassifierConfig }, signal),
    enabled: hasStudyId,
  });

  const { data: slides } = useQuery({
    queryKey: ['slidesDataForArtifactsResults', { studyId: studyId, caseParams: encodedParams }],
    queryFn: ({ signal }) => getSlidesDataForArtifactsResults(studyId, encodedParams, signal),
    enabled: hasStudyId,
  });

  const [selectedOrchestrations, setSelectedOrchestrations] = useState<OrchestrationBySlideByType>({});
  // The first selected orchestration provides the inference results artifact URL for the job.
  const selectedOrchestration = useMemo(
    () => first(map(flatMap(flatMap(values(selectedOrchestrations), values)), 'orchestration')),
    [selectedOrchestrations]
  );

  const [selectedPanelId, setSelectedPanelId] = useState<string | null>(null);
  const { data: panel } = useCellRules(selectedPanelId, studyId);

  const {
    handleSubmit,
    control,
    watch,
    setValue,
    formState: { errors },
  } = useForm<ThresholdFormValues>({
    mode: 'onChange',
    defaultValues: {
      configParams: defaultThresholdConfig,
    },
    resolver: yupResolver(currentValidationSchema),
  });

  // Mirror the selected orchestration's artifact URL into the form. Re-validation is only requested
  // once errors are already displayed. `errors` is intentionally not a dependency: re-validating
  // updates it and could re-run this effect repeatedly.
  useEffect(() => {
    setValue('configParams.inferenceResultsArtifactUrl', selectedOrchestration?.orchestrationResultArtifactUrlPattern, {
      shouldValidate: !isEmpty(errors),
    });
  }, [selectedOrchestration]);

  // Cell rules derived from the selected panel. The memo stays pure; writing the derived value into
  // the form is a side effect, so it happens in the effect below instead of during render.
  const cellRules = useMemo(() => convertToCellRules(panel), [panel]);

  useEffect(() => {
    setValue('configParams.cellRules', cellRules, { shouldValidate: !isEmpty(errors) });
  }, [cellRules]);

  /** Records whether the given stepper step currently fails validation. */
  const markStepFailed = (stepIndex: number, failed: boolean) =>
    setIsStepFailed((prev) => ({ ...prev, [stepIndex]: failed }));

  /** Validates `objectToValidate` against the step's schema and flags the step accordingly (not awaited). */
  const checkValidationAndSetIsStepFailed = (stepIndex: number, objectToValidate: Record<string, unknown>) => {
    void validationSchema[stepIndex]
      .validate(objectToValidate)
      .then(() => markStepFailed(stepIndex, false))
      .catch(() => markStepFailed(stepIndex, true));
  };

  /**
   * Step-0 `configParams` exactly as they are validated, summarised and submitted. These are the
   * form values with the artifact URL and cell rules taken from the current selections.
   */
  const withDerivedConfig = (configParams: MultiplexThresholdConfig) => ({
    ...configParams,
    inferenceResultsArtifactUrl: selectedOrchestration?.orchestrationResultArtifactUrlPattern,
    cellRules,
  });

  const runMultiplexThresholdMutation = useMutation(runMultiplexThreshold, {
    onError: () => {
      enqueueSnackbar('Error occurred, Multiplex Threshold failed', {
        variant: 'error',
      });
    },
    onSuccess: () => {
      enqueueSnackbar('Multiplex Threshold Started', { variant: 'success' });
    },
    onSettled() {
      // Dismiss the persistent "waiting" snackbar enqueued by onSubmit.
      closeSnackbar(SNACK_BAR_KEY_RUN_MULTIPLEX_THRESHOLD);
    },
  });

  const onSubmit: SubmitHandler<ThresholdFormValues> = async (data) => {
    const objectToValidate = {
      jobName: data.jobName,
      jobDescription: data.jobDescription,
      configParams: withDerivedConfig(data.configParams),
    };

    // The resolver only applies the active step's schema, so re-check the configuration step here.
    try {
      await validationSchema[0].validate(objectToValidate);
    } catch {
      // Flag the configuration step instead of ignoring the submit silently.
      markStepFailed(0, true);
      return;
    }

    runMultiplexThresholdMutation.mutate({
      ...casesParams,
      ...objectToValidate,
    });

    // Stays open (autoHideDuration: null) until the mutation's onSettled closes it by key.
    enqueueSnackbar({
      variant: 'success',
      message: (
        <Grid container>
          <Grid item>
            <Typography>Waiting for Multiplex Threshold to start</Typography>
          </Grid>
          <Grid item>
            <CircularProgress sx={{ marginLeft: 10 }} color="inherit" size={20} />
          </Grid>
        </Grid>
      ),
      key: SNACK_BAR_KEY_RUN_MULTIPLEX_THRESHOLD,
      autoHideDuration: null,
    });

    onClose();
  };

  /** True when the selections across all slides reference no more than one distinct orchestration. */
  const hasAtMostOneInferenceUrl = useCallback(() => {
    const uniqOrchestrationIds = uniq(
      map(flatMap(flatMap(values(selectedOrchestrations), values)), 'orchestration.orchestrationId')
    );
    return uniqOrchestrationIds.length <= 1;
  }, [selectedOrchestrations]);

  // Step 0's `onNextOrBackClick`. The single-orchestration rule spans several selections, so it is
  // checked here rather than in the yup schema.
  const validateThresholdParams = async () => {
    if (!hasAtMostOneInferenceUrl()) {
      markStepFailed(0, true);
      return false;
    }

    checkValidationAndSetIsStepFailed(0, { configParams: withDerivedConfig(watch('configParams')) });
  };

  const thresholdMethod = watch('configParams.thresholdMethod');
  // These methods additionally need a cell rules panel and annotation assignments.
  const showCellRulesAndAnnotations =
    thresholdMethod === ThresholdMethod.BALANCED_ACCURACY || thresholdMethod === ThresholdMethod.F1;
  const thresholdMethodError = get(errors, 'configParams.thresholdMethod');

  const steps = [
    {
      label: 'Threshold Configuration',
      content: (
        <Grid container direction="column" spacing={2}>
          <Grid item>
            {!hasAtMostOneInferenceUrl() && (
              <Typography color="error">Only one Inference Results Artifact URL is allowed</Typography>
            )}
            {get(errors, 'configParams.inferenceResultsArtifactUrl') && (
              <Typography color="error">
                {humanize(get(errors, 'configParams.inferenceResultsArtifactUrl')?.message)}
              </Typography>
            )}
            <InferenceModelsForSlides
              studyId={studyId}
              slides={slides}
              modelsType={getModelsTypeByModelInferences(inferenceArtifacts)}
              inferenceModels={inferenceArtifacts}
              isLoading={isInferenceArtifactsLoading || isInferenceArtifactsFetching}
              selectedOrchestrations={selectedOrchestrations}
              setSelectedOrchestrations={setSelectedOrchestrations}
            />
          </Grid>
          <Grid item>
            <FormControl sx={{ width: '100%' }} error={Boolean(thresholdMethodError)}>
              <InputLabel>Threshold Method</InputLabel>
              <Controller
                control={control}
                name="configParams.thresholdMethod"
                render={({ field }) => (
                  <Select
                    label="Threshold Method"
                    name={field.name}
                    inputRef={field.ref}
                    value={field.value}
                    onChange={field.onChange}
                    onBlur={field.onBlur}
                    required
                    error={Boolean(thresholdMethodError)}
                  >
                    {map(values(ThresholdMethod), (method) => (
                      <MenuItem key={method} value={method}>
                        {method}
                      </MenuItem>
                    ))}
                  </Select>
                )}
              />
              <FormHelperText>{thresholdMethodError?.message}</FormHelperText>
            </FormControl>
          </Grid>
          {showCellRulesAndAnnotations && (
            <>
              <Grid item>
                <CellRulesPanelSelect
                  selectedPanelId={selectedPanelId}
                  onSelectPanel={setSelectedPanelId}
                  studyId={studyId}
                  isRequired
                  showError={Boolean(get(errors, 'configParams.cellRules'))}
                />
              </Grid>
              <Grid item>
                <Controller
                  control={control}
                  name="configParams.annotationAssignmentIds"
                  render={({ field: { onChange, value } }) => (
                    <AnnotationAssignmentAutocomplete
                      casesParams={casesParams}
                      slideStainType={MULTIPLEX_STAIN_ID}
                      multiple
                      selectedValue={value || []}
                      onChange={(event, newValue) => {
                        // The form stores ids only, not the full assignment objects.
                        onChange(map(newValue, 'annotationAssignmentId'));
                      }}
                      textFieldProps={{
                        error: Boolean(get(errors, 'configParams.annotationAssignmentIds')),
                        helperText: humanize(get(errors, 'configParams.annotationAssignmentIds.message')),
                        required: true,
                      }}
                    />
                  )}
                />
              </Grid>
            </>
          )}
          <Grid item>
            <Controller
              control={control}
              name="configParams.thresholdPerMarker"
              render={({ field }) => (
                <FormControlLabel
                  control={<Checkbox checked={field.value} onChange={(e) => field.onChange(e.target.checked)} />}
                  label="Threshold Per Marker"
                />
              )}
            />
          </Grid>
        </Grid>
      ),
      onNextOrBackClick: validateThresholdParams,
    },
    {
      label: 'Job Name and Description',
      subLabel: activeStep > 0 && watch('jobName'),
      optional: true,
      content: (
        <Grid container direction="column" spacing={2}>
          <Grid item>
            <Controller
              control={control}
              name="jobName"
              render={({ field }) => (
                <TextField
                  label="Job Name"
                  name={field.name}
                  inputRef={field.ref}
                  // jobName has no default value, so fall back to '' to stay a controlled input.
                  value={field.value ?? ''}
                  onChange={field.onChange}
                  onBlur={field.onBlur}
                  placeholder="Type Here"
                  error={Boolean(errors.jobName)}
                  helperText={humanize(errors.jobName?.message)}
                />
              )}
            />
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="jobDescription"
              render={({ field }) => (
                <TextField
                  label="Job Description"
                  name={field.name}
                  inputRef={field.ref}
                  // jobDescription has no default value, so fall back to '' to stay a controlled input.
                  value={field.value ?? ''}
                  onChange={field.onChange}
                  onBlur={field.onBlur}
                  placeholder="Type Here"
                  error={Boolean(errors.jobDescription)}
                  helperText={humanize(errors.jobDescription?.message)}
                  multiline
                  minRows={4}
                />
              )}
            />
          </Grid>
        </Grid>
      ),
    },
  ];

  return (
    <>
      <PlatformStepper
        handleSubmit={handleSubmit(onSubmit)}
        steps={steps}
        setActiveStepForValidation={setActiveStep}
        isStepFailed={isStepFailed}
      />
      <Accordion>
        <AccordionSummary expandIcon={<ExpandIcon />}>Threshold Params Summary (JSON)</AccordionSummary>
        <AccordionDetails>
          <Typography component="pre">
            {JSON.stringify(
              {
                jobName: watch('jobName'),
                jobDescription: watch('jobDescription'),
                configParams: withDerivedConfig(watch('configParams')),
              },
              null,
              2
            )}
          </Typography>
        </AccordionDetails>
      </Accordion>
    </>
  );
};

/**
 * Shape of the react-hook-form state behind {@link RunMultiplexThreshold}.
 *
 * NOTE(review): `jobName` and `jobDescription` are typed as `string`, but they are not part of the
 * form's default values, so they stay `undefined` until the user types into them.
 */
export interface ThresholdFormValues {
  jobName: string;
  jobDescription: string;
  configParams: MultiplexThresholdConfig;
}

/**
 * Threshold methods for which the schema below additionally requires cell rules and at least one
 * annotation assignment.
 */
const needsCellRulesAndAnnotations = (thresholdMethod: ThresholdMethod): boolean =>
  thresholdMethod === ThresholdMethod.F1 || thresholdMethod === ThresholdMethod.BALANCED_ACCURACY;

/** Rules for `configParams`, validated as part of the "Threshold Configuration" step. */
const thresholdConfigParamsSchema = yup.object({
  inferenceResultsArtifactUrl: yup.string().required('Inference Results Artifact URL is required'),
  thresholdMethod: yup
    .mixed()
    // when platform supports rest of threshold methods, replace the array with values(ThresholdMethod)
    .oneOf([ThresholdMethod.AUTO], 'Threshold method is not supported')
    .required('Threshold method is required'),
  cellRules: yup.mixed().when('thresholdMethod', {
    is: needsCellRulesAndAnnotations,
    then: yup.mixed().required('Cell rules are required for balanced accuracy and f1 threshold methods'),
    otherwise: yup.mixed().nullable(),
  }),
  annotationAssignmentIds: yup.array().when('thresholdMethod', {
    is: needsCellRulesAndAnnotations,
    then: yup
      .array()
      .min(1, 'Annotation assignments are required for balanced accuracy and f1 threshold methods')
      .required('Annotation assignments are required for balanced accuracy and f1 threshold methods'),
    otherwise: yup.array().nullable(),
  }),
});

/**
 * yup schemas indexed by stepper step:
 * - `[0]` "Threshold Configuration": the `configParams` rules above.
 * - `[1]` "Job Name and Description": no constraints (both fields are optional).
 */
const validationSchema = [yup.object({ configParams: thresholdConfigParamsSchema }), yup.object({})];
