import { AxisBottom } from '@visx/axis';
import isNumber from 'lodash/isNumber';
import { DateTime } from 'luxon';
import { useMemo } from 'react';

import theme from 'config/theme';
import { ChartSize } from 'generated/graphql';
import { getMonthKey, getNumMonthsInclusive, shortMonthFormat } from 'helpers/dates';
import useChartContext from 'hooks/useChartContext';

// Theme tokens consumed by the tick-label styling in `TimeAxis`.
const { colors, fontSizes, fontWeights } = theme;

/** Props accepted by the `TimeAxis` component. */
interface Props {
  /** When false, the axis receives no tick values, so no labels are drawn. */
  showLabels: boolean;
  /** Chart size; selects the tick strategy and the label text anchoring. */
  size: ChartSize;
}

/**
 * Upper bound on the number of month ticks for each chart size.
 *
 * `TimeAxis` only reads this for non-medium sizes. Medium charts always get
 * exactly two ticks (the domain's start and end), so the `Medium` entry is
 * not consulted there.
 */
const SIZE_TO_MAX_TICKS: Record<ChartSize, number> = {
  [ChartSize.ExtraLarge]: 12,
  [ChartSize.Large]: 8,
  [ChartSize.Medium]: 2,
};

/**
 * Bottom axis for time-series charts that labels months along the x axis.
 *
 * Reads `height` and `timeScale` from the chart context and renders the axis
 * at `top={height}`. Tick placement depends on `size`:
 * - `ChartSize.Medium`: exactly two ticks, at the start and end of the
 *   scale's domain. The outer labels are anchored to the axis edges.
 * - Other sizes: one tick every `monthsPerTick` months from the domain start.
 *   The spacing is chosen so the tick count stays at or below roughly
 *   `SIZE_TO_MAX_TICKS[size]`.
 * When `showLabels` is false, no tick values are passed, so nothing is labeled.
 */
const TimeAxis: React.FC<Props> = ({ showLabels, size }) => {
  const { height, timeScale } = useChartContext();
  const [startJSDate, endJSDate] = timeScale.domain();

  // d3-style time scales (which visx wraps) build fresh Date objects on every
  // `domain()` call. React compares memo deps with Object.is, so depending on
  // the Dates themselves would recompute the ticks on every render. Depend on
  // the primitive timestamps instead.
  const startMs = startJSDate.valueOf();
  const endMs = endJSDate.valueOf();

  const ticks = useMemo((): Date[] => {
    if (!showLabels) {
      return [];
    }

    const startDate = new Date(startMs);
    const endDate = new Date(endMs);

    if (size === ChartSize.Medium) {
      // For medium charts, we just force the start and the end of the date
      // ranges to be pinned to either side of the axis.
      return [startDate, endDate];
    }

    const startDateTime = DateTime.fromJSDate(startDate);
    const endDateTime = DateTime.fromJSDate(endDate);

    const maxTicks = SIZE_TO_MAX_TICKS[size];
    const numMonths = getNumMonthsInclusive(getMonthKey(startDateTime), getMonthKey(endDateTime));
    // Clamp to >= 1 so a degenerate month count can never give a zero
    // divisor below (0 / 0 is NaN, and Math.max(1, NaN) is still NaN).
    const numTicks = Math.max(1, Math.min(maxTicks, numMonths));
    const monthsPerTick = Math.max(1, Math.ceil(numMonths / numTicks));

    // Step through the domain `monthsPerTick` months at a time. Each tick is
    // placed at the last whole second of its month.
    // NOTE(review): if the domain ends mid-month, the final tick lands after
    // `endDate`. Confirm this matches how the plotted data is aligned.
    const result: Date[] = [];
    for (
      let cursor = startDateTime;
      cursor <= endDateTime;
      cursor = cursor.plus({ months: monthsPerTick })
    ) {
      result.push(cursor.endOf('month').startOf('second').toJSDate());
    }

    return result;
  }, [startMs, endMs, showLabels, size]);

  return (
    <AxisBottom
      hideAxisLine
      hideTicks
      top={height}
      scale={timeScale}
      tickFormat={formatTick}
      tickValues={ticks}
      stroke="none"
      tickLabelProps={(_val, idx, vals) => ({
        y: 18,
        fill: colors.gray[500],
        textAnchor: getTextAnchorForIdx(idx, vals.length, size),
        fontSize: fontSizes.xxxs,
        fontWeight: fontWeights.bold,
        style: {
          userSelect: 'none',
          textTransform: 'uppercase',
        },
      })}
    />
  );
};

/**
 * Chooses the SVG `text-anchor` for the tick label at position `idx` out of
 * `numVals` labels.
 *
 * Only medium charts, whose two ticks sit at the domain's start and end,
 * anchor their outermost labels to the edges. The first label is anchored
 * at `start` and the last at `end`; when there is a single label, `start`
 * wins. Every other label is centered on its tick.
 */
function getTextAnchorForIdx(idx: number, numVals: number, size: ChartSize) {
  const isMedium = size === ChartSize.Medium;

  if (isMedium && idx === 0) return 'start';
  if (isMedium && idx === numVals - 1) return 'end';

  return 'middle';
}

/**
 * Tick formatter for the axis. It turns a tick value (a number, or anything
 * whose `valueOf()` yields a number, such as a Date) into a label via
 * `shortMonthFormat`.
 */
function formatTick(value: Date | number | { valueOf: () => number }): string {
  let epochMs: number;
  if (isNumber(value)) {
    epochMs = value;
  } else {
    epochMs = value.valueOf();
  }

  const dateTime = DateTime.fromMillis(epochMs);
  return shortMonthFormat(dateTime);
}

// The component is this module's default export.
export { TimeAxis as default };
