package com.zurrtum.create.client.catnip.outliner;

import com.zurrtum.create.catnip.data.Iterate;
import com.zurrtum.create.client.catnip.render.BindableTexture;
import com.zurrtum.create.client.catnip.render.PonderRenderTypes;
import com.zurrtum.create.client.catnip.render.SuperRenderTypeBuffer;
import org.joml.Vector3f;
import org.joml.Vector4f;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import net.minecraft.class_1921;
import net.minecraft.class_2338;
import net.minecraft.class_2350;
import net.minecraft.class_2350.class_2351;
import net.minecraft.class_2350.class_2352;
import net.minecraft.class_243;
import net.minecraft.class_310;
import net.minecraft.class_4587;
import net.minecraft.class_4588;

/**
 * Outline that renders the outer hull of an arbitrary collection of block positions as one
 * merged shape: faces and edges shared between included blocks cancel out, so only the
 * boundary of the cluster is drawn.
 *
 * <p>NOTE: faces and edges are tracked by parity (every included block toggles its 6 faces
 * and 12 edges). An edge shared by exactly two blocks is therefore hidden. That is correct for
 * side-by-side neighbours, but it also hides the edge where two blocks touch only diagonally.
 */
public class BlockClusterOutline extends Outline {

    /** Distance (in blocks) each face quad is pushed outward along its normal: 1/128 block. */
    private static final float FACE_NUDGE = 1 / 128f;

    private final Cluster cluster;

    protected final Vector3f pos0Temp = new Vector3f();
    protected final Vector3f pos1Temp = new Vector3f();
    protected final Vector3f pos2Temp = new Vector3f();
    protected final Vector3f pos3Temp = new Vector3f();
    protected final Vector3f normalTemp = new Vector3f();
    protected final Vector3f originTemp = new Vector3f();

    /**
     * Builds the merged outline for the given block positions.
     *
     * @param positions positions making up the cluster; repeated positions are ignored
     */
    public BlockClusterOutline(Iterable<class_2338> positions) {
        cluster = new Cluster();
        if (positions instanceof Set) {
            // Set elements are already distinct, so no extra de-duplication pass is needed.
            positions.forEach(cluster::include);
            return;
        }
        // Including a position twice would toggle its faces and edges back off (parity
        // tracking), so filter repeats first. This set only lives during construction.
        Set<class_2338> seen = new HashSet<>();
        for (class_2338 pos : positions) {
            // Copy in case the iterable reuses a single mutable position instance.
            class_2338 key = pos.method_10062();
            if (seen.add(key))
                cluster.include(key);
        }
    }

    /**
     * Renders the textured faces (if a face texture is set) and then the edges of the cluster.
     */
    @Override
    public void render(class_310 mc, class_4587 ms, SuperRenderTypeBuffer buffer, class_243 camera, float pt) {
        params.loadColor(colorTemp);
        Vector4f color = colorTemp;
        int lightmap = params.lightmap;
        boolean disableLineNormals = params.disableLineNormals;

        renderFaces(ms, buffer, camera, pt, color, lightmap);
        renderEdges(ms, buffer, camera, pt, color, lightmap, disableLineNormals);
    }

    /**
     * Buffers every visible face of the cluster into the late translucent outline buffer.
     * Does nothing when no face texture is configured or the cluster is empty.
     */
    protected void renderFaces(class_4587 ms, SuperRenderTypeBuffer buffer, class_243 camera, float pt, Vector4f color, int lightmap) {
        BindableTexture faceTexture = params.faceTexture;
        if (faceTexture == null || cluster.isEmpty())
            return;

        ms.method_22903();
        translateToAnchor(ms, camera);

        class_4587.class_4665 pose = ms.method_23760();
        class_1921 renderType = PonderRenderTypes.outlineTranslucent(faceTexture.getLocation(), true);
        class_4588 consumer = buffer.getLateBuffer(renderType);

        cluster.visibleFaces.forEach((face, axisDirection) -> {
            class_2350 direction = class_2350.method_10156(axisDirection, face.axis);
            class_2338 pos = face.pos;
            // Face entries are keyed by the position on the positive side of the face, so a
            // positive-facing face belongs to the block one step back along its axis.
            if (axisDirection == class_2352.field_11056)
                pos = pos.method_10093(direction.method_10153());
            bufferBlockFace(pose, consumer, pos, direction, color, lightmap);
        });

        ms.method_22909();
    }

    /**
     * Buffers every visible edge of the cluster as a unit-length cuboid line.
     * Does nothing when the configured line width is zero or the cluster is empty.
     */
    protected void renderEdges(
        class_4587 ms,
        SuperRenderTypeBuffer buffer,
        class_243 camera,
        float pt,
        Vector4f color,
        int lightmap,
        boolean disableNormals
    ) {
        float lineWidth = params.getLineWidth();
        if (lineWidth == 0 || cluster.isEmpty())
            return;

        ms.method_22903();
        translateToAnchor(ms, camera);

        class_4587.class_4665 pose = ms.method_23760();
        class_4588 consumer = buffer.getBuffer(PonderRenderTypes.outlineSolid());

        cluster.visibleEdges.forEach(edge -> {
            class_2338 pos = edge.pos;
            Vector3f origin = originTemp;
            origin.set(pos.method_10263(), pos.method_10264(), pos.method_10260());
            // Each edge starts at its stored position and runs one block along the positive axis.
            class_2350 direction = class_2350.method_10156(class_2352.field_11056, edge.axis);
            bufferCuboidLine(pose, consumer, origin, direction, 1, lineWidth, color, lightmap, disableNormals);
        });

        ms.method_22909();
    }

    /**
     * Translates the pose stack so that cluster-relative coordinates (relative to the anchor)
     * land at the anchor's position relative to the camera. Requires a non-empty cluster.
     */
    private void translateToAnchor(class_4587 ms, class_243 camera) {
        class_2338 anchor = cluster.anchor;
        ms.method_22904(
            anchor.method_10263() - camera.field_1352,
            anchor.method_10264() - camera.field_1351,
            anchor.method_10260() - camera.field_1350
        );
    }

    /**
     * Loads the four corners and the outward unit normal of one face of a unit cube.
     * Corners are in block-local coordinates (0..1) and wound counter-clockwise when viewed
     * from outside the face. The numbered comments appear to be the cube-corner indices used
     * by each face.
     */
    public static void loadFaceData(class_2350 face, Vector3f pos0, Vector3f pos1, Vector3f pos2, Vector3f pos3, Vector3f normal) {
        switch (face) {
            case field_11033 -> {
                // 0 1 2 3
                pos0.set(0, 0, 1);
                pos1.set(0, 0, 0);
                pos2.set(1, 0, 0);
                pos3.set(1, 0, 1);
                normal.set(0, -1, 0);
            }
            case field_11036 -> {
                // 4 5 6 7
                pos0.set(0, 1, 0);
                pos1.set(0, 1, 1);
                pos2.set(1, 1, 1);
                pos3.set(1, 1, 0);
                normal.set(0, 1, 0);
            }
            case field_11043 -> {
                // 7 2 1 4
                pos0.set(1, 1, 0);
                pos1.set(1, 0, 0);
                pos2.set(0, 0, 0);
                pos3.set(0, 1, 0);
                normal.set(0, 0, -1);
            }
            case field_11035 -> {
                // 5 0 3 6
                pos0.set(0, 1, 1);
                pos1.set(0, 0, 1);
                pos2.set(1, 0, 1);
                pos3.set(1, 1, 1);
                normal.set(0, 0, 1);
            }
            case field_11039 -> {
                // 4 1 0 5
                pos0.set(0, 1, 0);
                pos1.set(0, 0, 0);
                pos2.set(0, 0, 1);
                pos3.set(0, 1, 1);
                normal.set(-1, 0, 0);
            }
            case field_11034 -> {
                // 6 3 2 7
                pos0.set(1, 1, 1);
                pos1.set(1, 0, 1);
                pos2.set(1, 0, 0);
                pos3.set(1, 1, 0);
                normal.set(1, 0, 0);
            }
        }
    }

    /** Translates all four corner vectors in place by {@code (x, y, z)}. */
    public static void addPos(float x, float y, float z, Vector3f pos0, Vector3f pos1, Vector3f pos2, Vector3f pos3) {
        pos0.add(x, y, z);
        pos1.add(x, y, z);
        pos2.add(x, y, z);
        pos3.add(x, y, z);
    }

    /**
     * Buffers a single face of the block at {@code pos} (cluster-relative). The quad is pushed
     * outward along the face normal by {@link #FACE_NUDGE}, presumably so it does not z-fight
     * with the block's own surface.
     */
    protected void bufferBlockFace(class_4587.class_4665 pose, class_4588 consumer, class_2338 pos, class_2350 face, Vector4f color, int lightmap) {
        Vector3f pos0 = pos0Temp;
        Vector3f pos1 = pos1Temp;
        Vector3f pos2 = pos2Temp;
        Vector3f pos3 = pos3Temp;
        Vector3f normal = normalTemp;

        loadFaceData(face, pos0, pos1, pos2, pos3, normal);
        addPos(
            pos.method_10263() + face.method_10148() * FACE_NUDGE,
            pos.method_10264() + face.method_10164() * FACE_NUDGE,
            pos.method_10260() + face.method_10165() * FACE_NUDGE,
            pos0,
            pos1,
            pos2,
            pos3
        );

        bufferQuad(pose, consumer, pos0, pos1, pos2, pos3, color, lightmap, normal);
    }

    /**
     * Accumulates the visible faces and edges of a set of blocks. All stored positions are
     * relative to {@link #anchor}, the first block included.
     */
    private static final class Cluster {

        /** First included position (immutable copy); {@code null} while the cluster is empty. */
        private class_2338 anchor;
        /**
         * Visible faces, keyed by (face axis, position on the positive side of the face);
         * the value tells which way the face points.
         */
        private final Map<MergeEntry, class_2352> visibleFaces = new HashMap<>();
        /** Visible edges, keyed by (edge axis, edge start position). */
        private final Set<MergeEntry> visibleEdges = new HashSet<>();

        public boolean isEmpty() {
            return anchor == null;
        }

        /**
         * Adds one block to the cluster, toggling its 6 faces and 12 edges. Faces and edges
         * already present (shared with a previously included block) are removed instead.
         * Callers must not include the same position twice.
         */
        public void include(class_2338 pos) {
            if (anchor == null)
                anchor = pos.method_10062();

            class_2338 relative = pos.method_10059(anchor);
            toggleFaces(relative);
            toggleEdges(relative);
        }

        /** Toggles the 6 faces of the block at {@code pos}: two per axis (offset 0 and 1). */
        private void toggleFaces(class_2338 pos) {
            for (class_2351 axis : Iterate.axes) {
                class_2350 positive = class_2350.method_10156(class_2352.field_11056, axis);
                for (int offset : Iterate.zeroAndOne) {
                    MergeEntry entry = new MergeEntry(axis, pos.method_10079(positive, offset));
                    // Offset 0 is the block's negative-facing face, offset 1 its positive one.
                    if (visibleFaces.remove(entry) == null)
                        visibleFaces.put(entry, offset == 0 ? class_2352.field_11060 : class_2352.field_11056);
                }
            }
        }

        /**
         * Toggles the 12 edges of the block at {@code pos}: for each edge axis, the 4 edges
         * spanned by offsets 0/1 along the two perpendicular axes.
         */
        private void toggleEdges(class_2338 pos) {
            for (class_2351 axis : Iterate.axes) {
                class_2350 across = null;
                class_2350 across2 = null;
                for (class_2351 other : Iterate.axes) {
                    if (other == axis)
                        continue;
                    class_2350 direction = class_2350.method_10156(class_2352.field_11056, other);
                    if (across == null)
                        across = direction;
                    else
                        across2 = direction;
                }

                for (int offset : Iterate.zeroAndOne) {
                    for (int offset2 : Iterate.zeroAndOne) {
                        // Computed from the base position each time, so the result does not
                        // depend on the iteration order of zeroAndOne.
                        class_2338 edgePos = pos.method_10079(across, offset).method_10079(across2, offset2);
                        MergeEntry entry = new MergeEntry(axis, edgePos);
                        if (!visibleEdges.remove(entry))
                            visibleEdges.add(entry);
                    }
                }
            }
        }

    }

    /** Immutable (axis, position) key identifying a face or edge inside the cluster. */
    private static final class MergeEntry {

        private final class_2351 axis;
        private final class_2338 pos;

        public MergeEntry(class_2351 axis, class_2338 pos) {
            this.axis = axis;
            this.pos = pos;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o)
                return true;
            if (!(o instanceof MergeEntry))
                return false;

            MergeEntry other = (MergeEntry) o;
            return this.axis == other.axis && this.pos.equals(other.pos);
        }

        @Override
        public int hashCode() {
            return this.pos.hashCode() * 31 + axis.ordinal();
        }
    }

}
