import css from "@emotion/css";
import { estimateTextWidth } from "charts/lib/utils";
import { bisector, max } from "d3-array";
import * as scale from "d3-scale";
import { motion, transform } from "framer-motion";
import nanoid from "nanoid";
import React from "react";
import ColorLegend from "../charts/lib/ColorLegend";
import { ChartTooltip, TooltipGridBodyWithLegend } from "../components/tooltip";
import { useContentRect, usePointerPosition, useTheme } from "../hooks";
import { fromNonEmptyArray, insertionOrderSet } from "../lib";
import { getContrastText, lighten } from "../lib/colorManipulator";
import * as M from "../materials";
import { An, Ar, O, pipe } from "../prelude";

/**
 * Monoid over `Option<string[]>` used to merge optional lists of stack keys
 * (the hovered stack and any externally highlighted stacks): `none` is the
 * identity and two `some` values have their arrays concatenated.
 */
const MOAr = pipe(Ar.getMonoid<string>(), O.getMonoid);

// -----------------------------------------------------------------------------
// Helpers

/**
 * Total of `acc` over every element of a stack — i.e. how tall the stacked
 * bar is in value units. Despite the name this is a sum, not a maximum.
 *
 * @param stack Elements making up one stack.
 * @param acc   Numeric value of one element.
 * @returns The sum of `acc(x)` over `stack`, starting from 0.
 */
function maxForStack<A>(stack: An.NonEmptyArray<A>, acc: (a: A) => number) {
  let total = 0;
  for (const item of stack) {
    total += acc(item);
  }
  return total;
}

/**
 * Height of the tallest stack, in value units.
 *
 * Each stack's height is its `maxForStack` total; the result is the largest
 * of those, floored at 0 (an empty record or all non-positive totals give 0).
 *
 * @param stacks Stacks keyed by stack key.
 * @param acc    Numeric value of one element.
 */
function maxAcrossStacks<A>(
  stacks: Record<string, An.NonEmptyArray<A>>,
  acc: (a: A) => number
) {
  let tallest = 0;
  for (const stackKey of Object.keys(stacks)) {
    const stackSum = maxForStack(stacks[stackKey], acc);
    // Plain `>` (not Math.max) so a NaN total is ignored rather than propagated.
    if (stackSum > tallest) {
      tallest = stackSum;
    }
  }
  return tallest;
}

/**
 * Builds a function that locates one layer's segment inside a stack.
 *
 * The returned function walks `layerKeys` from the first key up to and
 * including `layerKey` (a `layerKey` not in `layerKeys` is treated as the
 * first key). For every key that has a datum in `stack`, it places that
 * datum's scaled height directly on top of the previously placed segment.
 * It returns the `start` offset and `height` of the last segment placed,
 * which is the requested layer when that layer is present in the stack. If
 * none of those layers is present, the result is `{ start: -1, height: 0 }`.
 *
 * @param scale     Maps a datum to its segment height (the chart passes the
 *                  y scale applied to the datum's value).
 * @param getLayer  Layer key of a datum.
 * @param layerKeys Layer order, from the baseline upwards.
 */
function mkLayerScale<A>(
  scale: (x: A) => number,
  getLayer: (x: A) => string,
  layerKeys: An.NonEmptyArray<string>
) {
  return (stack: An.NonEmptyArray<A>, layerKey: string) => {
    const target = layerKeys.indexOf(layerKey);
    const lastLayer = target === -1 ? 0 : target;

    let start = -1;
    let height = 0;
    for (let i = 0; i <= lastLayer; i++) {
      const key = layerKeys[i];
      const hit = stack.findIndex((x) => getLayer(x) === key);
      if (hit === -1) {
        continue;
      }
      // The first segment found sits on the baseline; later ones stack on top.
      start = start === -1 ? 0 : start + height;
      height = scale(stack[hit]);
    }
    return { start, height };
  };
}

// -----------------------------------------------------------------------------
// Components

/**
 * Emotion style fragments for the chart's SVG text and rule elements.
 */
const styles = {
  // x-axis tick labels, centred horizontally on each bar's tick.
  xLabel: css`
    ${M.fontChartLabel};
    fill: ${M.lightText};
    alignment-baseline: middle;
    text-anchor: middle;
  `,
  // Combined with `xLabel` when x labels are rotated (see XAxisLabels_), so the
  // rotated text is anchored at its end rather than its middle.
  xLabelRotate: css`
    text-anchor: end;
  `,
  // y tick labels and the optional y unit label; left-aligned.
  yLabel: css`
    ${M.fontChartLabel};
    fill: ${M.lightText};
    alignment-baseline: middle;
    text-anchor: start;
  `,
  // Short rule under each y tick label, drawn in the left margin (YAxisLabels_).
  yLine: css`
    stroke-width: 1;
    stroke: ${M.divider};
  `,
  // Rules drawn inside the plot area: full-width y grid lines (YAxisLongTicks_)
  // and the small x-axis tick marks (XAxisLabels_).
  yLineOver: css`
    stroke-width: 1;
    stroke: rgba(0, 0, 0, 0.15);
  `,
};

/**
 * Pixel dimensions of a chart.
 *
 * In {@link StackedBar}, `width` is the full SVG width (margins are carved out
 * of it) while `height` is the height of the plotting area only (margins are
 * added on top of it).
 */
export interface ChartBasis {
  /** Horizontal size, in px. */
  width: number;
  /** Vertical size, in px. */
  height: number;
}

/**
 * Props for the `StackedBar` chart. `A` is the type of one input datum; each
 * datum becomes one segment of one bar.
 */
export interface StackedBar<A> extends ChartBasis {
  /** Input data; grouped into bars by `getStack` and into segments by `getLayer`. */
  data: An.NonEmptyArray<A>;
  /** Key of the bar (x-axis category) a datum belongs to. */
  getStack(x: A): string;
  /** Key of the layer (segment / colour) a datum belongs to. */
  getLayer(x: A): string;
  /** Display name of a layer key; used in the legend, tooltip and annotations. */
  getLayerLabel(x: string): string;
  /** Numeric size of a datum; segment heights are this value on the y scale. */
  getValue(x: A): number;
  /**
   * x-axis label for stack key `x` at index `i` of `keys`. The tooltip title
   * always calls it with `i = 0`.
   */
  formatX(x: string, i: number, keys: An.NonEmptyArray<string>): string;
  /** Formats tooltip values, and y-axis labels unless `formatYAxisLabel` is set. */
  formatY(x: number): string;
  /** Returns `count` colours; the i-th colour is assigned to the i-th layer. */
  colorPalette(count: number): string[];
  /** Overrides `formatY` for y-axis tick labels only. */
  formatYAxisLabel?(x: number): string;
  /** Drawn just above the topmost y grid line (only when the y axis is shown). */
  yUnitLabel?: React.ReactNode;
  /** Tick count hint forwarded to d3 `scaleLinear().nice()` for the y domain. */
  yNice?: number;
  /** Track the pointer to show a tooltip for, and highlight, the hovered bar. */
  showTooltip?: boolean;
  /** Hide the x-axis tick marks and labels. */
  hideXAxisLabel?: boolean;
  /** Hide y grid lines and labels, and drop the left margin reserved for them. */
  hideYAxis?: boolean;
  /** Hide the colour legend rendered below the chart. */
  hideLegend?: boolean;
  /** Stack keys drawn at full colour; while any bar is selected, others are lightened. */
  highlightedItems?: O.Option<An.NonEmptyArray<string>>;
  /** Rotate x labels by -90° and grow the bottom margin to fit the longest one. */
  rotateXLabels?: boolean;
  /**
   * Caption plus arrow-shaped bands drawn above the bars. Each arrow runs from
   * the left edge of `startStack`'s bar to the left edge of `endStack`'s bar
   * and is coloured like `layer`.
   */
  annotations?: {
    label: string;
    items: { layer: string; startStack: string; endStack: string }[];
  };
}

export const StackedBar = React.memo(StackedBar_);
/**
 * Stacked bar chart rendered as SVG, with an optional colour legend below it.
 *
 * Data are grouped into bars ("stacks") by `p.getStack`, and each bar is split
 * into coloured segments ("layers") by `p.getLayer`. Segment heights are
 * `p.getValue` mapped through a linear y scale with domain `[0, tallest stack]`.
 * Stack and layer order come from `insertionOrderSet` over `p.data`.
 *
 * Sizing: `p.height` is the height of the plotting area. The SVG is taller by
 * the top margin (larger when `p.annotations` is set) and the bottom margin
 * (larger when `p.rotateXLabels` is set). `p.width` is the full SVG width.
 *
 * Interaction: when `p.showTooltip` is set, a transparent rect tracks the
 * pointer and the stack under it gets a tooltip. While any stack is selected
 * (hovered, or listed in `p.highlightedItems`), unselected stacks are drawn
 * in a lightened colour.
 */
function StackedBar_<A>(p: StackedBar<A>) {
  const { client } = useTheme();
  const [pointerPositionRef, pointerPosition] =
    usePointerPosition<SVGRectElement>();

  const stackKeys = React.useMemo(
    () => insertionOrderSet(p.data, p.getStack),
    [p.data, p.getStack]
  );
  const layerKeys = React.useMemo(
    () => insertionOrderSet(p.data, p.getLayer),
    [p.data, p.getLayer]
  );
  const stacks = React.useMemo(
    () => An.groupBy(p.getStack)(p.data),
    [p.data, p.getStack]
  );
  const maxValue = React.useMemo(
    () => maxAcrossStacks(stacks, p.getValue),
    [stacks, p.getValue]
  );

  // Rotated x labels hang below the axis, so reserve room for the longest one.
  const marginBottom = React.useMemo(() => {
    if (p.rotateXLabels) {
      const maxWidth = max(stackKeys, (d) => estimateTextWidth(d, 12)) ?? 0;

      return Math.max(maxWidth + 20, 30);
    }

    return 30;
  }, [p.rotateXLabels, stackKeys]);

  const haveAnnotations = !!p.annotations;
  const { margin, outerHeight, innerHeight, outerWidth, innerWidth } =
    React.useMemo(() => {
      const margin = {
        top: haveAnnotations ? (client.screenMDown ? 46 : 54) : 20,
        right: 0,
        bottom: marginBottom,
        left: p.hideYAxis ? 0 : client.screenMDown ? 30 : 50,
      };
      return {
        margin,
        outerHeight: p.height + margin.top + margin.bottom,
        innerHeight: p.height,
        outerWidth: p.width,
        innerWidth: p.width - margin.left - margin.right,
      };
    }, [
      client.screenMDown,
      p.height,
      p.width,
      p.hideYAxis,
      haveAnnotations,
      marginBottom,
    ]);

  // Inner padding of (stack count / width) leaves roughly a 1px gap between bars.
  const xScale = React.useMemo(
    () =>
      scale
        .scaleBand()
        .domain(fromNonEmptyArray(stackKeys))
        .range([0, innerWidth])
        .paddingInner(1 / (innerWidth / stackKeys.length))
        .paddingOuter(0),
    [stackKeys, innerWidth]
  );

  // Maps a pointer x position back to the stack whose band starts at or before it.
  const xScaleInvert = React.useMemo(() => {
    const steps = stackKeys.map((x) => ({
      domain: x,
      range: xScale(x) as number,
    }));
    // `right` returns the index of the first band starting strictly after
    // `pos`, so `idx - 1` is the band containing `pos` — including when `pos`
    // lies exactly on a band's left edge (`left` picked the previous band
    // there). Positions before the first band fall back to the first band.
    const bisect = bisector<{ range: number }, unknown>((x) => x.range).right;
    return (pos: number) => {
      const idx = bisect(steps, pos);
      return steps[Math.max(idx - 1, 0)].domain;
    };
  }, [stackKeys, xScale]);

  const { yScale, yTicks } = React.useMemo(() => {
    const s = scale
      .scaleLinear()
      .domain([0, maxValue])
      .range([0, innerHeight])
      .nice(p.yNice);
    return { yScale: s, yTicks: s.ticks(6) };
  }, [maxValue, innerHeight, p.yNice]);

  const layerScale = React.useMemo(
    () =>
      mkLayerScale<A>(
        (a: A) => yScale(p.getValue(a)) as $FixMe,
        (a: A) => p.getLayer(a),
        layerKeys
      ),
    [layerKeys, yScale, p.getLayer, p.getValue] // eslint-disable-line react-hooks/exhaustive-deps
  );

  const colorScale = React.useMemo(
    () =>
      scale.scaleOrdinal(p.colorPalette(layerKeys.length)).domain(layerKeys),
    [layerKeys, p.colorPalette] // eslint-disable-line react-hooks/exhaustive-deps
  );

  const annotations = React.useMemo(() => {
    return p.annotations
      ? {
          label: p.annotations.label,
          items: p.annotations.items.map((x) => ({
            x0: xScale(x.startStack) as number,
            x1: xScale(x.endStack) as number,
            label: p.getLayerLabel(x.layer),
            color: colorScale(x.layer),
          })),
        }
      : null;
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [xScale, colorScale, p.getLayerLabel, p.annotations]);

  const hoveredItem = pipe(
    pointerPosition,
    O.map((pos) => xScaleInvert(pos.x))
  );

  // Hovered stack plus externally highlighted stacks; `none` when neither exists.
  const selectedItems = MOAr.concat(
    pipe(
      hoveredItem,
      O.map((x) => An.of(x))
    ),
    p.highlightedItems || O.none
  );

  const colorLegendValues = React.useMemo(
    () =>
      layerKeys.map((key) => ({
        color: colorScale(key),
        label: p.getLayerLabel(key),
      })),
    // Depend on the callback rather than on `p` itself: `p` is a new object on
    // every render, which made this memo recompute every time.
    [layerKeys, colorScale, p.getLayerLabel] // eslint-disable-line react-hooks/exhaustive-deps
  );

  return (
    <>
      <svg height={outerHeight} width={outerWidth}>
        {annotations && (
          <g transform={`translate(${margin.left}, 0)`}>
            <StackAnnotation
              height={client.screenMDown ? 24 : 32}
              label={annotations.label}
              items={annotations.items}
            />
          </g>
        )}
        <g transform={`translate(${margin.left}, ${margin.top})`}>
          {stackKeys.map((key) => {
            const stack: An.NonEmptyArray<A> = stacks[key];
            return (
              <StackItem
                key={key}
                stack={stack}
                getLayer={p.getLayer}
                xScale={xScale}
                getStack={p.getStack}
                colorScale={colorScale}
                innerHeight={innerHeight}
                layerScale={
                  layerScale as $FixMe /* Can't infer A when memoized component … */
                }
                highlight={pipe(
                  selectedItems,
                  O.exists((xs) => xs.includes(key))
                )}
                hasSelectedItems={O.isSome(selectedItems)}
              />
            );
          })}
          {!p.hideXAxisLabel && (
            <XAxisLabels
              stackKeys={stackKeys}
              xScale={xScale}
              formatX={p.formatX}
              innerHeight={innerHeight}
              rotate={p.rotateXLabels}
            />
          )}
          {!p.hideYAxis && (
            <YAxisLongTicks
              yScale={yScale}
              yTicks={yTicks}
              innerHeight={innerHeight}
              innerWidth={innerWidth}
              yUnitLabel={p.yUnitLabel}
            />
          )}
        </g>

        <g transform={`translate(0, ${margin.top})`}>
          {!p.hideYAxis && (
            <YAxisLabels
              yScale={yScale}
              yTicks={yTicks}
              width={margin.left - 4}
              innerHeight={innerHeight}
              formatY={p.formatY}
              formatYAxisLabel={p.formatYAxisLabel}
            />
          )}
        </g>

        {/* Drawn last so the tooltip anchor and pointer-capture rect sit on top. */}
        <g transform={`translate(${margin.left}, ${margin.top})`}>
          {stackKeys
            .filter((key) =>
              pipe(
                hoveredItem,
                O.exists((x) => x === key)
              )
            )
            .map((key) => {
              const stack = stacks[key];
              return (
                <StackTooltip
                  key={key}
                  stack={stack}
                  x={xScale(key) as number}
                  width={xScale.bandwidth()}
                  colorScale={colorScale}
                  getLayer={p.getLayer}
                  getLayerLabel={p.getLayerLabel}
                  getValue={p.getValue}
                  formatY={p.formatY}
                  xScale={xScale}
                  yScale={yScale}
                  innerHeight={innerHeight}
                  title={p.formatX(
                    key,
                    0 /* Special case: we always want the rendering for the 0 position */,
                    stackKeys
                  )}
                />
              );
            })}

          {p.showTooltip && (
            <rect
              ref={pointerPositionRef}
              x={0}
              y={0}
              width={innerWidth}
              height={innerHeight}
              fill="transparent"
            />
          )}
        </g>
      </svg>
      <div>
        {!p.hideLegend && (
          <ColorLegend inline maxWidth={"100%"} values={colorLegendValues} />
        )}
      </div>
    </>
  );
}

/**
 * `StackedBar` that fills the width of its container.
 *
 * Measures the wrapping div with `useContentRect` and only renders the chart
 * once a positive width is known, passing that width through.
 */
export function StackedBarAuto<A>(props: Omit<StackedBar<A>, "width">) {
  const [containerRef, measured] = useContentRect();
  const availableWidth = measured.width;
  const chart =
    availableWidth > 0 ? (
      <StackedBar {...props} width={availableWidth} />
    ) : null;

  return (
    <div ref={containerRef} style={{ width: "100%" }}>
      {chart}
    </div>
  );
}

// -----------------------------------------------------------------------------

const StackItem = React.memo(StackItem_);
/**
 * Renders the segments of one stacked bar as animated rects.
 *
 * Each datum in `stack` becomes a `motion.rect` positioned by `layerScale`
 * (measured up from the baseline at `innerHeight`). The y/height animation of
 * each segment is delayed by 0–0.5s depending on the bar's x position. While
 * `hasSelectedItems` is set, a bar that is not `highlight`ed is drawn in its
 * layer colour lightened by 0.5.
 */
function StackItem_<A>({
  stack,
  xScale,
  layerScale,
  getLayer,
  getStack,
  colorScale,
  innerHeight,
  highlight = false,
  hasSelectedItems,
}: {
  stack: An.NonEmptyArray<A>;
  xScale: scale.ScaleBand<string>;
  layerScale: (
    stack: An.NonEmptyArray<A>,
    layerKey: string
  ) => {
    start: number;
    height: number;
  };
  getLayer(x: A): string;
  getStack(x: A): string;
  innerHeight: number;
  colorScale: scale.ScaleOrdinal<string, string>;
  highlight?: boolean;
  hasSelectedItems?: boolean;
}) {
  // Map the bar's left edge (0 … right end of the x range) to a 0 … 0.5s delay.
  const delayFor = transform([0, xScale.range()[1]], [0, 0.5]);
  const dimmed = hasSelectedItems && !highlight;

  const segments = stack.map((datum) => {
    const layer = getLayer(datum);
    const { start, height } = layerScale(stack, layer);
    const left = xScale(getStack(datum));
    const top = innerHeight - (start + height);
    const delay = delayFor(left || 0);
    const baseColor = colorScale(layer);

    return (
      <motion.rect
        key={`${layer}`}
        width={xScale.bandwidth()}
        initial={false}
        animate={{
          x: left,
          y: top,
          height,
          fill: dimmed ? lighten(baseColor, 0.5) : baseColor,
          opacity: 1,
        }}
        transition={{ y: { delay }, height: { delay } }}
      />
    );
  });

  return <g>{segments}</g>;
}

const XAxisLabels = React.memo(XAxisLabels_);
/**
 * x-axis tick marks and labels, one per stack key, centred under each bar.
 *
 * Labels are placed 1.2em below the axis, or, when `rotate` is set, rotated by
 * -90° and end-anchored so they hang down from the tick.
 */
function XAxisLabels_({
  xScale,
  stackKeys,
  formatX,
  innerHeight,
  rotate,
}: {
  xScale: scale.ScaleBand<string>;
  stackKeys: An.NonEmptyArray<string>;
  formatX(x: string, i: number, keys: An.NonEmptyArray<string>): string;
  innerHeight: number;
  rotate?: boolean;
}) {
  const halfBand = xScale.bandwidth() / 2;
  const textPlacement = rotate
    ? { dx: "-1.2em", transform: "rotate(-90)" }
    : { dy: "1.2em" };

  const ticks = stackKeys.map((stackKey, index) => {
    const center = (xScale(stackKey) || 0) + halfBand;
    return (
      <g key={stackKey} transform={`translate(${center}, ${innerHeight})`}>
        <line css={styles.yLineOver} y2={4} />
        <text
          css={[styles.xLabel, rotate && styles.xLabelRotate]}
          {...textPlacement}
        >
          {formatX(stackKey, index, stackKeys)}
        </text>
      </g>
    );
  });

  return <>{ticks}</>;
}

const YAxisLongTicks = React.memo(YAxisLongTicks_);
/**
 * Full-width horizontal grid lines at each y tick, measured up from the
 * baseline at `innerHeight`. Lines fade in and ease to new positions when the
 * ticks change. When `yUnitLabel` is given, it is drawn just above the line of
 * the last tick in `yTicks`.
 */
function YAxisLongTicks_({
  yScale,
  yTicks,
  innerHeight,
  innerWidth,
  yUnitLabel,
}: {
  yScale: scale.ScaleLinear<number, number>;
  yTicks: Array<number>;
  innerHeight: number;
  innerWidth: number;
  yUnitLabel?: React.ReactNode;
}) {
  const lastIndex = yTicks.length - 1;

  const gridLines = yTicks.map((tick, index) => {
    const isLast = index === lastIndex;
    const y = innerHeight - (yScale(tick) as $FixMe);
    return (
      <React.Fragment key={tick}>
        <motion.line
          css={styles.yLineOver}
          transition={{ ease: "easeOut" }}
          x1={0}
          x2={innerWidth}
          initial={{ y1: y, y2: y, opacity: 0 }}
          animate={{ y1: y, y2: y, opacity: [0, 1] }}
        />
        {yUnitLabel && isLast && (
          <motion.text
            css={styles.yLabel}
            transition={{ ease: "easeOut" }}
            x={0}
            dy={"-0.8em"}
            initial={{ y }}
            animate={{ y }}
          >
            {yUnitLabel}
          </motion.text>
        )}
      </React.Fragment>
    );
  });

  return <>{gridLines}</>;
}

const YAxisLabels = React.memo(YAxisLabels_);
/**
 * y tick labels drawn in the left margin, each sitting just above a short rule
 * of length `width`. Tick text comes from `formatYAxisLabel` when provided,
 * otherwise from `formatY`. Rules and labels fade in and ease to new
 * positions when the ticks change.
 */
function YAxisLabels_({
  yScale,
  yTicks,
  width,
  innerHeight,
  formatY,
  formatYAxisLabel,
}: {
  yScale: scale.ScaleLinear<number, number>;
  yTicks: Array<number>;
  width: number;
  innerHeight: number;
  formatY(x: number): string;
  formatYAxisLabel?(x: number): string;
}) {
  const formatTick = formatYAxisLabel ? formatYAxisLabel : formatY;

  const labels = yTicks.map((tick) => {
    const y = innerHeight - (yScale(tick) as $FixMe);
    return (
      <React.Fragment key={tick}>
        <motion.line
          css={styles.yLine}
          transition={{ ease: "easeOut" }}
          x1={0}
          x2={width}
          y1={y}
          y2={y}
          initial={{ y1: y, y2: y, opacity: 0 }}
          animate={{ y1: y, y2: y, opacity: [0, 1] }}
        />
        <motion.text
          css={styles.yLabel}
          y={y}
          initial={{ y, opacity: 0 }}
          animate={{ y, opacity: [0, 1] }}
          transition={{ ease: "easeOut" }}
          dy={"-0.8em"}
        >
          {formatTick(tick)}
        </motion.text>
      </React.Fragment>
    );
  });

  return <>{labels}</>;
}

const StackTooltip = React.memo(StackTooltip_);
/**
 * Tooltip for the hovered stack.
 *
 * Renders a transparent rect covering the stack's whole bar, from its top
 * (the y-scaled sum of all segment values) down to the baseline at
 * `innerHeight`, and anchors an always-visible `ChartTooltip` above it. The
 * tooltip body has one row per datum: layer colour, layer label and formatted
 * value. Rows are in reverse order of `stack` (top segment first when the
 * stack's data are ordered by layer).
 *
 * Props note: `key` is deliberately not part of the props type — React
 * consumes it and never passes it to the component. `xScale` is accepted for
 * call-site compatibility but is not read here.
 */
function StackTooltip_<A>({
  stack,
  colorScale,
  getLayer,
  getLayerLabel,
  getValue,
  formatY,
  yScale,
  innerHeight,
  title,
  x,
  width,
}: {
  stack: An.NonEmptyArray<A>;
  colorScale: scale.ScaleOrdinal<string, string>;
  /** Not read by this component; optional so existing call sites keep compiling. */
  xScale?: scale.ScaleBand<string>;
  yScale: scale.ScaleLinear<number, number>;
  getLayer(x: A): string;
  getLayerLabel(x: string): string;
  getValue(x: A): number;
  formatY(x: number): string;
  innerHeight: number;
  title: string;
  x: number;
  width: number;
}) {
  // Full bar height: the y-scaled total of every segment in the stack.
  const height = yScale(maxForStack(stack, getValue)) as number;
  const y = innerHeight - height;

  const rows = stack
    .map((d) => {
      const layer = getLayer(d);
      return [colorScale(layer), getLayerLabel(layer), formatY(getValue(d))];
    })
    .reverse();

  return (
    <ChartTooltip
      title={title}
      content={<TooltipGridBodyWithLegend rows={rows} />}
      placement="top"
      flipBehavior={["top"]}
      visible={true}
      hideOnClick={false}
    >
      <rect x={x} y={y} width={width} height={height} fill="transparent" />
    </ChartTooltip>
  );
}

// -----------------------------------------------------------------------------

/**
 * One arrow-shaped band in the annotation row drawn above the bars.
 * Coordinates are in the plot's horizontal coordinate space.
 */
export interface StackAnnotation {
  /** Text drawn inside the arrow. */
  label: string;
  /** Fill colour of the arrow; the label colour is chosen to contrast with it. */
  color: string;
  /** x-coordinate of the arrow's flat left end. */
  x0: number;
  /** x-coordinate of the arrow's tip. */
  x1: number;
}

const StackAnnotation = React.memo(StackAnnotation_);
/**
 * Annotation row drawn above the bars: a caption followed by arrow-shaped
 * bands.
 *
 * `label` is drawn as a caption starting at the first item's `x0` (or at 0
 * when `items` is empty). Each item is then drawn 16px lower as a filled,
 * right-pointing pentagon spanning `x0`…`x1`, with its own label inside. The
 * item label is masked by a horizontal white→black gradient over
 * `width - height / 4`, so long labels fade out before reaching the tip
 * instead of overflowing it.
 *
 * @param label  Caption text.
 * @param items  Arrows to draw; may be empty.
 * @param height Height of each arrow, in px.
 */
function StackAnnotation_({
  label,
  items,
  height,
}: {
  label: string;
  items: Array<StackAnnotation>;
  height: number;
}) {
  // SVG ids share one document-wide namespace, so each instance needs a unique
  // suffix. Generate it once per mount: calling nanoid() during render (as
  // before) produced new gradient/mask ids on every re-render.
  const [idSuffix] = React.useState(() => nanoid(5));
  const gradientId = `gradient-${idSuffix}`;

  // `items[0].x0` threw a TypeError when there were no annotation items.
  const captionX = items[0]?.x0 ?? 0;

  return (
    <>
      <defs>
        <linearGradient id={gradientId}>
          <stop offset="0%" stopColor="white" />
          <stop offset="80%" stopColor="white" />
          <stop offset="100%" stopColor="black" />
        </linearGradient>
      </defs>
      <text
        x={captionX}
        dy="0.8em"
        css={css`
          ${M.fontChartLabel};
          fill: ${M.lightText};
        `}
      >
        {label}
      </text>
      {items.map(({ label, color, x0, x1 }, i) => {
        // Stable per instance and unique per item within it.
        const maskId = `mask-${idSuffix}-${i}`;

        const width = x1 - x0;
        return (
          <g key={i} transform={`translate(${x0},16)`}>
            <mask id={maskId}>
              <rect
                x="0"
                y="0"
                width={width - height / 4}
                height={height}
                fill={`url(#${gradientId})`}
              />
            </mask>
            {/* Pentagon: flat left side, point at (width, height / 2). */}
            <path
              d={`M 0 0 L ${width - height / 2} 0 L ${width} ${height / 2} L ${
                width - height / 2
              } ${height} L 0 ${height} z`}
              fill={color}
            />
            <text
              x={height / 8}
              y={height / 2}
              dy="0.3em"
              css={css`
                ${M.fontChartLabel};
                fill: ${getContrastText(color, {
                  light: M.whiteText,
                  dark: M.blackText,
                })};
              `}
              mask={`url(#${maskId})`}
            >
              {label}
            </text>
          </g>
        );
      })}
    </>
  );
}
