import * as React from 'react';
import * as ReactRedux from 'react-redux';
import * as _ from 'lodash';
import { useResource, useFrame, CanvasContext } from 'react-three-fiber';
import { useTheme, Theme } from '@material-ui/core/styles';
import * as THREE from 'three';
import { List } from 'immutable';

import * as selectors from 'src/selectors';
import { DisplayPoint, Key, RootState, ProjectionView } from 'src/types';

import PointLabels from './point-labels/PointLabels';
import fragShader from './points.frag';
import vertShader from './points.vert';
import { MAX_POINTS_DISPLAY } from 'src/constants';
import { isRedditPost, isRedditComment } from '@nebula/common';

// Flat typed arrays that back the buffer attributes of the rendered points.
// Each array is allocated for MAX_POINTS_DISPLAY points and filled from the
// front; slots past the last real point are left at zero.
interface ProjectedTextAttributeArrays {
  // Three floats per point, taken from each point's `position.toArray()`.
  positions?: Float32Array;
  // One float per point, as computed by sizeForPoint.
  sizes?: Float32Array;
  // Three floats per point, taken from each point's `color.toArray()`.
  colors?: Float32Array;
}

// Size every point starts from; sizeForPoint adds a score-based bonus on top
// of this for Reddit posts and comments.
const BASELINE_POINT_SIZE = 0.09;

// Computes the size at which a point is drawn.
//
// Every point starts at BASELINE_POINT_SIZE. Reddit posts and comments get a
// bonus that grows with the square root of their score (normalised by 1000
// for posts and 150 for comments). The result is capped at 0.25.
//
// A missing (null / undefined / NaN) score falls back to 1000. A genuine
// score of 0 is kept as 0 - it used to be mistaken for "missing" and rendered
// with the 1000 default. Negative scores are clamped to 0.
function sizeForPoint(point: DisplayPoint): number {
  const DEFAULT_SCORE = 1000;
  const MAX_POINT_SIZE = 0.25;

  // TODO: Consider reviving this logic with stats from live dataset
  const scoreBonus = (
    rawScore: number | null | undefined,
    normalizer: number
  ): number => {
    // Only a truly absent score gets the default; `||` would also replace 0.
    const score =
      rawScore == null || Number.isNaN(rawScore)
        ? DEFAULT_SCORE
        : Math.max(0, rawScore);
    return 0.01 * 2 * Math.sqrt(score / normalizer);
  };

  let size = BASELINE_POINT_SIZE;
  if (isRedditPost(point.item)) {
    size += scoreBonus(point.item.score, 1000);
  } else if (isRedditComment(point.item)) {
    size += scoreBonus(point.item.score, 150);
  }
  // Any other kind of item keeps the baseline size. Disabled tweet sizing,
  // kept for reference:
  //   const tweet = point.item;
  //   const numFavorites = tweet.numFavorites || 0;
  //   const numFollowers = tweet.authorNumFollowers || 1;
  //   size += 0.01 * (50 * Math.sqrt(numFavorites / numFollowers));
  return Math.min(MAX_POINT_SIZE, size);
}

// Packs the first MAX_POINTS_DISPLAY points into the flat typed arrays used
// as buffer attributes: three floats per point for position and for color,
// one float per point for size. Every array is allocated at full capacity, so
// unused trailing slots stay zero.
function toPointsGeomAttrArrays(
  points: List<DisplayPoint>,
  theme: Theme
): ProjectedTextAttributeArrays {
  const visible: DisplayPoint[] = points.slice(0, MAX_POINTS_DISPLAY).toArray();

  // Copies a flat run of per-point triples into a full-capacity buffer.
  const packTriples = (flat: ArrayLike<number>): Float32Array => {
    const buffer = new Float32Array(MAX_POINTS_DISPLAY * 3);
    buffer.set(flat);
    return buffer;
  };

  const positions = packTriples(
    _.flatMap(visible, (p) => p.position.toArray())
  );
  const colors = packTriples(_.flatMap(visible, (p) => p.color.toArray()));

  const sizes = new Float32Array(MAX_POINTS_DISPLAY);
  sizes.set(visible.map((p) => sizeForPoint(p)));

  return { positions, colors, sizes };
}

// Largest screen-space (NDC) distance between the ray hit and a point at
// which the point still counts as hovered. Scales with the point's size, so
// bigger points are easier to hit.
function rayDistanceThreshold(point: DisplayPoint): number {
  const hoverRadiusPerUnitSize = 0.16;
  const renderedSize = sizeForPoint(point);
  return renderedSize * hoverRadiusPerUnitSize;
}

// Distance between two points after both have been projected through
// `camera`. NDC stands for normalized device coordinates - screen space.
function ndcDistance(
  a: THREE.Vector3,
  b: THREE.Vector3,
  camera: THREE.Camera
): number {
  // Project clones so the caller's vectors are left untouched.
  const projectedA = a.clone().project(camera);
  const projectedB = b.clone().project(camera);
  return projectedA.distanceTo(projectedB);
}

// Props of ProjectionPoints. `points` and `projectionView` are filled in from
// the Redux store by the connected default export; the rest come from the
// parent component.
interface ProjectionPointsProps {
  // NOTE(review): not read anywhere in this file - confirm whether it is
  // still needed.
  dims?: number;
  // Points to draw; only the first MAX_POINTS_DISPLAY are rendered.
  points?: List<DisplayPoint>;
  // Forwarded unchanged to PointLabels.
  cameraPosition: THREE.Vector3;
  // Forwarded unchanged to PointLabels.
  projectionView: ProjectionView;
  // Key of the point currently treated as hovered; compared against the
  // point under the pointer to avoid re-reporting the same hover.
  hoveredPointKey: Key;
  // Called with a point's key when a new point becomes hovered. Also called
  // with null when the nearest hit is too far from its point on screen.
  onPointHovered: (key: Key) => void;
  // Forwarded unchanged to PointLabels.
  onPointClosed: (key: Key) => void;
  // While true, hover detection is skipped entirely.
  pointerDown: boolean;
  // Forwarded unchanged to PointLabels.
  labelContainerRef: React.MutableRefObject<HTMLDivElement>;
}

// Renders the display points as a single THREE.Points object driven by a
// custom shader, plus their labels, and performs hover detection on every
// frame.
//
// Hover detection: the raycaster's hits on the points object are scanned for
// the one closest to the ray. If that hit lies within the point's
// size-dependent screen-space threshold, the point is reported through
// `onPointHovered` (only when it differs from `hoveredPointKey`); otherwise
// `onPointHovered(null)` is called. Nothing is reported when the ray hits no
// point at all, or while the pointer is down.
function ProjectionPoints(props: ProjectionPointsProps) {
  const theme = useTheme();
  // `points` is declared optional, so fall back to an empty list instead of
  // crashing on `.size` when it is absent.
  const points = props.points ?? List<DisplayPoint>();
  const effectiveNumPoints = Math.min(points.size, MAX_POINTS_DISPLAY);
  // TODO: Might be able to make this much faster (esp. fewer GCs) by updating
  // these in-place rather than recreating every time props change.
  // To do that, I'd need to know which elements changed. To do THAT I might
  // need to do something where each point is its own component that must be
  // passed as a child to this component. Then there might be some fancy way to
  // hook into React to get the diff on each update. Maybe with React's keys.
  // Related: https://stackoverflow.com/questions/28784050
  const { positions, sizes, colors } = React.useMemo(
    () => toPointsGeomAttrArrays(points, theme),
    [points, theme]
  );

  const [threePointsRef, threePoints]: [
    React.RefObject<THREE.Points>,
    THREE.Points
  ] = useResource();

  useFrame((state: CanvasContext) => {
    if (props.pointerDown) return;
    // The resource only becomes available once the <points> element has
    // mounted; frames rendered before that have nothing to raycast against,
    // and intersectObject cannot be handed an undefined object.
    if (!threePoints) return;
    const raycaster = state.raycaster;
    const camera = raycaster.camera;
    if (!camera) return;

    const intersections = raycaster.intersectObject(threePoints);
    if (!intersections[0]) return;

    let closestToRay = intersections[0];
    for (const intersection of intersections) {
      // TODO: Try using a min improvement threshold to make sure to select
      // point closer to camera even if slightly further from ray?
      if (intersection.distanceToRay < closestToRay.distanceToRay) {
        closestToRay = intersection;
      }
    }

    // The hit index addresses the geometry buffers, which are laid out in
    // the same order as `points`. Indices past the end of the list (padding
    // slots) yield no point and are ignored.
    const point = points.get(closestToRay.index);
    const pointVector = point?.position;
    if (!point || !pointVector) return;

    const closestPointOnRayVector = closestToRay.point;
    if (
      ndcDistance(closestPointOnRayVector, pointVector, camera) >
      rayDistanceThreshold(point)
    ) {
      props.onPointHovered(null);
    } else if (props.hoveredPointKey !== point.key) {
      props.onPointHovered(point.key);
    }
  });

  return (
    <group position={[0, 0, 0]} rotation={[0, 0, 0]}>
      <PointLabels
        points={points}
        threePoints={threePoints}
        cameraPosition={props.cameraPosition}
        projectionView={props.projectionView}
        labelContainerRef={props.labelContainerRef}
        onPointClosed={props.onPointClosed}
      />
      <points ref={threePointsRef}>
        <bufferGeometry
          attach="geometry"
          // The buffers are allocated for MAX_POINTS_DISPLAY points; only
          // draw the ones that hold real data.
          drawRange={{ start: 0, count: effectiveNumPoints }}
        >
          <bufferAttribute
            attachObject={['attributes', 'position']}
            array={positions}
            count={MAX_POINTS_DISPLAY}
            // We always represent points as Vector3s, even in 2D mode
            itemSize={3}
            onUpdate={(self) => (self.needsUpdate = true)}
          />
          <bufferAttribute
            attachObject={['attributes', 'customColor']}
            array={colors}
            count={MAX_POINTS_DISPLAY}
            itemSize={3}
            onUpdate={(self) => (self.needsUpdate = true)}
          />
          <bufferAttribute
            attachObject={['attributes', 'size']}
            array={sizes}
            count={MAX_POINTS_DISPLAY}
            itemSize={1}
            onUpdate={(self) => (self.needsUpdate = true)}
          />
        </bufferGeometry>
        <shaderMaterial
          attach="material"
          vertexShader={vertShader}
          fragmentShader={fragShader}
          // Manually tuned: tradeoff between looking jagged when small and
          // having transparent black artifacts on edges.
          alphaTest={0.4}
        />
      </points>
    </group>
  );
}

// Selects the store-backed props of ProjectionPoints: the points to draw and
// the current projection view.
const mapStateToProps = (state: RootState) => {
  const points = selectors.displayPoints(state);
  const projectionView = selectors.projectionView(state);
  return { points, projectionView };
};

// No action creators are bound; all remaining props come from the parent.
export default ReactRedux.connect(mapStateToProps, null)(ProjectionPoints);
