import { ChartDataset, ChartOptions } from 'chart.js';
import { CohortWithSelectedFeatures } from 'interfaces/cohort_old';
import { curry, map, sortBy } from 'lodash';
import { compose } from 'redux';
import { humanize } from 'utils/helpers';
import {
  cohortToClinicalData,
  Dataset,
  formatNumber,
  getHorizontalScaleOptions,
  getVerticalScaleOptions,
} from '../chart.util';
import { cohortToScatterDataset } from './scatter.util';

/** Number of days counted as one month when converting `*Days` values to the months shown on the x axis. */
const daysInMonth = 30;

/**
 * Survival endpoints that can be charted. Each value is the prefix of a `<value>Days` /
 * `<value>Event` pair read from the cohort's clinical data.
 */
export type SurvivalAnalysisType = 'overallSurvival' | 'pfs';

/** All supported survival endpoints. */
export const survivalAnalysisTypes: Array<SurvivalAnalysisType> = ['overallSurvival', 'pfs'];

/**
 * Builds the points of a Kaplan–Meier (product-limit) survival curve for one cohort.
 *
 * Curried with lodash `curry`, so `kaplanDataGenerator(selectedParameter)` yields a
 * `(cohort) => points` function; `cohortToKaplanMeierDataset` relies on that shape.
 *
 * Each patient's follow-up time is read from `<selectedParameter>Days` and converted to months of
 * `daysInMonth` days. `<selectedParameter>Event` is treated as truthy = event observed,
 * falsy = censored. Patients with no recorded time (null/undefined/non-numeric) are skipped,
 * because they cannot be placed on the time axis and must not inflate the risk set.
 *
 * At each distinct time t with d events among the n patients still at risk, the survival estimate
 * is multiplied by (1 - d / n). Afterwards every patient observed at t, whether an event or
 * censored, leaves the risk set. Handling all patients at t together also applies the usual
 * convention that patients censored at t still count as at risk for the events at t.
 *
 * @param selectedParameter - which survival endpoint to plot.
 * @param cohort - the cohort whose clinical data is plotted.
 * @returns `{ x: months, y: survival estimate }` points in time order. A time with events yields
 *   two points (before and after the drop) so the connecting line forms a vertical step. A time
 *   with only censorings yields a single point at the current level.
 */
const kaplanDataGenerator = curry(
  (selectedParameter: SurvivalAnalysisType, cohort: CohortWithSelectedFeatures): ChartDataset<'scatter'>['data'] => {
    const clinicalDatas = cohortToClinicalData(cohort);
    const allDays = map(clinicalDatas, `${selectedParameter}Days`);
    const allEvents = map(clinicalDatas, `${selectedParameter}Event`);

    // Pair each time with its event flag, drop patients without a usable time, and sort
    // numerically by time.
    const observations = sortBy(
      allDays
        .map((rawDays, index) => ({
          // `Number(null)` is 0, so null/undefined must be mapped to NaN explicitly to be filtered.
          days: rawDays == null ? NaN : Number(rawDays),
          event: Boolean(allEvents[index]),
        }))
        .filter((observation) => !Number.isNaN(observation.days)),
      'days'
    );

    let atRisk = observations.length;
    let survival = 1;
    let eventsAtTime = 0;
    let patientsAtTime = 0;

    const points: { x: number; y: number }[] = [];

    observations.forEach((observation, index) => {
      if (observation.event) {
        eventsAtTime += 1;
      }
      patientsAtTime += 1;

      // Keep accumulating until the last patient sharing this time, so tied events are applied
      // together as (1 - d / n) against the same at-risk count.
      const next = observations[index + 1];
      if (next !== undefined && next.days === observation.days) {
        return;
      }

      const time = observation.days / daysInMonth;

      if (eventsAtTime > 0) {
        // Level before the drop, so the line draws a vertical step at this time.
        points.push({ x: time, y: survival });
        survival *= 1 - eventsAtTime / atRisk;
      }
      points.push({ x: time, y: survival });

      // Everyone observed at this time (events and censorings) leaves the risk set.
      atRisk -= patientsAtTime;
      eventsAtTime = 0;
      patientsAtTime = 0;
    });

    return points;
  }
);

/**
 * Creates a Kaplan–Meier chart dataset builder for the given survival endpoint.
 *
 * Redux `compose` applies its functions right-to-left. Called with a single argument, this is
 * `cohortToScatterDataset(kaplanDataGenerator(selectedParameter))`: the curried point generator,
 * fixed to `selectedParameter`, is passed to `cohortToScatterDataset`. Per the annotation below,
 * the result is a `(cohort, colorIndex) => Dataset` function.
 *
 * NOTE(review): presumably `cohortToScatterDataset` runs the generator on each cohort and wraps
 * the points with label/colour styling chosen by `colorIndex` — confirm in ./scatter.util.
 */
export const cohortToKaplanMeierDataset: (
  selectedParameter: SurvivalAnalysisType
) => (cohort: CohortWithSelectedFeatures, colorIndex: number) => Dataset = compose(
  cohortToScatterDataset,
  kaplanDataGenerator
);

/**
 * Chart.js options for the Kaplan–Meier scatter chart.
 *
 * - Points are joined into a line (`showLine`) and drawn as crosses of radius 8.
 * - The legend sits below the chart with small (9px font, 10x10 box) labels.
 * - Tooltips read "<dataset label> (<x> Months, <y> Survival Rate)".
 * - The x axis is titled after the selected endpoint.
 * - The y axis ("Survival Rate") begins at zero.
 *
 * @param selectedParameter - endpoint whose humanized name appears in the x-axis title.
 */
export const kaplanMeierOptions = (selectedParameter: SurvivalAnalysisType): ChartOptions<'scatter'> => {
  // Derive the "(30d)" hint from the same constant used to convert days to months, so the axis
  // title cannot drift from the actual conversion.
  const xAxisTitle = `${humanize(selectedParameter)} Months (${daysInMonth}d)`;

  return {
    plugins: {
      legend: {
        position: 'bottom',
        labels: {
          font: {
            size: 9,
          },
          boxHeight: 10,
          boxWidth: 10,
        },
      },
      tooltip: {
        position: 'nearest',
        callbacks: {
          label: (context) => {
            const datasetLabel = context.dataset?.label || '';
            const { x, y } = context.parsed;
            return `${datasetLabel} (${formatNumber(x)} Months, ${formatNumber(y)} Survival Rate)`;
          },
        },
      },
    },
    elements: {
      point: {
        pointStyle: 'cross',
        radius: 8,
      },
    },
    showLine: true,
    scales: {
      x: getHorizontalScaleOptions(xAxisTitle),
      y: { ...getVerticalScaleOptions('Survival Rate'), beginAtZero: true },
    },
  };
};
