package com.zurrtum.create.client.vanillin.elements;

import com.zurrtum.create.client.flywheel.api.material.Material;
import com.zurrtum.create.client.flywheel.api.material.Transparency;
import com.zurrtum.create.client.flywheel.api.material.WriteMask;
import com.zurrtum.create.client.flywheel.api.model.Model;
import com.zurrtum.create.client.flywheel.api.vertex.MutableVertexList;
import com.zurrtum.create.client.flywheel.api.visual.DynamicVisual;
import com.zurrtum.create.client.flywheel.api.visualization.VisualizationContext;
import com.zurrtum.create.client.flywheel.lib.instance.InstanceTypes;
import com.zurrtum.create.client.flywheel.lib.instance.ShadowInstance;
import com.zurrtum.create.client.flywheel.lib.material.SimpleMaterial;
import com.zurrtum.create.client.flywheel.lib.model.QuadMesh;
import com.zurrtum.create.client.flywheel.lib.model.SingleMeshModel;
import com.zurrtum.create.client.flywheel.lib.visual.AbstractVisual;
import com.zurrtum.create.client.flywheel.lib.visual.SimpleDynamicVisual;
import com.zurrtum.create.client.flywheel.lib.visual.util.InstanceRecycler;
import net.minecraft.class_1297;
import net.minecraft.class_2338;
import net.minecraft.class_2350.class_2351;
import net.minecraft.class_2464;
import net.minecraft.class_265;
import net.minecraft.class_2680;
import net.minecraft.class_2791;
import net.minecraft.class_2960;
import net.minecraft.class_310;
import net.minecraft.class_3532;
import net.minecraft.class_4608;
import net.minecraft.class_765;
import org.jetbrains.annotations.Nullable;
import org.joml.Vector4f;
import org.joml.Vector4fc;

/**
 * A component that uses instances to render an entity's shadow.
 *
 * <p>Use {@link #radius(float)} to set the radius of the shadow, in blocks.
 * <br>
 * Use {@link #strength(float)} to set the strength of the shadow.
 * <br>
 * The shadow will be cast on blocks at most {@code min(radius, 2 * strength)} blocks below the entity.</p>
 *
 * <p>{@link #beginFrame(DynamicVisual.Context)} rebuilds the shadow every frame. It emits one instance per
 * eligible block surface in the box around the entity, and discards any instances left over from the
 * previous frame.</p>
 */
public final class ShadowElement extends AbstractVisual implements SimpleDynamicVisual {
    private static final class_2960 SHADOW_TEXTURE = class_2960.method_60656("textures/misc/shadow.png");
    private static final Material SHADOW_MATERIAL = SimpleMaterial.builder().texture(SHADOW_TEXTURE).mipmap(false)
        .polygonOffset(true) // vanilla shadows use "view offset" but this seems to work fine
        .transparency(Transparency.TRANSLUCENT).writeMask(WriteMask.COLOR).build();
    private static final Model SHADOW_MODEL = new SingleMeshModel(ShadowMesh.INSTANCE, SHADOW_MATERIAL);

    /** Upper bound applied to the radius by {@link #radius(float)}, in blocks. */
    private static final float MAX_RADIUS = 32;

    private final class_1297 entity;
    // Loop cursor over the candidate positions above the shadow-receiving blocks.
    private final class_2338.class_2339 pos = new class_2338.class_2339();
    // Scratch cursor for the block directly below `pos`. It is kept separate so that
    // setupInstance never moves the caller's loop cursor as a hidden side effect.
    private final class_2338.class_2339 belowPos = new class_2338.class_2339();

    private final InstanceRecycler<ShadowInstance> instances = new InstanceRecycler<>(this::createInstance);

    private float radius = Config.DEFAULT_RADIUS;
    private float strength = Config.DEFAULT_STRENGTH;

    /**
     * Create a shadow element for the given entity.
     *
     * @param ctx         The visualization context used to create instances.
     * @param entity      The entity casting the shadow.
     * @param partialTick The partial tick at creation time, forwarded to {@link AbstractVisual}.
     * @param config      The initial radius and strength; the radius is clamped as in {@link #radius(float)}.
     */
    public ShadowElement(VisualizationContext ctx, class_1297 entity, float partialTick, Config config) {
        super(ctx, entity.method_37908(), partialTick);
        this.entity = entity;
        radius(config.radius());
        strength(config.strength());
    }

    private ShadowInstance createInstance() {
        return visualizationContext.instancerProvider().instancer(InstanceTypes.SHADOW, SHADOW_MODEL).createInstance();
    }

    /**
     * @return The current shadow radius in blocks, after clamping.
     */
    public float radius() {
        return radius;
    }

    /**
     * @return The current shadow strength.
     */
    public float strength() {
        return strength;
    }

    /**
     * Set the radius of the shadow, in blocks, clamped to a maximum of 32.
     *
     * <p>Setting this to {@code <= 0} will disable the shadow.</p>
     *
     * @param radius The radius of the shadow, in blocks.
     * @return This element, for chaining.
     */
    public ShadowElement radius(float radius) {
        this.radius = Math.min(radius, MAX_RADIUS);
        return this;
    }

    /**
     * Set the strength of the shadow.
     *
     * <p>Together with the radius, this limits how far below the entity the shadow is cast
     * ({@code min(radius, 2 * strength)} blocks). It also scales each quad's alpha before block
     * brightness is applied.</p>
     *
     * @param strength The strength of the shadow.
     * @return This element, for chaining.
     */
    public ShadowElement strength(float strength) {
        this.strength = strength;
        return this;
    }

    /**
     * Update the shadow instances. You'd typically call this in your visual's
     * {@link com.zurrtum.create.client.flywheel.lib.visual.SimpleDynamicVisual#beginFrame(DynamicVisual.Context) beginFrame} method.
     *
     * <p>No instances are emitted when entity shadows are disabled in the client options, when the
     * radius is not positive, or when the entity is invisible. In those cases every previously used
     * instance is discarded.</p>
     *
     * @param context The frame context.
     */
    @Override
    public void beginFrame(DynamicVisual.Context context) {
        instances.resetCount();

        boolean shadowsEnabled = class_310.method_1551().field_1690.method_42435().method_41753();
        if (shadowsEnabled && radius > 0 && !entity.method_5767()) {
            setupInstances(context);
        }

        instances.discardExtra();
    }

    /**
     * Walk the box of block positions around the entity's interpolated position and try to emit a
     * shadow quad for each one.
     *
     * <p>The box spans {@code radius} blocks horizontally and {@code min(2 * strength, radius)}
     * blocks downward. The per-layer strength decreases by 0.5 for every block of vertical distance
     * below the entity.</p>
     */
    private void setupInstances(DynamicVisual.Context context) {
        double entityX = class_3532.method_16436(context.partialTick(), entity.field_6038, entity.method_23317());
        double entityY = class_3532.method_16436(context.partialTick(), entity.field_5971, entity.method_23318());
        double entityZ = class_3532.method_16436(context.partialTick(), entity.field_5989, entity.method_23321());
        float castDistance = Math.min(strength * 2, radius);
        int minXPos = class_3532.method_15357(entityX - (double) radius);
        int maxXPos = class_3532.method_15357(entityX + (double) radius);
        int minYPos = class_3532.method_15357(entityY - (double) castDistance);
        int maxYPos = class_3532.method_15357(entityY);
        int minZPos = class_3532.method_15357(entityZ - (double) radius);
        int maxZPos = class_3532.method_15357(entityZ + (double) radius);

        for (int z = minZPos; z <= maxZPos; ++z) {
            for (int x = minXPos; x <= maxXPos; ++x) {
                pos.method_10103(x, 0, z);
                // One chunk lookup per (x, z) column; the whole column shares it.
                class_2791 chunk = level.method_22350(pos);

                for (int y = minYPos; y <= maxYPos; ++y) {
                    pos.method_33098(y);
                    float strengthGivenYFalloff = strength - (float) (entityY - pos.method_10264()) * 0.5F;
                    setupInstance(chunk, pos, entityX, entityZ, strengthGivenYFalloff);
                }
            }
        }
    }

    /**
     * Emit a shadow quad on top of the block directly below {@code pos}, if that block is eligible.
     *
     * <p>{@code pos} is only read, never modified.</p>
     *
     * @param chunk           The chunk containing {@code pos}'s column.
     * @param pos             The position above the shadow-receiving block. Its light level controls brightness.
     * @param entityX         The entity's interpolated X in world space.
     * @param entityZ         The entity's interpolated Z in world space.
     * @param falloffStrength The shadow strength after vertical falloff for this layer.
     */
    private void setupInstance(class_2791 chunk, class_2338 pos, double entityX, double entityZ, float falloffStrength) {
        // TODO: cache this?
        var maxLocalRawBrightness = level.method_22339(pos);
        if (maxLocalRawBrightness <= 3) {
            // Too dark to render.
            return;
        }
        float blockBrightness = class_765.method_23284(level.method_8597(), maxLocalRawBrightness);
        float alpha = falloffStrength * 0.5F * blockBrightness;
        if (alpha < 0.0F) {
            // Too far away/too weak to render.
            return;
        }
        if (alpha > 1.0F) {
            alpha = 1.0F;
        }

        // Grab the shape of the block below the current position, without disturbing the caller's cursor.
        belowPos.method_10103(pos.method_10263(), pos.method_10264() - 1, pos.method_10260());
        var shape = getShapeAt(chunk, belowPos);
        if (shape == null) {
            // No shape means the block shouldn't receive a shadow.
            return;
        }

        // Integer coordinates of `pos` (the top face of the receiving block), relative to the render origin.
        var renderOrigin = visualizationContext.renderOrigin();
        int x = pos.method_10263() - renderOrigin.method_10263();
        int y = pos.method_10264() - renderOrigin.method_10264();
        int z = pos.method_10260() - renderOrigin.method_10260();

        double minX = x + shape.method_1091(class_2351.field_11048);
        double minY = y + shape.method_1091(class_2351.field_11052);
        double minZ = z + shape.method_1091(class_2351.field_11051);
        double maxX = x + shape.method_1105(class_2351.field_11048);
        double maxZ = z + shape.method_1105(class_2351.field_11051);

        var instance = instances.get();
        instance.x = (float) minX;
        instance.y = (float) minY;
        instance.z = (float) minZ;
        instance.entityX = (float) (entityX - renderOrigin.method_10263());
        instance.entityZ = (float) (entityZ - renderOrigin.method_10260());
        instance.sizeX = (float) (maxX - minX);
        instance.sizeZ = (float) (maxZ - minZ);
        instance.alpha = alpha;
        instance.radius = this.radius;
        instance.setChanged();
    }

    /**
     * Look up the shape of the block at {@code pos} if it can receive a shadow.
     *
     * @return The block's shape, or {@code null} in any of these cases: the block's render type is
     *     {@code class_2464.field_11455}, the state check {@code method_26234} fails, or the shape is empty.
     */
    @Nullable
    private class_265 getShapeAt(class_2791 chunk, class_2338 pos) {
        class_2680 state = chunk.method_8320(pos);
        if (state.method_26217() == class_2464.field_11455) {
            return null;
        }
        if (!state.method_26234(chunk, pos)) {
            return null;
        }
        class_265 shape = state.method_26218(chunk, pos);
        if (shape.method_1110()) {
            return null;
        }
        return shape;
    }

    @Override
    protected void _delete() {
        instances.delete();
    }

    /**
     * Initial shadow settings passed to {@link ShadowElement#ShadowElement}.
     *
     * @param radius   Shadow radius in blocks; {@code <= 0} disables the shadow.
     * @param strength Shadow strength.
     */
    public record Config(float radius, float strength) {
        // Defaults taken from EntityRenderer.
        public static final float DEFAULT_RADIUS = 0;
        public static final float DEFAULT_STRENGTH = 1.0F;
    }

    /**
     * A single quad extending from the origin to (1, 0, 1).
     * <br>
     * To be scaled and translated to the correct position and size.
     */
    private static class ShadowMesh implements QuadMesh {
        // Centered on the unit quad, with a radius equal to half its diagonal (sqrt(2) / 2).
        private static final Vector4fc BOUNDING_SPHERE = new Vector4f(0.5f, 0, 0.5f, (float) (Math.sqrt(2) * 0.5));
        private static final ShadowMesh INSTANCE = new ShadowMesh();

        private ShadowMesh() {
        }

        @Override
        public int vertexCount() {
            return 4;
        }

        /**
         * Write the quad's corners in (x, z) order (0, 0), (0, 1), (1, 1), (1, 0), all at y = 0.
         */
        @Override
        public void write(MutableVertexList vertexList) {
            writeVertex(vertexList, 0, 0, 0);
            writeVertex(vertexList, 1, 0, 1);
            writeVertex(vertexList, 2, 1, 1);
            writeVertex(vertexList, 3, 1, 0);
        }

        // Magic numbers taken from:
        // net.minecraft.client.renderer.entity.EntityRenderDispatcher#shadowVertex
        private static void writeVertex(MutableVertexList vertexList, int i, float x, float z) {
            vertexList.x(i, x);
            vertexList.y(i, 0);
            vertexList.z(i, z);
            vertexList.r(i, 1);
            vertexList.g(i, 1);
            vertexList.b(i, 1);
            vertexList.u(i, 0);
            vertexList.v(i, 0);
            vertexList.light(i, class_765.field_32767);
            vertexList.overlay(i, class_4608.field_21444);
            vertexList.normalX(i, 0);
            vertexList.normalY(i, 1);
            vertexList.normalZ(i, 0);
        }

        @Override
        public Vector4fc boundingSphere() {
            return BOUNDING_SPHERE;
        }
    }
}
