package com.zurrtum.create.client.content.fluids;

import com.zurrtum.create.catnip.data.Iterate;
import com.zurrtum.create.client.flywheel.api.material.CardinalLightingMode;
import com.zurrtum.create.client.flywheel.api.material.Transparency;
import com.zurrtum.create.client.flywheel.api.model.Model;
import com.zurrtum.create.client.flywheel.api.vertex.MutableVertexList;
import com.zurrtum.create.client.flywheel.lib.material.SimpleMaterial;
import com.zurrtum.create.client.flywheel.lib.model.QuadMesh;
import com.zurrtum.create.client.flywheel.lib.model.SingleMeshModel;
import com.zurrtum.create.client.flywheel.lib.util.RendererReloadCache;
import net.minecraft.class_1058;
import net.minecraft.class_2350;
import net.minecraft.class_3532;
import net.minecraft.class_4608;
import org.joml.Vector4f;
import org.joml.Vector4fc;

/**
 * Builds and caches the Flywheel {@link Model}s used to render fluids: a vertical "stream" column
 * ({@link #stream}) and a horizontal "surface" square ({@link #surface}).
 *
 * <p>Models are memoised in {@link RendererReloadCache}s keyed by sprite (and, for surfaces, by width).
 * NOTE(review): presumably those caches are invalidated on renderer reload — confirm against
 * RendererReloadCache.
 */
public class FluidMesh {
    /** Stream models, one per fluid sprite. */
    private static final RendererReloadCache<class_1058, Model> STREAM = new RendererReloadCache<>(
        sprite -> new SingleMeshModel(new FluidStreamMesh(sprite), material(sprite)));

    /** Surface models, one per (sprite, width) pair. */
    private static final RendererReloadCache<SurfaceKey, Model> SURFACE = new RendererReloadCache<>(
        key -> new SingleMeshModel(new FluidSurfaceMesh(key.texture(), key.width()), material(key.texture())));

    /** Half-width, in blocks, of the square column produced by {@link #stream}. */
    public static final float PIPE_RADIUS = 3f / 16f;

    // TODO: take a width parameter here too, as surface() does; the stream always uses PIPE_RADIUS.
    /**
     * Returns the cached stream model for {@code sprite}.
     *
     * @param sprite fluid sprite to texture the stream with
     * @return a model of four vertical walls at {@link #PIPE_RADIUS} around the Y axis, from y = 0 to y = 1
     */
    public static Model stream(class_1058 sprite) {
        return STREAM.get(sprite);
    }

    /**
     * Returns the cached surface model for {@code sprite} with the given half-width.
     *
     * @param sprite fluid sprite to texture the surface with
     * @param width half-extent of the square on X and Z (the mesh spans {@code [-width, width]})
     * @return a flat, upward-facing square at y = 0
     */
    public static Model surface(class_1058 sprite, float width) {
        return SURFACE.get(new SurfaceKey(sprite, width));
    }

    /**
     * Creates the material shared by all fluid meshes: cardinal lighting off, order-independent
     * transparency, textured with the sprite's atlas ({@code method_45852()}).
     */
    private static SimpleMaterial material(class_1058 sprite) {
        return SimpleMaterial.builder()
            .cardinalLightingMode(CardinalLightingMode.OFF)
            .texture(sprite.method_45852())
            .transparency(Transparency.ORDER_INDEPENDENT)
            .build();
    }

    /**
     * Writes the attributes every fluid-mesh vertex shares: opaque white colour (r = g = b = a = 1),
     * light 0 and overlay {@code class_4608.field_21444}.
     *
     * @param vertexList destination vertices
     * @param count number of leading vertices to initialise
     */
    private static void fillCommonAttributes(MutableVertexList vertexList, int count) {
        for (int i = 0; i < count; i++) {
            vertexList.r(i, 1);
            vertexList.g(i, 1);
            vertexList.b(i, 1);
            vertexList.a(i, 1);
            vertexList.light(i, 0);
            vertexList.overlay(i, class_4608.field_21444);
        }
    }

    /** Cache key for surface models. */
    private record SurfaceKey(class_1058 texture, float width) {
    }

    /**
     * A flat square at y = 0 spanning {@code [-width, width]} on X and Z, facing +Y.
     *
     * <p>The square is split at integer (block) coordinates so each quad covers at most one block and
     * samples the sprite from the fractional part of its coordinates.
     */
    public record FluidSurfaceMesh(class_1058 texture, float width) implements QuadMesh {
        @Override
        public int vertexCount() {
            // Number of block-aligned segments between -width and width along one axis; must match
            // the number of iterations of the loops in write().
            int quadWidth = class_3532.method_15386(width) - class_3532.method_15375(-width);
            return 4 * quadWidth * quadWidth;
        }

        @Override
        public void write(MutableVertexList vertexList) {
            int count = vertexCount();
            fillCommonAttributes(vertexList, count);
            for (int i = 0; i < count; i++) {
                // Flat plane at y = 0 with an upward normal.
                vertexList.normalX(i, 0);
                vertexList.normalY(i, 1);
                vertexList.normalZ(i, 0);

                vertexList.y(i, 0);
            }

            // (frac * 16) * (1/16) == frac, so one block maps onto the sprite's full u/v range.
            float textureScale = 1 / 16f;

            float left = -width;
            float right = width;
            float down = -width;
            float up = width;

            int vertex = 0;

            float x2;
            float y2;
            for (float x1 = left; x1 < right; x1 = x2) {
                float x1floor = class_3532.method_15375(x1);
                // Stop each quad at the next block boundary (or the edge of the square).
                x2 = Math.min(x1floor + 1, right);
                float u1 = texture.method_4580((x1 - x1floor) * 16 * textureScale);
                float u2 = texture.method_4580((x2 - x1floor) * 16 * textureScale);
                // "y" in this loop is the world Z axis of the horizontal plane.
                for (float y1 = down; y1 < up; y1 = y2) {
                    float y1floor = class_3532.method_15375(y1);
                    y2 = Math.min(y1floor + 1, up);
                    float v1 = texture.method_4570((y1 - y1floor) * 16 * textureScale);
                    float v2 = texture.method_4570((y2 - y1floor) * 16 * textureScale);

                    vertexList.x(vertex, x1);
                    vertexList.z(vertex, y1);
                    vertexList.u(vertex, u1);
                    vertexList.v(vertex, v1);

                    vertexList.x(vertex + 1, x1);
                    vertexList.z(vertex + 1, y2);
                    vertexList.u(vertex + 1, u1);
                    vertexList.v(vertex + 1, v2);

                    vertexList.x(vertex + 2, x2);
                    vertexList.z(vertex + 2, y2);
                    vertexList.u(vertex + 2, u2);
                    vertexList.v(vertex + 2, v2);

                    vertexList.x(vertex + 3, x2);
                    vertexList.z(vertex + 3, y1);
                    vertexList.u(vertex + 3, u2);
                    vertexList.v(vertex + 3, v1);
                    vertex += 4;
                }
            }
        }

        /**
         * Sphere centred at the origin that encloses the whole square. The corners (±width, 0, ±width)
         * lie {@code width * sqrt(2)} from the origin; dividing by sqrt(2) (as before) produced a radius
         * smaller than even the square's inscribed circle.
         */
        @Override
        public Vector4fc boundingSphere() {
            return new Vector4f(0, 0, 0, width * class_3532.field_15724);
        }
    }

    /**
     * Four vertical walls around the Y axis at distance {@link #PIPE_RADIUS}, from y = 0 to y = 1, each
     * facing outward along its horizontal direction. Every vertex has v = 0; only u varies across a wall.
     * NOTE(review): v is presumably supplied or animated elsewhere (e.g. by the instance) — confirm.
     */
    public record FluidStreamMesh(class_1058 texture) implements QuadMesh {
        @Override
        public int vertexCount() {
            // 4 horizontal walls x (block-aligned segments across one wall) x 4 vertices per quad.
            // For PIPE_RADIUS = 3/16 each wall is split once at 0, giving 4 * 2 * 4 = 32.
            int quadsPerWall = class_3532.method_15386(PIPE_RADIUS) - class_3532.method_15375(-PIPE_RADIUS);
            return 4 * quadsPerWall * 4;
        }

        @Override
        public void write(MutableVertexList vertexList) {
            int count = vertexCount();
            fillCommonAttributes(vertexList, count);
            for (int i = 0; i < count; i++) {
                vertexList.v(i, 0);
            }

            // One block maps onto half of the sprite's u range.
            float textureScale = 1 / 32f;

            float radius = PIPE_RADIUS;
            float left = -radius;
            float right = radius;

            int vertex = 0;

            for (var horizontalDirection : Iterate.horizontalDirections) {
                float x2;
                for (float x1 = left; x1 < right; x1 = x2) {
                    float x1floor = class_3532.method_15375(x1);
                    // Split the wall at the next block boundary (or its edge).
                    x2 = Math.min(x1floor + 1, right);
                    float u1 = texture.method_4580((x1 - x1floor) * 16 * textureScale);
                    float u2 = texture.method_4580((x2 - x1floor) * 16 * textureScale);

                    putQuad(vertexList, vertex, horizontalDirection, radius, x1, x2, u1, u2);
                    vertex += 4;
                }
            }
        }

        /**
         * Writes one vertical quad (y from 0 to 1) of the wall facing {@code horizontal}.
         *
         * @param vertexList destination vertices
         * @param i index of the quad's first vertex
         * @param horizontal wall direction; also used as the quad normal
         * @param radius distance of the wall from the Y axis
         * @param p0 start of the quad along the wall's tangent axis
         * @param p1 end of the quad along the wall's tangent axis
         * @param u0 texture u written with the tangent coordinate {@code xStart}/{@code zStart}
         * @param u1 texture u written with the tangent coordinate {@code xEnd}/{@code zEnd}
         * @throws IllegalStateException if {@code horizontal} is not one of the four horizontal directions
         */
        private static void putQuad(MutableVertexList vertexList, int i, class_2350 horizontal, float radius, float p0, float p1, float u0, float u1) {
            float xStart;
            float xEnd;
            float zStart;
            float zEnd;

            // North and east walls swap p0/p1 relative to south and west.
            switch (horizontal) {
                case field_11043 -> {
                    xStart = p1;
                    xEnd = p0;
                    zStart = zEnd = -radius;
                }
                case field_11035 -> {
                    xStart = p0;
                    xEnd = p1;
                    zStart = zEnd = radius;
                }
                case field_11039 -> {
                    zStart = p0;
                    zEnd = p1;
                    xStart = xEnd = -radius;
                }
                case field_11034 -> {
                    zStart = p1;
                    zEnd = p0;
                    xStart = xEnd = radius;
                }
                default -> throw new IllegalStateException("Unexpected value: " + horizontal);
            }

            vertexList.x(i, xStart);
            vertexList.y(i, 1);
            vertexList.z(i, zStart);
            vertexList.u(i, u0);

            vertexList.x(i + 1, xStart);
            vertexList.y(i + 1, 0);
            vertexList.z(i + 1, zStart);
            vertexList.u(i + 1, u0);

            vertexList.x(i + 2, xEnd);
            vertexList.y(i + 2, 0);
            vertexList.z(i + 2, zEnd);
            vertexList.u(i + 2, u1);

            vertexList.x(i + 3, xEnd);
            vertexList.y(i + 3, 1);
            vertexList.z(i + 3, zEnd);
            vertexList.u(i + 3, u1);

            for (int j = 0; j < 4; j++) {
                vertexList.normalX(i + j, horizontal.method_10148());
                vertexList.normalY(i + j, horizontal.method_10164());
                vertexList.normalZ(i + j, horizontal.method_10165());
            }
        }

        /**
         * Sphere centred at (0, 0.5, 0) with radius 1. The farthest vertices, at (±PIPE_RADIUS, 0 or 1,
         * ±PIPE_RADIUS), lie about 0.57 from the centre, so the sphere encloses the column.
         */
        @Override
        public Vector4fc boundingSphere() {
            return new Vector4f(0, 0.5f, 0, 1);
        }
    }
}
