import { useEffect, useRef } from 'react';
import { $scene } from '@google/model-viewer/lib/model-viewer-base';
import { Mesh, Object3D, Material } from 'three';
import * as THREE from 'three';
import { ModelViewerElement } from '@google/model-viewer';

/**
 * Returns the *name* of a mesh's material (not the material itself).
 * For a multi-material mesh only the first entry is inspected, so an empty
 * material array yields `undefined`.
 */
const getMaterial = (material: Material | Material[]) => {
  if (!Array.isArray(material)) return material.name;
  const [first] = material;
  return first?.name;
};
/**
 * Extracts the numeric index from a name of the form `<prefix><number>`,
 * e.g. `photo_3` -> 3.
 *
 * @param materialName - Material or object name to parse; may be undefined.
 * @param prefix - Required leading text; defaults to `'photo_'`.
 * @returns The decimal integer following `prefix`, or `null` if the name is
 *   missing, does not start with `prefix`, or has no leading digits after it.
 */
const parseName = (materialName?: string, prefix = 'photo_'): number | null => {
  if (!materialName || !materialName.startsWith(prefix)) return null;
  // Explicit radix 10: without it, parseInt reads a `0x…` suffix as hexadecimal.
  const index = Number.parseInt(materialName.slice(prefix.length), 10);
  return Number.isNaN(index) ? null : index;
};

/**
 * Wires pointer interaction for a `<model-viewer>` element.
 *
 * - Hovering a clickable target shows a pointer cursor on `document.body`.
 * - Clicking a target calls `onUpload` with the index parsed from its name.
 *
 * Targets are collected whenever a model finishes loading:
 * - Mesh objects whose name starts with `frame` become frame targets.
 * - Meshes whose first material name starts with `photo` become photo targets.
 *
 * If any frames exist, only frames are hit-tested, and the index comes from the
 * `frame_<n>` object name. Otherwise photos are hit-tested, and the index comes
 * from the `photo_<n>` material name.
 *
 * @param modelViewerRef - Ref to the `<model-viewer>` element. It is read once
 *   on mount; listeners are not re-attached if it later points elsewhere.
 * @param onUpload - Called with the parsed index when a target is clicked.
 *   The most recently passed callback is always the one invoked.
 */
const useMouseEvents = (
  modelViewerRef: React.RefObject<ModelViewerElement>,
  onUpload?: (position: number) => void
) => {
  // Listeners are registered once per mount. Read the callback through a ref
  // so a click never calls the stale callback captured on the first render.
  const onUploadRef = useRef(onUpload);
  useEffect(() => {
    onUploadRef.current = onUpload;
  }, [onUpload]);

  useEffect(() => {
    const modelViewer = modelViewerRef.current;
    if (!modelViewer) return;

    // Per-mount state; `loaded` rebuilds the target lists on every load.
    let photos: Object3D[] = [];
    let frames: Object3D[] = [];
    const mouse = new THREE.Vector2();
    const raycaster = new THREE.Raycaster();
    let hovering = false;

    // Only touch the body cursor on hover transitions, so we don't overwrite
    // it on every mousemove.
    const setHovering = (next: boolean) => {
      if (next === hovering) return;
      hovering = next;
      document.body.style.cursor = next ? 'pointer' : 'auto';
    };

    const getIntersections = (ev: MouseEvent) => {
      const rect = modelViewer.getBoundingClientRect();
      // A zero-sized element would turn the coordinates below into NaN.
      if (!rect.width || !rect.height) return [];

      // Map client coordinates to normalized device coordinates (-1..1, y up).
      const x = (ev.clientX - rect.left) / rect.width;
      const y = (ev.clientY - rect.top) / rect.height;
      mouse.set(x * 2 - 1, -(y * 2) + 1);

      raycaster.setFromCamera(mouse, modelViewer[$scene].camera);
      return frames.length
        ? raycaster.intersectObjects(frames, true)
        : raycaster.intersectObjects(photos);
    };

    const onClick = (ev: MouseEvent) => {
      const intersections = getIntersections(ev);
      if (!intersections.length) return;

      const useFrames = frames.length > 0;
      // Frames are hit-tested recursively, so the hit may be a child object.
      // Prefer a hit that is itself named `frame_*`; otherwise fall back to the
      // parent of the nearest hit.
      const target = useFrames
        ? intersections.find((hit) => hit.object.name.startsWith('frame_'))?.object ??
          intersections[0]?.object.parent
        : intersections[0]?.object;
      if (!target) return;

      const name = useFrames ? target.name : getMaterial((target as Mesh).material);
      const index = parseName(name, useFrames ? 'frame_' : 'photo_');
      if (index !== null) onUploadRef.current?.(index);
    };

    // TODO: Add throttling; every mousemove performs a raycast.
    const onMove = (ev: MouseEvent) => setHovering(getIntersections(ev).length > 0);

    // Without this, the cursor would stay a pointer after leaving the element
    // while over a target.
    const onLeave = () => setHovering(false);

    const loaded = () => {
      // Start from scratch so a new model doesn't leave stale or duplicate targets.
      photos = [];
      frames = [];
      modelViewer[$scene]?.traverse((object: Object3D) => {
        // Check the `isMesh` flag instead of `instanceof Mesh`; the flag also
        // works if the scene's objects come from a different three.js copy.
        const mesh = object as Mesh;
        if (!mesh.isMesh || !mesh.material) return;
        if (mesh.name.startsWith('frame')) frames.push(mesh);
        // `?.`: an empty material array gives no name.
        if (getMaterial(mesh.material)?.startsWith('photo')) photos.push(mesh);
      });
    };

    modelViewer.addEventListener('load', loaded);
    modelViewer.addEventListener('click', onClick);
    modelViewer.addEventListener('mousemove', onMove);
    modelViewer.addEventListener('mouseleave', onLeave);

    // The model may already have loaded before this effect attached `load`.
    if (modelViewer.loaded) loaded();

    return () => {
      modelViewer.removeEventListener('load', loaded);
      modelViewer.removeEventListener('click', onClick);
      modelViewer.removeEventListener('mousemove', onMove);
      modelViewer.removeEventListener('mouseleave', onLeave);
      // Don't leave a pointer cursor behind after unmount.
      setHovering(false);
    };
    // Intentionally run once per mount, matching the ref-based contract above.
  }, []);
};

export default useMouseEvents;
