package dev.doublekekse.map_utils.curve;

import java.util.ArrayList;
import java.util.List;
import net.minecraft.class_241;
import net.minecraft.class_243;
import net.minecraft.class_2487;
import net.minecraft.class_2499;
import net.minecraft.class_2540;
import net.minecraft.class_3532;
import net.minecraft.class_9135;
import net.minecraft.class_9139;

/**
 * An ordered path of control points, each carrying a position and a rotation.
 *
 * <p>Positions are interpolated with a uniform Catmull-Rom spline; at the two ends of the path
 * the first/last control point is reused as the missing outer neighbour. Rotations are
 * interpolated piecewise between adjacent control points.
 *
 * @param controlPoints the control points in path order; must be non-null and non-empty
 */
public record SplinePath(List<SplineControlPoint> controlPoints) {
    /** Stream codec that encodes the control point list using {@link SplineControlPoint#STREAM_CODEC}. */
    public static final class_9139<class_2540, SplinePath> STREAM_CODEC = class_9139.method_56434(
        class_9135.method_56376(ArrayList::new, SplineControlPoint.STREAM_CODEC), SplinePath::controlPoints,
        SplinePath::new
    );

    /**
     * Validates the control point list.
     *
     * @throws IllegalArgumentException if {@code controlPoints} is null or empty
     */
    public SplinePath {
        if (controlPoints == null || controlPoints.isEmpty()) {
            throw new IllegalArgumentException("At least one control point is required.");
        }
    }

    /**
     * Returns the interpolated position along the path.
     *
     * <p>{@code progress} in {@code [0, 1]} is spread uniformly over the {@code size() - 1}
     * segments. Values outside that range are not clamped: the segment index is clamped to the
     * first/last segment, so the cubic of that segment is extrapolated.
     *
     * @param progress normalized position along the path, nominally in {@code [0, 1]}
     * @return the interpolated position, or the only control point's position if there is one
     */
    public class_243 getPosition(double progress) {
        if (controlPoints.size() == 1) {
            return controlPoints.getFirst().position();
        }

        int n = controlPoints.size() - 1;
        int startIndex = (int) Math.floor(progress * n);
        // Keep a valid segment [startIndex, startIndex + 1]; progress == 1.0 lands in the last one.
        startIndex = Math.max(0, Math.min(startIndex, n - 1));

        double u = progress * n - startIndex;

        // Reuse the end points as the outer neighbours where the real ones don't exist.
        class_243 p0 = (startIndex > 0) ? controlPoints.get(startIndex - 1).position() : controlPoints.getFirst().position();
        class_243 p1 = controlPoints.get(startIndex).position();
        class_243 p2 = controlPoints.get(startIndex + 1).position();
        class_243 p3 = (startIndex + 2 < controlPoints.size()) ? controlPoints.get(startIndex + 2).position() : controlPoints.get(n).position();

        return interpolateCatmullRom(p0, p1, p2, p3, u);
    }

    /**
     * Returns the interpolated rotation along the path.
     *
     * <p>{@code progress} is clamped to {@code [0, 1]}, so values slightly out of range (for
     * example from floating-point drift) yield the first/last rotation instead of failing with an
     * out-of-bounds index.
     *
     * @param progress normalized position along the path, nominally in {@code [0, 1]}
     * @return the interpolated rotation, or the only control point's rotation if there is one
     */
    public class_241 getRotation(double progress) {
        if (controlPoints.size() == 1) {
            return controlPoints.getFirst().rotation();
        }

        int n = controlPoints.size() - 1;
        // Without clamping, progress < 0 or > 1 would index outside controlPoints.
        double clamped = Math.max(0.0, Math.min(1.0, progress));
        double scaled = clamped * n;

        int index0 = (int) Math.floor(scaled);
        // At progress == 1.0 index0 == n; reuse it so t == 0 yields the last rotation exactly.
        int index1 = Math.min(index0 + 1, n);
        double t = scaled - index0;

        return lerpRotation(controlPoints.get(index0).rotation(), controlPoints.get(index1).rotation(), (float) t);
    }

    /**
     * Interpolates each rotation component between {@code start} and {@code end} using
     * {@code class_3532.method_17821} (presumably MathHelper's angle-in-degrees lerp — confirm
     * against the mappings).
     *
     * @param start rotation at {@code t == 0}
     * @param end rotation at {@code t == 1}
     * @param t interpolation factor
     * @return a new rotation built from the two interpolated components
     */
    public static class_241 lerpRotation(class_241 start, class_241 end, float t) {
        return new class_241(class_3532.method_17821(t, start.field_1343, end.field_1343), class_3532.method_17821(t, start.field_1342, end.field_1342));
    }

    /**
     * Evaluates a uniform Catmull-Rom segment between {@code p1} and {@code p2}.
     *
     * @param p0 neighbour before {@code p1}
     * @param p1 segment start (returned at {@code u == 0})
     * @param p2 segment end (returned at {@code u == 1})
     * @param p3 neighbour after {@code p2}
     * @param u local parameter within the segment
     * @return the point on the segment, computed per component
     */
    private static class_243 interpolateCatmullRom(class_243 p0, class_243 p1, class_243 p2, class_243 p3, double u) {
        double u2 = u * u;
        double u3 = u2 * u;

        // Standard uniform Catmull-Rom basis (tension 0.5), one weight per control point.
        double[] coefficients = {
            -0.5 * u3 + u2 - 0.5 * u,
            1.5 * u3 - 2.5 * u2 + 1,
            -1.5 * u3 + 2.0 * u2 + 0.5 * u,
            0.5 * u3 - 0.5 * u2
        };

        return new class_243(
            coefficients[0] * p0.field_1352 + coefficients[1] * p1.field_1352 + coefficients[2] * p2.field_1352 + coefficients[3] * p3.field_1352,
            coefficients[0] * p0.field_1351 + coefficients[1] * p1.field_1351 + coefficients[2] * p2.field_1351 + coefficients[3] * p3.field_1351,
            coefficients[0] * p0.field_1350 + coefficients[1] * p1.field_1350 + coefficients[2] * p2.field_1350 + coefficients[3] * p3.field_1350
        );
    }

    /**
     * Returns the number of control points.
     *
     * @return {@code controlPoints().size()}
     */
    public int size() {
        return controlPoints.size();
    }

    /**
     * Serializes the path to a list tag, one entry per control point in path order.
     *
     * @return a new list tag containing {@link SplineControlPoint#write()} of every control point
     */
    public class_2499 write() {
        var tag = new class_2499();
        for (var controlPoint : controlPoints) {
            tag.add(controlPoint.write());
        }

        return tag;
    }

    /**
     * Deserializes a path written by {@link #write()}.
     *
     * @param tag list tag whose entries are expected to be compound tags
     * @return the reconstructed path
     * @throws ClassCastException if an entry is not a {@code class_2487}
     * @throws IllegalArgumentException if {@code tag} is empty
     */
    public static SplinePath read(class_2499 tag) {
        var controlPoints = new ArrayList<SplineControlPoint>(tag.size());

        for (var controlPointTag : tag) {
            controlPoints.add(SplineControlPoint.read((class_2487) controlPointTag));
        }

        return new SplinePath(controlPoints);
    }
}
