import { PartialCohort } from 'interfaces/cohort_old';
import React from 'react';

import { Skeleton } from '@mui/material';
import { BoxAndWiskers, BoxPlotController, Violin, ViolinController } from '@sgratzl/chartjs-chart-boxplot';
import {
  BarElement,
  CategoryScale,
  Chart as ChartJs,
  ChartOptions,
  Legend,
  LinearScale,
  PointElement,
  Tooltip,
} from 'chart.js';
import { ChartType } from 'interfaces/chart';
import { filter, includes, isEmpty, map } from 'lodash';
import { DistanceBasedOptions } from 'utils/features';
import { CategoricalChartKey, ChartKey, CountBy, NumericalChartKey } from './chart.util';
import BoxChart from './charts/BoxChart';
import DistanceBasedChart from './charts/DistanceBasedChart';
import HistogramChart from './charts/HistogramChart';
import KaplanMeierChart from './charts/KaplanMeierChart';
import PieChart from './charts/PieChart';
import ScatterChart from './charts/ScatterChart';
import { SurvivalAnalysisType } from './charts/kaplanMeier.util';

/**
 * Props accepted by every chart variant that `Chart` can render.
 */
type CommonChartProps = {
  /** Cohorts to plot. */
  cohorts: PartialCohort[];
  /** Which chart `Chart` should render; unhandled/missing values render a plain skeleton. */
  chartType?: ChartType;
  /**
   * Key for the horizontal axis. In `Chart` it is forwarded to scatter and histogram
   * charts; distance-based charts receive only its `name`.
   */
  horizontalKey?: ChartKey;
  /**
   * When provided, `Chart` keeps only procedures whose `id` is in this list and drops
   * cohorts left with no procedures. An empty array therefore hides every cohort.
   */
  filteredCaseIds?: number[];
  /** When truthy, `Chart` renders a loading skeleton instead of the chart. */
  loading?: boolean;
  /** In `Chart`, only sizes the loading skeleton: 250px tall for `small`, 300px otherwise. */
  layout?: 'small' | 'medium' | 'large';
  /** In `Chart`, only forwarded to the pie chart. */
  preview?: boolean;
  /** In `Chart`, only forwarded to the distance-based chart. */
  chartIndex?: number;
  /** In `Chart`, spread as extra props onto the distance-based chart. */
  featureOptions?: object;
};

/**
 * Props for every chart type other than distance-based.
 *
 * NOTE(review): `Omit` removes *keys* from an object type; applied to the `ChartType`
 * enum it produces a mapped object type, not "ChartType without DistanceBased".
 * `Exclude<ChartType, ChartType.DistanceBased>` is presumably what was meant, but it
 * would narrow what callers may pass, so it is left unchanged — confirm callers first.
 */
type GenericChartProps = CommonChartProps & {
  chartType?: Omit<ChartType, ChartType.DistanceBased>;
  /** Vertical-axis key; `Chart` casts it to `NumericalChartKey` for box charts. */
  verticalKey?: ChartKey;
  /** Category key; `Chart` casts it to `CategoricalChartKey` for box and pie charts. */
  categoricalKey?: ChartKey;
  /** How histogram and pie charts aggregate their values. */
  countBy?: CountBy;
  /** Forwarded as `parameter` to the Kaplan-Meier chart. */
  kaplanMeierParameter?: SurvivalAnalysisType;
  /** Chart.js options; `Chart` currently forwards them only to the pie chart (as doughnut options). */
  options?: ChartOptions<'doughnut'> | ChartOptions<'bar'> | ChartOptions<'scatter'> | ChartOptions<'boxplot'>;
};

/**
 * Props for the distance-based chart. `chartType` is required here so that
 * `isDistanceBasedChart` can identify this variant.
 */
type DistanceBasedChartProps = CommonChartProps & {
  chartType: ChartType.DistanceBased;
  featureOptions?: DistanceBasedOptions;
};

/** Every prop shape `Chart` accepts. */
type Props = GenericChartProps | DistanceBasedChartProps;

/**
 * Chart.js components used by the charts in this module. Chart.js only uses scales,
 * elements, controllers and plugins that have been registered, so they are registered
 * once when this module loads. The order is kept exactly as originally written.
 * (`BoxAndWiskers` is the name as exported by `@sgratzl/chartjs-chart-boxplot`.)
 */
const chartJsComponents = [
  CategoryScale,
  LinearScale,
  BarElement,
  PointElement,
  Violin,
  Tooltip,
  Legend,
  BoxPlotController,
  ViolinController,
  BoxAndWiskers,
] as const;

ChartJs.register(...chartJsComponents);

/**
 * Type guard: true when `props` describe a distance-based chart, narrowing them to
 * `DistanceBasedChartProps` so its `featureOptions` typing applies.
 *
 * NOTE(review): a dedicated guard is presumably needed because
 * `GenericChartProps.chartType` is typed with `Omit<…>`, which does not exclude
 * `ChartType.DistanceBased`, so an inline comparison may not narrow the union — confirm.
 */
const isDistanceBasedChart = (props: Props): props is DistanceBasedChartProps => {
  const { chartType } = props;
  return chartType === ChartType.DistanceBased;
};

/**
 * Builds the chart element matching `props.chartType` for the given cohorts.
 * Chart types not handled here fall back to an unsized rectangular skeleton.
 *
 * Called as a plain function (not rendered as a component) so the element tree
 * produced by `Chart` is exactly the chart element itself.
 *
 * @param props - the props `Chart` received.
 * @param cohorts - cohorts to plot, already narrowed by `filteredCaseIds` when given.
 * @returns the chart element to render.
 */
const renderChart = (props: Props, cohorts: PartialCohort[]): React.ReactElement => {
  if (isDistanceBasedChart(props)) {
    return (
      <DistanceBasedChart
        cohorts={cohorts}
        horizontalKeyName={props.horizontalKey?.name}
        chartIndex={props.chartIndex}
        {...props.featureOptions}
      />
    );
  }

  switch (props.chartType) {
    case ChartType.Scatter:
      return <ScatterChart cohorts={cohorts} verticalKey={props.verticalKey} horizontalKey={props.horizontalKey} />;
    case ChartType.KaplanMeier:
      return <KaplanMeierChart cohorts={cohorts} parameter={props.kaplanMeierParameter} />;
    case ChartType.Histogram:
      return <HistogramChart cohorts={cohorts} horizontalKey={props.horizontalKey} countBy={props.countBy} />;
    case ChartType.Box:
      // Props type the keys as plain `ChartKey`; these casts assert (without a runtime
      // check) the narrower key kinds BoxChart expects.
      return (
        <BoxChart
          cohorts={cohorts}
          categoricalKey={props.categoricalKey as CategoricalChartKey}
          verticalKey={props.verticalKey as NumericalChartKey}
        />
      );
    case ChartType.Pie:
      return (
        <PieChart
          cohorts={cohorts}
          categoricalKey={props.categoricalKey as CategoricalChartKey}
          countBy={props.countBy}
          preview={props.preview}
          options={props.options as ChartOptions<'doughnut'>}
        />
      );
    default:
      return <Skeleton variant="rectangular" />;
  }
};

/**
 * Renders one cohort chart of the requested `chartType`.
 *
 * - While `loading` is truthy, only a full-width skeleton is rendered (250px tall for
 *   the `small` layout, 300px otherwise); no cohort filtering is done in that case.
 * - When `filteredCaseIds` is provided, each cohort keeps only the procedures whose
 *   `id` is in that list, and cohorts left with no procedures are dropped. An empty
 *   array still counts as provided, so it hides every cohort.
 */
const Chart: React.FunctionComponent<React.PropsWithChildren<Props>> = (props) => {
  // Check loading first: the filtered cohorts and chart element are never shown while loading.
  if (props.loading) {
    return (
      <Skeleton
        variant="rectangular"
        sx={{
          width: '100%',
          height: props.layout === 'small' ? 250 : 300,
        }}
      />
    );
  }

  const { cohorts: allCohorts, filteredCaseIds } = props;
  const cohorts = filteredCaseIds
    ? filter(
        map(allCohorts, (cohort) => ({
          ...cohort,
          procedures: filter(cohort?.procedures, ({ id }) => includes(filteredCaseIds, id)),
        })),
        ({ procedures }) => !isEmpty(procedures)
      )
    : allCohorts;

  return renderChart(props, cohorts);
};

export default Chart;
