package com.zurrtum.create.client.mixin;

import com.llamalad7.mixinextras.injector.wrapoperation.Operation;
import com.llamalad7.mixinextras.injector.wrapoperation.WrapOperation;
import com.llamalad7.mixinextras.sugar.Local;
import com.mojang.datafixers.util.Pair;
import com.zurrtum.create.client.AllModels;
import com.zurrtum.create.client.flywheel.lib.model.baked.PartialModelEventHandler;
import com.zurrtum.create.client.model.NormalsModelElement;
import com.zurrtum.create.client.model.UnbakedModelParser;
import net.minecraft.class_10097;
import net.minecraft.class_10521;
import net.minecraft.class_10802;
import net.minecraft.class_1087;
import net.minecraft.class_1092;
import net.minecraft.class_1100;
import net.minecraft.class_2680;
import net.minecraft.class_2960;
import net.minecraft.class_3518;
import net.minecraft.class_793;
import net.minecraft.class_9824;
import net.minecraft.client.resources.model.*;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfoReturnable;

import java.io.Reader;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

import static com.zurrtum.create.Create.MOD_ID;

/**
 * Client mixin into the model manager ({@code class_1092}).
 *
 * <p>It does three things:
 * <ul>
 *   <li>parses this mod's model JSONs with {@link UnbakedModelParser#GSON};</li>
 *   <li>feeds the non-{@code class_793} models cached by that parser back into the loaded-model stream;</li>
 *   <li>registers extra block-state models and partial-model dependencies during dependency discovery.</li>
 * </ul>
 */
@Mixin(class_1092.class)
public class ModelManagerMixin {
    /**
     * Runs just before {@code BlockModel.fromStream(Reader)} is invoked inside {@code method_65750}.
     *
     * <p>For identifiers whose namespace equals {@code MOD_ID}, it parses {@code input} with
     * {@link UnbakedModelParser#GSON} and cancels the vanilla parse:
     * <ul>
     *   <li>If the result is a {@code class_793}, any element-list geometry has its elements passed to
     *       {@link NormalsModelElement#markFacingNormals}. The pair {@code (identifier, model)} is then
     *       returned.</li>
     *   <li>Otherwise the model is stored via {@link UnbakedModelParser#cache}. The injector returns
     *       {@code null}, and the cached model is re-added later by {@link #replace}.</li>
     * </ul>
     *
     * <p>Identifiers from other namespaces are left to the vanilla code path.
     *
     * @param cir        callback used to short-circuit the target with our own parse result
     * @param identifier the model identifier captured from the target's locals
     * @param input      the reader captured from the target's locals; closed here on a best-effort basis
     */
    @Inject(method = "method_65750", at = @At(value = "INVOKE", target = "Lnet/minecraft/client/renderer/block/model/BlockModel;fromStream(Ljava/io/Reader;)Lnet/minecraft/client/renderer/block/model/BlockModel;"), cancellable = true)
    private static void deserialize(CallbackInfoReturnable<Pair<class_2960, class_793>> cir, @Local class_2960 identifier, @Local Reader input) {
        if (!MOD_ID.equals(identifier.method_12836())) {
            return;
        }
        try {
            class_1100 model = class_3518.method_15276(UnbakedModelParser.GSON, input, class_1100.class);
            if (model instanceof class_793 jsonModel) {
                // The geometry is not guaranteed to be a class_10802 (the original needed an explicit cast),
                // so only mark normals when it is. Other geometry implementations pass through untouched
                // instead of failing the whole model load with a ClassCastException.
                if (jsonModel.comp_3739() instanceof class_10802 geometry) {
                    geometry.comp_3753().forEach(NormalsModelElement::markFacingNormals);
                }
                cir.setReturnValue(Pair.of(identifier, jsonModel));
            } else {
                // Not representable as the target's return type: stash it and return null.
                // The stash is merged back into the model stream by replace().
                UnbakedModelParser.cache(identifier, model);
                cir.setReturnValue(null);
            }
        } finally {
            // Best-effort close. Closeable#close has no effect on an already-closed stream,
            // so this stays safe if the target method also closes the reader.
            if (input != null) {
                try {
                    input.close();
                } catch (Exception ignored) {
                    // Deliberately ignored: a close failure must not mask the parse result or its exception.
                }
            }
        }
    }

    /**
     * Wraps the {@code List.stream()} call in {@code method_45897}.
     *
     * <p>It appends the models cached by {@link UnbakedModelParser#cache}, so that models which
     * {@link #deserialize} could not return directly still reach the downstream pipeline.
     *
     * @param instance the list whose stream is being created
     * @param original the wrapped {@code List.stream()} operation
     * @return the original stream followed by {@link UnbakedModelParser#getCaches()}
     */
    @WrapOperation(method = "method_45897", at = @At(value = "INVOKE", target = "Ljava/util/List;stream()Ljava/util/stream/Stream;"))
    private static Stream<Pair<class_2960, class_1100>> replace(
        List<Pair<class_2960, class_1100>> instance,
        Operation<Stream<Pair<class_2960, class_1100>>> original
    ) {
        return Stream.concat(original.call(instance), UnbakedModelParser.getCaches());
    }

    /**
     * Runs inside {@code discoverModelDependencies}, just before {@code ModelManager$ResolvedModels}
     * is constructed.
     *
     * <p>For every entry in {@link AllModels#ALL}, the resolver is applied to the block state and its
     * current model. The result's dependencies are registered with {@code collector}, and the result
     * replaces the state's entry in the state-definition map. Finally, every key of
     * {@link PartialModelEventHandler#getRegisterAdditional()} is registered with {@code collector}.
     *
     * @param modelMap        the target method's model map parameter (unused here)
     * @param stateDefinition the loaded block-state definitions whose model map is updated in place
     * @param result          the target method's third parameter (unused here)
     * @param cir             the callback info (the injector does not cancel)
     * @param collector       the dependency collector captured from the target's locals
     */
    @Inject(method = "discoverModelDependencies", at = @At(value = "NEW", target = "(Lnet/minecraft/client/resources/model/ResolvedModel;Ljava/util/Map;)Lnet/minecraft/client/resources/model/ModelManager$ResolvedModels;"))
    private static void collect(
        Map<class_2960, class_1100> modelMap,
        class_9824.class_10095 stateDefinition,
        class_10521.class_10522 result,
        CallbackInfoReturnable<class_1092.class_10816> cir,
        @Local class_10097 collector
    ) {
        // NOTE(review): assumes comp_3063() returns a mutable map. put() below would throw otherwise; confirm.
        Map<class_2680, class_1087.class_9979> models = stateDefinition.comp_3063();
        AllModels.ALL.forEach((state, resolver) -> {
            // The resolver may receive null if the state has no vanilla model entry.
            class_1087.class_9979 unbaked = resolver.apply(state, models.get(state));
            unbaked.method_62326(collector::method_68023);
            models.put(state, unbaked);
        });
        PartialModelEventHandler.getRegisterAdditional().keySet().forEach(collector::method_68023);
    }
}
