import React, { useCallback, useEffect, useMemo, useRef } from "react";
import { useFrame, useLoader, useThree } from "@react-three/fiber";
import { ShaderPass } from "three/examples/jsm/postprocessing/ShaderPass";
import * as THREE from "three";
import { EffectComposer } from "three/examples/jsm/postprocessing/EffectComposer";
import { TextureLoader } from "three/src/loaders/TextureLoader";
import { getDepthMapUrl } from "../../utils/image.ts";
import { useRecoilState } from "recoil";
import { lastImageEffectConfigState } from "../../states/effectState.ts";
import { buildControlsIfAdmin } from "../../utils/imageEffect.ts";

/**
 * Props for {@link CombinedDepthEffect}.
 */
type DepthMapWaterEffectProps = {
  /** Composer that receives the shader pass; while it is null no pass is registered. */
  composer: EffectComposer | null;
  /** Image that the fragment shader samples and distorts (bound to `u_image`). */
  texture: THREE.Texture;
  /** URL passed through `getDepthMapUrl`; the resulting URL is loaded as the depth map (`u_maps`). */
  depthMapTextureURL: string;
  /**
   * Optional pre-resolved shader parameters. When truthy they are used as-is and
   * `buildControlsIfAdmin` is not called. Presumably an object holding the numeric keys of
   * the component's control schema (heatThreshold, wavesSpeed, ...) — TODO: tighten this type.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- callers pass loosely-typed param objects
  params?: any;
};

/** GLSL vertex shader: forwards the quad's UVs to the fragment shader as `v_position`. */
const VERTEX_SHADER = `
  varying vec2 v_position;
  void main() {
    v_position = uv;
    gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
  }
`;

/**
 * GLSL fragment shader: samples `u_image` at a UV displaced by heat shimmer, waves/ripples and
 * a mouse-driven depth offset. The displacement is applied only where the depth map's blue
 * channel lies between `u_foregroundThreshold` and `u_backgroundThreshold`; the green channel
 * drives the wave terms.
 *
 * NOTE(review): `u_depthThreshold` is declared but never read. The second `waves` component adds
 * the same `wave(y, ...)` term twice. Both are kept as-is to preserve the current look.
 */
const FRAGMENT_SHADER = `
  precision mediump float;

  uniform float u_time;
  varying vec2 v_position;
  uniform sampler2D u_image;
  uniform sampler2D u_maps;
  uniform vec2 u_mouse;
  uniform float u_dpi;
  uniform vec2 u_resolution;

  uniform float u_heatDistortionIntensity;
  uniform float u_heatThreshold;
  uniform float u_verticalHeatIntensity;
  uniform float u_wavesIntensity;
  uniform float u_wavesFrequency;
  uniform float u_wavesSpeed;
  uniform float u_rippleIntensity;
  uniform float u_rippleFrequency;
  uniform float u_depthThreshold;
  uniform float u_depthIntensity;
  uniform float u_foregroundThreshold;
  uniform float u_backgroundThreshold;

  vec2 pixel() {
    return vec2(1.0 * u_dpi) / u_resolution;
  }

  float wave(float x, float freq, float speed) {
    return sin(x * freq + ((u_time * 20.0) * speed));
  }

  vec2 distortion(vec2 pos) {
    vec4 maps = texture2D(u_maps, pos);
    float depth = maps.b;

    // Check if depth is within the specified range
    float depthMask = step(u_foregroundThreshold, depth) * (1.0 - step(u_backgroundThreshold, depth));

    // Heat distortion
    float heatMask = 1.0 - step(u_heatThreshold, depth);
    float heatDistortionIntensity = u_heatDistortionIntensity + depth * 0.1;
    vec2 heatIntensity = vec2(heatDistortionIntensity, heatDistortionIntensity) * pixel() * 2.0;
    vec2 heatDistortion = vec2(
      wave(pos.y, 50.0, 0.05),
      wave(pos.x, 30.0, 0.02)
    ) * 2.0;
    float verticalHeat = sin(pos.y * 100.0 - (u_time * 20.0) * 5.0) * u_verticalHeatIntensity;
    heatDistortion.y += verticalHeat;
    heatDistortion *= heatMask;

    // Waves distortion
    float y = maps.g;
    float y2 = pow(y, 2.0);
    vec2 wavesIntensity = vec2(
      u_wavesIntensity - (y2 * u_wavesIntensity),
      u_wavesIntensity + (y2 * u_wavesIntensity)
    ) * pixel() * 2.0;
    vec2 waves = vec2(
      wave(y, u_wavesFrequency - (y2 * u_wavesFrequency), u_wavesSpeed),
      wave(y, u_wavesFrequency - (y2 * u_wavesFrequency * 0.08), u_wavesSpeed)
      + wave(y, u_wavesFrequency - (y2 * u_wavesFrequency * 0.08), u_wavesSpeed)
      + wave(pos.x, u_wavesFrequency + (y2 * u_wavesFrequency * 2.0), u_wavesSpeed)
    ) * 2.0;
    float ripple = sin(distance(pos, vec2(0.5)) * u_rippleFrequency - (u_time * 20.0) * 5.0) * u_rippleIntensity;
    waves += vec2(ripple, ripple);

    // Depth distortion
    vec2 depthIntensity = vec2(u_depthIntensity, u_depthIntensity);
    float d = 0.1 - pow(depth, 1.5);
    vec2 depthDistortion = depthIntensity * u_mouse * d * 3.0;

    // Apply distortion only within the specified depth range
    return pos + (heatDistortion * heatIntensity + waves * wavesIntensity + depthDistortion) * depthMask;
  }

  void main() {
    vec2 pos = v_position.xy;
    vec2 distortedPos = distortion(pos);
    vec4 color = texture2D(u_image, distortedPos);
    gl_FragColor = vec4(color.rgb, 1.0);
  }
`;

/** Float uniforms of FRAGMENT_SHADER, each mapped to the shader-param key that feeds it. */
const PARAM_UNIFORMS = {
  u_heatDistortionIntensity: "heatDistortionIntensity",
  u_heatThreshold: "heatThreshold",
  u_verticalHeatIntensity: "verticalHeatIntensity",
  u_wavesIntensity: "wavesIntensity",
  u_wavesFrequency: "wavesFrequency",
  u_wavesSpeed: "wavesSpeed",
  u_rippleIntensity: "rippleIntensity",
  u_rippleFrequency: "rippleFrequency",
  u_depthThreshold: "depthThreshold",
  u_depthIntensity: "depthIntensity",
  u_foregroundThreshold: "foregroundThreshold",
  u_backgroundThreshold: "backgroundThreshold",
} as const;

/**
 * Adds a full-screen ShaderPass to `composer` that distorts `texture` using a depth map
 * (heat shimmer, waves/ripples and a mouse-driven depth offset), and animates it every frame.
 *
 * The pass is built once per (texture, depth map) pair. Shader parameters, canvas size, DPI and
 * mouse position are written into its uniforms in place. Tweaking a control therefore neither
 * rebuilds and leaks a material nor moves the pass to the end of the composer chain. The
 * resolved parameters are also mirrored into `lastImageEffectConfigState`.
 *
 * NOTE(review): the shader declares no `tDiffuse` uniform, so this pass presumably ignores the
 * output of earlier composer passes and always samples `texture` directly — confirm intent.
 *
 * Renders nothing itself (returns null).
 */
const CombinedDepthEffect: React.FC<DepthMapWaterEffectProps> = ({
  composer,
  texture,
  depthMapTextureURL,
  params,
}) => {
  const { size } = useThree();
  const mouseRef = useRef(new THREE.Vector2());
  const depthMapTexture = useLoader(TextureLoader, getDepthMapUrl(depthMapTextureURL));
  const [, setShaderParams] = useRecoilState(lastImageEffectConfigState);

  // NOTE(review): if buildControlsIfAdmin calls hooks internally, calling it only when `params`
  // is falsy violates the rules of hooks should `params` toggle between renders — confirm.
  const shaderParams = params
    ? params
    : buildControlsIfAdmin({
        heatDistortionIntensity: { value: 0.05, min: 0, max: 1, step: 0.01 },
        heatThreshold: { value: 0.15, min: 0, max: 1, step: 0.01 },
        verticalHeatIntensity: { value: 0.01, min: 0, max: 0.1, step: 0.001 },
        wavesIntensity: { value: 0.8, min: 0, max: 2, step: 0.1 },
        wavesFrequency: { value: 500, min: 0, max: 1000, step: 10 },
        wavesSpeed: { value: -0.05, min: -1, max: 1, step: 0.01 },
        rippleIntensity: { value: 0.02, min: 0, max: 0.1, step: 0.001 },
        rippleFrequency: { value: 50, min: 0, max: 100, step: 1 },
        depthThreshold: { value: 0.45, min: 0, max: 1, step: 0.01 },
        depthIntensity: { value: 0.02, min: 0, max: 0.1, step: 0.001 },
        foregroundThreshold: { value: 0.0, min: 0, max: 1, step: 0.01 },
        backgroundThreshold: { value: 0.2, min: 0, max: 1, step: 0.01 },
      });

  // Value-based identity for shaderParams. The object may be recreated every render, so effects
  // below are keyed on its serialized contents rather than its reference.
  const paramsKey = JSON.stringify(shaderParams);

  const memoizedSetShaderParams = useCallback(() => {
    setShaderParams(shaderParams);
  }, [paramsKey, setShaderParams]);

  useEffect(() => {
    memoizedSetShaderParams();
  }, [memoizedSetShaderParams]);

  // One pass per texture pair. Everything else is pushed into its uniforms by the effects below.
  const customPass = useMemo(() => {
    const pass = new ShaderPass({
      uniforms: {
        u_time: { value: 0 },
        u_image: { value: texture },
        u_maps: { value: depthMapTexture },
        u_mouse: { value: mouseRef.current },
        u_dpi: { value: window.devicePixelRatio },
        u_resolution: { value: new THREE.Vector2() },
        // Placeholder values; the parameter-sync effect fills them in before the pass is added.
        ...Object.fromEntries(Object.keys(PARAM_UNIFORMS).map((name) => [name, { value: 0 }])),
      },
      vertexShader: VERTEX_SHADER,
      fragmentShader: FRAGMENT_SHADER,
    });

    // ShaderPass clones the uniforms it is given (UniformsUtils.clone), copying vectors and
    // textures. Re-point them at the live objects so mouse moves are seen without per-frame
    // reassignment, and no duplicate, never-released textures are created.
    pass.uniforms.u_image.value = texture;
    pass.uniforms.u_maps.value = depthMapTexture;
    pass.uniforms.u_mouse.value = mouseRef.current;
    return pass;
  }, [texture, depthMapTexture]);

  // Push the current parameter values into the existing pass instead of rebuilding it.
  useEffect(() => {
    for (const [uniformName, paramKey] of Object.entries(PARAM_UNIFORMS)) {
      const uniform = customPass.uniforms[uniformName];
      if (uniform) {
        uniform.value = shaderParams[paramKey];
      }
    }
  }, [customPass, paramsKey]);

  // Keep the resolution/DPI used by the shader's pixel() helper in sync with the canvas.
  useEffect(() => {
    customPass.uniforms.u_resolution.value.set(size.width, size.height);
    customPass.uniforms.u_dpi.value = window.devicePixelRatio;
  }, [customPass, size.width, size.height]);

  useEffect(() => {
    if (!composer) {
      return undefined;
    }
    composer.addPass(customPass);
    return () => {
      composer.removePass(customPass);
    };
  }, [composer, customPass]);

  // Release the pass's GPU program when the pass is replaced or the component unmounts.
  useEffect(() => {
    return () => {
      customPass.material.dispose();
    };
  }, [customPass]);

  // Track the pointer in window-normalized coordinates (y flipped so 0 is the bottom edge).
  useEffect(() => {
    const handleMouseMove = (event: MouseEvent) => {
      mouseRef.current.x = event.clientX / window.innerWidth;
      mouseRef.current.y = 1 - event.clientY / window.innerHeight;
    };

    window.addEventListener("mousemove", handleMouseMove);
    return () => {
      window.removeEventListener("mousemove", handleMouseMove);
    };
  }, []);

  useFrame(({ clock }) => {
    customPass.uniforms.u_time.value = clock.getElapsedTime();
  });

  return null;
};

export default CombinedDepthEffect;
