import React, { Component } from 'react';
import * as posenet from '@tensorflow-models/posenet';
import Stats from 'stats.js';

import './App.css';
import PoseDetectionResult from "./PoseDetectionResult";
import { VIDEO_SIZE, NNPARAMS } from "./App";
import { drawKeypoints } from "./drawUtil";
import { MIN_PARTS_SCORE } from "./GamePage";


/** Props accepted by the PoseDetection component. */
interface IPoseDetectionProps {
  /**
   * Camera stream to analyse. Declared optional, but effectively required:
   * mounting without it makes the component's video setup throw "Video not found".
   */
  videoStream?: MediaStream;
  /** Receives a result wrapping each pose estimated while detection is running. */
  onChange: (pose: PoseDetectionResult) => void;
}

/** Internal state of the PoseDetection component. */
interface IPoseDetectionState {
  /** Whether poses are estimated each frame; when false a "Paused" overlay is rendered. */
  started: boolean;
  /** Loaded PoseNet model; undefined until loading finishes, disposed on unmount. */
  net?: posenet.PoseNet;
  /** stats.js panel used to time each detection frame. */
  stats: Stats;
  /** Root <div> of the component; the stats panel's DOM node is appended to it. */
  rootRef: React.RefObject<HTMLDivElement>;
  /** 2D context of the main (video + keypoints) canvas; undefined until canvases are set up. */
  ctx?: CanvasRenderingContext2D;
  /** 2D context of the face-thumbnail canvas; set up together with `ctx`. */
  faceCtx?: CanvasRenderingContext2D;
  /** NOTE(review): initialised to "" but not read by PoseDetection as far as visible. */
  debug: string;
}

/** Side length, in pixels, of the square face-thumbnail canvas and of the face crop drawn into it. */
const FACE_WIDTH = 70;

/**
 * Webcam pose-detection view.
 *
 * Plays `props.videoStream` in a <video> element and runs PoseNet single-pose
 * estimation on it once per animation frame. Each estimate is reported through
 * `props.onChange`. The component draws the mirrored frame, the keypoints, a
 * score bar and a face thumbnail onto canvases. Clicking the canvas or the
 * "Paused" overlay toggles detection on and off.
 */
export default class PoseDetection extends Component<IPoseDetectionProps, IPoseDetectionState> {

  /** Set by componentWillUnmount; async continuations check it before touching the component. */
  private unmounted = false;

  /** Handle of the pending detection frame, kept so unmounting can cancel it. */
  private frameId?: number;

  constructor(props: IPoseDetectionProps) {
    super(props);

    this.state = {
      started: true,
      stats: new Stats(),
      rootRef: React.createRef(),
      ctx: undefined,
      faceCtx: undefined,
      debug: "",
      net: undefined
    }
    this.startStop = this.startStop.bind(this);
  }

  render() {
    return (
      <div ref={this.state.rootRef}>
        <video id="video" className="flipVideo" width={VIDEO_SIZE.width} height={VIDEO_SIZE.height} />
        <canvas id="canvas" onClick={this.startStop} width={VIDEO_SIZE.width} height={VIDEO_SIZE.height} />
        {this.state.started ?
          null :
          <div className="videoPaused" style={{
            ...VIDEO_SIZE,
          }} onClick={this.startStop}>
            <h1>Paused</h1>
          </div>
        }
        <canvas id="faceCanvas" className="faceCanvas" width={FACE_WIDTH} height={FACE_WIDTH} />
      </div>
    )
  }

  /**
   * Attaches the stream to the <video>, loads the PoseNet model, sets up the
   * stats panel and canvases, then starts the detection loop.
   *
   * @throws Error "Video not found" (as a rejected promise) when there is no stream or <video>.
   */
  public async componentDidMount() {
    const video = await this.initVideo();
    if (this.unmounted) {
      return;
    }
    const net = await posenet.load(NNPARAMS.multiplier);
    if (this.unmounted) {
      // componentWillUnmount already ran with no model in state; release this one here.
      net.dispose();
      return;
    }
    this.setState({
      net
    });
    this.initStats();
    this.initCanvas();
    this.detect(net, video);
  }

  /** Stops the detection loop and releases the PoseNet model. */
  public componentWillUnmount() {
    this.unmounted = true;
    if (this.frameId !== undefined) {
      cancelAnimationFrame(this.frameId);
      this.frameId = undefined;
    }
    if (this.state.net) {
      this.state.net.dispose();
    }
  }

  /**
   * Schedules one detection step on the next animation frame. Each step
   * schedules its successor itself, so at most one frame is pending at a time.
   */
  private detect(net: posenet.PoseNet, video: HTMLVideoElement) {
    this.frameId = requestAnimationFrame(() => {
      this.frameId = undefined;
      void this.runFrame(net, video);
    });
  }

  /**
   * One iteration of the detection loop.
   *
   * Estimates and draws the pose unless paused, times the work with stats.js,
   * then schedules the next frame. An error stops the loop. Errors are logged
   * unless they happen after unmount (disposing the model can reject an
   * in-flight estimate).
   */
  private async runFrame(net: posenet.PoseNet, video: HTMLVideoElement) {
    if (this.unmounted) {
      return;
    }
    this.state.stats.begin();
    try {
      if (this.state.started) {
        const pose = await this.detectPose(net, video);
        if (this.unmounted) {
          return;
        }
        this.drawPose(pose, video);
      }
    } catch (error) {
      if (!this.unmounted) {
        console.error("Pose detection failed; stopping the detection loop", error);
      }
      return;
    } finally {
      // Always close the timing window, even when the estimate throws.
      this.state.stats.end();
    }
    this.detect(net, video);
  }

  /**
   * Estimates a single pose from the current video frame and reports it via
   * `props.onChange` while still mounted.
   *
   * The third argument (`true`) asks PoseNet to flip keypoints horizontally, so
   * they line up with the mirrored drawing in drawVideo.
   */
  private async detectPose(net: posenet.PoseNet, video: HTMLVideoElement) {
    const pose = await net.estimateSinglePose(video, NNPARAMS.imageScaleFactor, true, NNPARAMS.outputStride);
    // Inference is async: don't notify a parent after this component has unmounted.
    if (!this.unmounted) {
      const result: PoseDetectionResult = {
        pose
      };
      this.props.onChange(result);
    }
    return pose;
  }

  /**
   * Binds `props.videoStream` to the #video element, waits for its metadata
   * and starts playback.
   *
   * @throws Error "Video not found" when either the stream or the element is missing.
   */
  private async initVideo() {
    const video = document.getElementById("video") as HTMLVideoElement;
    if (!this.props.videoStream || !video) {
      throw new Error("Video not found");
    }
    video.srcObject = this.props.videoStream;
    await new Promise((resolve) => {
      video.onloadedmetadata = resolve
    });
    // play() returns a promise that may reject (e.g. autoplay blocked). Report the
    // failure instead of leaving an unhandled rejection; setup continues either way.
    video.play().catch((error) => {
      console.warn("Unable to start video playback", error);
    });
    return video;
  }

  /** Shows stats.js panel 0 and mounts its DOM node, restyled via the `statsPanel` class, in the root div. */
  private initStats() {
    this.state.stats.showPanel(0);
    const rootDiv = this.state.rootRef.current;
    const element = this.state.stats.dom;
    element.className = 'statsPanel';
    // Drop stats.js's inline positioning so the CSS class controls layout.
    element.removeAttribute('style');
    if (rootDiv) {
      rootDiv.appendChild(element);
    }
  }

  /** Looks up both canvases' 2D contexts and stores them in state. */
  private initCanvas() {
    const ctx = this.initCtx("canvas");
    const faceCtx = this.initCtx("faceCanvas");
    this.setState({
      ctx,
      faceCtx
    });
  }

  /**
   * Returns the 2D context of the canvas with the given DOM id.
   *
   * @throws Error "Unable to find canvas" / "Unable to load canvas".
   */
  private initCtx(id: string): CanvasRenderingContext2D {
    const canvas = document.getElementById(id) as HTMLCanvasElement;
    if (!canvas) {
      throw new Error("Unable to find canvas");
    }
    const ctx = canvas.getContext('2d');
    if (!ctx) {
      throw new Error("Unable to load canvas");
    }
    return ctx;
  }

  /** Toggles detection on/off; bound in the constructor because it is a click handler. */
  private startStop() {
    // Updater form so each toggle flips the latest state rather than a stale snapshot.
    this.setState((prev) => ({
      started: !prev.started
    }));
  }

  /** Draws the frame, keypoints (above MIN_PARTS_SCORE) and score bar; no-op before canvases exist. */
  private drawPose(pose: posenet.Pose, video: HTMLVideoElement) {
    const ctx = this.state.ctx;
    if (!ctx) {
      return;
    }
    this.drawVideo(ctx, video, pose);
    drawKeypoints(pose.keypoints, MIN_PARTS_SCORE, ctx);
    this.drawKarma(ctx, pose.score);
  }

  /**
   * Draws the current video frame mirrored horizontally, so it lines up with the
   * horizontally flipped keypoints. It then updates the face thumbnail.
   */
  private drawVideo(ctx: CanvasRenderingContext2D, video: HTMLVideoElement, pose: posenet.Pose) {
    ctx.save();
    ctx.scale(-1, 1);
    ctx.translate(-VIDEO_SIZE.width, 0);
    ctx.drawImage(video, 0, 0, VIDEO_SIZE.width, VIDEO_SIZE.height);
    ctx.restore();
    // The thumbnail is drawn on a separate canvas and doesn't need this transform.
    // Drawing it after restore() means a throw there cannot leave ctx's state stack unbalanced.
    this.drawFace(this.state.faceCtx, video, pose);
  }

  /**
   * Strokes a 50px bar along the bottom edge whose hue follows the pose score:
   * hue 0 (red) at score 0 up to hue 90 at score 1. Half of the stroke width
   * falls below the canvas edge.
   */
  private drawKarma(ctx: CanvasRenderingContext2D, score: number) {
    ctx.beginPath();
    ctx.moveTo(0, VIDEO_SIZE.height);
    ctx.lineTo(VIDEO_SIZE.width, VIDEO_SIZE.height);
    ctx.lineWidth = 50;
    ctx.strokeStyle = "hsl(" + score * 90 + ", 40%, 50%)";
    ctx.stroke();
  }

  /**
   * Copies a square crop centred on the nose into the face canvas.
   *
   * The crop's half-size is twice the nose-to-eye distance. keypoints[0] and [1]
   * are the nose and an eye in PoseNet's part order. When either scores below
   * 0.7 (or is missing), nothing is drawn and the previous thumbnail stays.
   *
   * @throws Error "Ctx not found" when the face canvas context is missing.
   */
  private drawFace(ctx: CanvasRenderingContext2D | undefined, video: HTMLVideoElement, pose: posenet.Pose) {
    if (!ctx) {
      throw new Error("Ctx not found");
    }
    const nose = pose.keypoints[0];
    const eye = pose.keypoints[1];
    if (!nose || !eye || nose.score < 0.7 || eye.score < 0.7) {
      return;
    }
    const distx = nose.position.x - eye.position.x;
    const disty = nose.position.y - eye.position.y;
    const dist = Math.sqrt(distx * distx + disty * disty);
    const r = dist * 2;
    // Keypoints are horizontally flipped but the raw video frame is not, so map
    // the nose x back into video coordinates before cropping.
    // NOTE(review): the destination starts at (10, 10) on a FACE_WIDTH-square canvas,
    // so the right/bottom 10px of the crop are clipped — confirm this offset is intended.
    ctx.drawImage(video, (VIDEO_SIZE.width - nose.position.x) - r, nose.position.y - r, r * 2, r * 2, 10, 10, FACE_WIDTH, FACE_WIDTH);
  }

}