import css from "@emotion/css";
import * as scale from "d3-scale";
import { tuple } from "fp-ts/lib/function";
import { pipe } from "fp-ts/lib/pipeable";
import { motion, transform } from "framer-motion";
import React from "react";
import ColorLegend, { ColorLegendEntry } from "../charts/lib/ColorLegend";
import { findClosest } from "../charts/lib/utils";
import { ChartTooltip, TooltipGridBodyWithLegend } from "../components/tooltip";
import { useContentRect, usePointerPosition, useTheme } from "../hooks";
import { extentNumber, insertionOrderSet } from "../lib";
import * as M from "../materials";
import { An, Ar, O } from "../prelude";
import { YAnnotation, YAnnotations } from "./annotations";
import { ChartBasis } from "./stacked-bar";

// -----------------------------------------------------------------------------
// Helpers

/**
 * Smallest and largest value of `acc` over a non-empty stack.
 *
 * Comparisons are strict `<` / `>`, so a NaN produced by `acc` never replaces
 * the running bounds.
 *
 * @returns `[min, max]` of the accessed values.
 */
function stackExtent<A>(stack: An.NonEmptyArray<A>, acc: (a: A) => number) {
  let lo = Infinity;
  let hi = -Infinity;
  for (const item of stack) {
    const value = acc(item);
    if (value < lo) {
      lo = value;
    }
    if (value > hi) {
      hi = value;
    }
  }
  return tuple(lo, hi);
}

// Popper options forwarded to every dot-group tooltip (`ChartTooltip`).
// Defined at module scope so the same object reference is passed on every
// render.
const popperOptions = {
  positionFixed: true,
};

// Combines two optional string arrays. Per fp-ts `O.getMonoid`, `none` is the
// identity and two `some`s are concatenated. Used to merge the pointer-selected
// column key with `highlightItems`.
const MOAr = O.getMonoid(Ar.getMonoid<string>());

// -----------------------------------------------------------------------------
// Components

// Emotion class definitions shared by the dot plot and its subcomponents.
const styles = {
  // X-axis key labels, centred under each column.
  xLabel: css`
    ${M.fontAxisLabel};
    fill: ${M.lightText};
    alignment-baseline: middle;
    text-anchor: middle;
  `,
  // Y-axis tick labels and the optional y unit label, left-aligned.
  yLabel: css`
    ${M.fontAxisLabel};
    fill: ${M.lightText};
    alignment-baseline: middle;
    text-anchor: start;
  `,
  // Vertical line joining the lowest and highest dot of a column.
  connectingLine: css`
    stroke: ${M.divider};
    shape-rendering: crispEdges;
  `,
  // Dashed horizontal guide drawn at each value of a selected column when
  // `showYExtent` is set; the stroke colour is supplied per line.
  xHighlightLine: css`
    stroke-width: 1;
    stroke-dasharray: 4 4;
    shape-rendering: crispEdges;
  `,
  // Short rules in the y-axis label gutter.
  yLine: css`
    stroke-width: 1;
    stroke: ${M.divider};
    shape-rendering: crispEdges;
  `,
  // Grid lines across the plot area and the x-axis tick marks.
  yLineOver: css`
    stroke-width: 1;
    stroke: rgba(0, 0, 0, 0.15);
    shape-rendering: crispEdges;
  `,
};

/** Visual variant used to draw the dots of a `VerticalDotPlot`. */
export type DotStyle = "default" | "medium" | "small";
/** Whether x-axis labels are drawn ("default") or hidden ("none"). */
export type LabelStyle = "default" | "none";

/**
 * Exhaustive case analysis over `DotStyle`: runs the thunk registered for
 * `style` and returns its result.
 *
 * The thunk is invoked with `pattern` as its receiver, as a direct
 * `pattern[style]()` call would do.
 */
export function foldDotStyle<A>(
  style: DotStyle,
  pattern: { [K in DotStyle]: () => A }
): A {
  const handler = pattern[style];
  return handler.call(pattern);
}

/**
 * Props for `VerticalDotPlot`.
 *
 * Sizing (`width`, `height`) comes from `ChartBasis`. The component treats
 * `width` as the full SVG width and `height` as the plot-area height, to which
 * fixed top/bottom margins are added.
 */
export interface VerticalDotPlot<A> extends ChartBasis {
  /** Data points; grouped into one column per distinct `getX` key. */
  data: An.NonEmptyArray<A>;
  /** Column key of a datum; also the value matched against `highlightItems`. */
  getX(x: A): string;
  /**
   * Colour key of a datum. It is looked up in `colorScale` and used as the
   * row label in the tooltip.
   */
  getColor(x: A): string;
  /** Numeric value of a datum, plotted on the y axis. */
  getValue(x: A): number;
  /**
   * Formats a column key. Used for x-axis labels, where `i` is the key's
   * index in `keys`, and for the tooltip title, where it is always called
   * with `i = 0`.
   */
  formatX(x: string, i: number, keys: An.NonEmptyArray<string>): string;
  /**
   * Formats values in the tooltip and in `yAnnotations`. Also used for y-axis
   * labels when `formatAxisValue` is absent.
   */
  formatValue(x: number): string;
  /** Maps colour keys (from `getColor`) to colours. */
  colorScale: scale.ScaleOrdinal<string, string>;
  /** Entries rendered by the legend when `showLegend` is set. */
  colorLegendValues: Array<ColorLegendEntry>;
  /**
   * Column keys rendered as selected regardless of pointer position.
   * Defaults to `O.none`.
   */
  highlightItems?: O.Option<An.NonEmptyArray<string>>;
  /** Formats y-axis tick labels; falls back to `formatValue`. */
  formatAxisValue?(x: number): string;
  /** Fixed y domain; defaults to the extent of `getValue` over `data`. */
  domainY?: [number, number];
  /** Label drawn just above the highest y tick. */
  yUnitLabel?: React.ReactNode;
  /** Hides the y-axis labels and collapses the left margin to 0. */
  hideYAxis?: boolean;
  /**
   * Tracks the pointer over the plot area; the column nearest to it becomes
   * selected and shows its tooltip.
   */
  showTooltip?: boolean;
  /** Renders `colorLegendValues` below the chart. */
  showLegend?: boolean;
  /** Dot rendering variant; defaults to "default". */
  dotStyle?: DotStyle;
  /** "none" hides the x-axis ticks and labels; defaults to "default". */
  labelStyleX?: LabelStyle;
  /** Draws a dashed horizontal guide at each value of a selected column. */
  showYExtent?: boolean;
  /** Horizontal annotations rendered via `YAnnotations`. */
  yAnnotations?: Array<YAnnotation>;
  /** Left margin in px while the y axis is shown; defaults to 30. */
  marginLeft?: number;
}

export const VerticalDotPlot = React.memo(VerticalDotPlot_);

/**
 * Renders a vertical dot plot.
 *
 * - There is one column per distinct `getX` key, in data insertion order.
 * - Each datum in a column is a dot at `getValue`, filled with
 *   `colorScale(getColor(d))`.
 * - A vertical line joins the column's lowest and highest value.
 *
 * A column counts as selected when either:
 * - it is the column nearest to the pointer (tracked only when `showTooltip`
 *   is set), or
 * - its key appears in `highlightItems`.
 *
 * Selected columns show a tooltip and the `dotStyle`-specific emphasis.
 *
 * Sizing: `width` is the full SVG width. `height` is the plot-area height;
 * the SVG adds fixed top and bottom margins around it.
 */
function VerticalDotPlot_<A>({
  colorScale,
  colorLegendValues,
  marginLeft,
  labelStyleX = "default",
  dotStyle = "default",
  highlightItems = O.none,
  ...p
}: VerticalDotPlot<A>) {
  const { client } = useTheme();
  const [pointerPositionRef, pointerPosition] =
    usePointerPosition<SVGRectElement>();
  const showYAxis = !p.hideYAxis;

  // The left margin hosts the y-axis labels. It collapses to 0 when the y axis
  // is hidden; otherwise an explicit `marginLeft` wins over the 30px default.
  const margin = React.useMemo(
    () => ({
      top: 20,
      right: 0,
      bottom: client.screenSDown ? 12 : 30,
      left: showYAxis ? (marginLeft != null ? marginLeft : 30) : 0,
    }),
    [client.screenSDown, marginLeft, showYAxis]
  );
  // `p.height` is the plot-area height, so the SVG grows by the vertical
  // margins. `p.width` is the full SVG width, so the plot area shrinks by the
  // horizontal margins.
  const outerHeight = p.height + margin.top + margin.bottom;
  const innerHeight = p.height;
  const outerWidth = p.width;
  const innerWidth = p.width - margin.left - margin.right;

  const domainX = React.useMemo(
    () => insertionOrderSet(p.data, p.getX),
    [p.data, p.getX]
  );
  // An explicit `domainY` overrides the data extent.
  const domainY = React.useMemo(
    () => (p.domainY != null ? p.domainY : extentNumber(p.data, p.getValue)),
    [p.data, p.domainY, p.getValue]
  );
  // One stack (column) of data per x key.
  const stacks = React.useMemo(
    () => An.groupBy(p.getX)(p.data),
    [p.data, p.getX]
  );

  // Point scale over the x keys. The range is inset by the dot radius on both
  // sides so the dots of the first and last columns stay inside the plot area.
  const xScale = React.useMemo(
    () =>
      scale
        .scalePoint()
        .domain(domainX)
        .range([dotRadius(dotStyle), innerWidth - dotRadius(dotStyle)]),
    [domainX, innerWidth, dotStyle]
  );

  // Maps a pointer x coordinate back to the nearest x key.
  const xScaleInvert = React.useMemo(() => {
    const steps = domainX.map((x) => ({
      domain: x,
      range: xScale(x) as number,
    }));
    return (pos: number) => findClosest(steps, (x) => x.range, pos).domain;
  }, [domainX, xScale]);

  // SVG y grows downwards, so the range is flipped to place larger values higher.
  const [yScale, yTicks] = React.useMemo(() => {
    const y = scale.scaleLinear().domain(domainY).range([innerHeight, 0]);
    return [y, y.ticks(client.screenMDown ? 4 : 6)] as const;
  }, [client.screenMDown, domainY, innerHeight]);

  // Column keys selected by pointer hover and/or `highlightItems`. The result
  // is `none` only when neither source selects anything.
  const selectedItems = MOAr.concat(
    pipe(
      pointerPosition,
      O.map((pos) => An.of(xScaleInvert(pos.x)))
    ),
    highlightItems
  );

  return (
    <div>
      <svg height={outerHeight} width={outerWidth}>
        {showYAxis && (
          <g key={"ylabels"} transform={`translate(0, ${margin.top})`}>
            <YLabels
              yTicks={yTicks}
              yScale={yScale}
              xEnd={margin.left - 4}
              formatValue={p.formatAxisValue || p.formatValue}
            />
          </g>
        )}
        <g key={"main"} transform={`translate(${margin.left}, ${margin.top})`}>
          <XTicks
            labelStyleX={labelStyleX}
            domainX={domainX}
            xScale={xScale}
            formatX={p.formatX}
            innerHeight={innerHeight}
          />
          <YTicks
            yTicks={yTicks}
            yScale={yScale}
            yUnitLabel={p.yUnitLabel}
            innerWidth={innerWidth}
          />
          {p.yAnnotations && (
            <YAnnotations
              width={innerWidth + margin.right}
              yScale={yScale}
              yFormat={p.formatValue}
              annotations={p.yAnnotations}
            />
          )}
          {domainX.map((key) => {
            const stack = stacks[key];
            // Every datum in `stack` has `getX(d) === key`.
            const x = xScale(key) as number;
            const isSelected = pipe(
              selectedItems,
              O.fold(
                () => false,
                (sel) => sel.includes(key)
              )
            );
            return (
              <DotGroup
                key={key}
                name={key}
                stack={stack}
                x={x}
                isSelected={isSelected}
                yScale={yScale}
                colorScale={colorScale}
                getValue={p.getValue}
                getColor={p.getColor}
                formatX={p.formatX}
                formatValue={p.formatValue}
                domainX={domainX}
                showYExtent={p.showYExtent}
                dotStyle={dotStyle}
                innerWidth={innerWidth}
              />
            );
          })}
          {p.showTooltip && (
            // Transparent hit area over the plot that feeds `pointerPosition`.
            <rect
              ref={pointerPositionRef}
              x={0}
              y={0}
              width={innerWidth}
              height={yScale.range()[0]}
              fill="transparent"
            />
          )}
        </g>
      </svg>
      {p.showLegend && (
        <ColorLegend inline maxWidth={"100%"} values={colorLegendValues} />
      )}
    </div>
  );
}

/**
 * `VerticalDotPlot` that fills its container's width.
 *
 * The container is measured with `useContentRect`. The chart is rendered only
 * once the measured width is positive.
 */
export function VerticalDotPlotAuto<A>(
  props: Omit<VerticalDotPlot<A>, "width">
) {
  const [ref, contentRect] = useContentRect();
  return (
    <div ref={ref} style={{ width: "100%" }}>
      {contentRect.width > 0 && (
        <VerticalDotPlot {...props} width={contentRect.width} />
      )}
    </div>
  );
}

// -----------------------------------------------------------------------------

/**
 * Builds the colour scale and legend entries for a `VerticalDotPlot`.
 *
 * The colour domain is `domainColor` when given; otherwise it is the distinct
 * `getColor` values in data order. The palette is requested with exactly one
 * colour per domain entry.
 *
 * @returns `[colorScale, colorLegendValues]`. Each legend entry pairs a domain
 *   key's colour with `formatColor(key)`.
 */
export function useColorScaleVDP<A>({
  colorPalette,
  formatColor,
  ...p
}: {
  data: An.NonEmptyArray<A>;
  getColor(x: A): string;
  colorPalette(count: number): Array<string>;
  formatColor(x: string): string;
  domainColor?: Array<string>;
}) {
  const domainColor = React.useMemo(
    () => p.domainColor || insertionOrderSet(p.data, p.getColor),
    [p.data, p.getColor, p.domainColor]
  );

  const colorScale = React.useMemo(
    () =>
      scale.scaleOrdinal(colorPalette(domainColor.length)).domain(domainColor),
    [colorPalette, domainColor]
  );

  const colorLegendValues = React.useMemo(() => {
    return domainColor.map((key) => ({
      color: colorScale(key),
      label: formatColor(key),
    }));
  }, [formatColor, colorScale, domainColor]);

  return [colorScale, colorLegendValues] as const;
}

// -----------------------------------------------------------------------------

/**
 * Dot radius in px for a dot style.
 *
 * For "medium" and "small" this is the maximum extent: a selected dot is drawn
 * with `r = radius - 1.5` and a 3px stroke, so its outer edge reaches `radius`.
 * Unselected dots are smaller.
 */
function dotRadius(dotStyle: DotStyle) {
  return foldDotStyle(dotStyle, {
    default: () => 8,
    medium: () => 7.5, // This is the max radius when hovered, smaller otherwise
    small: () => 6.5, // This is the max radius when hovered, smaller otherwise
  });
}

interface DotGroupProps<A> {
  /** Column key; used for the tooltip title. */
  name: string;
  /** Data in this column. */
  stack: An.NonEmptyArray<A>;
  /** Horizontal pixel position of the column. */
  x: number;
  /** Whether the column is selected (hovered or highlighted). */
  isSelected: boolean;
  yScale: scale.ScaleLinear<number, number>;
  colorScale: scale.ScaleOrdinal<string, string>;
  getColor(x: A): string;
  getValue(x: A): number;
  formatX(x: string, i: number, keys: An.NonEmptyArray<string>): string;
  formatValue(x: number): string;
  domainX: An.NonEmptyArray<string>;
  showYExtent?: boolean;
  dotStyle?: DotStyle;
  /** Width of the plot area; right end of the `showYExtent` guide lines. */
  innerWidth: number;
}

/**
 * One column of the dot plot: the connecting line, one dot per datum, and,
 * while selected, the tooltip and the optional value guide lines.
 */
function DotGroup_<A>({
  name,
  stack,
  x,
  isSelected,
  yScale,
  colorScale,
  getValue,
  getColor,
  formatX,
  formatValue,
  domainX,
  showYExtent,
  dotStyle = "default",
  innerWidth,
}: DotGroupProps<A>) {
  // Pixel positions of the stack's smallest (y1) and largest (y2) value.
  const [y1, y2] = stackExtent(stack, getValue).map(yScale);
  const radius = dotRadius(dotStyle);

  // Dot fade-ins are staggered by index, spread over at most `maxDuration` s.
  const [maxDuration, maxDots] = [0.4, stack.length];

  return (
    <motion.g>
      {showYExtent &&
        isSelected &&
        stack.map((d) => {
          const y = yScale(getValue(d));
          const color = colorScale(getColor(d));
          return (
            <motion.line
              // Keyed like the dots below: distinct colour keys can map to the
              // same scale colour, which would make colour-based keys collide.
              key={`y-highlight-${getColor(d)}`}
              css={styles.xHighlightLine}
              x1={0}
              x2={innerWidth}
              y1={y}
              y2={y}
              stroke={color}
            />
          );
        })}
      <motion.line
        css={styles.connectingLine}
        initial={{
          x1: x,
          x2: x,
          y1,
          y2,
          strokeWidth: isSelected ? 3 : 1,
          opacity: 0,
        }}
        animate={{
          x1: x,
          x2: x,
          y1,
          y2,
          strokeWidth: isSelected ? 3 : 1,
          opacity: [0, 1],
        }}
        transition={{ ease: "easeOut" }}
      />
      {stack.map((d, index) => {
        const cy = yScale(getValue(d));
        // "medium"/"small": selected dots become hollow rings whose outer edge
        // (r + strokeWidth / 2) equals `radius`; unselected dots are small and
        // solid.
        const { r, fill, strokeWidth } = foldDotStyle(dotStyle, {
          default: () => ({
            r: radius,
            fill: colorScale(getColor(d)),
            strokeWidth: 0,
          }),
          medium: () => ({
            r: isSelected ? radius - 3 / 2 : radius / 2.5,
            fill: isSelected ? "#fff" : colorScale(getColor(d)),
            strokeWidth: isSelected ? 3 : 0,
          }),
          small: () => ({
            r: isSelected ? radius - 3 / 2 : radius / 2.6,
            fill: isSelected ? "#fff" : colorScale(getColor(d)),
            strokeWidth: isSelected ? 3 : 0,
          }),
        });
        return (
          <motion.circle
            key={`dot-${getColor(d)}`}
            stroke={colorScale(getColor(d))}
            initial={{ cx: x, cy, r, fill, strokeWidth, opacity: 0 }}
            animate={{ cx: x, cy, r, fill, strokeWidth, opacity: [0, 1] }}
            transition={{
              ease: "easeOut",
              delay: transform(index, [0, maxDots], [0, maxDuration]),
            }}
          />
        );
      })}
      {isSelected && (
        <ChartTooltip
          title={formatX(
            name,
            0 /* Special case: we always want the rendering for the 0 position */,
            domainX
          )}
          content={
            <TooltipGridBodyWithLegend
              rows={stack.map((d) => [
                colorScale(getColor(d)),
                getColor(d),
                formatValue(getValue(d)),
              ])}
            />
          }
          placement="top"
          visible={isSelected}
          hideOnClick={false}
          popperOptions={popperOptions}
        >
          {/* Invisible anchor at the column's highest value for the tooltip. */}
          <circle cx={x} cy={y2} r={radius} fill={"transparent"} />
        </ChartTooltip>
      )}
    </motion.g>
  );
}
const DotGroup = React.memo(DotGroup_);

/**
 * X axis: for each key in `domainX`, a 4px tick mark below the plot area with
 * the `formatX` label under it. Renders nothing when `labelStyleX` is "none".
 *
 * `innerHeight` (the plot-area height, i.e. the axis position) is passed in
 * explicitly. Left unbound, the identifier resolves to the global
 * `window.innerHeight` and the axis is drawn outside the SVG.
 */
const XTicks = React.memo(
  ({
    labelStyleX,
    domainX,
    xScale,
    formatX,
    innerHeight,
  }: {
    labelStyleX: LabelStyle;
    domainX: An.NonEmptyArray<string>;
    xScale: scale.ScalePoint<string>;
    formatX: (x: string, i: number, keys: An.NonEmptyArray<string>) => string;
    innerHeight: number;
  }) => {
    return labelStyleX !== "none" ? (
      <>
        {domainX.map((tick, i) => {
          const x = xScale(tick) || 0;
          return (
            <React.Fragment key={`xaxis-${tick}`}>
              <line
                css={styles.yLineOver}
                x1={x}
                x2={x}
                y1={innerHeight}
                y2={innerHeight + 4}
              />
              <text css={styles.xLabel} x={x} y={innerHeight} dy={"1.2em"}>
                {formatX(tick, i, domainX)}
              </text>
            </React.Fragment>
          );
        })}
      </>
    ) : null;
  }
);

/**
 * Horizontal grid lines across the plot area, one per y tick, plus the
 * optional `yUnitLabel` just above the last (highest) tick.
 *
 * `innerWidth` (the plot-area width) is passed in explicitly. Left unbound, the
 * identifier resolves to the global `window.innerWidth`.
 */
const YTicks = React.memo(
  ({
    yTicks,
    yScale,
    yUnitLabel,
    innerWidth,
  }: {
    yTicks: Array<number>;
    yScale: scale.ScaleLinear<number, number>;
    yUnitLabel: React.ReactNode;
    innerWidth: number;
  }) => {
    return (
      <>
        {yTicks.map((tick, i) => {
          const isLast = i === yTicks.length - 1;
          const y = yScale(tick);
          return (
            <React.Fragment key={`yaxis-${tick}`}>
              <motion.line
                css={styles.yLineOver}
                transition={{ ease: "easeOut" }}
                x1={0}
                x2={innerWidth}
                y1={y}
                y2={y}
                initial={{ y1: y, y2: y, opacity: 0 }}
                animate={{ y1: y, y2: y, opacity: [0, 1] }}
              />
              {yUnitLabel && isLast && (
                <motion.text
                  css={styles.yLabel}
                  transition={{ ease: "easeOut" }}
                  x={0}
                  y={y}
                  dy={"-0.8em"}
                  initial={{ y }}
                  animate={{ y }}
                >
                  {yUnitLabel}
                </motion.text>
              )}
            </React.Fragment>
          );
        })}
      </>
    );
  }
);

/**
 * Y-axis gutter. For every tick it draws a short rule from x = 0 to `xEnd` and
 * the formatted tick value just above that rule. Both fade in on mount and
 * animate their vertical position when the scale changes.
 */
function YLabels_({
  yTicks,
  yScale,
  xEnd,
  formatValue,
}: {
  yTicks: Array<number>;
  yScale: scale.ScaleLinear<number, number>;
  xEnd: number;
  formatValue: (x: number) => string;
}) {
  const renderTick = (tick: number) => {
    const yPos = yScale(tick);
    return (
      <React.Fragment key={`ylabel-${tick}`}>
        <motion.line
          css={styles.yLine}
          x1={0}
          x2={xEnd}
          y1={yPos}
          y2={yPos}
          initial={{ y1: yPos, y2: yPos, opacity: 0 }}
          animate={{ y1: yPos, y2: yPos, opacity: [0, 1] }}
          transition={{ ease: "easeOut" }}
        />
        <motion.text
          css={styles.yLabel}
          y={yPos}
          dy={"-0.8em"}
          initial={{ y: yPos, opacity: 0 }}
          animate={{ y: yPos, opacity: [0, 1] }}
          transition={{ ease: "easeOut" }}
        >
          {formatValue(tick)}
        </motion.text>
      </React.Fragment>
    );
  };

  return <>{yTicks.map((tick) => renderTick(tick))}</>;
}
const YLabels = React.memo(YLabels_);
