import { FilesetResolver, FaceLandmarker, ImageSegmenter } from '@mediapipe/tasks-vision'

// NOTE(review): this WASM URL is not version-pinned, so the CDN may serve a build
// that differs from the installed @mediapipe/tasks-vision JS package — consider
// pinning it (e.g. `@mediapipe/tasks-vision@<version>/wasm`).
const VISION_WASM_URL = "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision/wasm"
const FACE_LANDMARKER_MODEL_URL =
  "https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task"
const SELFIE_SEGMENTER_MODEL_URL =
  "https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1683332563830600"

// Module-level caches of the loader promises, so the WASM fileset and each model
// are created at most once and shared by all callers. A slot is reset to
// `undefined` when its promise rejects, so the next call retries the load.
let visionReady: ReturnType<typeof FilesetResolver.forVisionTasks> | undefined
let faceLandmarkerReady: Promise<FaceLandmarker> | undefined
let segmenterReady: Promise<ImageSegmenter> | undefined

/**
 * Attaches a rejection handler that runs `reset`, then returns `promise` unchanged.
 * Callers awaiting `promise` still see its rejection; the handler only evicts the
 * failed promise from its cache slot so it is not re-served forever.
 */
function retryable<T>(promise: Promise<T>, reset: () => void): Promise<T> {
  promise.catch(reset)
  return promise
}

/**
 * Lazily loads the MediaPipe vision WASM fileset, then the face-landmarker and
 * selfie-segmentation models, caching each so later calls reuse the same instances.
 *
 * Once the fileset is available, the two models are loaded concurrently. If any
 * load fails, only that cache entry is cleared, so a subsequent call retries the
 * failed piece instead of re-throwing the stale error.
 *
 * @returns the shared `FaceLandmarker` and `ImageSegmenter` instances.
 * @throws whatever the underlying MediaPipe loaders reject with (e.g. fetch failures).
 */
const mediapipeInit = async (): Promise<{ faceLandmarker: FaceLandmarker, imageSegmenter: ImageSegmenter}> => {
  if (!visionReady) {
    visionReady = retryable(FilesetResolver.forVisionTasks(VISION_WASM_URL), () => {
      visionReady = undefined
    })
  }

  const vision = await visionReady

  if (!faceLandmarkerReady) {
    faceLandmarkerReady = retryable(
      FaceLandmarker.createFromModelPath(vision, FACE_LANDMARKER_MODEL_URL),
      () => {
        faceLandmarkerReady = undefined
      }
    )
  }

  if (!segmenterReady) {
    segmenterReady = retryable(
      ImageSegmenter.createFromModelPath(vision, SELFIE_SEGMENTER_MODEL_URL),
      () => {
        segmenterReady = undefined
      }
    )
  }

  // The models are independent once the fileset exists, so wait for them together.
  const [faceLandmarker, imageSegmenter] = await Promise.all([faceLandmarkerReady, segmenterReady])

  return { faceLandmarker, imageSegmenter }
}

/**
 * Detects facial landmarks in an image.
 *
 * `mediapipeInit()` initialises both MediaPipe models, so the first call also
 * waits for the segmentation model, not only the face landmarker.
 *
 * @param imageEl - image to analyse.
 * @returns the landmark list of the first detected face, or `undefined` when the
 *   detection result holds no faces.
 */
export const findLandmarks = async (imageEl: HTMLImageElement) => {
  const models = await mediapipeInit()
  const result = await models.faceLandmarker.detect(imageEl)
  const faces = result?.faceLandmarks

  return faces?.[0]
}

/**
 * Segments the people in `imageEl` with the selfie-segmentation model and renders
 * the first confidence mask into a new canvas sized to the mask.
 *
 * Each byte from `getAsUint8Array()` is written to the R, G, B and alpha channels
 * of its pixel. The canvas therefore shows the mask as a greyscale image whose
 * opacity also follows the mask.
 *
 * @param imageEl - source image to segment.
 * @returns a canvas containing the rendered mask.
 * @throws Error if the segmenter yields no confidence mask, or if a 2D context
 *   cannot be obtained for the output canvas.
 */
export const segmentPeople = async (imageEl: HTMLImageElement): Promise<HTMLCanvasElement> => {
  const {imageSegmenter} = await mediapipeInit()
  const segmentResult = await imageSegmenter.segment(imageEl)

  try {
    const confidenceMask = segmentResult?.confidenceMasks?.[0]

    if (!confidenceMask) {
      throw new Error('Could not generate segmentation mask!')
    }

    const { width, height } = confidenceMask
    const data = confidenceMask.getAsUint8Array()

    // Expand the single-channel mask to RGBA, copying each value into all four channels.
    const clampedData = new Uint8ClampedArray(data.length * 4)
    data.forEach((byte, i) => {
      clampedData.fill(byte, i * 4, i * 4 + 4)
    })

    const segmentCanvas = document.createElement('canvas')
    segmentCanvas.width = width
    segmentCanvas.height = height

    const segmentCtx = segmentCanvas.getContext('2d')
    if (!segmentCtx) {
      // Fail loudly rather than returning a blank canvas the caller cannot tell apart from an empty mask.
      throw new Error('Could not get 2D canvas context!')
    }
    segmentCtx.putImageData(new ImageData(clampedData, width), 0, 0)

    return segmentCanvas
  } finally {
    // Per the MediaPipe docs, calling `segment()` without a callback returns copies of the masks.
    // Release them here. The canvas has already been drawn, so nothing reads them afterwards.
    segmentResult?.close()
  }
}