import { Skeleton } from '@mui/material';
import { Image } from 'image-js';
import _ from 'lodash';
import React, { useEffect, useRef, useState } from 'react';
import { MAX_UINT16, MAX_UINT8 } from 'utils/constants';

// 8-bit sample range: 0 (black) .. MAX_UINT8 (white).
const EIGHT_BIT_MIN_BLACK = 0;
const EIGHT_BIT_MAX_WHITE = MAX_UINT8;
/** Default 8-bit `colorAdjustment`: every RGBA entry at full scale. */
const EIGHT_BIT_WHITE_RGBA: [number, number, number, number] = [
  EIGHT_BIT_MAX_WHITE,
  EIGHT_BIT_MAX_WHITE,
  EIGHT_BIT_MAX_WHITE,
  EIGHT_BIT_MAX_WHITE,
];

// 16-bit sample range: 0 (black) .. MAX_UINT16 (white).
const SIXTEEN_BIT_MIN_BLACK = 0;
const SIXTEEN_BIT_MAX_WHITE = MAX_UINT16;
/** Default 16-bit `colorAdjustment`: every RGBA entry at full scale. */
const SIXTEEN_BIT_WHITE_RGBA: [number, number, number, number] = [
  SIXTEEN_BIT_MAX_WHITE,
  SIXTEEN_BIT_MAX_WHITE,
  SIXTEEN_BIT_MAX_WHITE,
  SIXTEEN_BIT_MAX_WHITE,
];

/** Level/color settings for one bit depth; all values are in that bit depth's sample range. */
type LevelAdjustmentSettings = {
  blackPoint: number;
  whitePoint: number;
  colorAdjustment: [number, number, number, number];
};

/**
 * Named level-adjustment presets, keyed first by preset name and then by bit depth (8 or 16).
 * `default` is the identity mapping: full black-to-white range and full-scale color channels.
 */
export const adjustmentSettingsPresets: Record<string, Record<number, LevelAdjustmentSettings>> = {
  // Multiplex presets were tuned by trial and error rather than derived from any scientific data.
  // - blackPoint is raised to suppress background noise.
  // - whitePoint is lowered to brighten the image.
  // - colorAdjustment uses red/green ceilings of 0, which clamps those channels to black,
  //   so only the blue channel remains visible.
  multiplex: {
    8: { blackPoint: 3, whitePoint: 200, colorAdjustment: [0, 0, MAX_UINT8, MAX_UINT8] },
    16: { blackPoint: 200, whitePoint: 5000, colorAdjustment: [0, 0, MAX_UINT16, MAX_UINT16] },
  },
  default: {
    8: {
      blackPoint: EIGHT_BIT_MIN_BLACK,
      whitePoint: EIGHT_BIT_MAX_WHITE,
      colorAdjustment: EIGHT_BIT_WHITE_RGBA,
    },
    16: {
      blackPoint: SIXTEEN_BIT_MIN_BLACK,
      whitePoint: SIXTEEN_BIT_MAX_WHITE,
      colorAdjustment: SIXTEEN_BIT_WHITE_RGBA,
    },
  },
};

/** Props for `ImageWithAdjustedLevels`. */
interface ImageWithAdjustedLevelsProps {
  /** Image location handed to image-js `Image.load`. */
  src: string;
  /** Defer loading until the canvas scrolls into view (IntersectionObserver). Defaults to `true`. */
  lazyLoad?: boolean;
  /** Invoked when loading the image fails. */
  onError?: () => void;

  /** Sample bit depth the level values below are expressed in; max sample is `2 ** bitDepth - 1`. Defaults to 8. */
  bitDepth?: 8 | 16;
  /** Sample value mapped to black; lower samples clamp to 0. Falls back to the default preset. */
  blackPoint?: number;
  /** Sample value mapped to full intensity; higher samples clamp to the max. Falls back to the default preset. */
  whitePoint?: number;
  /**
   * RGBA ceilings in the source sample range. Each color channel is shifted down by
   * `max - ceiling` (clamped at 0), so a full-scale ceiling leaves it unchanged and 0 blacks
   * it out; the fourth entry sets the output alpha. Falls back to the default preset.
   */
  colorAdjustment?: [number, number, number, number];
}

/** Level and color adjustment parameters, all expressed in the source image's sample range. */
interface PixelAdjustment {
  /** Largest representable sample value, i.e. `2 ** bitDepth - 1`. */
  maxVal: number;
  blackPoint: number;
  whitePoint: number;
  /** Per-channel RGBA ceilings (same meaning as the component's `colorAdjustment` prop). */
  colorAdjustment: readonly [number, number, number, number];
}

/**
 * Writes the level- and color-adjusted pixels of `source` into `target` as 8-bit RGBA.
 *
 * @param source - interleaved samples, `channels` values per pixel.
 * @param channels - samples per pixel. Fewer than 3 means grayscale (optionally followed by
 *   an alpha sample); the gray value is then replicated into R, G and B.
 * @param target - RGBA output with 4 entries per source pixel (e.g. `ImageData.data`).
 *   Fractional/out-of-range values are rounded and clamped by `Uint8ClampedArray` assignment.
 * @param adjustment - level/color settings in the source sample range.
 */
function writeAdjustedPixels(
  source: ArrayLike<number>,
  channels: number,
  target: Uint8ClampedArray,
  { maxVal, blackPoint, whitePoint, colorAdjustment }: PixelAdjustment,
): void {
  const isGrayscale = channels < 3;
  // Linear levels mapping: blackPoint -> 0, whitePoint -> maxVal.
  const levelScale = maxVal / (whitePoint - blackPoint);
  // Any source alpha sample is ignored; output alpha comes from colorAdjustment[3].
  const alpha = (colorAdjustment[3] * MAX_UINT8) / maxVal;

  const adjustChannel = (value: number, ceiling: number): number => {
    const leveled = Math.min(maxVal, (value - blackPoint) * levelScale);
    // Shift down so `ceiling` becomes the effective maximum, clamping at black.
    const corrected = Math.max(0, leveled - (maxVal - ceiling));
    // ImageData only holds 8-bit channels, so rescale from the source range.
    return (corrected * MAX_UINT8) / maxVal;
  };

  for (let inIdx = 0, outIdx = 0; inIdx < source.length; inIdx += channels, outIdx += 4) {
    const red = source[inIdx];
    const green = isGrayscale ? red : source[inIdx + 1];
    const blue = isGrayscale ? red : source[inIdx + 2];
    target[outIdx] = adjustChannel(red, colorAdjustment[0]);
    target[outIdx + 1] = adjustChannel(green, colorAdjustment[1]);
    target[outIdx + 2] = adjustChannel(blue, colorAdjustment[2]);
    target[outIdx + 3] = alpha;
  }
}

/**
 * Draws `src` onto a canvas after applying a levels adjustment (`blackPoint`/`whitePoint`)
 * and a per-channel color adjustment, showing a skeleton until the first frame is drawn.
 *
 * Omitted `blackPoint`, `whitePoint` and `colorAdjustment` fall back to
 * `adjustmentSettingsPresets.default[bitDepth]`. With `lazyLoad` (the default) the image is
 * only loaded once the canvas intersects the viewport. The container background is set to the
 * processed image's top-left pixel, presumably so the area around the size-constrained canvas
 * blends with the image.
 */
function ImageWithAdjustedLevels({
  src,
  bitDepth = 8,
  blackPoint,
  whitePoint,
  colorAdjustment,
  lazyLoad = true,
  onError = () => {},
}: ImageWithAdjustedLevelsProps) {
  const defaultSettings = adjustmentSettingsPresets.default[bitDepth];
  const effectiveBlackPoint = blackPoint ?? defaultSettings.blackPoint;
  const effectiveWhitePoint = whitePoint ?? defaultSettings.whitePoint;
  // Unpacked into scalars so the render effect's dependencies stay stable even when the
  // parent passes a fresh array literal with identical values on every render.
  const [redCeiling, greenCeiling, blueCeiling, alphaLevel] =
    colorAdjustment ?? defaultSettings.colorAdjustment;

  const canvasRef = useRef<HTMLCanvasElement>(null);
  const [isLoading, setIsLoading] = useState(true);
  const [bgColor, setBgColor] = useState('#ffffff');
  const [image, setImage] = useState<Image | undefined>(undefined);

  // Always report through the latest callback without restarting the load when the
  // parent passes a new function identity.
  const onErrorRef = useRef(onError);
  useEffect(() => {
    onErrorRef.current = onError;
  }, [onError]);

  // Load `src`, either immediately or once the canvas becomes visible.
  useEffect(() => {
    // Flipped by cleanup so a load that settles after `src` changed (or after unmount)
    // cannot overwrite the newer image or report an error for a stale source.
    let cancelled = false;
    setIsLoading(true);

    const load = () => {
      void Image.load(src)
        .then((img) => {
          if (!cancelled) setImage(img);
        })
        .catch((reason: unknown) => {
          if (cancelled) return;
          // The rejection reason is forwarded even though the prop type declares no parameters.
          const report: (reason: unknown) => void = onErrorRef.current;
          report(reason);
        });
    };

    let observer: IntersectionObserver | undefined;
    const canvas = canvasRef.current;
    if (!lazyLoad || !canvas) {
      // Without a mounted canvas there is nothing to observe, so load eagerly.
      load();
    } else {
      observer = new IntersectionObserver((entries, self) => {
        if (_.some(entries, (entry) => entry.isIntersecting)) {
          // Only the first intersection matters; stop watching and load once.
          self.disconnect();
          load();
        }
      });
      observer.observe(canvas);
    }

    return () => {
      cancelled = true;
      observer?.disconnect();
    };
  }, [src, lazyLoad]);

  // Re-render the canvas whenever the image or any adjustment setting changes.
  useEffect(() => {
    if (!image) return;
    const canvas = canvasRef.current;
    const ctx = canvas?.getContext('2d');
    if (!canvas || !ctx) return;

    canvas.width = image.width;
    canvas.height = image.height;
    const imageData = new ImageData(image.width, image.height);
    writeAdjustedPixels(image.data, image.channels, imageData.data, {
      maxVal: 2 ** bitDepth - 1,
      blackPoint: effectiveBlackPoint,
      whitePoint: effectiveWhitePoint,
      colorAdjustment: [redCeiling, greenCeiling, blueCeiling, alphaLevel],
    });
    ctx.putImageData(imageData, 0, 0);

    // Read the top-left pixel from the buffer we just wrote instead of reading the
    // canvas back (avoids a readback and premultiplied-alpha rounding).
    const pixels = imageData.data;
    setBgColor(`rgb(${pixels[0]}, ${pixels[1]}, ${pixels[2]})`);
    setIsLoading(false);
  }, [
    image,
    bitDepth,
    effectiveBlackPoint,
    effectiveWhitePoint,
    redCeiling,
    greenCeiling,
    blueCeiling,
    alphaLevel,
  ]);

  return (
    <div
      style={{
        position: 'absolute',
        width: '100%',
        height: '100%',
        backgroundColor: isLoading ? 'transparent' : bgColor,
      }}
    >
      <canvas
        data-testid="carousel-item-thumbnail-img"
        data-src={src}
        ref={canvasRef}
        style={{
          position: 'absolute',
          top: '50%',
          left: '50%',
          transform: 'translate(-50%, -50%)',
          maxWidth: '100%',
          maxHeight: '100%',
        }}
      />
      {isLoading && (
        <Skeleton
          variant="rectangular"
          animation="wave"
          sx={{ position: 'absolute', top: 0, left: 0, width: '100%', height: '100%' }}
        />
      )}
    </div>
  );
}

export default ImageWithAdjustedLevels;
