package com.zurrtum.create.client.mixin;

import com.llamalad7.mixinextras.injector.wrapoperation.Operation;
import com.llamalad7.mixinextras.injector.wrapoperation.WrapOperation;
import com.llamalad7.mixinextras.sugar.Local;
import com.mojang.datafixers.util.Pair;
import com.zurrtum.create.client.AllModels;
import com.zurrtum.create.client.flywheel.lib.model.baked.PartialModelEventHandler;
import com.zurrtum.create.client.model.NormalsModelElement;
import com.zurrtum.create.client.model.UnbakedModelParser;
import net.minecraft.class_10097;
import net.minecraft.class_10521;
import net.minecraft.class_10769;
import net.minecraft.class_10801;
import net.minecraft.class_10802;
import net.minecraft.class_1086;
import net.minecraft.class_1087;
import net.minecraft.class_1088;
import net.minecraft.class_1088.class_7778;
import net.minecraft.class_1092;
import net.minecraft.class_1100;
import net.minecraft.class_2680;
import net.minecraft.class_2960;
import net.minecraft.class_3518;
import net.minecraft.class_793;
import net.minecraft.class_9824;
import net.minecraft.class_9826;
import net.minecraft.client.render.model.*;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfoReturnable;

import java.io.Reader;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.stream.Stream;

import static com.zurrtum.create.Create.MOD_ID;

/**
 * Mixin into {@code class_1092} (the model manager whose methods are targeted below; presumably
 * {@code BakedModelManager}, per this class's name and the {@code BakedModelManager$Models}
 * descriptor used in {@link #collect}).
 *
 * <p>It lets Create load its own model JSON format, registers Create's block-state model overrides
 * and Flywheel partial models for dependency resolution, and bakes those partial models alongside
 * the vanilla model bake.
 */
@Mixin(class_1092.class)
public class BakedModelManagerMixin {
    /**
     * Replaces vanilla JSON model deserialization for models in Create's namespace.
     *
     * <p>The reader is parsed with {@link UnbakedModelParser#GSON}. If the result is a plain
     * {@code class_793} JSON model, the facing normals of its geometry's elements are marked, and
     * the {@code (identifier, model)} pair is returned in place of the vanilla result. Any other
     * model type is stored via {@link UnbakedModelParser#cache}, and {@code null} is returned for
     * this entry. Those cached models are presumably re-added later through
     * {@link UnbakedModelParser#getCaches()} in {@link #replace}.
     *
     * @param cir        callback used to cancel the vanilla deserialize call and supply the result
     * @param identifier id of the model being loaded (captured local of the target method)
     * @param input      reader over the model resource (captured local of the target method)
     */
    @Inject(method = "method_65750", at = @At(value = "INVOKE", target = "Lnet/minecraft/client/render/model/json/JsonUnbakedModel;deserialize(Ljava/io/Reader;)Lnet/minecraft/client/render/model/json/JsonUnbakedModel;"), cancellable = true)
    private static void deserialize(
        CallbackInfoReturnable<Pair<class_2960, class_793>> cir,
        @Local class_2960 identifier,
        @Local Reader input
    ) {
        if (identifier.method_12836().equals(MOD_ID)) {
            try {
                class_1100 model = class_3518.method_15276(UnbakedModelParser.GSON, input, class_1100.class);
                if (model instanceof class_793 jsonModel) {
                    // Pattern match instead of a blind cast: a model may carry no geometry (null),
                    // or a geometry of another implementation that has no element list to touch.
                    if (jsonModel.comp_3739() instanceof class_10802 geometry) {
                        geometry.comp_3753().forEach(NormalsModelElement::markFacingNormals);
                    }
                    cir.setReturnValue(Pair.of(identifier, jsonModel));
                } else {
                    UnbakedModelParser.cache(identifier, model);
                    // NOTE(review): assumes the target method tolerates/filters a null pair.
                    cir.setReturnValue(null);
                }
            } finally {
                // Cancelling returns straight out of the target method from this injection point,
                // which presumably skips the target's own cleanup of the reader, so close it here.
                if (input != null) {
                    try {
                        input.close();
                    } catch (Exception ignored) {
                        // Best-effort close: the model has already been parsed (or parsing failed
                        // with its own exception), so a failure to close must not mask that.
                    }
                }
            }
        }
    }

    /**
     * Wraps the {@code List.stream()} call in {@code method_45897} so the stream also yields the
     * {@code (id, model)} pairs held by {@link UnbakedModelParser#getCaches()}.
     *
     * @param instance the list the target method was about to stream
     * @param original the wrapped {@code stream()} call
     * @return the original stream followed by the cached entries
     */
    @WrapOperation(method = "method_45897", at = @At(value = "INVOKE", target = "Ljava/util/List;stream()Ljava/util/stream/Stream;"))
    private static Stream<Pair<class_2960, class_1100>> replace(
        List<Pair<class_2960, class_1100>> instance,
        Operation<Stream<Pair<class_2960, class_1100>>> original
    ) {
        return Stream.concat(original.call(instance), UnbakedModelParser.getCaches());
    }

    /**
     * Runs in {@code collect} just before {@code BakedModelManager$Models} is constructed.
     *
     * <p>For every entry in {@link AllModels#ALL}, the resolver receives the block state and its
     * current unbaked model and returns a replacement. The replacement's dependencies are reported
     * to {@code collector} via {@code method_62326}, and the replacement is written back into the
     * state map. Finally, every Flywheel partial-model id from
     * {@link PartialModelEventHandler#getRegisterAdditional()} is added to the collector.
     *
     * <p>NOTE(review): this assumes the map returned by {@code stateDefinition.comp_3063()} is
     * mutable, since {@code put} is called on it.
     *
     * @param modelMap        unbaked models by id (target-method parameter; unused here)
     * @param stateDefinition loaded block-state definitions; its state-to-model map is updated
     * @param result          target-method parameter (unused here)
     * @param cir             callback info (not cancelled)
     * @param collector       dependency collector captured from the target method
     */
    @Inject(method = "collect", at = @At(value = "NEW", target = "(Lnet/minecraft/client/render/model/BakedSimpleModel;Ljava/util/Map;)Lnet/minecraft/client/render/model/BakedModelManager$Models;"))
    private static void collect(
        Map<class_2960, class_1100> modelMap,
        class_9824.class_10095 stateDefinition,
        class_10521.class_10522 result,
        CallbackInfoReturnable<class_1092.class_10816> cir,
        @Local class_10097 collector
    ) {
        Map<class_2680, class_1087.class_9979> models = stateDefinition.comp_3063();
        AllModels.ALL.forEach((state, resolver) -> {
            class_1087.class_9979 unbaked = resolver.apply(state, models.get(state));
            unbaked.method_62326(collector::method_68023);
            models.put(state, unbaked);
        });
        PartialModelEventHandler.getRegisterAdditional().keySet().forEach(collector::method_68023);
    }

    /**
     * Wraps {@code ModelBaker.bake(ErrorCollectingSpriteGetter, Executor)} so that Flywheel's
     * partial models are baked alongside the vanilla models.
     *
     * <p>The vanilla bake is started first. Then each entry of
     * {@link PartialModelEventHandler#getRegisterAdditional()} is baked on {@code executor} with a
     * baker instance built from the same sprite getter and the rotation {@code class_1086.field_5350}
     * (presumably the identity rotation; confirm). {@code PartialModelEventHandler.onBakingCompleted}
     * is notified once per model and then once with the whole result map.
     *
     * @param baker        the model baker the target method invoked {@code bake} on
     * @param spriteGetter sprite getter passed to the bake
     * @param executor     executor passed to the bake; also used for the partial-model bakes
     * @param original     the wrapped vanilla {@code bake} call
     * @return a future that completes with the vanilla bake result, after the partial models
     *     have been baked and reported
     */
    @WrapOperation(method = "bake", at = @At(value = "INVOKE", target = "Lnet/minecraft/client/render/model/ModelBaker;bake(Lnet/minecraft/client/render/model/ErrorCollectingSpriteGetter;Ljava/util/concurrent/Executor;)Ljava/util/concurrent/CompletableFuture;"))
    private static CompletableFuture<class_1088.class_10524> bake(
        class_1088 baker,
        class_9826 spriteGetter,
        Executor executor,
        Operation<CompletableFuture<class_1088.class_10524>> original
    ) {
        class_1088.class_7778 bakerImpl = baker.new class_7778(spriteGetter);
        CompletableFuture<class_1088.class_10524> modelsCompletableFuture = original.call(baker, spriteGetter, executor);
        return class_10769.method_67612(
            PartialModelEventHandler.getRegisterAdditional(), (id, model) -> {
                class_10801 bakedModel = class_10801.method_67931(bakerImpl, id, class_1086.field_5350);
                PartialModelEventHandler.onBakingCompleted(model, bakedModel);
                return bakedModel;
            }, executor
        ).thenAccept(PartialModelEventHandler::onBakingCompleted).thenCompose(v -> modelsCompletableFuture);
    }
}
