import {
  select,
  rollups,
  sort,
  sum,
  max,
  range,
  json,
  scaleOrdinal,
  schemeTableau10,
  selectAll,
} from "d3";
import {
  memoize,
  observeResize,
  formatNumber,
  formatPerc,
  getDate,
  one,
} from "./utils";
import { sankey, sankeyLeft, sankeyLinkHorizontal } from "d3-sankey";
import { colorLegend } from "./components/colorLegend";

// Layout constants shared by the funnel subplots.
const config = {
  // Height of each subplot, and its width as a fraction of the container width.
  subPlotHeight: 300,
  subPlotWidthPct: 0.7,
  // Offset of the subplot title from the svg's top-left corner.
  titlePadding: 10,
  // Space kept free around the Sankey layout extent inside each subplot.
  marginTop: 35,
  marginRight: 10,
  marginBottom: 10,
  marginLeft: 10,
  // d3-sankey node geometry.
  nodeWidth: 25,
  nodePadding: 10,
};

/**
 * Read a form control's current value, treating the empty string as "nothing
 * selected".
 * @param {{ value: string }} el input/select element
 * @returns {string | undefined} the value, or undefined when it is ""
 */
const getSelectionValue = (el) => {
  const { value } = el;
  if (value === "") return undefined;
  return value;
};

// Sankey column of each onboarding status. The sequential steps occupy
// columns 0-4 in order; the three outcome statuses all share column 5.
const onboardingStatus2idx = Object.fromEntries([
  ...[
    "presignup",
    "onboarding-started",
    "onboarding-form",
    "onboarding-bank",
    "validation-queued",
  ].map((status, column) => [status, column]),
  ...["denied", "loan-requested", "no-loan-requested"].map((status) => [
    status,
    5,
  ]),
]);

/**
 * @typedef {import("../bindings").BQFunnelRowGrouped} BQFunnelRowGrouped
 */

/**
 * Mount the funnel dashboard.
 *
 * Resolves the elements named by the given ids, fetches the funnel rows,
 * fades out and removes the loading indicator, seeds the state from the
 * controls' current values and performs the first render. Later `setState`
 * calls re-render unless their second argument is false.
 *
 * @param {object} ids element ids of the plot container, the loading
 *   indicator, the start/end date inputs, the rows and colour selects and
 *   the colour legend
 */
export async function plotFunnel({
  containerId,
  loadingId,
  startDateId,
  endDateId,
  rowsSelectionId,
  colorSelectionId,
  colorLegendId,
}) {
  const byId = (id) => document.getElementById(id);
  const loadingElement = byId(loadingId);
  // Everything viz() needs from the DOM, keyed by its parameter names.
  const elements = {
    container: byId(containerId),
    startDateElement: byId(startDateId),
    endDateElement: byId(endDateId),
    rowsSelectionElement: byId(rowsSelectionId),
    colorSelectionElement: byId(colorSelectionId),
    colorLegendElement: byId(colorLegendId),
  };

  const rawData = await getData();
  select(loadingElement).transition().style("opacity", 0).remove();

  let state = {
    rawData,
    startDate: getDate(getSelectionValue(elements.startDateElement)),
    endDate: getDate(getSelectionValue(elements.endDateElement)),
    rowsSelection: getSelectionValue(elements.rowsSelectionElement),
    colorSelection: getSelectionValue(elements.colorSelectionElement),
  };

  // Always renders from the latest `state` binding.
  const render = () => viz({ ...elements, state, setState });

  // Functional update: `next` maps the previous state to the new one.
  const setState = (next, doRender = true) => {
    state = next(state);
    if (doRender) {
      render();
    }
  };

  render();
}

/**
 * Fetch the grouped funnel rows from the server and parse their dates.
 * @returns {Promise<Array<BQFunnelRowGrouped>>} rows whose `created_at` has
 *   been passed through getDate; every other field is copied as-is
 */
async function getData() {
  /** @type {Array<BQFunnelRowGrouped>} */
  const rows = await json("/funnel/data");
  return rows.map((row) =>
    Object.assign({}, row, { created_at: getDate(row.created_at) }),
  );
}

/**
 * Filtering/aggregation stage of the funnel view.
 *
 * When `state.data` is not set yet, it computes the data for the filters
 * stored in `state`, saves it through `setState` (which re-renders) and
 * returns `null` so the current render pass can stop. Otherwise it returns
 * the stored data together with `processData`, which event handlers call to
 * recompute the data for new filter values.
 *
 * @returns {{ data: Array<{ name: any, data: { links: Array<object>, nodes: Array<object> }, n: number }>, processData: Function } | null}
 */
function dataFilter({ state, setState, container }) {
  // Column shared by every outcome status in onboardingStatus2idx.
  const terminalIdx = 5;
  // Outcome statuses keep their own counts instead of cumulative totals.
  const terminalStatuses = new Set([
    "denied",
    "loan-requested",
    "no-loan-requested",
  ]);

  /**
   * Turn grouped rows into Sankey `{ links, nodes }`.
   *
   * Row counts `n` are summed per (colour value, status). Within each colour
   * group, every non-terminal step's count is replaced by the sum of its own
   * count and all later steps' counts, i.e. the users that reached at least
   * that step. Each link between consecutive columns carries its target's
   * count. A synthetic "tmp" node receives the users lost right after
   * "presignup" so that node is sized by its full count; plotSankey hides it.
   *
   * @param {Array<BQFunnelRowGrouped>} rawData
   * @param {string?} colorSelection column to split links by; undefined puts
   *   every row in a single group whose colour value is undefined
   */
  function formatData(rawData, colorSelection) {
    const data = rollups(
      rawData,
      (rows) => sum(rows.map((row) => row.n)),
      (row) => row[colorSelection],
      (row) => row.status,
    ).map(([colorValue, counts]) => {
      let dd = counts.map(([status, n]) => ({
        idx: onboardingStatus2idx[status],
        status,
        colorValue,
        n,
      }));
      // Make sure every non-terminal step exists (with n = 0 when absent) so
      // "presignup" and "onboarding-started" are dd[0] and dd[1] after sorting.
      const present = new Set(dd.map((d) => d.status));
      for (const [status, idx] of Object.entries(onboardingStatus2idx)) {
        if (idx !== terminalIdx && !present.has(status)) {
          dd.push({ idx, status, colorValue, n: 0 });
        }
      }
      dd = sort(dd, (d) => d.idx);
      // Walk from the end of the funnel, accumulating counts, so each
      // non-terminal step holds the number of users that got at least that far.
      let accum = 0;
      for (let i = dd.length - 1; i >= 0; i--) {
        accum += dd[i].n;
        if (!terminalStatuses.has(dd[i].status)) {
          dd[i].n = accum;
        }
      }
      return { colorValue, data: dd };
    });

    const links = [];
    let totalPresignupLost = 0;

    for (const { colorValue, data: dd } of data) {
      // Users who reached "presignup" but not "onboarding-started".
      const presignupLost = dd[0].n - dd[1].n;
      totalPresignupLost += presignupLost;
      const maxIdx = max(dd, (d) => d.idx);
      // Link every step in column i to every step in column i + 1.
      for (let i = 0; i < maxIdx; i++) {
        const res = dd
          .filter((d) => d.idx === i)
          .flatMap((source) =>
            dd
              .filter((d) => d.idx === i + 1)
              .map((target) => ({
                source: source.status,
                target: target.status,
                value: target.n,
                color: colorValue,
              })),
          );
        links.push(...res);
      }
      links.push({
        source: "presignup",
        target: "tmp",
        value: presignupLost,
        color: colorValue,
      });
    }

    // One node per status, summed across colour groups, plus the "tmp" sink.
    const nodes = [
      ...rollups(
        data.flatMap(({ data }) =>
          data.map(({ status, n }) => ({
            name: status,
            value: n,
          })),
        ),
        (d) => sum(d.map((dd) => dd.value)),
        (d) => d.name,
      ).map(([name, value]) => ({ name, value })),
      { name: "tmp", value: totalPresignupLost },
    ];

    return { links, nodes };
  }

  /**
   * Users in a formatted group: the cumulative "presignup" count. Looked up
   * by name rather than by node position; a group without rows has no
   * "presignup" node and counts as 0.
   * @param {{ nodes: Array<{ name: string, value: number }> }} formatted
   */
  function groupTotal({ nodes }) {
    const presignup = nodes.find((node) => node.name === "presignup");
    return presignup === undefined ? 0 : presignup.value;
  }

  /**
   * Recompute the per-row-group funnel data for the given filters and store
   * the result together with the filters in state (setState re-renders).
   * The computation goes through `memoize` from ./utils with the filter values
   * as its dependency list — presumably cached per container; verify in utils.
   *
   * @param {Date | undefined} startDate inclusive lower bound on created_at;
   *   undefined means unbounded
   * @param {Date | undefined} endDate inclusive upper bound on created_at;
   *   undefined means unbounded
   * @param {string?} rowsSelection column to split subplots by; undefined
   *   yields a single "all" group
   * @param {string?} colorSelection column to colour links by
   */
  function processData(startDate, endDate, rowsSelection, colorSelection) {
    const data = memoize(
      () => {
        const rawData = state.rawData.filter(({ created_at }) => {
          const startFilter =
            startDate === undefined ? true : created_at >= startDate;
          const endFilter =
            endDate === undefined ? true : created_at <= endDate;
          return startFilter && endFilter;
        });
        if (rowsSelection === undefined) {
          const d = formatData(rawData, colorSelection);
          return [{ name: "all", data: d, n: groupTotal(d) }];
        }
        // One entry per distinct row value, largest group first.
        return sort(
          rollups(
            rawData,
            (d) => formatData(d, colorSelection),
            (d) => d[rowsSelection],
          ).map(([name, data]) => ({
            name,
            data,
            n: groupTotal(data),
          })),
          (d) => -d.n,
        );
      },
      [startDate, endDate, rowsSelection, colorSelection],
      container,
    );
    setState((state) => ({
      ...state,
      data,
      startDate,
      endDate,
      rowsSelection,
      colorSelection,
    }));
  }

  if (state.data === undefined) {
    // First pass: compute, store (which re-renders), and abort this pass.
    processData(
      state.startDate,
      state.endDate,
      state.rowsSelection,
      state.colorSelection,
    );
    return null;
  }

  return { data: state.data, processData };
}

/**
 * One render pass of the funnel view.
 *
 * (Re)binds the filter controls' change handlers, updates the colour legend
 * and draws one subplot per row group: a Sankey svg plus a side panel
 * (`.data-view`) that receives node details.
 *
 * Returns early when observeResize or dataFilter returns null. dataFilter does
 * so after starting a state update that re-renders; observeResize presumably
 * behaves the same until the container is measured — confirm in ./utils.
 */
function viz({
  container,
  startDateElement,
  endDateElement,
  rowsSelectionElement,
  colorSelectionElement,
  colorLegendElement,
  state,
  setState,
}) {
  const dimensions = observeResize({ state, setState, container });
  const filters = dataFilter({ state, setState, container });
  if (dimensions === null || filters === null) return;

  const { subPlotHeight, subPlotWidthPct } = config;
  const { width } = dimensions;
  const subplotWidth = width * subPlotWidthPct;

  const { data, processData } = filters;
  const { startDate, endDate, rowsSelection, colorSelection } = state;
  const nPlots = data.length;

  // d3's .on() replaces any listener bound by a previous render pass, so
  // rebinding here on every render does not stack handlers.
  select(startDateElement).on("change", (event) => {
    // Node details belong to the old date range; clear them.
    selectAll(".data-view").text("Select a node.");
    processData(
      getDate(getSelectionValue(event.target)),
      endDate,
      rowsSelection,
      colorSelection,
    );
  });
  select(endDateElement).on("change", (event) => {
    selectAll(".data-view").text("Select a node.");
    processData(
      startDate,
      getDate(getSelectionValue(event.target)),
      rowsSelection,
      colorSelection,
    );
  });
  select(rowsSelectionElement).on("change", (event) => {
    processData(
      startDate,
      endDate,
      getSelectionValue(event.target),
      colorSelection,
    );
  });
  select(colorSelectionElement).on("change", (event) => {
    processData(
      startDate,
      endDate,
      rowsSelection,
      getSelectionValue(event.target),
    );
  });

  // Colour domain is memoized on colorSelection only. This presumably keeps
  // colour assignments stable while dates/rows change; unseen values are
  // still appended by the ordinal scale. NOTE(review): confirm intent.
  const colorValues = memoize(
    () =>
      new Set(
        data
          .flatMap((d) => d.data.links.map((dd) => dd.color))
          .filter((o) => o !== undefined),
      ),
    [colorSelection],
    container,
    "colors",
  );
  const colorScale = scaleOrdinal(schemeTableau10).domain(colorValues);

  // Legend
  colorLegend(select(colorLegendElement), {
    colorScale,
  });

  // Plots: one div per row group, holding the Sankey svg and a details panel.
  const totalUsers = sum(data.map((d) => d.n));
  select(container)
    .selectAll("div.subplot")
    .data(range(nPlots))
    .join("div")
    .attr("class", "subplot")
    .each(function (i) {
      const dataViewId = `subplot${i}`;
      one(select(this), "svg", "plot-background")
        .attr("width", subplotWidth)
        .attr("height", subPlotHeight)
        .call(plotSankey, {
          data: data[i],
          width: subplotWidth,
          height: subPlotHeight,
          totalUsers,
          colorScale,
          dataViewId,
          startDate,
          endDate,
          rowsSelection,
        });
      // Details panel: its placeholder text is (re)set via memoize keyed on
      // rowsSelection; its width is refreshed on every pass to fill the space
      // left of the svg.
      memoize(
        () =>
          one(select(this), "div", "data-view")
            .attr("id", dataViewId)
            .text("Select a node."),
        [rowsSelection],
        container,
        dataViewId,
      ).style("max-width", `${width - subplotWidth - 8}px`);
    });
}

/**
 * Draw one funnel subplot as a Sankey diagram into `svg`.
 *
 * Runs the d3-sankey layout on copies of the group's nodes and links, then
 * hides the synthetic "tmp" sink and its links. That sink exists only to size
 * the "presignup" node. It then draws the title, clickable nodes, coloured
 * links and node labels. Clicking a node issues an htmx GET to
 * `/funnel/details` whose response targets `#<dataViewId>`.
 *
 * @param svg d3 selection of the subplot's <svg>
 * @param {object} options
 * @param {{ name: any, data: { nodes: Array<object>, links: Array<object> }, n: number }} options.data
 *   one row group: its name, Sankey input and user count
 * @param {number} options.width svg width
 * @param {number} options.height svg height
 * @param {number} options.totalUsers users across all row groups; base of the
 *   title percentage
 * @param options.colorScale scale used for links that carry a colour value
 * @param {string} options.dataViewId id of the element receiving node details
 * @param {Date | undefined} options.startDate
 * @param {Date | undefined} options.endDate
 * @param {string | undefined} options.rowsSelection
 */
function plotSankey(
  svg,
  {
    data: { name, data, n },
    width,
    height,
    totalUsers,
    colorScale,
    dataViewId,
    startDate,
    endDate,
    rowsSelection,
  },
) {
  const {
    titlePadding,
    marginLeft,
    marginRight,
    marginBottom,
    marginTop,
    nodeWidth,
    nodePadding,
  } = config;
  const innerWidth = width - marginLeft - marginRight;
  const innerHeight = height - marginBottom - marginTop;
  // Label percentages are relative to the largest node of this group.
  const baseCount = max(data.nodes.map((d) => d.value));
  // Ratio that is 0 instead of NaN/Infinity when the base is empty, e.g. when
  // the date filter leaves no users at all.
  const share = (part, whole) => (whole > 0 ? part / whole : 0);

  const sankeyGenerator = sankey()
    // @ts-ignore
    .nodeId((d) => d.name)
    .nodeWidth(nodeWidth)
    .nodePadding(nodePadding)
    .extent([
      [marginLeft, marginTop],
      [marginLeft + innerWidth, marginTop + innerHeight],
    ])
    .nodeAlign(sankeyLeft);
  // Copies: the sankey layout mutates the objects it is given.
  // @ts-ignore
  let { nodes, links } = sankeyGenerator({
    nodes: data.nodes.map((d) => Object.assign({}, d)),
    links: data.links.map((d) => Object.assign({}, d)),
  });
  // @ts-ignore
  nodes = nodes.filter((d) => d.name !== "tmp");
  // @ts-ignore
  links = links.filter((d) => d.target.name !== "tmp");

  // title
  one(svg, "text", "subplot-title")
    .attr("x", titlePadding)
    .attr("y", titlePadding)
    .attr("dominant-baseline", "hanging")
    .text(
      `${name}: ${formatNumber(n)} (${formatPerc(share(n, totalUsers))})`,
    );

  // Nodes
  one(svg, "g", "nodes")
    .selectAll("rect.node")
    // @ts-ignore
    .data(nodes)
    .join((enter) =>
      enter
        .append("rect")
        .attr("class", "node")
        .call((selection) => selection.append("title")),
    )
    .attr("x", (d) => d.x0)
    .attr("y", (d) => d.y0)
    .attr("height", (d) => d.y1 - d.y0)
    .attr("width", (d) => d.x1 - d.x0)
    .on("click", function (event, d) {
      const values = {
        // @ts-ignore
        step: d.name,
        rows_column: rowsSelection,
        rows_selection: name,
      };
      // NOTE(review): toISOString() is UTC; if getDate yields local-midnight
      // dates, east-of-UTC users would send the previous day. Confirm.
      if (startDate !== undefined) {
        values["start_date"] = startDate.toISOString().slice(0, 10);
      }
      if (endDate !== undefined) {
        values["end_date"] = endDate.toISOString().slice(0, 10);
      }
      // @ts-ignore
      htmx.ajax("GET", "/funnel/details", { target: `#${dataViewId}`, values });
    })
    .select("title")
    // @ts-ignore
    .text((d) => d.name);

  // Links
  one(svg, "g", "links")
    .attr("fill", "none")
    .selectAll("g.link")
    .data(links)
    .join("g")
    .attr("class", "link")
    .selectAll("path.link")
    .data((d) => [d])
    .join("path")
    .attr("class", "link")
    .attr("d", sankeyLinkHorizontal())
    .attr("stroke", (d) =>
      // @ts-ignore
      d.color === undefined ? "#335f4240" : colorScale(d.color),
    )
    .attr("stroke-width", (d) => Math.max(1, d.width))
    .selectAll("title.link-title")
    .data((d) => [d])
    .join("title")
    .attr("class", "link-title")
    // @ts-ignore
    .text((d) => {
      // @ts-ignore
      const txt = `${d.source.name} -> ${d.target.name} (${d.value})`;
      // @ts-ignore
      return d.color === undefined ? txt : `${d.color}: ${txt}`;
    });

  // Labels: outside the node, on the side facing the centre of the plot.
  one(svg, "g", "labels")
    .selectAll("text.label")
    .data(nodes)
    .join("text")
    .attr("class", "label")
    .attr("x", (d) => (d.x0 < width / 2 ? d.x1 + 6 : d.x0 - 6))
    .attr("y", (d) => (d.y1 + d.y0) / 2)
    .attr("dy", "0.35em")
    .attr("text-anchor", (d) => (d.x0 < width / 2 ? "start" : "end"))
    .text(
      (d) =>
        // @ts-ignore
        `${d.name}\n${formatNumber(d.value)} (${formatPerc(share(d.value, baseCount))})`,
    );
}

/**
 * Trigger a browser download of the funnel CSV for one step.
 *
 * Builds `/funnel/download?...` from the arguments and clicks a temporary
 * `<a download>` link, which is always removed afterwards.
 *
 * @param {object} args
 * @param {string} args.step funnel step (also used in the file name)
 * @param {string | null | undefined} args.start_date optional lower date
 *   bound; omitted when null or undefined
 * @param {string | null | undefined} args.end_date optional upper date
 *   bound; omitted when null or undefined
 * @param {string} args.rows_column
 * @param {string} args.rows_selection
 * @param {string} args.filter
 */
export const downloadFunnelCsv = ({
  step,
  start_date,
  end_date,
  rows_column,
  rows_selection,
  filter,
}) => {
  const params = new URLSearchParams();
  params.append("step", step);
  // `!= null` also skips null (e.g. a serialized server-side None), which
  // would otherwise be sent as the literal string "null".
  if (start_date != null) {
    params.append("start_date", start_date);
  }
  if (end_date != null) {
    params.append("end_date", end_date);
  }
  params.append("rows_column", rows_column);
  params.append("rows_selection", rows_selection);
  params.append("filter", filter);

  const link = document.createElement("a");
  link.href = `/funnel/download?${params.toString()}`;
  link.setAttribute("download", `funnel-${step}.csv`);
  document.body.appendChild(link);
  try {
    link.click();
  } finally {
    link.remove();
  }
};
