import { sum, sumBy } from 'lodash';
import React, { Fragment, useMemo, useRef } from 'react';
import { UncontrolledTooltip } from 'reactstrap';
import { Forest, Tree } from '../../effects';
import classnames from 'classnames';

/** Props for {@link ForestDiagram}. */
export interface ForestDiagramProps<T> {
    /** Trees to lay out side by side at the top level of the diagram. */
    forest: Forest<T>;
    /** Value that corresponds to 100% of the diagram's width; computed from `forest` when omitted. */
    total?: number;
    /** Returns the identifier of an item; used as the React key of the item's cell. */
    getId: (item: T) => number | string;
    /** Returns the label shown inside an item's cell and in its tooltip. */
    getName: (item: T) => React.ReactNode;
    /** Returns the numeric value of an item; `undefined` is treated as 0. */
    getValue: (item: T) => number | undefined;
    /**
     * Number of levels to render. 0 renders nothing; a negative value (the default, -1) never
     * reaches 0 while being decremented per level, so every level is rendered.
     */
    maxLevel?: number;
}

/**
 * Renders a forest as rows of proportionally sized cells. Every tree gets a cell whose width
 * reflects its value (or its descendants' combined value, when that is larger), and each cell
 * renders its own children as a nested diagram underneath it.
 *
 * @param props.forest - Trees to lay out side by side at this level.
 * @param props.total - Value that corresponds to 100% of the width. When omitted, it is the sum of
 *   the widths the cells themselves claim, so the top level exactly fills the row.
 * @param props.getId - Identifier of an item; used as the React key of its cell.
 * @param props.getName - Label shown in each cell and in its tooltip.
 * @param props.getValue - Numeric value of an item; `undefined` counts as 0.
 * @param props.maxLevel - Levels to render; 0 renders nothing, negative (the default) renders all.
 * @returns The diagram, or `null` when `maxLevel` is 0.
 */
export const ForestDiagram = <T,>({
    forest,
    total,
    getId,
    getName,
    getValue,
    maxLevel = -1,
}: ForestDiagramProps<T>): React.ReactElement | null => {
    // Each cell sizes itself by getDescendantValues([tree]), the larger of its own value and its
    // subtree's total. The default total must therefore sum that same measure. Summing only the
    // direct values let cells overflow past 100% (or divide by zero) whenever children carry
    // more value than their parent.
    const totalValue = useMemo(
        () => total ?? sumBy(forest, (tree) => getDescendantValues([tree], getValue)),
        [total, forest, getValue],
    );

    if (maxLevel === 0) return null;

    return (
        <div className="forest-diagram">
            <div className="forest-diagram__part">
                {forest.map((tree) => (
                    <ForestDiagramCell
                        tree={tree}
                        sum={totalValue}
                        getId={getId}
                        getName={getName}
                        getValue={getValue}
                        maxLevel={maxLevel - 1}
                        key={getId(tree)}
                    />
                ))}
            </div>
        </div>
    );
};

/** Props for the cell that renders a single tree inside a {@link ForestDiagram}. */
interface ForestDiagramCellProps<T> {
    /** The tree rendered by this cell; its children are rendered as a nested diagram below it. */
    tree: Tree<T>;
    /** Value corresponding to the full width of the enclosing diagram; the cell's width is relative to it. */
    sum: number;
    /** Returns the identifier of an item; forwarded to the nested diagram. */
    getId: (item: T) => number | string;
    /** Returns the label shown inside the cell and in its tooltip. */
    getName: (item: T) => React.ReactNode;
    /** Returns the numeric value of an item; `undefined` is treated as 0. */
    getValue: (item: T) => number | undefined;
    /** Remaining levels, forwarded to the nested diagram, which renders nothing when this is 0. */
    maxLevel: number;
}

/**
 * Renders one tree of a {@link ForestDiagram}: a row labelled with the tree's name, followed by
 * the tree's children rendered as a nested diagram.
 *
 * The cell is `getDescendantValues([tree])` wide relative to `sum`, i.e. the larger of the tree's
 * own value and its children's combined value. Inside the row:
 * - the share covered by the tree's own value uses `forest-diagram__cell`;
 * - any remainder (children worth more than the tree itself) uses `forest-diagram__cell--invalid`.
 * Each part additionally gets `forest-diagram__cell--combined` when both parts are shown.
 *
 * @returns The cell, or `null` when `getDescendantValues([tree])` is 0 or NaN.
 */
const ForestDiagramCell = <T,>({ tree, sum, getId, getName, getValue, maxLevel }: ForestDiagramCellProps<T>) => {
    const name = getName(tree);
    const value = getValue(tree) ?? 0;
    const descendantValues = useMemo(() => getDescendantValues([tree], getValue), [tree, getValue]);

    const ref = useRef<HTMLDivElement>(null);

    // Hooks must run on every render, so this exit comes after them. `!` also catches NaN,
    // which would otherwise produce `NaN%` widths.
    if (!descendantValues) return null;

    return (
        <Fragment>
            <div style={{ width: `${(descendantValues / sum) * 100}%` }}>
                <div className="forest-diagram-row" ref={ref}>
                    {!!value && (
                        <div
                            style={{ width: `${(value / descendantValues) * 100}%` }}
                            className={classnames('forest-diagram__cell', {
                                'forest-diagram__cell--combined': value < descendantValues,
                            })}
                        >
                            {name}
                        </div>
                    )}
                    {value < descendantValues && (
                        <div
                            style={{ width: `${(1 - value / descendantValues) * 100}%` }}
                            className={classnames('forest-diagram__cell--invalid', {
                                'forest-diagram__cell--combined': !!value,
                            })}
                        >
                            {name}
                        </div>
                    )}
                </div>

                <ForestDiagram
                    forest={tree.children}
                    total={descendantValues}
                    getId={getId}
                    getName={getName}
                    getValue={getValue}
                    maxLevel={maxLevel}
                />
            </div>
            {/* Pass the ref object itself; reactstrap resolves it when the tooltip mounts, after the
                row's ref is attached. Reading `ref.current` during render yields null on the first
                render, and attaching a ref does not re-render, so the tooltip used to appear only
                after some unrelated re-render. */}
            <UncontrolledTooltip target={ref} placement="top">
                {name}
            </UncontrolledTooltip>
        </Fragment>
    );
};

/**
 * Totals a forest for layout purposes. Each tree contributes the larger of:
 * - its own value (missing values count as 0), and
 * - the recursively computed total of its children.
 *
 * This is the width a tree needs so that both its own value and all of its descendants fit
 * inside it. An empty forest totals 0.
 */
function getDescendantValues<T>(forest: Forest<T>, getValue: (item: T) => number | undefined): number {
    let total = 0;
    for (const tree of forest) {
        // Children are evaluated before the tree's own value.
        const childrenTotal = getDescendantValues(tree.children, getValue);
        const ownValue = getValue(tree) ?? 0;
        total += Math.max(ownValue, childrenTotal);
    }
    return total;
}
