package com.zurrtum.create.mixin;

import com.llamalad7.mixinextras.injector.ModifyExpressionValue;
import com.mojang.datafixers.util.Pair;
import com.mojang.serialization.Codec;
import com.mojang.serialization.DataResult;
import com.mojang.serialization.DynamicOps;
import com.mojang.serialization.MapLike;
import com.zurrtum.create.foundation.recipe.ComponentsIngredient;
import org.spongepowered.asm.mixin.Final;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.Mutable;
import org.spongepowered.asm.mixin.Shadow;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfo;

import java.util.Optional;
import net.minecraft.class_1856;
import net.minecraft.class_9129;
import net.minecraft.class_9139;

/**
 * Wraps the codecs of {@link class_1856} (ingredient) so that {@link ComponentsIngredient}
 * instances round-trip through both the network packet codecs and the data {@link Codec}.
 *
 * <p>Network: a {@code ComponentsIngredient} is written as a VarInt {@code -1} marker followed by
 * the {@code ComponentsIngredient} packet payload; every other value is written by the original
 * codec. When reading, the first VarInt is peeked and the reader index is rewound unless it is the
 * marker. NOTE(review): this relies on the wrapped codecs' payloads always beginning with a VarInt
 * that is never {@code -1} — confirm against the vanilla codec definitions.
 *
 * <p>Data: a map input whose {@code ComponentsIngredient.TYPE_KEY} entry is the string
 * {@code ComponentsIngredient.STRING_ID} is decoded with the {@code ComponentsIngredient} codec;
 * all other input is delegated to the original codec.
 */
@Mixin(class_1856.class)
public class IngredientMixin {
    @Mutable
    @Shadow
    @Final
    public static Codec<class_1856> CODEC;

    /**
     * Wraps the first {@code PacketCodec.xmap} result built in the static initializer (the
     * non-optional ingredient packet codec, going by the handler's type; ordinal assumed stable).
     *
     * @param packetCodec the original packet codec
     * @return a codec that additionally handles {@link ComponentsIngredient}
     */
    @ModifyExpressionValue(method = "<clinit>", at = @At(value = "INVOKE", target = "Lnet/minecraft/network/codec/PacketCodec;xmap(Ljava/util/function/Function;Ljava/util/function/Function;)Lnet/minecraft/network/codec/PacketCodec;", ordinal = 0))
    private static class_9139<class_9129, class_1856> getPacketCodec(class_9139<class_9129, class_1856> packetCodec) {
        return new class_9139<>() {
            @Override
            public class_1856 decode(class_9129 buf) {
                int index = buf.readerIndex();
                // Peek the leading VarInt; -1 marks a ComponentsIngredient payload.
                if (buf.method_10816() != -1) {
                    // Not our marker: rewind so the original codec sees the untouched payload.
                    buf.method_52988(index);
                    return packetCodec.decode(buf);
                }
                return ComponentsIngredient.field_48355.decode(buf);
            }

            @Override
            public void encode(class_9129 buf, class_1856 value) {
                if (value instanceof ComponentsIngredient componentsIngredient) {
                    buf.method_10804(-1);
                    ComponentsIngredient.field_48355.encode(buf, componentsIngredient);
                } else {
                    packetCodec.encode(buf, value);
                }
            }
        };
    }

    /**
     * Wraps the second {@code PacketCodec.xmap} result built in the static initializer (the
     * optional ingredient packet codec, going by the handler's type; ordinal assumed stable).
     *
     * @param packetCodec the original optional-ingredient packet codec
     * @return a codec that additionally handles a present {@link ComponentsIngredient}
     */
    @ModifyExpressionValue(method = "<clinit>", at = @At(value = "INVOKE", target = "Lnet/minecraft/network/codec/PacketCodec;xmap(Ljava/util/function/Function;Ljava/util/function/Function;)Lnet/minecraft/network/codec/PacketCodec;", ordinal = 1))
    private static class_9139<class_9129, Optional<class_1856>> getIngredientPacketCodec(class_9139<class_9129, Optional<class_1856>> packetCodec) {
        return new class_9139<>() {
            @Override
            public Optional<class_1856> decode(class_9129 buf) {
                int index = buf.readerIndex();
                // Peek the leading VarInt; -1 marks a ComponentsIngredient payload.
                if (buf.method_10816() != -1) {
                    // Not our marker: rewind so the original codec sees the untouched payload.
                    buf.method_52988(index);
                    return packetCodec.decode(buf);
                }
                return Optional.of(ComponentsIngredient.field_48355.decode(buf));
            }

            @Override
            public void encode(class_9129 buf, Optional<class_1856> value) {
                if (value.isPresent() && value.get() instanceof ComponentsIngredient componentsIngredient) {
                    buf.method_10804(-1);
                    ComponentsIngredient.field_48355.encode(buf, componentsIngredient);
                } else {
                    // Empty optionals and ordinary ingredients use the original encoding.
                    packetCodec.encode(buf, value);
                }
            }
        };
    }

    /**
     * Replaces {@code CODEC} after the static initializer has run with a codec that recognizes
     * {@link ComponentsIngredient} input and delegates everything else to the original codec.
     */
    @Inject(method = "<clinit>", at = @At("TAIL"))
    private static void injectCodec(CallbackInfo ci) {
        Codec<class_1856> codec = CODEC;
        CODEC = new Codec<>() {
            @Override
            public <T> DataResult<Pair<class_1856, T>> decode(DynamicOps<T> ops, T input) {
                DataResult<MapLike<T>> map = ops.getMap(input);
                if (map.isError()) {
                    // Not a map: cannot be a ComponentsIngredient.
                    return codec.decode(ops, input);
                }
                T type = map.getOrThrow().get(ComponentsIngredient.TYPE_KEY);
                if (type == null) {
                    return codec.decode(ops, input);
                }
                // A missing or non-string type value must not throw here; such input is simply
                // not ours and is left to the original codec to accept or reject.
                boolean isComponentsIngredient = ops.getStringValue(type)
                    .result()
                    .filter(id -> id.equals(ComponentsIngredient.STRING_ID))
                    .isPresent();
                if (!isComponentsIngredient) {
                    return codec.decode(ops, input);
                }
                // Widen the decoded value to class_1856 without a raw/unchecked cast.
                return ComponentsIngredient.field_46095.decode(ops, input)
                    .map(pair -> Pair.<class_1856, T>of(pair.getFirst(), pair.getSecond()));
            }

            @Override
            public <T> DataResult<T> encode(class_1856 input, DynamicOps<T> ops, T prefix) {
                if (input instanceof ComponentsIngredient componentsIngredient) {
                    return ComponentsIngredient.field_46095.encode(componentsIngredient, ops, prefix);
                }
                return codec.encode(input, ops, prefix);
            }
        };
    }
}
