import {
  createComputed,
  createEffect,
  createMemo,
  createSignal,
  For,
  JSX,
  onCleanup,
  onMount,
  Show,
  untrack,
} from "solid-js";
import c from "class-c";

import Range from "@repo/utils/Range";
import createPropsProvider from "@repo/utils-solid/createPropsProvider";
import USolid from "@repo/utils-solid/USolid";

import styles from "./Masonry.module.scss";

declare namespace Masonry {
  /** Props accepted by the `Masonry` component. */
  interface Props {
    /**
     * Minimum width (px) of a column. The number of columns is the largest
     * count whose columns (plus column gaps) fit in the container, at least 1.
     */
    minColumnWidth: number;
    /** Items to lay out; each must resolve to an `HTMLElement`. */
    children: JSX.Element;
    /** Extra class name(s) appended to the root element. */
    class?: string;
    /**
     * Spacing in px. A single number is used for both directions; an object
     * sets the row and column gaps separately. Defaults to 0.
     */
    gap?: number | { column: number; row: number };
    /**
     * Adds the `animate` modifier class and per-item transform / opacity /
     * animation-delay styles. Defaults to false.
     */
    animate?: boolean;
    /**
     * Adds the `align` modifier class to the root (exact visual effect lives
     * in the stylesheet — presumably aligns column contents; confirm).
     * Defaults to false.
     */
    alignColumns?: boolean;
  }
}

/**
 * Masonry layout component.
 *
 * Splits its children across a width-dependent number of columns, always
 * appending the next child to the currently shortest column. The column count
 * is derived from the container width, `minColumnWidth` and the column gap.
 * Children are first rendered inside a measuring container sized to the
 * computed column width, then distributed once that width is known.
 *
 * When `animate` is enabled, each item wrapper receives transform / opacity /
 * animation-delay styles derived from the child's stage ("measuring",
 * "rendered" or "move") until its CSS animation ends.
 *
 * @throws Error if a provided child does not resolve to an `HTMLElement`.
 */
function Masonry(_props: Masonry.Props) {
  const {
    gap: _gap = 0,
    animate = false,
    alignColumns = false,
    minColumnWidth,
    children: _children,
  } = Masonry.PropsProvider.useMerge(_props) satisfies D;

  const [columnWidth, setColumnWidth] = createSignal(0);

  // Prevent children from being rendered until columnWidth is computed so they
  // only render when they're measured
  // TODO: Check if Solid 2.0 fixes this
  const childArray = createMemo(() => {
    if (USolid.runMemo(() => columnWidth() === 0)) return [];
    return USolid.children(() => _children)
      .toArray()
      .filter(Boolean) as HTMLElement[];
  });

  // Normalize `gap` into separate row/column values.
  const gap = createMemo(
    () =>
      (typeof _gap === "number" ? { row: _gap, column: _gap } : _gap) as {
        row: number;
        column: number;
      },
  );
  // Child -> index in childArray; used to stagger the initial animation.
  const indexByChild = createMemo(() =>
    animate ? new Map(childArray().map((child, i) => [child, i])) : null,
  );

  // We need clientHeight to calculate distribution
  createComputed(() => {
    for (const child of childArray())
      if (!(child instanceof HTMLElement))
        throw new Error(
          "Masonry: Provided child is not an instance of HTMLElement",
        );
  });

  const [columnCount, setColumnCount] = createSignal(0);
  const [columns, setColumns] = createSignal<HTMLElement[][]>([]);

  let ref: HTMLDivElement = null!;

  onMount(() => {
    const getColCount = (width: number) =>
      Math.max(
        1,
        Math.floor((width + gap().column) / (minColumnWidth + gap().column)),
      );

    let widthObserver: ResizeObserver | undefined;
    // Set on cleanup so animation frames scheduled before unmount do nothing
    // (in particular, they must not create a new, never-disconnected observer).
    let disposed = false;

    function updateColumns(initial = false) {
      widthObserver?.disconnect();
      requestAnimationFrame(() => {
        if (disposed) return;
        setColumnCount(getColCount(ref.clientWidth));
        // One distinct empty array per column; `fill([])` would share a single
        // array instance between all columns.
        if (initial)
          setColumns(
            Array.from({ length: columnCount() }, () => [] as HTMLElement[]),
          );

        // Column elements will be rendered using columnCount, need to wait for
        // render
        requestAnimationFrame(() => {
          if (disposed) return;
          const column = ref.querySelector(`.${styles.column}`);
          if (!column) return;
          // Overlapping updateColumns calls can both reach this point; drop the
          // observer created by an earlier call so it is not leaked.
          widthObserver?.disconnect();
          const updateWidth = () => setColumnWidth(column.clientWidth);
          widthObserver = new ResizeObserver(updateWidth);
          widthObserver.observe(column);
          updateWidth();
        });
      });
    }
    updateColumns(true);

    const observer = new ResizeObserver(() => {
      if (
        // Compensate for scrollbar jank
        ![
          getColCount(ref.clientWidth),
          getColCount(ref.clientWidth - 20),
        ].includes(columnCount())
      )
        updateColumns();
    });
    observer.observe(ref);
    onCleanup(() => {
      disposed = true;
      observer.disconnect();
      widthObserver?.disconnect();
    });
  });

  type Stage =
    | readonly [name: "measuring", initial: boolean]
    | readonly [name: "rendered", col: number]
    | readonly [name: "move", fromRelative: { x: number; y: number }];
  const [childStages, setChildStages] = createSignal<Map<HTMLElement, Stage>>(
    new Map(),
  );

  // Keep one stage entry per current child: existing children keep their
  // stage, new ones start as "measuring" (flagged `initial` only on the first
  // run, which drives the staggered animation delay).
  let initial = true;
  createComputed(() => {
    const stages = untrack(childStages);

    setChildStages(
      new Map(
        childArray().map((child) => [
          child,
          stages.get(child) || (["measuring", initial] as const),
        ]),
      ),
    );
    initial = false;
  });

  // Distribute children: each child goes into the currently shortest column.
  // A child previously rendered in a different column is marked "move" with
  // its offset from the new slot so it can animate from its old position.
  createEffect(() => {
    if (columnWidth() === 0) return;

    const heights: number[] = Array(columnCount()).fill(0);
    const nextColumns = Array.from(
      { length: columnCount() },
      () => [] as HTMLElement[],
    );
    const stages = childStages();

    // Viewport rect of the column at `index`; if that column is not in the DOM
    // yet, fall back to the container's top-right corner.
    const getColumnRect = (index: number) =>
      ref
        .querySelector(`:scope > .${styles.column}:nth-child(${index + 1})`)
        ?.getBoundingClientRect() ??
      (() => {
        const rect = ref.getBoundingClientRect();
        return { y: rect.y, x: rect.right };
      })();

    for (const child of childArray()) {
      const stage = stages.get(child);

      const minIndex = heights.indexOf(Math.min(...heights));
      nextColumns[minIndex]!.push(child);
      if (
        stage instanceof Array &&
        stage[0] === "rendered" &&
        stage[1] !== minIndex
      ) {
        const current = child.getBoundingClientRect();

        const next = getColumnRect(minIndex);

        next.y += heights[minIndex]!;

        // Mutated in place on purpose: no signal update, just tracking.
        stages.set(child, [
          "move",
          { x: current.x - next.x, y: current.y - next.y },
        ]);
      }

      // Items inside a column are stacked vertically, so they are separated
      // by the row gap (not the column gap).
      heights[minIndex]! += child.scrollHeight + gap().row;
    }

    setColumns(nextColumns);
  });

  return (
    <div
      class={c`
        ${c(styles)`
          masonry
          ${{ animate }}
          ${alignColumns && "align"}
        `}
        ${_props.class}`}
      style={{
        "--column-gap": `${gap().column}px`,
        "--row-gap": `${gap().row}px`,
      }}
      ref={ref}
    >
      <For each={Range(columns().length)}>
        {(i) => (
          <div class={styles.column}>
            <For each={columns()[i]}>
              {(child) => {
                const [animationDone, setAnimationDone] =
                  createSignal(!animate);

                return (
                  <div
                    class={styles.item}
                    onAnimationEnd={() => {
                      setAnimationDone(true);
                    }}
                    style={(() => {
                      const stage = childStages().get(child)!;
                      // Update mutably to prevent signal update, this is just for tracking
                      childStages().set(child, ["rendered", i]);

                      return animationDone()
                        ? {
                            animation: "none",
                          }
                        : {
                            transform:
                              stage[0] === "measuring"
                                ? "translateY(24px)"
                                : stage instanceof Array && stage[0] === "move"
                                  ? `translate(${stage[1].x}px, ${stage[1].y}px)`
                                  : "",
                            opacity: stage[0] === "measuring" ? 0 : 1,
                            "animation-delay":
                              stage[0] === "measuring" && stage[1] && animate
                                ? `${indexByChild()!.get(child)! * 50}ms`
                                : undefined,
                          };
                    })()}
                  >
                    {child}
                  </div>
                );
              }}
            </For>
          </div>
        )}
      </For>
      <Show when={columnWidth() > 0}>
        <div class={styles.measure} style={{ width: `${columnWidth()}px` }}>
          <For each={[...childStages().keys()]}>
            {(child) => childStages().get(child)![0] === "measuring" && child}
          </For>
        </div>
      </Show>
    </div>
  );
}

// Props provider keyed "Masonry". The component reads its props through
// `Masonry.PropsProvider.useMerge(_props)`; presumably this lets ancestors
// supply shared default props — confirm against createPropsProvider.
// Kept as a direct property assignment so TypeScript types it as an expando
// member of the `Masonry` function.
Masonry.PropsProvider = createPropsProvider<Masonry.Props>("Masonry");

export default Masonry;
