package com.zurrtum.create.content.trains.track;

import com.mojang.serialization.Codec;
import com.mojang.serialization.codecs.RecordCodecBuilder;
import com.zurrtum.create.AllBlocks;
import com.zurrtum.create.AllItems;
import com.zurrtum.create.AllTrackMaterials;
import com.zurrtum.create.catnip.data.Couple;
import com.zurrtum.create.catnip.data.Iterate;
import com.zurrtum.create.catnip.data.Pair;
import com.zurrtum.create.catnip.math.VecHelper;
import com.zurrtum.create.foundation.codec.CreateCodecs;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import net.minecraft.class_11368;
import net.minecraft.class_11372;
import net.minecraft.class_1542;
import net.minecraft.class_1657;
import net.minecraft.class_1661;
import net.minecraft.class_1799;
import net.minecraft.class_1928;
import net.minecraft.class_1937;
import net.minecraft.class_2338;
import net.minecraft.class_2350.class_2351;
import net.minecraft.class_238;
import net.minecraft.class_2388;
import net.minecraft.class_2398;
import net.minecraft.class_243;
import net.minecraft.class_2540;
import net.minecraft.class_3218;
import net.minecraft.class_3532;

/**
 * A curved track connection between two track block entities.
 * <p>
 * The curve is a Bezier defined by the two end points in {@link #starts} and two handle points.
 * Each handle point lies on the (normalized) end axis, at a distance given by a lazily computed
 * handle length (see {@link Runtime}). {@link #secondary()} produces the mirrored view of the same
 * connection: all couples are swapped and {@link #primary} is flipped.
 */
public class BezierConnection implements Iterable<BezierConnection.Segment> {
    /**
     * Codec storing absolute positions. {@link #write(class_11372, class_2338)} instead stores
     * positions relative to a block entity.
     */
    public static final Codec<BezierConnection> CODEC = RecordCodecBuilder.create(instance -> instance.group(
        CreateCodecs.COUPLE_BLOCK_POS_CODEC.fieldOf("Positions").forGetter(i -> i.bePositions),
        CreateCodecs.COUPLE_VEC3D_CODEC.fieldOf("Starts").forGetter(i -> i.starts),
        CreateCodecs.COUPLE_VEC3D_CODEC.fieldOf("Axes").forGetter(i -> i.axes),
        CreateCodecs.COUPLE_VEC3D_CODEC.fieldOf("Normals").forGetter(i -> i.normals),
        Codec.BOOL.fieldOf("Primary").forGetter(i -> i.primary),
        Codec.BOOL.fieldOf("Girder").forGetter(i -> i.hasGirder),
        TrackMaterial.CODEC.fieldOf("Material").forGetter(i -> i.trackMaterial),
        CreateCodecs.COUPLE_INT_CODEC.optionalFieldOf("Smoothing").forGetter(i -> Optional.ofNullable(i.smoothing))
    ).apply(instance, BezierConnection::new));

    /** Positions of the two connected block entities; the second one is returned by {@link #getKey()}. */
    public final Couple<class_2338> bePositions;
    /** Absolute end points of the curve. */
    public final Couple<class_243> starts;
    /** Tangent direction at each end point; the curve math normalizes these before use. */
    public final Couple<class_243> axes;
    /** Face normal at each end point; slerped along the curve when the two differ. */
    public final Couple<class_243> normals;
    /** Optional per-end vertical offsets returned by {@link #yOffsetAt(class_243)}; {@code null} when unset. */
    @Nullable
    public Couple<Integer> smoothing;
    public final boolean primary;
    public final boolean hasGirder;
    protected TrackMaterial trackMaterial;

    // Derived geometry (handles, length, LUT, bounds), computed from starts/axes on first use.
    private final AtomicReference<@Nullable Runtime> lazyRuntime = new AtomicReference<>(null);

    public BezierConnection(
        Couple<class_2338> positions,
        Couple<class_243> starts,
        Couple<class_243> axes,
        Couple<class_243> normals,
        boolean primary,
        boolean girder,
        TrackMaterial material
    ) {
        bePositions = positions;
        this.starts = starts;
        this.axes = axes;
        this.normals = normals;
        this.primary = primary;
        this.hasGirder = girder;
        this.trackMaterial = material;
    }

    /**
     * Returns the same connection seen from the other end. Positions, starts, axes, normals and
     * smoothing are swapped, and {@code primary} is inverted.
     */
    public BezierConnection secondary() {
        BezierConnection bezierConnection = new BezierConnection(
            bePositions.swap(),
            starts.swap(),
            axes.swap(),
            normals.swap(),
            !primary,
            hasGirder,
            trackMaterial
        );
        if (smoothing != null)
            bezierConnection.smoothing = smoothing.swap();
        return bezierConnection;
    }

    /** Returns a copy whose couples (including smoothing) are copied rather than shared. */
    @Override
    public BezierConnection clone() {
        var out = new BezierConnection(bePositions.copy(), starts.copy(), axes.copy(), normals.copy(), primary, hasGirder, trackMaterial);
        if (smoothing != null) {
            out.smoothing = smoothing.copy();
        }
        return out;
    }

    /**
     * Compares two couples element-wise. If both couples hold vectors, it also accepts
     * approximately equal elements (tolerance {@code 1e-6}) to absorb floating-point drift.
     */
    private static boolean coupleEquals(Couple<?> a, Couple<?> b) {
        if (a.getFirst().equals(b.getFirst()) && a.getSecond().equals(b.getSecond()))
            return true;
        return a.getFirst() instanceof class_243 aFirst
            && a.getSecond() instanceof class_243 aSecond
            && b.getFirst() instanceof class_243 bFirst
            && b.getSecond() instanceof class_243 bSecond
            && aFirst.method_24802(bFirst, 1e-6)
            && aSecond.method_24802(bSecond, 1e-6);
    }

    /**
     * Returns whether {@code other} describes the same curve (in either orientation), ignoring
     * material, {@code primary} and smoothing.
     *
     * @param other connection to compare with; {@code null} yields {@code false}
     */
    public boolean equalsSansMaterial(BezierConnection other) {
        // Guard here: the mirrored comparison below would otherwise dereference a null argument.
        if (other == null)
            return false;
        return equalsSansMaterialInner(other) || equalsSansMaterialInner(other.secondary());
    }

    private boolean equalsSansMaterialInner(BezierConnection other) {
        return this == other || (other != null
            && coupleEquals(this.bePositions, other.bePositions)
            && coupleEquals(this.starts, other.starts)
            && coupleEquals(this.axes, other.axes)
            && coupleEquals(this.normals, other.normals)
            && this.hasGirder == other.hasGirder);
    }

    /** Codec constructor; {@code smoothing} is unwrapped into the nullable field. */
    @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
    public BezierConnection(
        Couple<class_2338> bePositions,
        Couple<class_243> starts,
        Couple<class_243> axes,
        Couple<class_243> normals,
        boolean primary,
        boolean hasGirder,
        TrackMaterial trackMaterial,
        Optional<Couple<Integer>> smoothing
    ) {
        this(bePositions, starts, axes, normals, primary, hasGirder, trackMaterial);
        this.smoothing = smoothing.orElse(null);
    }

    /**
     * Reads a connection written by {@link #write(class_11372, class_2338)}, translating positions
     * and starts back by {@code localTo}. A missing material defaults to andesite, and missing
     * Primary/Girder flags default to {@code false}.
     */
    public BezierConnection(class_11368 view, class_2338 localTo) {
        this(
            view.method_71426("Positions", CreateCodecs.COUPLE_BLOCK_POS_CODEC).orElseThrow().map(b -> b.method_10081(localTo)),
            view.method_71426("Starts", CreateCodecs.COUPLE_VEC3D_CODEC).orElseThrow().map(v -> v.method_1019(class_243.method_24954(localTo))),
            view.method_71426("Axes", CreateCodecs.COUPLE_VEC3D_CODEC).orElseThrow(),
            view.method_71426("Normals", CreateCodecs.COUPLE_VEC3D_CODEC).orElseThrow(),
            view.method_71433("Primary", false),
            view.method_71433("Girder", false),
            view.method_71426("Material", TrackMaterial.CODEC).orElse(AllTrackMaterials.ANDESITE)
        );

        view.method_71426("Smoothing", CreateCodecs.COUPLE_INT_CODEC).ifPresent(couple -> smoothing = couple);
    }

    /** Writes this connection with positions and starts stored relative to {@code localTo}. */
    public void write(class_11372 view, class_2338 localTo) {
        Couple<class_2338> tePositions = this.bePositions.map(b -> b.method_10059(localTo));
        Couple<class_243> starts = this.starts.map(v -> v.method_1020(class_243.method_24954(localTo)));

        view.method_71472("Girder", hasGirder);
        view.method_71472("Primary", primary);
        view.method_71468("Positions", CreateCodecs.COUPLE_BLOCK_POS_CODEC, tePositions);
        view.method_71468("Starts", CreateCodecs.COUPLE_VEC3D_CODEC, starts);
        view.method_71468("Axes", CreateCodecs.COUPLE_VEC3D_CODEC, axes);
        view.method_71468("Normals", CreateCodecs.COUPLE_VEC3D_CODEC, normals);
        view.method_71468("Material", TrackMaterial.CODEC, getMaterial());

        if (smoothing != null)
            view.method_71468("Smoothing", CreateCodecs.COUPLE_INT_CODEC, smoothing);
    }

    /** Reads a connection in the order written by {@link #write(class_2540)}. */
    public BezierConnection(class_2540 buffer) {
        this(
            Couple.create(buffer::method_10811),
            Couple.create(() -> VecHelper.read(buffer)),
            Couple.create(() -> VecHelper.read(buffer)),
            Couple.create(() -> VecHelper.read(buffer)),
            buffer.readBoolean(),
            buffer.readBoolean(),
            TrackMaterial.fromId(buffer.method_10810())
        );
        if (buffer.readBoolean())
            smoothing = Couple.create(buffer::method_10816);
    }

    /**
     * Writes this connection to a network buffer. The field order must match
     * {@link #BezierConnection(class_2540)}.
     */
    public void write(class_2540 buffer) {
        bePositions.forEach(buffer::method_10807);
        starts.forEach(v -> VecHelper.write(v, buffer));
        axes.forEach(v -> VecHelper.write(v, buffer));
        normals.forEach(v -> VecHelper.write(v, buffer));
        buffer.method_52964(primary);
        buffer.method_52964(hasGirder);
        buffer.method_10812(getMaterial().getId());
        buffer.method_52964(smoothing != null);
        if (smoothing != null)
            smoothing.forEach(buffer::method_10804);
    }

    /** Returns {@code bePositions.getSecond()}. */
    public class_2338 getKey() {
        return bePositions.getSecond();
    }

    public boolean isPrimary() {
        return primary;
    }

    /**
     * Returns the smoothing offset for the end whose start matches {@code end} (as judged by
     * {@link TrackBlockEntityTilt#compareHandles}). Returns 0 when there is no smoothing or no
     * match.
     */
    public int yOffsetAt(class_243 end) {
        if (smoothing == null)
            return 0;
        if (TrackBlockEntityTilt.compareHandles(starts.getFirst(), end))
            return smoothing.getFirst();
        if (TrackBlockEntityTilt.compareHandles(starts.getSecond(), end))
            return smoothing.getSecond();
        return 0;
    }

    // Runtime information

    /** Approximate arc length of the curve (16-step polyline estimate). */
    public double getLength() {
        return resolve().length;
    }

    public float[] getStepLUT() {
        return resolve().stepLUT;
    }

    /** Number of segments, computed as {@code (int) (length * 2)}. */
    public int getSegmentCount() {
        return resolve().segments;
    }

    /** Point on the curve at parameter {@code t} in [0, 1], in absolute coordinates. */
    public class_243 getPosition(double t) {
        var runtime = resolve();
        return VecHelper.bezier(starts.getFirst(), starts.getSecond(), runtime.finish1, runtime.finish2, (float) t);
    }

    /** Turn radius derived from the end axes; 0 when the ends are parallel or no intersection was found. */
    public double getRadius() {
        return resolve().radius;
    }

    public double getHandleLength() {
        return resolve().handleLength;
    }

    /** Curve parameter for segment {@code index}; the last segment maps to exactly 1. */
    public float getSegmentT(int index) {
        return resolve().getSegmentT(index);
    }

    /**
     * Advances {@code currentT} by approximately {@code distance} along the curve. The step uses a
     * first-order estimate: the derivative magnitude at {@code currentT} divided by the total length.
     */
    public double incrementT(double currentT, double distance) {
        var runtime = resolve();
        double dx = VecHelper.bezierDerivative(starts.getFirst(), starts.getSecond(), runtime.finish1, runtime.finish2, (float) currentT)
            .method_1033() / runtime.length;
        return currentT + distance / dx;
    }

    /** Box around all sampled curve points, inflated by 1.375. */
    public class_238 getBounds() {
        return resolve().bounds;
    }

    /** Returns {@code derivative x (faceNormal x derivative)} at {@code t}, with face normals slerped between the ends. */
    public class_243 getNormal(double t) {
        var runtime = resolve();
        class_243 end1 = starts.getFirst();
        class_243 end2 = starts.getSecond();
        class_243 fn1 = normals.getFirst();
        class_243 fn2 = normals.getSecond();

        class_243 derivative = VecHelper.bezierDerivative(end1, end2, runtime.finish1, runtime.finish2, (float) t).method_1029();
        class_243 faceNormal = fn1.equals(fn2) ? fn1 : VecHelper.slerp((float) t, fn1, fn2);
        class_243 normal = faceNormal.method_1036(derivative).method_1029();
        return derivative.method_1036(normal);
    }

    @NotNull
    private Runtime resolve() {
        var out = lazyRuntime.get();

        if (out == null) {
            // Since this can be accessed from multiple threads, we consolidate the intermediary
            // computation into a class and only publish complete results.
            out = new Runtime(starts, axes);
            // Doesn't matter if this one becomes the canonical value because all results are the same.
            lazyRuntime.set(out);
        }

        return out;
    }

    /**
     * Iterates the segment sample points, {@code 0..getSegmentCount()} inclusive. Positions are
     * relative to {@code bePositions.getFirst()} and raised by 3/16.
     * <p>
     * The iterator reuses a single {@link Segment} instance and overwrites it on every
     * {@code next()}, so callers must copy any values they want to keep.
     */
    @Override
    public Iterator<Segment> iterator() {
        var offset = class_243.method_24954(bePositions.getFirst()).method_1021(-1).method_1031(0, 3 / 16f, 0);
        return new Bezierator(this, offset);
    }

    /**
     * Gives the player the track (and girder) items this connection costs, in stacks of at most 64.
     */
    public void addItemsToPlayer(class_1657 player) {
        class_1661 inv = player.method_31548();
        int tracks = getTrackItemCost();
        while (tracks > 0) {
            inv.method_7398(new class_1799(getMaterial().getBlock(), Math.min(64, tracks)));
            tracks -= 64;
        }
        int girders = getGirderItemCost();
        while (girders > 0) {
            inv.method_7398(new class_1799(AllItems.METAL_GIRDER, Math.min(64, girders)));
            girders -= 64;
        }
    }

    /** Two girders per track item when the connection has a girder, otherwise 0. */
    public int getGirderItemCost() {
        return hasGirder ? getTrackItemCost() * 2 : 0;
    }

    /** One track item per two segments, rounded up. */
    public int getTrackItemCost() {
        return (getSegmentCount() + 1) / 2;
    }

    /**
     * Drops the connection's items along the curve on server worlds when the gamerule checked here
     * is enabled. It drops one track item on every even segment except the last one, plus two
     * girders at each of those points when {@link #hasGirder} is set.
     */
    public void spawnItems(class_1937 level) {
        if (!(level instanceof class_3218 serverWorld) || !serverWorld.method_64395().method_8355(class_1928.field_19392))
            return;
        class_243 origin = class_243.method_24954(bePositions.getFirst());
        int segmentCount = getSegmentCount();
        for (Segment segment : this) {
            if (segment.index % 2 != 0 || segment.index == segmentCount)
                continue;
            class_243 v = VecHelper.offsetRandomly(segment.position, level.field_9229, .125f).method_1019(origin);
            class_1542 entity = new class_1542(level, v.field_1352, v.field_1351, v.field_1350, new class_1799(getMaterial()));
            entity.method_6988();
            level.method_8649(entity);
            if (!hasGirder)
                continue;
            for (int i = 0; i < 2; i++) {
                entity = new class_1542(level, v.field_1352, v.field_1351, v.field_1350, AllItems.METAL_GIRDER.method_7854());
                entity.method_6988();
                level.method_8649(entity);
            }
        }
    }

    /**
     * On server worlds, spawns block particles of the track material on both sides of every segment
     * (offset 14/16 along the segment normal). When {@link #hasGirder} is set, it also spawns girder
     * particles 0.5 below each of those points.
     */
    public void spawnDestroyParticles(class_1937 level) {
        if (!(level instanceof class_3218 slevel))
            return;
        class_2388 data = new class_2388(class_2398.field_11217, getMaterial().getBlock().method_9564());
        class_2388 girderData = new class_2388(class_2398.field_11217, AllBlocks.METAL_GIRDER.method_9564());
        class_243 origin = class_243.method_24954(bePositions.getFirst());
        for (Segment segment : this) {
            for (int offset : Iterate.positiveAndNegative) {
                class_243 v = segment.position.method_1019(segment.normal.method_1021(14 / 16f * offset)).method_1019(origin);
                slevel.method_65096(data, v.field_1352, v.field_1351, v.field_1350, 1, 0, 0, 0, 0);
                if (!hasGirder)
                    continue;
                slevel.method_65096(girderData, v.field_1352, v.field_1351 - .5f, v.field_1350, 1, 0, 0, 0, 0);
            }
        }
    }

    public TrackMaterial getMaterial() {
        return trackMaterial;
    }

    public void setMaterial(TrackMaterial material) {
        trackMaterial = material;
    }

    /**
     * Immutable snapshot of the geometry derived from the end points and normalized axes: handle
     * points, length, segment count, step LUT and bounds. It is built in one go so that it can be
     * published safely across threads.
     */
    private static class Runtime {
        private final class_243 finish1;
        private final class_243 finish2;
        private final double length;
        private final float[] stepLUT;
        private final int segments;

        private double radius;
        private double handleLength;

        private final class_238 bounds;

        private Runtime(Couple<class_243> starts, Couple<class_243> axes) {
            class_243 end1 = starts.getFirst();
            class_243 end2 = starts.getSecond();
            class_243 axis1 = axes.getFirst().method_1029();
            class_243 axis2 = axes.getSecond().method_1029();

            determineHandles(end1, end2, axis1, axis2);

            finish1 = axis1.method_1021(handleLength).method_1019(end1);
            finish2 = axis2.method_1021(handleLength).method_1019(end2);

            int scanCount = 16;

            this.length = computeLength(finish1, finish2, end1, end2, scanCount);

            segments = (int) (length * 2);
            stepLUT = new float[segments + 1];
            stepLUT[0] = 1;
            float combinedDistance = 0;

            class_238 bounds = new class_238(end1, end2);

            // determine step lut
            {
                class_243 previous = end1;
                for (int i = 0; i <= segments; i++) {
                    // A very short curve has zero segments; avoid 0/0 = NaN, which would poison the bounds.
                    float t = segments == 0 ? 0 : i / (float) segments;
                    class_243 result = VecHelper.bezier(end1, end2, finish1, finish2, t);
                    bounds = bounds.method_991(new class_238(result, result));
                    if (i > 0) {
                        combinedDistance += result.method_1022(previous) / length;
                        stepLUT[i] = (float) (t / combinedDistance);
                    }
                    previous = result;
                }
            }

            this.bounds = bounds.method_1014(1.375f);
        }

        /** Sums the distances between {@code scanCount + 1} evenly spaced curve samples. */
        private static double computeLength(class_243 finish1, class_243 finish2, class_243 end1, class_243 end2, int scanCount) {
            double length = 0;

            class_243 previous = end1;
            for (int i = 0; i <= scanCount; i++) {
                float t = i / (float) scanCount;
                class_243 result = VecHelper.bezier(end1, end2, finish1, finish2, t);
                length += result.method_1022(previous);
                previous = result;
            }
            return length;
        }

        public float getSegmentT(int index) {
            return index == segments ? 1 : index * stepLUT[index] / segments;
        }

        /**
         * Sets {@link #handleLength} and {@link #radius}. For (nearly) parallel axes it uses the
         * intersection of axis1 with cross2, falling back to a third of the end-to-end distance.
         * Otherwise it uses the circular-arc handle factor {@code 4/3 * tan(angle/4)} applied to the
         * radius found by intersecting the two horizontal cross vectors.
         */
        private void determineHandles(class_243 end1, class_243 end2, class_243 axis1, class_243 axis2) {
            class_243 cross1 = axis1.method_1036(new class_243(0, 1, 0));
            class_243 cross2 = axis2.method_1036(new class_243(0, 1, 0));

            radius = 0;
            double a1 = class_3532.method_15349(-axis2.field_1350, -axis2.field_1352);
            double a2 = class_3532.method_15349(axis1.field_1350, axis1.field_1352);
            double angle = a1 - a2;

            float circle = 2 * class_3532.field_29844;
            angle = (angle + circle) % circle;
            if (Math.abs(circle - angle) < Math.abs(angle))
                angle = circle - angle;

            if (class_3532.method_20390(angle, 0)) {
                double[] intersect = VecHelper.intersect(end1, end2, axis1, cross2, class_2351.field_11052);
                if (intersect != null) {
                    double t = Math.abs(intersect[0]);
                    double u = Math.abs(intersect[1]);
                    double min = Math.min(t, u);
                    double max = Math.max(t, u);

                    if (min > 1.2 && max / min > 1 && max / min < 3) {
                        handleLength = (max - min);
                        return;
                    }
                }

                handleLength = end2.method_1022(end1) / 3;
                return;
            }

            double n = circle / angle;
            double factor = 4 / 3d * Math.tan(Math.PI / (2 * n));
            double[] intersect = VecHelper.intersect(end1, end2, cross1, cross2, class_2351.field_11052);

            if (intersect == null) {
                handleLength = end2.method_1022(end1) / 3;
                return;
            }

            radius = Math.abs(intersect[1]);
            handleLength = radius * factor;
            if (class_3532.method_20390(handleLength, 0))
                handleLength = 1;
        }
    }

    /**
     * One sample along the curve, as produced by {@link #iterator()}. Instances are reused by the
     * iterator between calls to {@code next()}.
     */
    public static class Segment {

        /** Segment index, {@code 0..getSegmentCount()} inclusive. */
        public int index;
        /** Sample position, already translated by the iterator's offset. */
        public class_243 position;
        /** Normalized curve derivative at the sample. */
        public class_243 derivative;
        /** End face normal, slerped to this sample when the two ends differ. */
        public class_243 faceNormal;
        /** Normalized {@code faceNormal x derivative}. */
        public class_243 normal;

    }

    private static class Bezierator implements Iterator<Segment> {
        private final Segment segment;
        private final class_243 end1;
        private final class_243 end2;
        private final class_243 finish1;
        private final class_243 finish2;
        private final class_243 faceNormal1;
        private final class_243 faceNormal2;
        private final Runtime runtime;

        private Bezierator(BezierConnection bc, class_243 offset) {
            runtime = bc.resolve();

            end1 = bc.starts.getFirst().method_1019(offset);
            end2 = bc.starts.getSecond().method_1019(offset);

            // Reuse the handle points built from the normalized axes, so the iterated curve matches
            // the one the step LUT, getPosition and getNormal are computed for.
            finish1 = runtime.finish1.method_1019(offset);
            finish2 = runtime.finish2.method_1019(offset);

            faceNormal1 = bc.normals.getFirst();
            faceNormal2 = bc.normals.getSecond();
            segment = new Segment();
            segment.index = -1; // will get incremented to 0 in #next()
        }

        @Override
        public boolean hasNext() {
            return segment.index + 1 <= runtime.segments;
        }

        @Override
        public Segment next() {
            if (!hasNext())
                throw new NoSuchElementException();
            segment.index++;
            float t = runtime.getSegmentT(segment.index);
            segment.position = VecHelper.bezier(end1, end2, finish1, finish2, t);
            segment.derivative = VecHelper.bezierDerivative(end1, end2, finish1, finish2, t).method_1029();
            segment.faceNormal = faceNormal1.equals(faceNormal2) ? faceNormal1 : VecHelper.slerp(t, faceNormal1, faceNormal2);
            segment.normal = segment.faceNormal.method_1036(segment.derivative).method_1029();
            return segment;
        }
    }

    public Object bakedSegments;
    public Object bakedGirders;

    /**
     * Returns the cached segment data, creating it with {@code factory} on first use. Later calls
     * return the first result regardless of the factory passed.
     */
    @SuppressWarnings("unchecked")
    public <T> T getBakedSegments(Function<BezierConnection, T> factory) {
        if (bakedSegments != null) {
            return (T) bakedSegments;
        }
        T segments = factory.apply(this);
        bakedSegments = segments;
        return segments;
    }

    /**
     * Returns the cached girder data, creating it with {@code factory} on first use. Later calls
     * return the first result regardless of the factory passed.
     */
    @SuppressWarnings("unchecked")
    public <T> T getBakedGirders(Function<BezierConnection, T> factory) {
        if (bakedGirders != null) {
            return (T) bakedGirders;
        }
        T girders = factory.apply(this);
        bakedGirders = girders;
        return girders;
    }

    /**
     * Projects the curve onto the (x, z) block grid relative to {@code bePositions.getFirst()}.
     * For each covered column it records the lowest sampled y. It then prunes columns that create
     * doubled (staircase) steps, preferring the column closer to the midpoint of the two ends.
     *
     * @return map from (x, z) to the lowest sampled y in that column
     */
    public Map<Pair<Integer, Integer>, Double> rasterise() {
        Map<Pair<Integer, Integer>, Double> yLevels = new HashMap<>();
        class_2338 tePosition = bePositions.getFirst();
        class_243 origin = class_243.method_24954(tePosition);
        class_243 end1 = starts.getFirst().method_1020(origin).method_1031(0, 3 / 16f, 0);
        class_243 end2 = starts.getSecond().method_1020(origin).method_1031(0, 3 / 16f, 0);

        // Shift the shared handle points (built from the normalized axes) into the same local frame as
        // end1/end2, so the rasterised curve matches the one described by the step LUT.
        Runtime runtime = resolve();
        class_243 finish1 = runtime.finish1.method_1020(origin).method_1031(0, 3 / 16f, 0);
        class_243 finish2 = runtime.finish2.method_1020(origin).method_1031(0, 3 / 16f, 0);

        class_243 faceNormal1 = normals.getFirst();
        class_243 faceNormal2 = normals.getSecond();

        int segCount = runtime.segments;
        float[] lut = runtime.stepLUT;
        class_243[] samples = new class_243[segCount];

        for (int i = 0; i < segCount; i++) {
            float t = class_3532.method_15363((i + 0.5f) * lut[i] / segCount, 0, 1);
            class_243 result = VecHelper.bezier(end1, end2, finish1, finish2, t);
            class_243 derivative = VecHelper.bezierDerivative(end1, end2, finish1, finish2, t).method_1029();
            class_243 faceNormal = faceNormal1.equals(faceNormal2) ? faceNormal1 : VecHelper.slerp(t, faceNormal1, faceNormal2);
            class_243 normal = faceNormal.method_1036(derivative).method_1029();
            class_243 below = result.method_1019(faceNormal.method_1021(-.25f));
            class_243 rail1 = below.method_1019(normal.method_1021(.05f));
            class_243 rail2 = below.method_1020(normal.method_1021(.05f));
            class_243 railMiddle = rail1.method_1019(rail2).method_1021(.5);
            samples[i] = railMiddle;
        }

        class_243 center = end1.method_1019(end2).method_1021(0.5);

        Pair<Integer, Integer> prev = null;
        Pair<Integer, Integer> prev2 = null;
        Pair<Integer, Integer> prev3 = null;

        for (int i = 0; i < segCount; i++) {
            class_243 railMiddle = samples[i];
            class_2338 pos = class_2338.method_49638(railMiddle);
            Pair<Integer, Integer> key = Pair.of(pos.method_10263(), pos.method_10260());
            boolean alreadyPresent = yLevels.containsKey(key);
            if (alreadyPresent && yLevels.get(key) <= railMiddle.field_1351)
                continue;
            yLevels.put(key, railMiddle.field_1351);
            if (alreadyPresent)
                continue;

            if (prev3 != null) { // Remove obsolete pixels
                boolean doubledViaPrev = isLineDoubled(prev2, prev, key);
                boolean doubledViaPrev2 = isLineDoubled(prev3, prev2, prev);
                boolean prevCloser = diff(prev, center) > diff(prev2, center);

                if (doubledViaPrev2 && (!doubledViaPrev || !prevCloser)) {
                    yLevels.remove(prev2);
                    prev2 = prev;
                    prev = key;
                    continue;

                } else if (doubledViaPrev && doubledViaPrev2 && prevCloser) {
                    yLevels.remove(prev);
                    prev = key;
                    continue;
                }
            }

            prev3 = prev2;
            prev2 = prev;
            prev = key;
        }

        return yLevels;
    }

    /** Squared horizontal distance from the centre of column {@code pFrom} to {@code to}. */
    private double diff(Pair<Integer, Integer> pFrom, class_243 to) {
        return to.method_1028(pFrom.getFirst() + 0.5, to.field_1351, pFrom.getSecond() + 0.5);
    }

    /**
     * Returns whether from -> via -> to is two unit axis-aligned steps that differ in both x and z
     * deltas, i.e. an L-shaped corner in the grid.
     */
    private boolean isLineDoubled(Pair<Integer, Integer> pFrom, Pair<Integer, Integer> pVia, Pair<Integer, Integer> pTo) {
        int diff1x = pVia.getFirst() - pFrom.getFirst();
        int diff1z = pVia.getSecond() - pFrom.getSecond();
        int diff2x = pTo.getFirst() - pVia.getFirst();
        int diff2z = pTo.getSecond() - pVia.getSecond();
        return Math.abs(diff1x) + Math.abs(diff1z) == 1 && Math.abs(diff2x) + Math.abs(diff2z) == 1 && diff1x != diff2x && diff1z != diff2z;
    }

}
