package foundry.veil.api.client.necromancer.render;

import com.mojang.blaze3d.systems.RenderSystem;
import foundry.veil.api.client.necromancer.Skeleton;
import foundry.veil.api.client.render.MatrixStack;
import foundry.veil.api.client.render.VeilRenderSystem;
import foundry.veil.api.client.render.shader.block.DynamicShaderBlock;
import foundry.veil.api.client.render.vertex.VertexArray;
import foundry.veil.api.client.render.vertex.VertexArrayBuilder;
import it.unimi.dsi.fastutil.floats.FloatList;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import it.unimi.dsi.fastutil.objects.Object2IntArrayMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntMaps;
import org.jetbrains.annotations.ApiStatus;
import org.joml.*;
import org.lwjgl.system.MemoryUtil;
import org.lwjgl.system.NativeResource;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import net.minecraft.class_1921;
import net.minecraft.class_284;
import net.minecraft.class_293;
import net.minecraft.class_310;
import net.minecraft.class_3532;
import net.minecraft.class_4587;
import net.minecraft.class_5944;
import net.minecraft.class_9799;

import static org.lwjgl.opengl.ARBDirectStateAccess.glNamedBufferSubData;
import static org.lwjgl.opengl.GL15C.glBindBuffer;
import static org.lwjgl.opengl.GL15C.glBufferSubData;
import static org.lwjgl.opengl.GL30C.glUniform1ui;
import static org.lwjgl.opengl.GL31C.GL_UNIFORM_BUFFER;

/**
 * A mesh in which every vertex is attached to exactly one bone.
 *
 * <p>The mesh lives in a {@link VertexArray}. It is drawn instanced, once per {@link Skeleton}. The per-bone data
 * of each skeleton is written to a uniform buffer right before the draw.</p>
 */
public class Skin implements NativeResource {

    private final VertexArray vertexArray;
    private final Object2IntMap<String> boneIds;
    private final Vector4f color;

    // Instance count the instanced vertex buffer binding was last configured for
    private int instances;

    // Scratch state reused between render calls so no per-frame allocation is needed
    private final Matrix3f normalMatrix;
    private Matrix4x3f[] matrixStack;
    private Quaternionf[] orientationStack;

    /**
     * Creates a skin around an already populated vertex array.
     *
     * @param vertexArray the mesh data; released by {@link #free()}
     * @param boneIds     maps each bone name to the bone index written into the vertex data
     */
    public Skin(VertexArray vertexArray, Object2IntMap<String> boneIds) {
        this.vertexArray = vertexArray;
        this.boneIds = boneIds;
        this.color = new Vector4f();

        this.normalMatrix = new Matrix3f();
        this.matrixStack = null;
        this.orientationStack = null;
    }

    /**
     * Draws one instance of this skin for every skeleton.
     *
     * <p>The bone data of every skeleton is staged in {@code boneBuilder}. It is then uploaded to
     * {@code boneBuffer} at offset 0 and bound to the shader as the {@code NecromancerBones} block.</p>
     *
     * @param renderType      render type whose state is applied around the draw
     * @param transforms      per-instance transform, indexed like {@code skeletons}
     * @param skeletons       skeletons to draw; nothing happens when empty
     * @param instancedBuffer GL buffer with 8 bytes per instance: overlay (byte 0), lightmap (byte 1),
     *                        normalized RGBA color (bytes 2-5)
     * @param boneBuilder     scratch builder used to stage the bone data
     * @param boneBuffer      GL uniform buffer that receives the bone data
     * @param boneBlock       shader block bound as {@code NecromancerBones} for the draw
     * @param partialTicks    per-instance partial tick, indexed like {@code skeletons}
     */
    @ApiStatus.Internal
    public void render(class_1921 renderType, List<Matrix4x3f> transforms, List<Skeleton> skeletons, int instancedBuffer, class_9799 boneBuilder, int boneBuffer, DynamicShaderBlock<?> boneBlock, FloatList partialTicks) {
        if (skeletons.isEmpty()) {
            return;
        }

        if (this.instances != skeletons.size()) {
            VertexArrayBuilder format = this.vertexArray.editFormat();
            // The instanced buffer has to be redefined each time it changes size, so re-attach it
            format.defineVertexBuffer(1, instancedBuffer, 0, 8, 1);
            if (this.instances == 0) {
                format.setVertexIAttribute(4, 1, 1, VertexArrayBuilder.DataType.UNSIGNED_BYTE, 0); // Overlay Coordinates
                format.setVertexIAttribute(5, 1, 1, VertexArrayBuilder.DataType.UNSIGNED_BYTE, 1); // Lightmap Coordinates
                format.setVertexAttribute(6, 1, 4, VertexArrayBuilder.DataType.UNSIGNED_BYTE, true, 2); // Color
            }
            this.instances = skeletons.size();
        }

        // NOTE(review): the scratch stacks are sized from the first skeleton only; this assumes every skeleton
        // in the batch shares its hierarchy depth — confirm with callers.
        Skeleton first = skeletons.getFirst();
        int maxDepth = first.getMaxDepth();
        if (this.matrixStack == null || this.matrixStack.length < maxDepth) {
            this.matrixStack = new Matrix4x3f[maxDepth];
            this.orientationStack = new Quaternionf[maxDepth];
            for (int i = 0; i < maxDepth; i++) {
                this.matrixStack[i] = new Matrix4x3f();
                this.orientationStack[i] = new Quaternionf();
            }
        }

        // Store bone data in buffer
        int skeletonDataSize = Skeleton.UNIFORM_STRIDE * this.getSkeletonDataSize();
        int size = skeletonDataSize * skeletons.size();
        ByteBuffer buffer = MemoryUtil.memByteBuffer(boneBuilder.method_60808(size), size);

        for (int i = 0; i < skeletons.size(); i++) {
            Skeleton skeleton = skeletons.get(i);
            for (int j = 0; j < maxDepth; j++) {
                this.matrixStack[j].identity();
            }
            // Each skeleton owns a fixed-size slice of the staging buffer
            buffer.position(i * skeletonDataSize);
            skeleton.storeInstancedData(
                    buffer,
                    skeleton.roots,
                    this.boneIds,
                    0,
                    this.color,
                    this.normalMatrix,
                    transforms.get(i),
                    this.matrixStack,
                    this.orientationStack,
                    partialTicks.getFloat(i)
            );
        }

        buffer.rewind();

        // Upload data. The builder result stays open until the upload has consumed the staged bytes. Closing it
        // presumably lets the builder recycle that memory. A null result (nothing staged) is simply not closed.
        try (class_9799.class_9800 ignored = boneBuilder.method_60807()) {
            if (VeilRenderSystem.directStateAccessSupported()) {
                glNamedBufferSubData(boneBuffer, 0, buffer);
            } else {
                glBindBuffer(GL_UNIFORM_BUFFER, boneBuffer);
                glBufferSubData(GL_UNIFORM_BUFFER, 0, buffer);
            }
        }

        // Draw
        this.vertexArray.bind();
        renderType.method_23516();
        VeilRenderSystem.bind("NecromancerBones", boneBlock);
        class_5944 shader = RenderSystem.getShader();
        if (shader != null) {
            shader.method_60897(class_293.class_5596.field_27379, RenderSystem.getModelViewMatrix(), RenderSystem.getProjectionMatrix(), class_310.method_1551().method_22683());
            shader.method_34586();

            class_284 uniform = shader.method_34582("NecromancerBoneCount");
            if (uniform != null) {
                glUniform1ui(uniform.method_35660(), this.boneIds.size());
            }
        }

        // TODO query uniform block size
        this.vertexArray.drawInstanced(skeletons.size());

        if (shader != null) {
            shader.method_34585();
        }

        VeilRenderSystem.unbind(boneBlock);
        renderType.method_23518();
    }

    /**
     * @return the vertex array holding this skin's mesh
     */
    public VertexArray getVertexArray() {
        return this.vertexArray;
    }

    /**
     * Returns the number of bones this skin references.
     *
     * <p>Multiplied by {@link Skeleton#UNIFORM_STRIDE}, this gives the number of bytes of bone data staged per
     * skeleton.</p>
     *
     * @return the bone count
     */
    public int getSkeletonDataSize() {
        return this.boneIds.size();
    }

    /**
     * Creates an empty vertex array using the skin vertex layout.
     *
     * <p>Each vertex is 24 bytes:</p>
     * <ul>
     *     <li>position: 3 floats</li>
     *     <li>UV: 2 floats</li>
     *     <li>normal: 3 signed normalized bytes</li>
     *     <li>bone index: 1 unsigned byte</li>
     * </ul>
     * <p>This must be called on the render thread or during initialization.</p>
     *
     * @return the new vertex array, already unbound
     */
    public static VertexArray createVertexArray() {
        RenderSystem.assertOnRenderThreadOrInit();
        VertexArray vertexArray = VertexArray.create();

        int vbo = vertexArray.getOrCreateBuffer(VertexArray.VERTEX_BUFFER);
        VertexArrayBuilder format = vertexArray.editFormat();
        format.defineVertexBuffer(0, vbo, 0, 24, 0);

        format.setVertexAttribute(0, 0, 3, VertexArrayBuilder.DataType.FLOAT, false, 0); // Position
        format.setVertexAttribute(1, 0, 2, VertexArrayBuilder.DataType.FLOAT, false, 12); // UV
        format.setVertexAttribute(2, 0, 3, VertexArrayBuilder.DataType.BYTE, true, 20); // Normal
        format.setVertexIAttribute(3, 0, 1, VertexArrayBuilder.DataType.UNSIGNED_BYTE, 23); // Bone Index

        VertexArray.unbind();
        return vertexArray;
    }

    /**
     * Releases the underlying vertex array.
     */
    @Override
    public void free() {
        this.vertexArray.free();
    }

    /**
     * Starts building a new skin.
     *
     * @param textureWidth  texture width in pixels, used to normalize U coordinates in {@link Builder#addCube}
     * @param textureHeight texture height in pixels, used to normalize V coordinates in {@link Builder#addCube}
     * @return a new builder
     */
    public static Builder builder(int textureWidth, int textureHeight) {
        return new Builder(textureWidth, textureHeight);
    }

    /**
     * Incrementally builds the mesh and index data for a {@link Skin}.
     *
     * <p>Mesh data is grouped by bone: call {@link #startBone(String)} first. Every vertex added afterwards is bound
     * to that bone.</p>
     *
     * <p>{@link #nextIndex()} tracks one past the largest index emitted so far. Use it as the base index for
     * vertices that are about to be added.</p>
     */
    public static class Builder {

        private final VertexArray vertexArray;
        private final class_9799 vertices;
        private final IntList indices;
        private final List<String> boneNames;
        private final float textureWidth;
        private final float textureHeight;
        private final Matrix4f position;
        private final Matrix3f normal;
        // Scratch vectors for addVertex; per instance so separate builders never share mutable state
        private final Vector3f transformedPos;
        private final Vector3f transformedNormal;

        private int nextIndex;

        private Builder(float textureWidth, float textureHeight) {
            this.vertexArray = createVertexArray();
            this.vertices = new class_9799(Skeleton.MAX_BONES * 24 * 24);
            this.indices = new IntArrayList();
            this.boneNames = new ArrayList<>();
            this.textureWidth = textureWidth;
            this.textureHeight = textureHeight;
            this.position = new Matrix4f();
            this.normal = new Matrix3f();
            this.transformedPos = new Vector3f();
            this.transformedNormal = new Vector3f();

            this.nextIndex = 0;
        }

        /**
         * Packs a normal component in [-1, 1] into a signed byte scaled by 127.
         * Values outside that range are clamped.
         */
        private static byte normalIntValue(float value) {
            return (byte) ((int) (class_3532.method_15363(value, -1.0F, 1.0F) * 127.0F) & 0xFF);
        }

        /**
         * Begins mesh data for a new bone. Subsequent vertices are assigned to this bone.
         *
         * @param boneId name of the bone
         * @return this builder
         * @throws IllegalStateException if the bone already has mesh data or {@link Skeleton#MAX_BONES} is reached
         */
        public Builder startBone(String boneId) {
            if (this.boneNames.contains(boneId)) {
                throw new IllegalStateException("Bone '" + boneId + "' has already defined mesh data");
            }
            if (this.boneNames.size() >= Skeleton.MAX_BONES) {
                throw new IllegalStateException("Too many bones defined. Max is " + Skeleton.MAX_BONES);
            }

            this.boneNames.add(boneId);
            return this;
        }

        /**
         * Uses the current position of the given stack as the vertex transform.
         */
        public Builder setTransform(MatrixStack stack) {
            return this.setTransform(stack.position());
        }

        /**
         * Uses the current pose of the given stack as the vertex transform.
         */
        public Builder setTransform(class_4587 stack) {
            return this.setTransform(stack.method_23760());
        }

        /**
         * Uses the position matrix of the given pose as the vertex transform.
         */
        public Builder setTransform(class_4587.class_4665 pose) {
            return this.setTransform(pose.method_23761());
        }

        /**
         * Sets the matrix applied to subsequent vertex positions. The normal matrix is derived from it.
         *
         * @param position the position transform
         * @return this builder
         */
        public Builder setTransform(Matrix4fc position) {
            this.position.set(position);
            this.position.normal(this.normal);
            return this;
        }

        /**
         * Appends one vertex bound to the current bone. The current transform is applied to its position and normal.
         *
         * <p>This does not emit any index.</p>
         *
         * @return this builder
         * @throws IllegalStateException if {@link #startBone(String)} has not been called yet
         */
        public Builder addVertex(float x, float y, float z, float u, float v, float normalX, float normalY, float normalZ) {
            if (this.boneNames.isEmpty()) {
                throw new IllegalStateException("No bone specified. Call #startBone(String) to start building a mesh.");
            }

            Vector3f pos = this.position.transformPosition(x, y, z, this.transformedPos);
            Vector3f norm = this.normal.transform(normalX, normalY, normalZ, this.transformedNormal);
            long pointer = this.vertices.method_60808(24);
            MemoryUtil.memPutFloat(pointer, pos.x);
            MemoryUtil.memPutFloat(pointer + 4, pos.y);
            MemoryUtil.memPutFloat(pointer + 8, pos.z);
            MemoryUtil.memPutFloat(pointer + 12, u);
            MemoryUtil.memPutFloat(pointer + 16, v);
            MemoryUtil.memPutByte(pointer + 20, normalIntValue(norm.x));
            MemoryUtil.memPutByte(pointer + 21, normalIntValue(norm.y));
            MemoryUtil.memPutByte(pointer + 22, normalIntValue(norm.z));
            // The current bone is always the most recently started one
            MemoryUtil.memPutByte(pointer + 23, (byte) (this.boneNames.size() - 1));
            return this;
        }

        /**
         * Appends a raw index.
         *
         * <p>{@link #nextIndex()} is advanced so that it stays one past the largest index seen.</p>
         *
         * @param index the vertex index
         * @return this builder
         */
        public Builder addIndex(int index) {
            this.indices.add(index);
            // Must be >=: an index equal to nextIndex still consumes that slot
            if (index >= this.nextIndex) {
                this.nextIndex = index + 1;
            }
            return this;
        }

        /**
         * Emits two triangles for the quad made of vertices {@code index .. index + 3}.
         *
         * @param index index of the quad's first vertex
         * @return this builder
         */
        public Builder addQuadIndices(int index) {
            this.addIndex(index);
            this.addIndex(index + 1);
            this.addIndex(index + 2);
            this.addIndex(index + 2);
            this.addIndex(index + 3);
            this.addIndex(index);
            return this;
        }

        /**
         * Adds an axis-aligned box with box-style UVs, as six quads.
         *
         * @param xSize    box size on X
         * @param ySize    box size on Y
         * @param zSize    box size on Z
         * @param xOffset  minimum corner on X (before inflation)
         * @param yOffset  minimum corner on Y (before inflation)
         * @param zOffset  minimum corner on Z (before inflation)
         * @param xInflate amount the box grows on each side along X
         * @param yInflate amount the box grows on each side along Y
         * @param zInflate amount the box grows on each side along Z
         * @param uOffset  texture U origin in pixels
         * @param vOffset  texture V origin in pixels
         * @param mirrored whether the texture is mirrored horizontally
         * @return this builder
         */
        public Skin.Builder addCube(
                float xSize, float ySize, float zSize,
                float xOffset, float yOffset, float zOffset,
                float xInflate, float yInflate, float zInflate,
                float uOffset, float vOffset, boolean mirrored) {
            float minX = xOffset - xInflate, minY = yOffset - yInflate, minZ = zOffset - zInflate;
            float maxX = xOffset + xSize + xInflate, maxY = yOffset + ySize + yInflate, maxZ = zOffset + zSize + zInflate;

            float u0 = uOffset, u1 = u0 + zSize, u2 = u1 + xSize, u3 = u2 + zSize, u4 = u3 + xSize;
            u0 /= this.textureWidth;
            u1 /= this.textureWidth;
            u2 /= this.textureWidth;
            u3 /= this.textureWidth;
            u4 /= this.textureWidth;
            float topBottomU0 = uOffset + zSize, topBottomU1 = topBottomU0 + xSize, topBottomU2 = topBottomU1 + xSize;
            topBottomU0 /= this.textureWidth;
            topBottomU1 /= this.textureWidth;
            topBottomU2 /= this.textureWidth;

            float v0 = vOffset, v1 = v0 + zSize, v2 = v1 + ySize;
            v0 /= this.textureHeight;
            v1 /= this.textureHeight;
            v2 /= this.textureHeight;

            // A little gross, but should work in every case.
            if (!mirrored) {
                // Up
                this.addVertex(minX, maxY, minZ, topBottomU1, v1, 0.0F, 1.0F, 0.0F);
                this.addVertex(minX, maxY, maxZ, topBottomU1, v0, 0.0F, 1.0F, 0.0F);
                this.addVertex(maxX, maxY, maxZ, topBottomU0, v0, 0.0F, 1.0F, 0.0F);
                this.addVertex(maxX, maxY, minZ, topBottomU0, v1, 0.0F, 1.0F, 0.0F);

                // Down
                this.addVertex(maxX, minY, maxZ, topBottomU1, v0, 0.0F, -1.0F, 0.0F);
                this.addVertex(minX, minY, maxZ, topBottomU2, v0, 0.0F, -1.0F, 0.0F);
                this.addVertex(minX, minY, minZ, topBottomU2, v1, 0.0F, -1.0F, 0.0F);
                this.addVertex(maxX, minY, minZ, topBottomU1, v1, 0.0F, -1.0F, 0.0F);

                // East
                this.addVertex(maxX, minY, maxZ, u0, v2, 1.0F, 0.0F, 0.0F);
                this.addVertex(maxX, minY, minZ, u1, v2, 1.0F, 0.0F, 0.0F);
                this.addVertex(maxX, maxY, minZ, u1, v1, 1.0F, 0.0F, 0.0F);
                this.addVertex(maxX, maxY, maxZ, u0, v1, 1.0F, 0.0F, 0.0F);

                // West
                this.addVertex(minX, minY, minZ, u2, v2, -1.0F, 0.0F, 0.0F);
                this.addVertex(minX, minY, maxZ, u3, v2, -1.0F, 0.0F, 0.0F);
                this.addVertex(minX, maxY, maxZ, u3, v1, -1.0F, 0.0F, 0.0F);
                this.addVertex(minX, maxY, minZ, u2, v1, -1.0F, 0.0F, 0.0F);

                // North
                this.addVertex(maxX, minY, minZ, u1, v2, 0.0F, 0.0F, -1.0F);
                this.addVertex(minX, minY, minZ, u2, v2, 0.0F, 0.0F, -1.0F);
                this.addVertex(minX, maxY, minZ, u2, v1, 0.0F, 0.0F, -1.0F);
                this.addVertex(maxX, maxY, minZ, u1, v1, 0.0F, 0.0F, -1.0F);

                // South
                this.addVertex(minX, minY, maxZ, u3, v2, 0.0F, 0.0F, 1.0F);
                this.addVertex(maxX, minY, maxZ, u4, v2, 0.0F, 0.0F, 1.0F);
                this.addVertex(maxX, maxY, maxZ, u4, v1, 0.0F, 0.0F, 1.0F);
                this.addVertex(minX, maxY, maxZ, u3, v1, 0.0F, 0.0F, 1.0F);
            } else {
                // Up, mirrored
                this.addVertex(minX, maxY, minZ, topBottomU0, v1, 0.0F, 1.0F, 0.0F);
                this.addVertex(minX, maxY, maxZ, topBottomU0, v0, 0.0F, 1.0F, 0.0F);
                this.addVertex(maxX, maxY, maxZ, topBottomU1, v0, 0.0F, 1.0F, 0.0F);
                this.addVertex(maxX, maxY, minZ, topBottomU1, v1, 0.0F, 1.0F, 0.0F);

                // Down, mirrored
                this.addVertex(maxX, minY, maxZ, topBottomU2, v0, 0.0F, -1.0F, 0.0F);
                this.addVertex(minX, minY, maxZ, topBottomU1, v0, 0.0F, -1.0F, 0.0F);
                this.addVertex(minX, minY, minZ, topBottomU1, v1, 0.0F, -1.0F, 0.0F);
                this.addVertex(maxX, minY, minZ, topBottomU2, v1, 0.0F, -1.0F, 0.0F);

                // East, mirrored
                this.addVertex(maxX, minY, maxZ, u3, v2, 1.0F, 0.0F, 0.0F);
                this.addVertex(maxX, minY, minZ, u2, v2, 1.0F, 0.0F, 0.0F);
                this.addVertex(maxX, maxY, minZ, u2, v1, 1.0F, 0.0F, 0.0F);
                this.addVertex(maxX, maxY, maxZ, u3, v1, 1.0F, 0.0F, 0.0F);

                // West, mirrored
                this.addVertex(minX, minY, minZ, u1, v2, -1.0F, 0.0F, 0.0F);
                this.addVertex(minX, minY, maxZ, u0, v2, -1.0F, 0.0F, 0.0F);
                this.addVertex(minX, maxY, maxZ, u0, v1, -1.0F, 0.0F, 0.0F);
                this.addVertex(minX, maxY, minZ, u1, v1, -1.0F, 0.0F, 0.0F);

                // North, mirrored
                this.addVertex(maxX, minY, minZ, u2, v2, 0.0F, 0.0F, -1.0F);
                this.addVertex(minX, minY, minZ, u1, v2, 0.0F, 0.0F, -1.0F);
                this.addVertex(minX, maxY, minZ, u1, v1, 0.0F, 0.0F, -1.0F);
                this.addVertex(maxX, maxY, minZ, u2, v1, 0.0F, 0.0F, -1.0F);

                // South, mirrored
                this.addVertex(minX, minY, maxZ, u4, v2, 0.0F, 0.0F, 1.0F);
                this.addVertex(maxX, minY, maxZ, u3, v2, 0.0F, 0.0F, 1.0F);
                this.addVertex(maxX, maxY, maxZ, u3, v1, 0.0F, 0.0F, 1.0F);
                this.addVertex(minX, maxY, maxZ, u4, v1, 0.0F, 0.0F, 1.0F);
            }

            // One quad per face; each call advances nextIndex by 4
            for (int i = 0; i < 6; i++) {
                this.addQuadIndices(this.nextIndex());
            }

            return this;
        }

        /**
         * Adds a single triangle that shares one normal.
         *
         * @return this builder
         */
        public Builder addTri(
                float x1, float y1, float z1, float u1, float v1,
                float x2, float y2, float z2, float u2, float v2,
                float x3, float y3, float z3, float u3, float v3,
                float normalX, float normalY, float normalZ) {
            // No up-front reservation: reserving bytes here would insert uninitialised vertices into the buffer
            int start = this.nextIndex();
            this.addVertex(x1, y1, z1, u1, v1, normalX, normalY, normalZ);
            this.addVertex(x2, y2, z2, u2, v2, normalX, normalY, normalZ);
            this.addVertex(x3, y3, z3, u3, v3, normalX, normalY, normalZ);
            this.addIndex(start);
            this.addIndex(start + 1);
            this.addIndex(start + 2);
            return this;
        }

        /**
         * Adds a quad that shares one normal, as two triangles.
         *
         * @return this builder
         */
        public Builder addFace(
                float x1, float y1, float z1, float u1, float v1,
                float x2, float y2, float z2, float u2, float v2,
                float x3, float y3, float z3, float u3, float v3,
                float x4, float y4, float z4, float u4, float v4,
                float normalX, float normalY, float normalZ) {
            // No up-front reservation: reserving bytes here would insert uninitialised vertices into the buffer
            int start = this.nextIndex();
            this.addVertex(x1, y1, z1, u1, v1, normalX, normalY, normalZ);
            this.addVertex(x2, y2, z2, u2, v2, normalX, normalY, normalZ);
            this.addVertex(x3, y3, z3, u3, v3, normalX, normalY, normalZ);
            this.addVertex(x4, y4, z4, u4, v4, normalX, normalY, normalZ);
            this.addQuadIndices(start);
            return this;
        }

        /**
         * @return one past the largest index emitted so far
         */
        public int nextIndex() {
            return this.nextIndex;
        }

        /**
         * Writes all indices into {@code buffer} using absolute puts sized by {@code indexType}.
         */
        private void storeIndices(VertexArray.IndexType indexType, ByteBuffer buffer) {
            for (int i = 0; i < this.indices.size(); i++) {
                int index = this.indices.getInt(i);
                switch (indexType) {
                    case BYTE -> buffer.put(i, (byte) index);
                    case SHORT -> buffer.putShort(i * 2, (short) index);
                    case INT -> buffer.putInt(i * 4, index);
                }
            }
        }

        /**
         * Uploads the collected vertex and index data and creates the skin.
         *
         * <p>The vertex builder is closed in every case, so this builder cannot be reused afterwards. If building
         * fails, the vertex array created for this builder is freed.</p>
         *
         * @return a skin that owns the uploaded vertex array
         * @throws IllegalStateException if no vertex data was added
         */
        public Skin build() {
            ByteBuffer indices = null;
            boolean success = false;
            try (this.vertices; class_9799.class_9800 result = this.vertices.method_60807()) {
                if (result == null) {
                    throw new IllegalStateException("No mesh data provided to skin");
                }
                int vertexBuffer = this.vertexArray.getOrCreateBuffer(VertexArray.VERTEX_BUFFER);
                VertexArray.upload(vertexBuffer, result.method_60817(), VertexArray.DrawUsage.STATIC);

                // Allocate and store index buffer.
                // NOTE(review): the shift assumes IndexType declares BYTE, SHORT, INT in that order, so the
                // ordinal is log2 of the element size — confirm.
                VertexArray.IndexType indexType = VertexArray.IndexType.least(this.nextIndex - 1);
                indices = MemoryUtil.memAlloc(this.indices.size() << indexType.ordinal());
                this.storeIndices(indexType, indices);
                this.vertexArray.uploadIndexBuffer(indices, indexType);

                Object2IntMap<String> boneIds = new Object2IntArrayMap<>(this.boneNames.size());
                for (int i = 0; i < this.boneNames.size(); i++) {
                    boneIds.put(this.boneNames.get(i), i);
                }

                Skin skin = new Skin(this.vertexArray, Object2IntMaps.unmodifiable(boneIds));
                success = true;
                return skin;
            } finally {
                MemoryUtil.memFree(indices);
                if (!success) {
                    // Nothing else references the vertex array, so release it rather than leak the GL objects
                    this.vertexArray.free();
                }
            }
        }
    }
}
