package com.zurrtum.create.client.vanillin.visuals;

import com.zurrtum.create.client.flywheel.api.instance.Instance;
import com.zurrtum.create.client.flywheel.api.material.Material;
import com.zurrtum.create.client.flywheel.api.visualization.VisualizationContext;
import com.zurrtum.create.client.flywheel.lib.material.CutoutShaders;
import com.zurrtum.create.client.flywheel.lib.material.SimpleMaterial;
import com.zurrtum.create.client.flywheel.lib.model.part.InstanceTree;
import com.zurrtum.create.client.flywheel.lib.model.part.ModelTrees;
import com.zurrtum.create.client.flywheel.lib.visual.AbstractBlockEntityVisual;
import com.zurrtum.create.client.flywheel.lib.visual.SimpleDynamicVisual;
import org.joml.Matrix4f;

import java.util.Set;
import java.util.function.Consumer;
import net.minecraft.class_1767;
import net.minecraft.class_2350;
import net.minecraft.class_2480;
import net.minecraft.class_2627;
import net.minecraft.class_3532;
import net.minecraft.class_4722;
import net.minecraft.class_5602;

/**
 * Instanced (Flywheel) visual for a shulker box block entity.
 *
 * <p>The model is built once as an {@link InstanceTree}. Each frame, only the {@code "lid"} child
 * is re-posed, and only when the block entity's progress value has changed since the last update.
 */
public class ShulkerBoxVisual extends AbstractBlockEntityVisual<class_2627> implements SimpleDynamicVisual {
    /** Cutout material on the shulker-box atlas, with mipmapping and backface culling disabled. */
    private static final Material MATERIAL = SimpleMaterial.builder()
        .cutout(CutoutShaders.ONE_TENTH)
        .texture(class_4722.field_21704)
        .mipmap(false)
        .backfaceCulling(false)
        .build();

    /** Model-tree paths removed before the model is instanced. */
    private static final Set<String> PATHS_TO_PRUNE = Set.of("/head");

    /** Root of the instanced model; owns every instance created for this visual. */
    private final InstanceTree modelTree;
    /** The {@code "lid"} child of {@link #modelTree}, the only animated part. */
    private final InstanceTree lidPart;

    /** World-space base transform (position, facing, model-space flip) shared by all parts. */
    private final Matrix4f basePose;

    /** Progress value last applied to the lid; NaN so the first comparison never matches. */
    private float lastOpenness = Float.NaN;

    public ShulkerBoxVisual(VisualizationContext ctx, class_2627 blockEntity, float partialTick) {
        super(ctx, blockEntity, partialTick);

        net.minecraft.class_4730 sprite = textureFor(blockEntity);
        modelTree = InstanceTree.create(
            instancerProvider(),
            ModelTrees.of(class_5602.field_27596, PATHS_TO_PRUNE, sprite, MATERIAL)
        );
        lidPart = modelTree.childOrThrow("lid");

        basePose = buildBasePose();
        refreshLid(partialTick);
    }

    /**
     * Chooses the sprite for the given shulker box.
     *
     * @param shulker the block entity whose color (via {@code method_11320}) selects the sprite
     * @return {@code class_4722.field_21710} when the color is {@code null}, otherwise the entry of
     *     {@code class_4722.field_21711} indexed by the color's {@code method_7789()} value
     */
    private static net.minecraft.class_4730 textureFor(class_2627 shulker) {
        class_1767 dye = shulker.method_11320();
        return dye == null ? class_4722.field_21710 : class_4722.field_21711.get(dye.method_7789());
    }

    /**
     * Builds the base pose: move to the visual position, then to the block center, then shrink
     * slightly (0.9995, presumably to avoid z-fighting with neighbours — unconfirmed), rotate to the
     * block's facing, and finally apply the (1, -1, -1) flip and -1 Y shift into model space.
     */
    private Matrix4f buildBasePose() {
        var facingRotation = facing().method_23224();
        var origin = getVisualPosition();
        return new Matrix4f()
            .translate(origin.method_10263(), origin.method_10264(), origin.method_10260())
            .translate(0.5f, 0.5f, 0.5f)
            .scale(0.9995f)
            .rotate(facingRotation)
            .scale(1, -1, -1)
            .translate(0, -1, 0);
    }

    /**
     * Returns the block's facing property when the block is a {@code class_2480}; otherwise falls
     * back to {@code class_2350.field_11036}.
     */
    private class_2350 facing() {
        return blockState.method_26204() instanceof class_2480
            ? blockState.method_11654(class_2480.field_11496)
            : class_2350.field_11036;
    }

    /** Re-poses the lid unless the visual is distance-limited this frame or outside the frustum. */
    @Override
    public void beginFrame(Context context) {
        boolean skipFrame = doDistanceLimitThisFrame(context) || !isVisible(context.frustum());
        if (!skipFrame) {
            refreshLid(context.partialTick());
        }
    }

    /**
     * Reads the block entity's progress value ({@code method_11312}, presumably the open/close
     * animation progress — confirm) and, if it differs from the last applied value, rotates and
     * raises the lid accordingly and pushes the updated transforms to every instance.
     *
     * @param partialTicks partial tick forwarded to {@code method_11312}
     */
    private void refreshLid(float partialTicks) {
        float openness = blockEntity.method_11312(partialTicks);
        if (openness != lastOpenness) {
            lastOpenness = openness;

            lidPart.yRot(1.5f * class_3532.field_29844 * openness);
            lidPart.yPos(24f - openness * 8f);

            modelTree.updateInstancesStatic(basePose);
        }
    }

    /** Applies the freshly computed packed light to every instance and marks each as changed. */
    @Override
    public void updateLight(float partialTick) {
        final int light = computePackedLight();
        modelTree.traverse(part -> part.light(light).setChanged());
    }

    /** Hands every instance in the model tree to {@code consumer}. */
    @Override
    public void collectCrumblingInstances(Consumer<Instance> consumer) {
        modelTree.traverse(consumer);
    }

    /** Deletes all instances owned by the model tree. */
    @Override
    protected void _delete() {
        modelTree.delete();
    }
}
