package games.enchanted.eg_stop_unloading_my_shaders.common.mixin.shader.gl;

import com.llamalad7.mixinextras.injector.wrapmethod.WrapMethod;
import com.llamalad7.mixinextras.injector.wrapoperation.Operation;
import com.llamalad7.mixinextras.injector.wrapoperation.WrapOperation;
import com.mojang.blaze3d.pipeline.RenderPipeline;
import com.mojang.blaze3d.shaders.ShaderType;
import games.enchanted.eg_stop_unloading_my_shaders.common.Logging;
import games.enchanted.eg_stop_unloading_my_shaders.common.ModConstants;
import games.enchanted.eg_stop_unloading_my_shaders.common.ShaderReloadManager;
import games.enchanted.eg_stop_unloading_my_shaders.common.translations.Messages;
import org.slf4j.Logger;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Slice;

import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import net.minecraft.class_10151;
import net.minecraft.class_10865;
import net.minecraft.class_10867;
import net.minecraft.class_2561;
import net.minecraft.class_2960;

/**
 * Mixin applied to {@code class_10865} that surfaces shader lookup, compilation and link
 * failures through {@link ShaderReloadManager#showShaderErrorMessage} and retries a failed
 * pipeline compilation with the shader sources held by
 * {@link ModConstants#getVanillaShaderConfigs()}.
 */
@Mixin(class_10865.class)
public class GlDeviceMixin {
    /**
     * Wraps the {@code Map.computeIfAbsent} call inside {@code getOrCompileShader}.
     *
     * <p>While {@link ShaderReloadManager#shouldLoadVanillaFallback()} is {@code true} the map is
     * skipped entirely: the mapping function is invoked directly and its result is returned
     * without being stored in {@code shaderCache}. Otherwise the original call runs unchanged.
     *
     * @param shaderCache    the map the wrapped call would have consulted
     * @param compilationKey the key passed to the wrapped call
     * @param mapping        the function that produces the value for {@code compilationKey}
     * @param original       the wrapped {@code computeIfAbsent} operation
     * @return the freshly mapped value when the fallback is active, otherwise the result of the
     *         original operation
     */
    @WrapOperation(
        at = @At(value = "INVOKE", target = "Ljava/util/Map;computeIfAbsent(Ljava/lang/Object;Ljava/util/function/Function;)Ljava/lang/Object;"),
        method = "getOrCompileShader"
    )
    private <K, V> Object eg_sumr$bypassShaderCacheToForceVanillaFallback(Map<K, V> shaderCache, K compilationKey, Function<? super K, ? extends V> mapping, Operation<V> original) {
        if(ShaderReloadManager.shouldLoadVanillaFallback()) {
            return mapping.apply(compilationKey);
        }

        return original.call(shaderCache, compilationKey, mapping);
    }

    /**
     * Wraps {@code compilePipeline}: compiles with the supplied shader source first and, if the
     * result reports {@code isValid() == false}, compiles a second time with a source function
     * that reads from {@link ModConstants#getVanillaShaderConfigs()}.
     *
     * <p>The fallback flag on {@link ShaderReloadManager} is cleared before the first attempt and
     * set before the second one. It is not cleared again after the second attempt; the next
     * invocation of this handler resets it.
     * NOTE(review): confirm nothing outside this class reads the flag between pipeline
     * compilations and expects it to be {@code false}.
     *
     * @param pipeline     the pipeline being compiled
     * @param shaderSource the shader source lookup used for the first attempt
     * @param original     the wrapped {@code compilePipeline} method
     * @return the first compilation result if it is valid, otherwise the result of the fallback
     *         compilation
     */
    @WrapMethod(
        method = "compilePipeline"
    )
    private class_10867 eg_sumr$checkForFailedShaderCompilationAndTriggerVanillaFallback(RenderPipeline pipeline, BiFunction<class_2960, ShaderType, String> shaderSource, Operation<class_10867> original) {
        ShaderReloadManager.setShouldLoadVanillaFallback(false);
        class_10867 compiledPipeline = original.call(pipeline, shaderSource);
        if(compiledPipeline.isValid()) {
            return compiledPipeline;
        }

        // First attempt failed: switch to the fallback sources and compile again.
        ShaderReloadManager.setShouldLoadVanillaFallback(true);
        BiFunction<class_2960, ShaderType, String> vanillaShaderSource = (location, shaderType) ->
            ModConstants.getVanillaShaderConfigs().comp_3106().get(new class_10151.class_10155(location, shaderType));
        return original.call(pipeline, vanillaShaderSource);
    }

    /**
     * Wraps the two-argument {@code Logger.error} call that follows {@code GlProgram.link} in
     * {@code compilePipeline}, i.e. the error logged after attempting to link GL programs.
     *
     * <p>The original log call always runs first. Afterwards a link-failure message is shown
     * through {@link ShaderReloadManager}. If the arguments are not of the expected types
     * (including {@code null}), a generic message is shown and the mismatch is logged instead of
     * throwing.
     *
     * @param instance              the logger the wrapped call targets
     * @param string                the log format string
     * @param oLocation             first log argument, expected to be a {@code class_2960}
     * @param oCompilationException second log argument, expected to be a
     *                              {@code class_10151.class_10152}
     * @param original              the wrapped {@code Logger.error} operation
     */
    @WrapOperation(
        at = @At(value = "INVOKE", target = "Lorg/slf4j/Logger;error(Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)V", remap = false),
        slice = @Slice(from = @At(value = "INVOKE", target = "Lcom/mojang/blaze3d/opengl/GlProgram;link(Lcom/mojang/blaze3d/opengl/GlShaderModule;Lcom/mojang/blaze3d/opengl/GlShaderModule;Lcom/mojang/blaze3d/vertex/VertexFormat;Ljava/lang/String;)Lcom/mojang/blaze3d/opengl/GlProgram;")),
        method = "compilePipeline"
    )
    private void eg_sumr$logShaderLinkerError(Logger instance, String string, Object oLocation, Object oCompilationException, Operation<Void> original) {
        original.call(instance, string, oLocation, oCompilationException);
        if(oLocation instanceof class_2960 location1 && oCompilationException instanceof class_10151.class_10152 compilationException) {
            ShaderReloadManager.showShaderErrorMessage(
                Messages.getFailedToLinkMessage(location1.toString()),
                class_2561.method_43470(compilationException.getMessage())
            );
            return;
        }

        // The pattern match also fails for null, so nothing here may dereference the arguments.
        ShaderReloadManager.showShaderErrorMessage(
            Messages.getFailedToLinkMessage(String.valueOf(oLocation)),
            Messages.getCouldntGetFullErrorMessage()
        );
        Logging.error("eg_sumr$logShaderLinkerError did not receive the correct parameters, please report this! Got '%s' and '%s'".formatted(
            eg_sumr$typeNameOf(oLocation),
            eg_sumr$typeNameOf(oCompilationException)
        ));
    }

    /**
     * Wraps the first two-argument {@code Logger.error} call in {@code compileShader}, i.e. the
     * error logged after checking whether the shader source exists.
     *
     * <p>A "couldn't find source" message is shown through {@link ShaderReloadManager} and then
     * the original log call runs. Unexpected argument types (including {@code null}) produce a
     * message built from the arguments' string forms plus a logged mismatch report; they never
     * prevent the original log call.
     *
     * @param instance    the logger the wrapped call targets
     * @param string      the log format string
     * @param oShaderType first log argument, expected to be a {@link ShaderType}
     * @param oLocation   second log argument, expected to be a {@code class_2960}
     * @param original    the wrapped {@code Logger.error} operation
     */
    @WrapOperation(
        at = @At(value = "INVOKE", target = "Lorg/slf4j/Logger;error(Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)V", ordinal = 0, remap = false),
        method = "compileShader"
    )
    private void eg_sumr$showFailedToFindShaderMessage(Logger instance, String string, Object oShaderType, Object oLocation, Operation<Void> original) {
        if(oLocation instanceof class_2960 location && oShaderType instanceof ShaderType shaderType) {
            ShaderReloadManager.showShaderErrorMessage(
                Messages.getCouldntFindSourceMessage(shaderType.getName(), location.toString()),
                null
            );
        } else {
            // The pattern match also fails for null, so use null-safe conversions only.
            ShaderReloadManager.showShaderErrorMessage(
                Messages.getCouldntFindSourceMessage(String.valueOf(oShaderType), String.valueOf(oLocation)),
                null
            );
            Logging.error("eg_sumr$showFailedToFindShaderMessage did not receive the correct parameters, please report this! Got '%s' and '%s'".formatted(
                eg_sumr$typeNameOf(oShaderType),
                eg_sumr$typeNameOf(oLocation)
            ));
        }
        original.call(instance, string, oShaderType, oLocation);
    }

    /**
     * Wraps the first varargs {@code Logger.error} call that follows
     * {@code GlStateManager.glGetShaderInfoLog} in {@code compileShader}, i.e. the error logged
     * after the shader compilation info has been fetched.
     *
     * <p>The log arguments are expected to be, in order: the shader type name ({@link String}),
     * the shader location ({@code class_2960}) and the compilation message ({@link String}).
     * A "couldn't compile" message is shown through {@link ShaderReloadManager} and then the
     * original log call runs. A missing, short or unexpectedly typed argument array is reported
     * rather than causing an exception.
     *
     * @param instance the logger the wrapped call targets
     * @param string   the log format string
     * @param objects  the varargs array passed to the wrapped call
     * @param original the wrapped {@code Logger.error} operation
     */
    @WrapOperation(
        slice = @Slice(from = @At(value = "INVOKE", target = "Lcom/mojang/blaze3d/opengl/GlStateManager;glGetShaderInfoLog(II)Ljava/lang/String;", remap = false)),
        at = @At(
            value = "INVOKE",
            target = "Lorg/slf4j/Logger;error(Ljava/lang/String;[Ljava/lang/Object;)V",
            ordinal = 0,
            remap = false
        ),
        method = "compileShader"
    )
    private void eg_sumr$showFailedToCompileShaderMessage(Logger instance, String string, Object[] objects, Operation<Void> original) {
        // Read the arguments defensively: this handler must not fail on an unexpected array shape.
        Object oTypeName = eg_sumr$argumentAt(objects, 0);
        Object oLocation = eg_sumr$argumentAt(objects, 1);
        Object oCompilationMessage = eg_sumr$argumentAt(objects, 2);

        if(oLocation instanceof class_2960 location && oTypeName instanceof String typeName && oCompilationMessage instanceof String compilationMessage) {
            ShaderReloadManager.showShaderErrorMessage(
                Messages.getCouldntCompileShaderMessage(typeName, location.toString()),
                class_2561.method_43470(compilationMessage)
            );
        } else {
            ShaderReloadManager.showShaderErrorMessage(
                Messages.getCouldntCompileShaderMessage(String.valueOf(oTypeName), String.valueOf(oLocation)),
                Messages.getCouldntGetFullErrorMessage()
            );
            Logging.error("eg_sumr$showFailedToCompileShaderMessage did not receive the correct parameters, please report this! Got '%s', '%s', and '%s'".formatted(
                eg_sumr$typeNameOf(oTypeName),
                eg_sumr$typeNameOf(oLocation),
                eg_sumr$typeNameOf(oCompilationMessage)
            ));
        }
        original.call(instance, string, objects);
    }

    /**
     * Returns the runtime class name of {@code value}, or the text {@code "null"} when
     * {@code value} is {@code null}.
     */
    private static String eg_sumr$typeNameOf(Object value) {
        return value == null ? "null" : value.getClass().getName();
    }

    /**
     * Returns {@code arguments[index]}, or {@code null} when {@code arguments} is {@code null}
     * or has no element at {@code index}.
     */
    private static Object eg_sumr$argumentAt(Object[] arguments, int index) {
        if(arguments == null || index < 0 || index >= arguments.length) {
            return null;
        }
        return arguments[index];
    }
}
