import { GridRowId } from '@mui/x-data-grid';
import { filter, includes } from 'lodash';
import React from 'react';

/**
 * Value exposed by {@link RowSelectionContextProvider}.
 *
 * Selection is tracked in one of two modes:
 * - `'select'`: only the ids in `selectedRows` are selected.
 * - `'omit'`: every row is selected except the ids in `omittedRows`. `selectAll` enters
 *   this mode without needing the list of row ids.
 *
 * Members that take `numRows` need the caller to supply the current total row count,
 * since the context does not track it.
 */
export interface RowSelectionContextProps {
  /** True when `numRows > 0` and every row is selected. */
  allSelected: (numRows: number) => boolean;
  /** True when `numRows > 0` and no row is selected (note: false when `numRows` is 0). */
  noneSelected: (numRows: number) => boolean;
  /** True when `numRows > 0` and at least one row is selected (also true when all are). */
  someSelected: (numRows: number) => boolean;
  /** Explicitly selected row ids, sorted; consulted in `'select'` mode. */
  selectedRows: GridRowId[];
  /** Row ids excluded from the selection, sorted; consulted in `'omit'` mode. */
  omittedRows: GridRowId[];
  /** Adds/removes `row` in `selectedRows` (select mode) or `omittedRows` (omit mode). */
  toggleSelection: (row: GridRowId) => void;
  /** Current selection mode (see the interface description). */
  selectionMode: 'select' | 'omit';
  /** Raw mode setter; switches mode without changing either id list. */
  setSelectionMode: (mode: 'select' | 'omit') => void;
  /** Returns to `'select'` mode with nothing selected. */
  clearSelection: () => void;
  /** Switches to `'omit'` mode with nothing omitted, i.e. every row selected. */
  selectAll: () => void;
  /** Whether `row` is selected under the current mode. */
  isRowSelected: (row: GridRowId) => boolean;
  /** Number of selected rows given a total of `numRows`. */
  selectedRowsCount: (numRows: number) => number;
}

/** Props for {@link RowSelectionContextProvider}. */
interface RowSelectionContextProviderProps {
  /** Subtree that can read the selection via `useRowSelectionContext`. */
  children?: React.ReactNode;
}

/**
 * Context carrying the row-selection state. The default value is `undefined` so that
 * `useRowSelectionContext` can detect that no provider is mounted above the caller.
 */
const RowSelectionContext = React.createContext<RowSelectionContextProps | undefined>(undefined);
// Label shown by React DevTools instead of the generic "Context".
RowSelectionContext.displayName = 'RowSelectionContext';

/**
 * Returns a new array in which `rowId` has been removed if `rows` contained it,
 * or appended if it did not. `rows` itself is left untouched.
 */
const toggleRowIdSelection = (rowId: GridRowId, rows: GridRowId[]) => {
  if (!rows.includes(rowId)) {
    return [...rows, rowId];
  }
  return rows.filter((candidate) => candidate !== rowId);
};

/**
 * Provides row-selection state to its subtree; read it with `useRowSelectionContext`.
 *
 * State is a mode plus one id list per mode. In `'select'` mode `selectedRows` lists the
 * selected ids; in `'omit'` mode every row is selected except those in `omittedRows`.
 */
export const RowSelectionContextProvider: React.FunctionComponent<
  React.PropsWithChildren<RowSelectionContextProviderProps>
> = ({ children }: RowSelectionContextProviderProps) => {
  const [selectionMode, setSelectionMode] = React.useState<'select' | 'omit'>('select');
  const [selectedRows, setSelectedRows] = React.useState<GridRowId[]>([]);
  const [omittedRows, setOmittedRows] = React.useState<GridRowId[]>([]);

  const contextValue: RowSelectionContextProps = React.useMemo(
    () => ({
      allSelected: (numRows: number) =>
        numRows > 0 && (selectionMode === 'select' ? selectedRows.length === numRows : omittedRows.length === 0),
      // Returns false when numRows is 0, like allSelected/someSelected.
      noneSelected: (numRows: number) =>
        numRows > 0 && (selectionMode === 'select' ? selectedRows.length === 0 : omittedRows.length === numRows),
      someSelected: (numRows: number) =>
        numRows > 0 && (selectionMode === 'select' ? selectedRows.length > 0 : omittedRows.length < numRows),
      toggleSelection: (row: GridRowId) => {
        // Toggle within whichever list defines the selection in the current mode.
        if (selectionMode === 'select') {
          setSelectedRows((prev) => toggleRowIdSelection(row, prev));
        } else {
          setOmittedRows((prev) => toggleRowIdSelection(row, prev));
        }
      },
      // Sort copies: Array#sort works in place, and these arrays are React state that must
      // not be mutated during render. The default (string-based) ordering is kept unchanged.
      selectedRows: [...selectedRows].sort(),
      omittedRows: [...omittedRows].sort(),
      selectedRowsCount: (numRows: number) =>
        // Clamp at 0: omittedRows may hold more ids than numRows (e.g. if numRows shrinks),
        // which would otherwise yield a negative count.
        selectionMode === 'select' ? selectedRows.length : Math.max(0, numRows - omittedRows.length),
      selectionMode,
      setSelectionMode,
      selectAll: () => {
        setSelectionMode('omit');
        setOmittedRows([]);
        // Reset the other list too, so stale ids are neither exposed nor revived by setSelectionMode.
        setSelectedRows([]);
      },
      clearSelection: () => {
        setSelectionMode('select');
        setSelectedRows([]);
        // Reset the other list too, so stale ids are neither exposed nor revived by setSelectionMode.
        setOmittedRows([]);
      },
      isRowSelected: (row: GridRowId) => {
        if (selectionMode === 'select') {
          return includes(selectedRows, row);
        } else {
          return !includes(omittedRows, row);
        }
      },
    }),
    // useState setters are stable across renders, so only the state values are dependencies.
    [selectedRows, omittedRows, selectionMode]
  );

  return <RowSelectionContext.Provider value={contextValue}>{children}</RowSelectionContext.Provider>;
};

/**
 * Reads the row-selection state from the nearest `RowSelectionContextProvider`.
 *
 * @returns The selection state and actions.
 * @throws Error when called outside a `RowSelectionContextProvider`.
 */
export const useRowSelectionContext = (): RowSelectionContextProps => {
  // No type assertion: keeping `| undefined` in the context's type lets the check below
  // narrow it, instead of the compiler treating the check as dead code.
  const context = React.useContext(RowSelectionContext);
  if (context === undefined) {
    throw new Error('useRowSelectionContext must be used within a RowSelectionContextProvider');
  }
  return context;
};
