import {
  Box,
  Card,
  CardContent,
  Collapse,
  FormControl,
  FormControlLabel,
  FormLabel,
  Grid,
  Radio,
  RadioGroup,
} from '@mui/material';
import { DataGrid, GridColDef, GridRowSelectionModel, GridValueGetterParams } from '@mui/x-data-grid';
import { useQuery } from '@tanstack/react-query';
import { fetchStudies } from 'api/study';
import Loader from 'components/Loader';
import { ChartType } from 'interfaces/chart';
import CohortWithQuery from 'interfaces/cohort_old';
import {
  castArray,
  compact,
  filter,
  first,
  flatMap,
  get,
  includes,
  keys,
  last,
  map,
  noop,
  orderBy,
  sample,
  uniq,
} from 'lodash';
import numeral from 'numeral';
import React, { useEffect } from 'react';
import { useDispatch } from 'react-redux';
import { useAppSelector } from 'redux/hooks';
import { setCohorts } from 'redux/modules/chartSlice';
import { JsonParam, QueryParamConfig, StringParam, useQueryParam, withDefault } from 'use-query-params';
import { studyToCohort } from 'utils/cohort.util';
import { useGetNameOverrideOrDisplayNameWithContext } from 'utils/features/contextHooks';
import { useCurrentLabId } from 'utils/useCurrentLab';
import { useFilledCohorts } from '.';
import TableCellTextContent from '../atoms/TableCellTextContent';
import AddChart from './AddChart';
import CohortSelect from './CohortSelect/CohortSelect';
import ControlledChart, { ControlledChartOptions } from './ControlledChart';
import KeySelect from './KeySelect/KeySelect';
import { ChartKeyType, allCategoricalKeys, convertToChartKey, convertToChartKeys, flattenCohort } from './chart.util';
import { SurvivalAnalysisType, survivalAnalysisTypes } from './charts/kaplanMeier.util';

/** Props accepted by the exploratory-analysis view. */
interface Props {
  /** Ids of the cohorts (looked up in the Redux `charts.cohorts` slice) to analyse; when omitted, none are selected. */
  selectedCohortIds?: string[];
}

/**
 * Formats a numeric survival statistic for table display using numeral's `'0,0[.]0000'` pattern.
 */
const formatNumber = (value: number) => {
  return numeral(value).format('0,0[.]0000');
};

/**
 * Builds a 150px-wide DataGrid column whose cells show the row's numeric `field` through `formatNumber`.
 * Hoisted out of the component because it does not depend on any component state.
 *
 * @param field - row property to display.
 * @param headerName - column header text.
 */
const createColumn = (field: string, headerName: string): GridColDef => ({
  field,
  headerName,
  width: 150,
  renderCell: (params: GridValueGetterParams) => {
    const value = params?.value;
    // Only null/undefined count as missing; a plain truthiness check would blank out a legitimate 0.
    return params && <TableCellTextContent text={value == null ? value : formatNumber(value)} />;
  },
});

/**
 * Exploratory survival analysis for the selected cohorts.
 *
 * - Loads the current lab's studies and stores them in Redux as cohorts (`setCohorts`).
 * - Shows a table of per-feature survival statistics, read from the first filled cohort, for the
 *   survival-analysis baseline chosen in the `survivalAnalysisType` query param.
 * - The selected table row is the "target feature". It becomes the splitting key of every chart
 *   whose splitting key was not chosen by the user.
 * - Chart configuration (`charts`), the survival baseline and the analysis type
 *   (`analysisCompareOption`) are persisted in the URL query string.
 *
 * @param selectedCohortIds - ids of cohorts from the Redux store to analyse.
 */
const ExploratoryAnalysis: React.FunctionComponent<React.PropsWithChildren<Props>> = ({ selectedCohortIds }) => {
  const cohorts = useAppSelector((state) => state.charts?.cohorts);

  const dispatch = useDispatch();

  const { labId } = useCurrentLabId();

  const { data: studies, isFetching: isLoadingStudies } = useQuery(['studies', labId], fetchStudies(labId));

  const selectedCohorts = filter(cohorts, (cohort: CohortWithQuery) => includes(selectedCohortIds, cohort.id));

  const [filledCohorts, isLoading] = useFilledCohorts(selectedCohorts);

  // Mirror the fetched studies into the Redux store as cohorts once fetching has settled.
  // The dependency uses the JSON form of `studies`, so a new-but-equal array reference alone
  // does not retrigger the effect.
  useEffect(() => {
    if (isLoadingStudies) {
      return;
    }
    const newCohorts = compact(map(studies, (study) => studyToCohort({ labId, ...study })));

    dispatch(setCohorts(newCohorts));
  }, [dispatch, labId, JSON.stringify(studies), isLoadingStudies]);

  const [existingChartsMap, setExistingChartsMap] = useQueryParam('charts', ChartConfigurationParam);
  const [survivalAnalysisType, setSurvivalAnalysisType] = useQueryParam<SurvivalAnalysisType>('survivalAnalysisType');

  const [analysisCompareOption, setAnalysisCompareOption] = useQueryParam(
    'analysisCompareOption',
    withDefault(StringParam, 'survival')
  );

  const updateSurvivalAnalysisType = (type: string) => {
    // 'replaceIn' replaces the current history entry (keeping other query params) instead of pushing a new one.
    setSurvivalAnalysisType(type as SurvivalAnalysisType, 'replaceIn');
  };

  /** Adds (or overwrites) a chart in the `charts` query param under its id. */
  const addChart = (chart: ChartDefinition) => {
    const { id, chartOptions } = chart;
    setExistingChartsMap({ ...existingChartsMap, [id]: chartOptions });
  };

  /** Removes the chart with the given id from the `charts` query param. */
  const removeChart = (id: number) => {
    const newCharts = { ...existingChartsMap };
    delete newCharts[id];
    setExistingChartsMap(newCharts);
  };

  const features = flatMap(filledCohorts, flattenCohort);
  const stringKeys = uniq(flatMap(features, (feature) => keys(feature)));
  const availableKeys = convertToChartKeys(stringKeys, ChartKeyType.Numerical);
  // Survival statistics are taken from the first filled cohort only.
  const cohort = first(filledCohorts);

  const { survivalAnalysis } = cohort || {};

  const records = get(survivalAnalysis, survivalAnalysisType as SurvivalAnalysisType);

  const { getNameOverrideOrDisplayNameWithContext, isLoadingFormatterData } =
    useGetNameOverrideOrDisplayNameWithContext(true);
  // The raw feature name is kept as the row id (used as the target key); `feature` holds the display name.
  const rows = !isLoadingFormatterData
    ? map(records, (record) => ({
        id: record.feature,
        ...record,
        feature: getNameOverrideOrDisplayNameWithContext(record.feature),
      }))
    : [];

  // Default target: the row with the smallest log-rank p-value.
  // lodash orderBy takes iteratees and sort orders as separate arrays.
  const defaultTargetFeature = get(first(orderBy(rows, ['logrankPValue'], ['asc'])), 'id');

  // Start with an empty selection (not `[undefined]`) while no default target is known yet.
  const [selectionModel, setSelectionModel] = React.useState<GridRowSelectionModel>(
    defaultTargetFeature ? castArray(defaultTargetFeature) : []
  );

  // Re-select the best-ranked feature whenever it changes after loading finishes.
  useEffect(() => {
    if (!isLoading && !isLoadingStudies && defaultTargetFeature) {
      setSelectionModel(castArray(defaultTargetFeature));
    }
  }, [isLoading, defaultTargetFeature, isLoadingStudies]);

  const targetFeatureId = first(selectionModel);

  // NOTE(review): this is `undefined` when nothing is selected, despite the `string` annotation.
  const targetFeatureKey: string = targetFeatureId?.toString();

  // Charts whose splitting key was not picked by the user follow the currently selected target feature.
  const existingCharts: ChartDefinition[] = map(existingChartsMap, (chartOptions, id) => ({
    id: Number(id),
    chartOptions: {
      ...chartOptions,
      splittingKey: chartOptions.userSelectedSplittingKey
        ? chartOptions.splittingKey
        : convertToChartKey(targetFeatureKey, ChartKeyType.Numerical),
    },
  }));

  // Chart ids are the numeric keys of the `charts` query param. Object iteration visits integer-like
  // keys in ascending order, so the last chart carries the largest id. With no charts (or a
  // non-finite id), numbering starts at 0.
  const lastChartId = last(existingCharts)?.id;
  const nextChartId = lastChartId !== undefined && Number.isFinite(lastChartId) ? lastChartId + 1 : 0;

  const columns: GridColDef[] = [
    {
      field: 'radiobutton',
      headerName: 'Target',
      width: 100,
      sortable: false,
      renderCell: (params) => <Radio checked={params.id === targetFeatureId} value={params.id} />,
    },
    {
      field: 'feature',
      headerName: 'Feature',
      width: 350,
      renderCell: (params: GridValueGetterParams) =>
        params && <TableCellTextContent text={params?.value?.toString()} />,
    },
    // NOTE(review): header says "threshold=0.5"; confirm whether 0.05 was intended.
    createColumn('logrankPValue', 'Log-Rank p-value threshold=0.5'),
    createColumn('coxPValue', 'Cox p-value'),
    createColumn('coxHazard', 'Cox Hazard Ratio'),
    createColumn('concordanceIndex', 'Concordance Index'),
  ];

  return isLoadingStudies ? (
    <Loader />
  ) : (
    <Grid item container spacing={2} xs={12} direction="column">
      <Grid item>
        <Card>
          <CardContent>
            <Grid container alignItems="center" pl={1}>
              <Grid item xs={6} md={7}>
                <FormControl fullWidth>
                  <FormLabel id="choose-compare-option-label">Analysis Type</FormLabel>
                  <RadioGroup
                    row
                    aria-labelledby="choose-compare-option-label"
                    name="choose-compare-option-group"
                    value={analysisCompareOption}
                    onChange={(e: React.ChangeEvent<HTMLInputElement>) => setAnalysisCompareOption(e.target.value)}
                  >
                    <FormControlLabel value="cohort" control={<Radio />} label="Compare to another Cohort" />
                    <FormControlLabel value="targetFeature" control={<Radio />} label="Target Feature" />
                    <FormControlLabel value="survival" control={<Radio />} label="Survival Analysis" />
                  </RadioGroup>
                </FormControl>
              </Grid>
              <Grid item xs={6} md={4}>
                {analysisCompareOption === 'survival' ? (
                  <KeySelect
                    keys={survivalAnalysisTypes}
                    name={'Survival Analysis Baseline'}
                    updateSelectedKey={updateSurvivalAnalysisType}
                    selectedKey={survivalAnalysisType}
                  />
                ) : analysisCompareOption === 'cohort' ? (
                  <CohortSelect
                    cohorts={cohorts}
                    isLoading={isLoading}
                    selectedCohortIds={[]}
                    updateSelectedCohortIds={noop}
                  />
                ) : analysisCompareOption === 'targetFeature' ? (
                  <KeySelect keys={stringKeys} selectedKey={''} updateSelectedKey={noop} name="Target Feature" />
                ) : null}
              </Grid>
            </Grid>
          </CardContent>
          <CardContent>
            <Collapse in={Boolean(survivalAnalysisType)}>
              <Box sx={{ width: '100%', height: 320 }}>
                <DataGrid
                  loading={isLoading || isLoadingFormatterData}
                  rows={rows}
                  columns={columns}
                  pageSizeOptions={[5]}
                  rowSelectionModel={selectionModel}
                  onRowSelectionModelChange={setSelectionModel}
                  initialState={{
                    sorting: {
                      sortModel: [{ field: 'logrankPValue', sort: 'asc' }],
                    },
                  }}
                />
              </Box>
            </Collapse>
          </CardContent>
        </Card>
      </Grid>
      <Grid item>
        <Collapse in={Boolean(survivalAnalysisType)}>
          <Grid item container xs={12} spacing={1}>
            {/* The list key belongs on the outermost element returned from map, i.e. the wrapping Grid. */}
            {map(existingCharts, (chart: ChartDefinition) => (
              <Grid item xs={6} key={chart.id}>
                <ControlledChart
                  id={chart.id}
                  availableKeys={availableKeys}
                  loading={isLoading}
                  cohorts={filledCohorts}
                  onRemove={() => removeChart(chart.id)}
                  chartOptions={chart.chartOptions}
                  onChangeChartOptions={(chartOptions) =>
                    setExistingChartsMap({ ...existingChartsMap, [chart.id]: chartOptions })
                  }
                />
              </Grid>
            ))}
            <Grid item xs={6} sx={{ display: 'flex', minHeight: 400 }}>
              <AddChart
                cohorts={filledCohorts}
                onClick={(type: ChartType) =>
                  addChart({
                    id: nextChartId,
                    chartOptions: {
                      type,
                      // Kaplan-Meier charts split on the target feature; the splitting key of every other
                      // chart type is marked as user-selected.
                      splittingKey:
                        type === ChartType.KaplanMeier
                          ? convertToChartKey(targetFeatureKey)
                          : convertToChartKey('None'),
                      userSelectedSplittingKey: type !== ChartType.KaplanMeier,
                      kaplanMeierParameter: type === ChartType.KaplanMeier ? survivalAnalysisType : null,
                      ...(includes([ChartType.Histogram, ChartType.Scatter], type) && {
                        horizontalKey: convertToChartKey(targetFeatureKey, ChartKeyType.Numerical),
                        verticalKey: sample(availableKeys),
                      }),
                      ...(type === ChartType.Box && {
                        horizontalKey: convertToChartKey(targetFeatureKey, ChartKeyType.Numerical),
                        verticalKey: convertToChartKey(targetFeatureKey, ChartKeyType.Numerical),
                        categoricalKey: sample(convertToChartKeys(allCategoricalKeys, ChartKeyType.Categorical)),
                      }),
                    },
                  })
                }
              />
            </Grid>
          </Grid>
        </Collapse>
      </Grid>
    </Grid>
  );
};

/** Decoded shape of the `charts` query param: chart options keyed by numeric chart id. */
type ChartConfiguration = Record<number, ControlledChartOptions>;

/**
 * Query-param config for `charts`, stored as JSON.
 * When the param is missing, it defaults to two 'kaplan-meier' charts (ids 1 and 2) with the
 * `kaplanMeierParameter` values 'pfs' and 'overallSurvival'.
 */
const ChartConfigurationParam = withDefault(JsonParam, {
  1: { type: 'kaplan-meier', kaplanMeierParameter: 'pfs' },
  2: { type: 'kaplan-meier', kaplanMeierParameter: 'overallSurvival' },
}) as QueryParamConfig<ChartConfiguration>;

/** A chart rendered in the analysis grid. */
interface ChartDefinition {
  /** Options passed to `ControlledChart`. */
  chartOptions: ControlledChartOptions;
  /** Numeric id; also the key under which the options are stored in the `charts` query param. */
  id: number;
}

export default ExploratoryAnalysis;
