import Node, { RangeNode, RangePointNode } from "./Node.js";


/**
 * Reports whether every point in the array has the same x and y as the
 * first point.
 *
 * A single-element array deliberately yields `false`; an empty array
 * yields `true` because there is nothing to contradict the first element.
 *
 * @param {Array<{x: number, y: number}>} pointArray - Points to compare.
 * @returns {boolean} `true` when all points coincide with the first one.
 */
const allEqual = pointArray => {
    if (pointArray.length === 1) {
        return false;
    }

    const reference = pointArray[0];

    // Skip index 0: it is the reference itself.
    return pointArray
        .slice(1)
        .every(candidate =>
            candidate.x === reference.x && candidate.y === reference.y);
}

/**
 * Sort comparator that orders points by ascending x coordinate.
 *
 * @param {{x: number}} a - First point.
 * @param {{x: number}} b - Second point.
 * @returns {number} -1 when `a.x < b.x`, 1 when `a.x > b.x`, otherwise 0
 *     (which includes the case where the two are not comparable).
 */
const comparator = (a, b) => {
    return a.x < b.x ? -1 : (a.x > b.x ? 1 : 0);
}

/**
 * Appends `point` to one of two partitions according to how its coordinate
 * on the chosen axis compares with `median`.
 *
 * @param {Array} leftArray - Receives points whose coordinate is <= median.
 * @param {Array} rightArray - Receives points whose coordinate is > median.
 * @param {{x: number, y: number}} point - Point to classify.
 * @param {number} median - Split value.
 * @param {boolean} yAxis - Truthy to compare `point.y`, falsy for `point.x`.
 * @returns {void}
 */
const pushValue = (leftArray, rightArray, point, median, yAxis) => {
    const coordinate = yAxis ? point.y : point.x;

    if (coordinate <= median) {
        leftArray.push(point);
        return;
    }

    // Tested explicitly (not a bare else) so that a coordinate which is
    // neither <= nor > the median, e.g. NaN, is added to neither array.
    if (coordinate > median) {
        rightArray.push(point);
    }
}

/**
 * Computes the split value for a set of points along one axis.
 *
 * Despite the name, the result is the arithmetic mean of the selected
 * coordinate (sum divided by the number of points), not a statistical
 * median. An empty array produces NaN (0 / 0).
 *
 * @param {Array<{x: number, y: number}>} array - Points to average.
 * @param {boolean} yAxis - Truthy to average `y`, falsy to average `x`.
 * @returns {number} Mean of the chosen coordinate.
 */
const findMedian = (array, yAxis) => {
    const total = array.reduce(
        (runningSum, point) => runningSum + (yAxis ? point.y : point.x),
        0
    );

    return total / array.length;
}

/**
 * Two-dimensional kd-tree built from `{x, y}` points that answers
 * axis-aligned rectangular range-count queries.
 *
 * Internal nodes (`RangeNode`) store a split value and the axis it applies
 * to; single points are stored in leaf `Node`s; a group of identical points
 * is stored in one `RangePointNode`.
 */
class kdTree {
    // Axis flag handed to the root split by addPoints():
    // false splits on x first, true would split on y first.
    static startDepth = false;
    #root = null;
    constructor() {
        this.#root = null;
    }

    /**
     * Tests whether a point lies inside the query rectangle, bounds included.
     *
     * @param {{x1: number, x2: number, y1: number, y2: number}} boundary -
     *     Rectangle with x in [x1, x2] and y in [y1, y2].
     * @param {{x: number, y: number}} point - Point to test.
     * @returns {boolean} `true` when the point is inside or on the border.
     */
    _containsPoint(boundary, point) {
        const { x1 : minX, x2 : maxX, y1 : minY, y2 : maxY } = boundary;

        return point.x >= minX &&
            point.x <= maxX &&
            point.y >= minY &&
            point.y <= maxY;
    }

    /**
     * Recursively counts the points under `currentNode` that fall inside
     * `boundary`, accumulating into `resultArray.count`.
     *
     * @param {Node} currentNode - Subtree root to inspect (must be non-null).
     * @param {{x1: number, x2: number, y1: number, y2: number}} boundary -
     *     Inclusive query rectangle.
     * @param {{count: number}} resultArray - Accumulator object; despite the
     *     name it is not an array, only its `count` property is updated.
     * @returns {void}
     */
    _rangeQueryTraversalCount(currentNode, boundary, resultArray) {
        // A RangePointNode holds several identical points. It is handled
        // before the isLeaf() test so the outcome does not depend on how
        // Node.js classifies it. The stored point itself is tested against
        // the whole rectangle, because the node's split value only describes
        // one axis.
        if (currentNode instanceof RangePointNode) {
            const points = currentNode.getPoints();
            if (points.length > 0 &&
                    this._containsPoint(boundary, points[0])) {
                resultArray.count += points.length;
            }
            return;
        }

        const value = currentNode.getValue();

        if (currentNode.isLeaf()) {
            // For a leaf the value is the point itself.
            if (this._containsPoint(boundary, value)) {
                resultArray.count++;
            }
            return;
        }

        const { x1 : minX, x2 : maxX, y1 : minY, y2 : maxY } = boundary;
        const yAxis = currentNode.isYAxis();
        const min = yAxis ? minY : minX;
        const max = yAxis ? maxY : maxX;

        const left = currentNode.getLeft();
        const right = currentNode.getRight();

        // The left subtree holds coordinates <= split value, so it can
        // contain matches whenever the lower bound is <= the split value
        // (inclusive: a point sitting exactly on the split may equal min).
        if (left && value >= min) {
            this._rangeQueryTraversalCount(left, boundary, resultArray);
        }

        // The right subtree holds coordinates strictly > split value, so it
        // is only relevant when the upper bound exceeds the split value.
        if (right && value < max) {
            this._rangeQueryTraversalCount(right, boundary, resultArray);
        }
    }

    /**
     * Recursively detaches the subtree below `currentNode`: the left subtree
     * is torn down and unlinked, the node's value is released, then the
     * right subtree is torn down and unlinked.
     *
     * @param {Node} currentNode - Subtree root to dismantle (non-null).
     * @returns {void}
     */
    _inorderDeletion(currentNode) {
        const left = currentNode.getLeft();
        const right = currentNode.getRight();

        if (left) {
            this._inorderDeletion(left);
            currentNode.setLeft(null);
        }

        currentNode.releaseValue();

        if (right) {
            this._inorderDeletion(right);
            currentNode.setRight(null);
        }
    }

    /**
     * Recursively builds a subtree from `pointsArray`, alternating the split
     * axis at every level.
     *
     * @param {Array<{x: number, y: number}>} pointsArray - Points to store.
     * @param {boolean} yAxis - Truthy to split this level on y, falsy on x.
     * @returns {Node|null} Root of the built subtree, or `null` when
     *     `pointsArray` is empty.
     */
    _addPoint(pointsArray, yAxis) {
        const length = pointsArray.length;

        // Base cases: no points, a single point, or a group of points that
        // are all identical.
        if (length === 0) {
            return null;
        }

        if (length === 1) {
            // A lone point becomes a leaf node.
            return new Node(pointsArray[0], true);
        }

        // Computed only after the trivial cases so empty and single-point
        // inputs do not pay for (or produce NaN from) the averaging pass.
        const median = findMedian(pointsArray, yAxis);

        // This check is O(n), the same as the partition loop below.
        if (allEqual(pointsArray)) {
            return new RangePointNode(median, yAxis, pointsArray);
        }

        const leftArray = [];
        const rightArray = [];
        for (const point of pointsArray) {
            pushValue(leftArray, rightArray, point, median, yAxis);
        }

        const left = this._addPoint(leftArray, !yAxis);
        const right = this._addPoint(rightArray, !yAxis);
        const rangeNode = new RangeNode(median, yAxis);

        rangeNode.setLeft(left);
        rangeNode.setRight(right);

        return rangeNode;
    }

    /**
     * Clears the tree: dismantles every node and drops the root.
     *
     * @returns {void}
     */
    clear() {
        const root = this.#root;
        if (root) {
            this._inorderDeletion(root);
            this.#root = null;
        }
    }

    /**
     * Builds the tree from `pointsArray`, replacing the current root.
     *
     * The array is sorted by x **in place** and that same array is returned.
     * When the array is empty the existing root is left untouched.
     *
     * @param {Array<{x: number, y: number}>} pointsArray - Points to index.
     * @returns {Array<{x: number, y: number}>} `pointsArray`, now sorted by x.
     */
    addPoints(pointsArray) {
        const sortedArray = pointsArray.sort(comparator);

        if (pointsArray.length) {
            this.#root =
                this._addPoint(sortedArray, kdTree.startDepth);
        }

        return sortedArray;
    }

    /**
     * Counts the stored points inside an inclusive axis-aligned rectangle.
     *
     * @param {{x1: number, x2: number, y1: number, y2: number}} boundary -
     *     Rectangle with x in [x1, x2] and y in [y1, y2].
     * @returns {{count: number}} Object whose `count` is the number of
     *     matching points (0 for an empty tree).
     */
    rangeQuery(boundary) {
        const root = this.#root;
        const queryResult = {
            count : 0
        };

        if (root) {
            this._rangeQueryTraversalCount(root, boundary, queryResult);
        }

        return queryResult;
    }

}

export default kdTree;