/*
 * Decompiled with CFR 0.152.
 */
package dev.corgitaco.dataanchor.storage._3D;

import dev.corgitaco.dataanchor.DataAnchor;
import dev.corgitaco.dataanchor.storage.NearestPoint;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.TreeSet;
import net.minecraft.class_2382;
import net.minecraft.class_3532;

public class OctreeNearestPointData<T>
implements NearestPoint<T> {
    private final NearestPoint<T>[] leafs;
    private final byte bitShiftScale;
    private final byte highestShiftScale;

    public OctreeNearestPointData(int highestShiftScale) {
        this(0, (byte)highestShiftScale, 2);
    }

    public static <T> OctreeNearestPointData<T> fromSize(int xyzSize) {
        return new OctreeNearestPointData<T>((byte)(32 - Integer.numberOfLeadingZeros(xyzSize)));
    }

    public OctreeNearestPointData() {
        this(0, 31, 2);
    }

    public OctreeNearestPointData(byte bitShiftScale, byte highestShiftScale, int rowSize) {
        this.bitShiftScale = bitShiftScale;
        this.highestShiftScale = highestShiftScale;
        if (bitShiftScale < 0 || bitShiftScale > 31) {
            throw new IllegalArgumentException("bitShiftScale must be between 0 and 31");
        }
        if (rowSize < 2) {
            throw new IllegalArgumentException("rowSize must be greater than 1");
        }
        int smallestEncompassingPowerOfTwo = class_3532.method_15339((int)rowSize);
        if (smallestEncompassingPowerOfTwo != rowSize) {
            DataAnchor.LOGGER.warn("rowSize is not a power of two, rounding up to the nearest power of two...");
            rowSize = smallestEncompassingPowerOfTwo;
        }
        this.leafs = new NearestPoint[rowSize * rowSize * rowSize];
    }

    @Override
    public void setPoint(class_2382 point, T o) {
        int x = point.method_10263();
        int y = point.method_10264();
        int z = point.method_10260();
        int xIndex = this.getXYZIndex(x);
        int yIndex = this.getXYZIndex(y);
        int zIndex = this.getXYZIndex(z);
        this.setPointRecursively(point, o, this.getIndex(xIndex, yIndex, zIndex));
    }

    private void setPointRecursively(class_2382 point, T o, int index) {
        if (this.bitShiftScale == this.highestShiftScale) {
            if (this.leafs[index] == null) {
                this.leafs[index] = new Target<T>(new NearestPoint.PointData<T>(o, point));
            }
            return;
        }
        if (this.leafs[index] == null) {
            this.leafs[index] = new OctreeNearestPointData<T>((byte)(this.bitShiftScale + 1), this.highestShiftScale, this.rowSize());
        }
        this.leafs[index].setPoint(point, o);
    }

    @Override
    public NearestPoint.PointData<T> getNearestPointData(class_2382 point, NearestPoint.DistanceFunction distanceFunction) {
        int x = point.method_10263();
        int y = point.method_10264();
        int z = point.method_10260();
        int xIndex = this.getXYZIndex(x);
        int yIndex = this.getXYZIndex(y);
        int zIndex = this.getXYZIndex(z);
        NearestPoint.PointData<T> nearest = null;
        for (int i = 0; i < this.rowSize(); ++i) {
            int[][] distance;
            for (int[] position : distance = SPIRAL_FAST_3D[i]) {
                NearestPoint.PointData<T> offsetNearest;
                NearestPoint<T> offsetNearestPoint;
                int offsetX = position[0];
                int offsetY = position[1];
                int offsetZ = position[2];
                int offsetXIndex = offsetX + xIndex;
                int offsetYIndex = offsetY + yIndex;
                int offsetZIndex = offsetZ + zIndex;
                int index = this.getIndex(offsetXIndex, offsetYIndex, offsetZIndex);
                if (offsetXIndex < 0 || offsetXIndex >= this.rowSize() || offsetYIndex < 0 || offsetYIndex >= this.rowSize() || offsetZIndex < 0 || offsetZIndex >= this.rowSize() || (offsetNearestPoint = this.leafs[index]) == null || (offsetNearest = offsetNearestPoint.getNearestPointData(point, distanceFunction)) == null) continue;
                if (nearest == null) {
                    nearest = offsetNearest;
                    continue;
                }
                if (!(distanceFunction.apply(nearest.point(), point) > distanceFunction.apply(offsetNearest.point(), point))) continue;
                nearest = offsetNearest;
            }
            if (i <= 0 || nearest == null) continue;
            return nearest;
        }
        return nearest;
    }

    @Override
    public Collection<NearestPoint.PointData<T>> getPointDataWithinRange(class_2382 point, double radius, NearestPoint.DistanceFunction distanceFunction) {
        int x = point.method_10263();
        int y = point.method_10264();
        int z = point.method_10260();
        TreeSet<NearestPoint.PointData<T>> points = new TreeSet<NearestPoint.PointData<T>>(Comparator.comparing(pointData -> distanceFunction.apply(point, pointData.point())));
        int xIndex = this.getXYZIndex(x);
        int yIndex = this.getXYZIndex(y);
        int zIndex = this.getXYZIndex(z);
        for (int i = 0; i < this.rowSize(); ++i) {
            int[][] distance;
            for (int[] position : distance = SPIRAL_FAST_3D[i]) {
                NearestPoint.PointData<T> offsetNearest;
                NearestPoint<T> offsetNearestPoint;
                int offsetX = position[0];
                int offsetY = position[1];
                int offsetZ = position[2];
                int offsetXIndex = offsetX + xIndex;
                int offsetYIndex = offsetY + yIndex;
                int offsetZIndex = offsetZ + zIndex;
                if (offsetXIndex < 0 || offsetXIndex >= this.rowSize() || offsetYIndex < 0 || offsetYIndex >= this.rowSize() || offsetZIndex < 0 || offsetZIndex >= this.rowSize() || (offsetNearestPoint = this.leafs[this.getIndex(offsetXIndex, offsetYIndex, offsetZIndex)]) == null || !(distanceFunction.apply((offsetNearest = offsetNearestPoint.getNearestPointData(point, distanceFunction)).point(), point) <= radius)) continue;
                points.add(offsetNearest);
            }
        }
        return points;
    }

    private int getXYZIndex(int coord) {
        return coord >> this.highestShiftScale - this.bitShiftScale & this.rowSize() - 1;
    }

    public int rowSize() {
        return this.leafs.length >> 2;
    }

    private int getIndex(int x, int y, int z) {
        return x * (this.rowSize() * this.rowSize()) + y * this.rowSize() + z;
    }

    @Override
    public boolean isEmpty() {
        for (NearestPoint<T> leaf : this.leafs) {
            if (leaf == null || leaf.isEmpty()) continue;
            return false;
        }
        return true;
    }

    @Override
    public Collection<NearestPoint.PointData<T>> getAllPointData() {
        ArrayList<NearestPoint.PointData<T>> points = new ArrayList<NearestPoint.PointData<T>>();
        for (NearestPoint<T> leaf : this.leafs) {
            if (leaf == null || leaf.isEmpty()) continue;
            points.addAll(leaf.getAllPointData());
        }
        return Collections.unmodifiableList(points);
    }

    @Override
    public void clear() {
        Arrays.fill(this.leafs, null);
    }

    @Override
    public void removePoint(class_2382 point) {
        int x = point.method_10263();
        int y = point.method_10264();
        int z = point.method_10260();
        int xIndex = this.getXYZIndex(x);
        int yIndex = this.getXYZIndex(y);
        int zIndex = this.getXYZIndex(z);
        this.removePointRecursively(point, this.getIndex(xIndex, yIndex, zIndex));
    }

    private void removePointRecursively(class_2382 point, int index) {
        NearestPoint<T> nearestPoint = this.leafs[index];
        if (this.bitShiftScale == this.highestShiftScale) {
            if (nearestPoint != null) {
                this.leafs[index] = null;
            }
            return;
        }
        if (nearestPoint != null) {
            nearestPoint.removePoint(point);
            if (nearestPoint.isEmpty()) {
                this.leafs[index] = null;
            }
        }
    }

    public record Target<T>(NearestPoint.PointData<T> pointData) implements NearestPoint<T>
    {
        @Override
        public void setPoint(class_2382 point, T o) {
            throw new IllegalArgumentException("Cannot set lowest level point, use constructor.");
        }

        @Override
        public NearestPoint.PointData<T> getNearestPointData(class_2382 point, NearestPoint.DistanceFunction distanceFunction) {
            return this.pointData;
        }

        @Override
        public Collection<NearestPoint.PointData<T>> getPointDataWithinRange(class_2382 point, double radius, NearestPoint.DistanceFunction distanceFunction) {
            return distanceFunction.apply(point, this.pointData.point()) <= radius ? Collections.singleton(this.pointData) : Collections.emptyList();
        }

        @Override
        public boolean isEmpty() {
            return false;
        }

        @Override
        public Collection<NearestPoint.PointData<T>> getAllPointData() {
            return Collections.singleton(this.pointData);
        }

        @Override
        public void clear() {
            throw new IllegalArgumentException("Cannot clear lowest level point");
        }

        @Override
        public void removePoint(class_2382 point) {
            throw new IllegalArgumentException("Cannot remove lowest level point");
        }
    }
}

