/*
 * Decompiled with CFR 0.152.
 */
package dev.djefrey.colorwheel.compile.oit;

import dev.djefrey.colorwheel.ClrwlSamplers;
import dev.djefrey.colorwheel.Colorwheel;
import dev.djefrey.colorwheel.compile.ClrwlCompilation;
import dev.djefrey.colorwheel.compile.ClrwlPipelineStage;
import dev.djefrey.colorwheel.compile.ClrwlPipelines;
import dev.djefrey.colorwheel.compile.oit.ClrwlOitCompositeShaderKey;
import dev.djefrey.colorwheel.engine.ClrwlOitAccumulateOverride;
import dev.engine_room.flywheel.backend.compile.core.FailedCompilation;
import dev.engine_room.flywheel.backend.compile.core.ProgramLinker;
import dev.engine_room.flywheel.backend.compile.core.ShaderResult;
import dev.engine_room.flywheel.backend.gl.GlCompat;
import dev.engine_room.flywheel.backend.gl.shader.GlProgram;
import dev.engine_room.flywheel.backend.gl.shader.GlShader;
import dev.engine_room.flywheel.backend.glsl.ShaderSources;
import dev.engine_room.flywheel.backend.glsl.SourceComponent;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import net.irisshaders.iris.helpers.StringPair;
import net.minecraft.resources.ResourceLocation;
import org.lwjgl.opengl.GL20;

/**
 * Compiles, links and caches the programs for the {@link ClrwlPipelines#OIT_COMPOSITE} pipeline.
 *
 * <p>Linked programs are cached in a plain (unsynchronized) {@link HashMap} keyed by
 * {@link ClrwlOitCompositeShaderKey}. Shader stages are compiled directly through {@link GL20}.
 */
public class ClrwlOitPrograms {
    // NOTE(review): not referenced anywhere in this class; confirm whether it is still needed.
    private static final ResourceLocation DEPTH = Colorwheel.rl("internal/oit/depth.frag");

    /** Source loader handed to every {@link ClrwlCompilation} created by this class. */
    private final ShaderSources sources;
    /** Linked composite programs, keyed by the configuration they were compiled for. */
    private final Map<ClrwlOitCompositeShaderKey, GlProgram> compositeProgramCache = new HashMap<>();

    /**
     * @param sources shader sources used when assembling each shader stage
     */
    public ClrwlOitPrograms(ShaderSources sources) {
        this.sources = sources;
    }

    /**
     * Returns the composite program for the given configuration, compiling and linking it on first use.
     *
     * <p>Compile or link failures propagate as thrown by {@code ShaderResult.unwrap()} /
     * {@code ProgramLinker.link}; nothing is cached in that case.
     *
     * <p>NOTE(review): the cache key wraps the {@code int[]} arguments. Confirm that
     * {@link ClrwlOitCompositeShaderKey} compares arrays by content (a plain record would compare them
     * by reference and miss the cache on every call), and that callers do not mutate the arrays afterwards.
     *
     * @param drawBuffers draw-buffer indices; one {@code _clrwl_accumulate<index>} sampler is bound per entry
     * @param ranks       one {@code clrwl_coefficients<i>} sampler is bound per entry
     * @param overrides   accumulate overrides, forwarded into the cache key
     * @return the cached or freshly linked program
     */
    public GlProgram getOitCompositeProgram(int[] drawBuffers, int[] ranks, List<ClrwlOitAccumulateOverride> overrides) {
        ClrwlOitCompositeShaderKey key = new ClrwlOitCompositeShaderKey(drawBuffers, ranks, overrides);
        return compositeProgramCache.computeIfAbsent(key, this::compileComposite);
    }

    /**
     * Compiles both stages for {@code key}, links them and binds the program's uniform block and samplers.
     * The intermediate shader objects are always released, whether or not compilation/linking succeeds.
     */
    private GlProgram compileComposite(ClrwlOitCompositeShaderKey key) {
        String id = ClrwlPipelines.OIT_COMPOSITE.id();
        GlShader vertex = compileStage(id, ClrwlPipelines.OIT_COMPOSITE.vertex(), key).unwrap();
        GlShader fragment = null;
        try {
            fragment = compileStage(id, ClrwlPipelines.OIT_COMPOSITE.fragment(), key).unwrap();
            GlProgram program = new ProgramLinker().link(List.of(vertex, fragment), p -> {});
            bindResources(program, key);
            return program;
        } finally {
            // Shader objects are no longer needed once linking has been attempted. Deleting a shader
            // that is still attached to a program only flags it for deletion, so the linked program
            // stays usable. Without this, both shaders leaked on every compile (and the vertex shader
            // also leaked when the fragment stage failed).
            vertex.delete();
            if (fragment != null) {
                fragment.delete();
            }
        }
    }

    /** Binds the frame uniform block and every sampler the composite shader expects, then unbinds. */
    private static void bindResources(GlProgram program, ClrwlOitCompositeShaderKey key) {
        program.bind();
        program.setUniformBlockBinding("_ClrwlFrameUniforms", 0);
        program.setSamplerBinding("_flw_depthRange", ClrwlSamplers.DEPTH_RANGE);
        for (int i = 0; i < key.ranks().length; ++i) {
            program.setSamplerBinding("clrwl_coefficients" + i, ClrwlSamplers.getCoefficient(i));
        }
        int[] drawBuffers = key.drawBuffers();
        for (int i = 0; i < drawBuffers.length; ++i) {
            // The sampler name carries the draw-buffer index; the texture unit is the position in the array.
            program.setSamplerBinding("_clrwl_accumulate" + drawBuffers[i], ClrwlSamplers.getAccumulate(i));
        }
        GlProgram.unbind();
    }

    /**
     * Assembles the GLSL for one pipeline stage and compiles it.
     *
     * @param name  pipeline id, used together with the stage's file extension as the shader's name
     * @param stage stage description supplying extensions, defines, a compile callback and source fetchers
     * @param key   key forwarded to the stage's compile callback and fetchers
     * @return a success result owning a new {@link GlShader}, or a failure carrying the source and info log
     *         (the GL shader object is deleted in the failure case)
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private <K> ShaderResult compileStage(String name, ClrwlPipelineStage<K> stage, K key) {
        // NOTE(review): the four null arguments are passed through unchanged; their meaning is
        // defined by ClrwlCompilation.
        ClrwlCompilation compile = new ClrwlCompilation(null, null, null, null, sources);
        compile.version(GlCompat.MAX_GLSL_VERSION);
        for (String extension : stage.extensions()) {
            compile.enableExtension(extension);
        }
        for (StringPair define : stage.defines()) {
            compile.define(define);
        }
        stage.compile().accept(key, compile);

        // Shared across fetchers so a component included by several of them is appended only once.
        Set<SourceComponent> emitted = new LinkedHashSet<>();
        // Raw type kept: the element type of fetchers() is declared outside this file.
        for (BiFunction fetcher : stage.fetchers()) {
            SourceComponent root = (SourceComponent) fetcher.apply(key, compile);
            expand(root, emitted, compile::appendComponent);
        }

        String source = compile.getShaderCode();
        String shaderName = name + "." + stage.type().extension;
        int handle = GL20.glCreateShader(stage.type().glEnum);
        GlCompat.safeShaderSource(handle, source);
        GL20.glCompileShader(handle);
        String infoLog = GL20.glGetShaderInfoLog(handle);
        if (GL20.glGetShaderi(handle, GL20.GL_COMPILE_STATUS) == GL20.GL_TRUE) {
            return ShaderResult.success(new GlShader(handle, stage.type().toFlw().orElseThrow(), shaderName), infoLog);
        }
        GL20.glDeleteShader(handle);
        return ShaderResult.failure(new FailedCompilation(shaderName, List.of(), "", source, infoLog));
    }

    /**
     * Emits {@code root} and its transitive includes to {@code out}. Includes come before the components
     * that include them, and {@code root} comes last. Components already in {@code emitted} are skipped;
     * every emitted component is added to {@code emitted}.
     */
    private static void expand(SourceComponent root, Set<SourceComponent> emitted, Consumer<SourceComponent> out) {
        Set<SourceComponent> ordered = new LinkedHashSet<>();
        recursiveDepthFirstInclude(ordered, root);
        ordered.add(root);
        for (SourceComponent component : ordered) {
            if (emitted.add(component)) {
                out.accept(component);
            }
        }
    }

    /**
     * Adds every transitive include of {@code component} to {@code included}. A component's direct includes
     * are added only after recursing into them, so their own includes precede them in insertion order.
     */
    private static void recursiveDepthFirstInclude(Set<SourceComponent> included, SourceComponent component) {
        for (SourceComponent include : component.included()) {
            recursiveDepthFirstInclude(included, include);
        }
        included.addAll(component.included());
    }
}

