package com.zurrtum.create.client.catnip.impl.client.render.model;

import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.zurrtum.create.client.catnip.client.render.model.ShadeSeparatedBufferSource;
import com.zurrtum.create.client.catnip.client.render.model.ShadeSeparatedResultConsumer;
import com.zurrtum.create.client.catnip.impl.client.render.TransformingVertexConsumer;
import com.zurrtum.create.client.infrastructure.model.CopycatModel;
import com.zurrtum.create.client.infrastructure.model.WrapperBlockStateModel;
import com.zurrtum.create.client.model.LayerBakedModel;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import org.jetbrains.annotations.Nullable;

import java.util.Iterator;
import java.util.List;
import net.minecraft.class_1087;
import net.minecraft.class_10889;
import net.minecraft.class_11515;
import net.minecraft.class_1920;
import net.minecraft.class_2338;
import net.minecraft.class_2464;
import net.minecraft.class_2680;
import net.minecraft.class_310;
import net.minecraft.class_3610;
import net.minecraft.class_4587;
import net.minecraft.class_4608;
import net.minecraft.class_4696;
import net.minecraft.class_5819;
import net.minecraft.class_776;
import net.minecraft.class_778;

// Modified from https://github.com/Engine-Room/Flywheel/blob/2f67f54c8898d91a48126c3c753eefa6cd224f84/forge/src/lib/java/dev/engine_room/flywheel/lib/model/baked/BakedModelBufferer.java
/**
 * Static helpers that tessellate block-state models (and, for {@code bufferBlocks}, optionally fluids) into a
 * {@link ShadeSeparatedBufferSource}, or into a {@link ShadeSeparatedResultConsumer} via a per-thread
 * {@code DefaultShadeSeparatedBufferSource}.
 *
 * <p>Scratch objects (pose stack, random source, emitters, default buffer source) are held in a
 * {@link ThreadLocal}, so calls on different threads never share them. NOTE(review): calls on the same thread
 * presumably must not nest, since they would share those scratch objects.
 */
public final class BakedModelBuffererImpl {
    private static final ThreadLocal<ThreadLocalObjects> THREAD_LOCAL_OBJECTS = ThreadLocal.withInitial(ThreadLocalObjects::new);

    private BakedModelBuffererImpl() {
    }

    /**
     * Collects the parts of {@code model} for {@code state} at {@code pos} and buffers them into
     * {@code bufferSource}, routing each part to the render layer reported by {@link LayerBakedModel}.
     *
     * <p>The thread's random source is reseeded from {@code state}'s positional seed before parts are
     * collected. {@link CopycatModel}s contribute their parts via {@code addPartsWithInfo}, which also receives
     * the level and position.
     *
     * @param model        the block-state model to render
     * @param pos          position of the block; used for the random seed and passed to the renderer
     * @param level        level passed to part collection and to the renderer
     * @param state        block state being rendered
     * @param poseStack    transform to apply, or {@code null} to use this thread's identity pose stack
     * @param bufferSource destination for the emitted geometry
     */
    public static void bufferModel(
        class_1087 model,
        class_2338 pos,
        class_1920 level,
        class_2680 state,
        @Nullable class_4587 poseStack,
        ShadeSeparatedBufferSource bufferSource
    ) {
        ThreadLocalObjects objects = THREAD_LOCAL_OBJECTS.get();
        class_5819 random = objects.random;
        random.method_43052(state.method_26190(pos));
        if (poseStack == null) {
            poseStack = objects.identityPoseStack;
        }
        List<class_10889> parts = new ObjectArrayList<>();
        if (model instanceof CopycatModel copycatModel) {
            copycatModel.addPartsWithInfo(level, pos, state, random, parts);
        } else {
            model.method_68513(random, parts);
        }
        bufferModel(parts, pos, level, state, poseStack, bufferSource, objects.universalEmitter);
    }

    /**
     * Buffers an already-collected list of parts (renderer flag {@code false}), then releases the emitter.
     * Does nothing when {@code parts} is empty.
     */
    private static void bufferModel(
        List<class_10889> parts,
        class_2338 pos,
        class_1920 level,
        class_2680 state,
        class_4587 poseStack,
        ShadeSeparatedBufferSource bufferSource,
        UniversalMeshEmitter universalEmitter
    ) {
        if (parts.isEmpty()) {
            return;
        }
        class_778 blockRenderer = class_310.method_1551().method_1541().method_3350();
        try {
            renderParts(universalEmitter, bufferSource, poseStack, blockRenderer, level, parts, state, pos, false);
        } finally {
            // Drop the emitter's reference to the buffer source even if rendering throws.
            universalEmitter.clear();
        }
    }

    /**
     * Renders {@code parts}, routing each part to the layer returned by
     * {@link LayerBakedModel#getBlockRenderLayer}. When every part resolves to the same layer, they are rendered
     * in a single call. Otherwise each part is rendered individually into its own layer. Does not clear the
     * emitter; callers are responsible for that.
     *
     * @param cull passed through as the boolean argument of {@code method_3374}
     *             (NOTE(review): presumably neighbour face culling — confirm against the mappings)
     */
    private static void renderParts(
        UniversalMeshEmitter universalEmitter,
        ShadeSeparatedBufferSource bufferSource,
        class_4587 poseStack,
        class_778 blockRenderer,
        class_1920 level,
        List<class_10889> parts,
        class_2680 state,
        class_2338 pos,
        boolean cull
    ) {
        int size = parts.size();
        if (size == 0) {
            return;
        }

        // Memoized so the state's default layer is looked up at most once per call, and only if requested.
        Supplier<class_11515> defaultLayer = Suppliers.memoize(() -> class_4696.method_23679(state));
        class_11515 firstLayer = LayerBakedModel.getBlockRenderLayer(parts.getFirst(), defaultLayer);
        if (size == 1) {
            render(universalEmitter, bufferSource, firstLayer, poseStack, blockRenderer, level, parts, state, pos, cull);
            return;
        }

        class_11515[] renderLayers = new class_11515[size];
        renderLayers[0] = firstLayer;
        boolean simple = true;
        for (int i = 1; i < size; i++) {
            renderLayers[i] = LayerBakedModel.getBlockRenderLayer(parts.get(i), defaultLayer);
            if (simple && renderLayers[i] != firstLayer) {
                simple = false;
            }
        }

        if (simple) {
            render(universalEmitter, bufferSource, firstLayer, poseStack, blockRenderer, level, parts, state, pos, cull);
        } else {
            for (int i = 0; i < size; i++) {
                render(universalEmitter, bufferSource, renderLayers[i], poseStack, blockRenderer, level, List.of(parts.get(i)), state, pos, cull);
            }
        }
    }

    /**
     * Points the emitter at {@code layer} and tessellates {@code parts} through it. The pose stack is pushed
     * around the render call and always popped, even if rendering throws.
     */
    private static void render(
        UniversalMeshEmitter universalEmitter,
        ShadeSeparatedBufferSource bufferSource,
        class_11515 layer,
        class_4587 poseStack,
        class_778 blockRenderer,
        class_1920 level,
        List<class_10889> parts,
        class_2680 state,
        class_2338 pos,
        boolean cull
    ) {
        universalEmitter.prepare(bufferSource, layer);
        poseStack.method_22903();
        try {
            blockRenderer.method_3374(level, parts, state, pos, poseStack, universalEmitter, cull, class_4608.field_21444);
        } finally {
            poseStack.method_22909();
        }
    }

    /**
     * Buffers an already-collected list of parts into this thread's default buffer source, which is prepared
     * with {@code resultConsumer} before rendering and ended afterwards. The random source is not reseeded,
     * because the parts are supplied by the caller.
     *
     * @param poseStack transform to apply, or {@code null} to use this thread's identity pose stack
     */
    public static void bufferModel(
        List<class_10889> parts,
        class_2338 pos,
        class_1920 level,
        class_2680 state,
        @Nullable class_4587 poseStack,
        ShadeSeparatedResultConsumer resultConsumer
    ) {
        ThreadLocalObjects objects = THREAD_LOCAL_OBJECTS.get();
        DefaultShadeSeparatedBufferSource bufferSource = objects.defaultBufferSource;
        bufferSource.prepare(resultConsumer);
        if (poseStack == null) {
            poseStack = objects.identityPoseStack;
        }
        bufferModel(parts, pos, level, state, poseStack, bufferSource, objects.universalEmitter);
        bufferSource.end();
    }

    /**
     * Like {@link #bufferModel(class_1087, class_2338, class_1920, class_2680, class_4587, ShadeSeparatedBufferSource)},
     * but writes into this thread's default buffer source, prepared with {@code resultConsumer} and ended
     * afterwards.
     */
    public static void bufferModel(
        class_1087 model,
        class_2338 pos,
        class_1920 level,
        class_2680 state,
        @Nullable class_4587 poseStack,
        ShadeSeparatedResultConsumer resultConsumer
    ) {
        ThreadLocalObjects objects = THREAD_LOCAL_OBJECTS.get();
        DefaultShadeSeparatedBufferSource bufferSource = objects.defaultBufferSource;
        bufferSource.prepare(resultConsumer);
        bufferModel(model, pos, level, state, poseStack, bufferSource);
        bufferSource.end();
    }

    /**
     * Buffers every block position produced by {@code posIterator}, reading states from {@code level}.
     *
     * <p>For each position:
     * <ul>
     *   <li>If {@code renderFluids} is set, any non-empty fluid state is rendered first.</li>
     *   <li>Then, if the state's render shape is {@code class_2464.field_11458}, its model parts are rendered.
     *       Each part goes into the layer reported by {@link LayerBakedModel}, translated to the block's
     *       position.</li>
     * </ul>
     *
     * <p>The static {@code class_778.method_20544()} / {@code method_20545()} pair brackets the whole loop.
     * The closing call and the release of the thread's wrappers happen in a {@code finally}, so they run even
     * if rendering throws.
     *
     * @param poseStack transform to apply, or {@code null} to use this thread's identity pose stack
     */
    public static void bufferBlocks(
        Iterator<class_2338> posIterator,
        class_1920 level,
        @Nullable class_4587 poseStack,
        boolean renderFluids,
        ShadeSeparatedBufferSource bufferSource
    ) {
        ThreadLocalObjects objects = THREAD_LOCAL_OBJECTS.get();
        if (poseStack == null) {
            poseStack = objects.identityPoseStack;
        }
        class_5819 random = objects.random;
        UniversalMeshEmitter universalEmitter = objects.universalEmitter;
        TransformingVertexConsumer transformingWrapper = objects.transformingWrapper;

        class_776 renderDispatcher = class_310.method_1551().method_1541();
        class_778 blockRenderer = renderDispatcher.method_3350();

        // NOTE(review): presumably enables the block renderer's per-thread lighting cache; method_20545 undoes it.
        class_778.method_20544();
        try {
            while (posIterator.hasNext()) {
                class_2338 pos = posIterator.next();
                class_2680 state = level.method_8320(pos);

                if (renderFluids) {
                    bufferFluid(renderDispatcher, transformingWrapper, bufferSource, poseStack, level, pos, state);
                }

                // NOTE(review): field_11458 is presumably the "render as model" block render shape.
                if (state.method_26217() == class_2464.field_11458) {
                    bufferBlockModel(renderDispatcher, blockRenderer, random, universalEmitter, bufferSource, poseStack, level, pos, state);
                }
            }
        } finally {
            class_778.method_20545();
            transformingWrapper.clear();
            universalEmitter.clear();
        }
    }

    /**
     * Renders the fluid of {@code state} at {@code pos}, if it has one, through {@code transformingWrapper}
     * into the fluid's layer buffer.
     */
    private static void bufferFluid(
        class_776 renderDispatcher,
        TransformingVertexConsumer transformingWrapper,
        ShadeSeparatedBufferSource bufferSource,
        class_4587 poseStack,
        class_1920 level,
        class_2338 pos,
        class_2680 state
    ) {
        class_3610 fluidState = state.method_26227();
        if (fluidState.method_15769()) {
            return;
        }

        class_11515 renderType = class_4696.method_23680(fluidState);
        transformingWrapper.prepare(bufferSource.getBuffer(renderType, true), poseStack);

        poseStack.method_22903();
        try {
            // Translate to the origin of the 16-block section containing pos (coordinates with the low four bits
            // cleared). NOTE(review): presumably the fluid renderer emits section-relative vertices.
            poseStack.method_46416(pos.method_10263() - (pos.method_10263() & 0xF), pos.method_10264() - (pos.method_10264() & 0xF), pos.method_10260() - (pos.method_10260() & 0xF));
            renderDispatcher.method_3352(pos, level, transformingWrapper, state, fluidState);
        } finally {
            poseStack.method_22909();
        }
    }

    /**
     * Collects the model parts of {@code state} at {@code pos} and renders them (renderer flag {@code true}),
     * translated to the block's position. Each part goes into the layer reported by {@link LayerBakedModel}.
     */
    private static void bufferBlockModel(
        class_776 renderDispatcher,
        class_778 blockRenderer,
        class_5819 random,
        UniversalMeshEmitter universalEmitter,
        ShadeSeparatedBufferSource bufferSource,
        class_4587 poseStack,
        class_1920 level,
        class_2338 pos,
        class_2680 state
    ) {
        long seed = state.method_26190(pos);
        class_1087 model = renderDispatcher.method_3349(state);
        random.method_43052(seed);

        List<class_10889> parts = new ObjectArrayList<>();
        if (WrapperBlockStateModel.unwrapCompat(model) instanceof WrapperBlockStateModel wrapper) {
            wrapper.addPartsWithInfo(level, pos, state, random, parts);
        } else {
            model.method_68513(random, parts);
        }

        poseStack.method_22903();
        try {
            poseStack.method_46416(pos.method_10263(), pos.method_10264(), pos.method_10260());
            renderParts(universalEmitter, bufferSource, poseStack, blockRenderer, level, parts, state, pos, true);
        } finally {
            poseStack.method_22909();
        }
    }

    /**
     * Like {@link #bufferBlocks(Iterator, class_1920, class_4587, boolean, ShadeSeparatedBufferSource)}, but
     * writes into this thread's default buffer source, prepared with {@code resultConsumer} and ended
     * afterwards.
     */
    public static void bufferBlocks(
        Iterator<class_2338> posIterator,
        class_1920 level,
        @Nullable class_4587 poseStack,
        boolean renderFluids,
        ShadeSeparatedResultConsumer resultConsumer
    ) {
        ThreadLocalObjects objects = THREAD_LOCAL_OBJECTS.get();
        DefaultShadeSeparatedBufferSource bufferSource = objects.defaultBufferSource;
        bufferSource.prepare(resultConsumer);
        bufferBlocks(posIterator, level, poseStack, renderFluids, bufferSource);
        bufferSource.end();
    }

    /** Per-thread scratch objects reused across calls to avoid reallocating them. */
    private static final class ThreadLocalObjects {
        // Freshly constructed pose stack, used when callers pass no pose stack of their own.
        private final class_4587 identityPoseStack = new class_4587();
        private final class_5819 random = class_5819.method_43053();

        private final DefaultShadeSeparatedBufferSource defaultBufferSource = new DefaultShadeSeparatedBufferSource();
        private final UniversalMeshEmitter universalEmitter = new UniversalMeshEmitter();
        private final TransformingVertexConsumer transformingWrapper = new TransformingVertexConsumer();
    }
}
