package net.mehvahdjukaar.moonlight.api.client.model.fabric;

import net.mehvahdjukaar.moonlight.api.client.model.BakedQuadsTransformer;
import net.minecraft.class_1058;
import net.minecraft.class_2350;
import net.minecraft.class_290;
import net.minecraft.class_777;
import org.joml.Matrix3f;
import org.joml.Matrix4f;
import org.joml.Vector3f;
import org.joml.Vector4f;
import ;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.IntUnaryOperator;
import java.util.function.UnaryOperator;


public class BakedQuadsTransformerImpl implements BakedQuadsTransformer {

    private Consumer<class_777> inner = i -> {
    };

    private Boolean shade = null;
    private Integer emissivity = null;
    private Integer tintIndex = null;
    private UnaryOperator<class_2350> directionRemap = UnaryOperator.identity();
    private class_1058 sprite = null;

    public static BakedQuadsTransformer create() {
        return new BakedQuadsTransformerImpl();
    }

    @Override
    public BakedQuadsTransformer applyingColor(IntUnaryOperator indexToABGR) {
        inner = inner.andThen(applyingColorInplace(indexToABGR));
        return this;
    }

    @Override
    public BakedQuadsTransformer applyingLightMap(int packedLight) {
        inner = inner.andThen(applyingLightmapInplace(packedLight));
        return this;
    }

    @Override
    public BakedQuadsTransformer applyingTransform(Matrix4f transform) {
        inner = inner.andThen(applyingTransformInplace(transform));
        directionRemap = d -> class_2350.method_23225(new Matrix4f(new Matrix3f(transform)), d);
        return this;
    }

    @Override
    public BakedQuadsTransformer applyingAmbientOcclusion(boolean ambientOcclusion) {
        return this;
    }

    @Override
    public BakedQuadsTransformer applyingShade(boolean shade) {
        this.shade = shade;
        return this;
    }

    @Override
    public BakedQuadsTransformer applyingTintIndex(int tintIndex) {
        this.tintIndex = tintIndex;
        return this;
    }

    @Override
    public BakedQuadsTransformer applyingEmissivity(int emissivity) {
        this.emissivity = emissivity;
        return this;
    }

    @Override
    public BakedQuadsTransformer applyingSprite(class_1058 sprite) {
        inner = inner.andThen(applyingSpriteInplace(sprite));
        this.sprite = sprite;
        return this;
    }

    private class_1058 lastSpriteHack = null;

    @Override
    public class_777 transform(class_777 quad) {
        int[] v = Arrays.copyOf(quad.method_3357(), quad.method_3357().length);

        int tint = this.tintIndex == null ? quad.method_3359() : this.tintIndex;
        boolean shade = this.shade == null ? quad.method_24874() : this.shade;
        class_1058 sprite = this.sprite == null ? quad.method_35788() : this.sprite;
        lastSpriteHack = quad.method_35788();
        class_777 newQuad = new class_777(v, tint, directionRemap.apply(quad.method_3358()), sprite, shade);
        inner.accept(newQuad);
        lastSpriteHack = null;
        if (emissivity != null) {
            AtomicReference<class_777> emissiveQuad = new AtomicReference<>();
            try (BakedQuadBuilderImpl builder = (BakedQuadBuilderImpl) BakedQuadBuilderImpl
                    .create(sprite, null, emissiveQuad::set)) {
                builder.fromVanilla(newQuad);
                builder.lightEmission(emissivity);
            } catch (Exception ignored) {
            }
            newQuad = emissiveQuad.get();
        }
        return newQuad;
    }


    private Consumer<class_777> applyingSpriteInplace(class_1058 sprite) {
        return q -> {
            class_1058 oldSprite = lastSpriteHack;
            int stride = getStride();
            int[] v = q.method_3357();
            float segmentWScale = sprite.method_45851().method_45807() / (float) oldSprite.method_45851().method_45807();
            float segmentHScale = sprite.method_45851().method_45815() / (float) oldSprite.method_45851().method_45815();

            for (int i = 0; i < 4; i++) {
                int offset = i * stride + UV0;
                float originalU = Float.intBitsToFloat(v[offset]);
                float originalV = Float.intBitsToFloat(v[offset + 1]);

                float u1 = (originalU - oldSprite.method_4594()) * segmentWScale;
                v[offset] = Float.floatToRawIntBits(u1 + sprite.method_4594());

                float v1 = (originalV - oldSprite.method_4593()) * segmentHScale;
                v[offset + 1] = Float.floatToRawIntBits(v1 + sprite.method_4593());
            }
        };
    }

    private static Consumer<class_777> applyingColorInplace(IntUnaryOperator indexToABGR) {
        return quad -> {
            int[] v = quad.method_3357();
            int stride = getStride();
            for (int i = 0; i < 4; i++) {
                int i1 = indexToABGR.applyAsInt(i);
                v[i * stride + COLOR] = i1;
            }
        };
    }

    private static Consumer<class_777> applyingLightmapInplace(int packedLight) {
        return quad -> {
            var vertices = quad.method_3357();
            for (int i = 0; i < 4; i++)
                vertices[i * getStride() + UV2] = packedLight;
        };
    }

    private static Consumer<class_777> applyingTransformInplace(Matrix4f transform) {
        return quad -> {
            var v = quad.method_3357();
            int stride = getStride();
            for (int i = 0; i < 4; i++) {
                int offset = i * stride + POSITION;
                float originalX = Float.intBitsToFloat(v[offset]) - 0.5f;
                float originalY = Float.intBitsToFloat(v[offset + 1]) - 0.5f;
                float originalZ = Float.intBitsToFloat(v[offset + 2]) - 0.5f;

                Vector4f vec = new Vector4f(originalX, originalY, originalZ, 1);
                vec.mul(transform);
                // Divide by homogeneous coordinate to obtain transformed 3D point
                vec.div(vec.w);

                v[offset] = Float.floatToRawIntBits(vec.x() + 0.5f);
                v[offset + 1] = Float.floatToRawIntBits(vec.y() + 0.5f);
                v[offset + 2] = Float.floatToRawIntBits(vec.z() + 0.5f);
            }
            var normalTransform = new Matrix3f(transform).invert().transpose();

            for (int i = 0; i < 4; i++) {
                int offset = i * stride + NORMAL;
                int normalIn = v[offset];
                if ((normalIn & 0x00FFFFFF) != 0) {
                    float normalX = ((byte) (normalIn & 0xFF)) / 127.0f;
                    float normalY = ((byte) ((normalIn >> 8) & 0xFF)) / 127.0f;
                    float normalZ = ((byte) ((normalIn >> 16) & 0xFF)) / 127.0f;

                    Vector3f vec = new Vector3f(normalX, normalY, normalZ);
                    vec.mul(normalTransform);
                    vec.normalize();
                    v[offset] = (((byte) (vec.x() * 127.0f)) & 0xFF) |
                            ((((byte) (vec.y() * 127.0f)) & 0xFF) << 8) |
                            ((((byte) (vec.z() * 127.0f)) & 0xFF) << 16) |
                            (normalIn & 0xFF000000);
                }
            }
        };
    }

    private static int getStride() {
        return class_290.field_1590.method_1362() / 4;
    }

    private static final int POSITION = 0;
    private static final int COLOR = 3;
    private static final int UV0 = 4;
    private static final int UV2 = 6;
    private static final int NORMAL = 7;

}
