package net.coderbot.iris.sodium.shader_overrides;

import me.jellysquid.mods.sodium.client.gl.device.RenderDevice;
import me.jellysquid.mods.sodium.client.gl.shader.GlProgram;
import me.jellysquid.mods.sodium.client.gl.shader.GlShader;
import me.jellysquid.mods.sodium.client.gl.shader.ShaderConstants;
import me.jellysquid.mods.sodium.client.gl.shader.ShaderType;
import me.jellysquid.mods.sodium.client.render.chunk.passes.BlockRenderPass;
import me.jellysquid.mods.sodium.client.render.chunk.shader.ChunkProgram;
import me.jellysquid.mods.sodium.client.render.chunk.shader.ChunkShaderBindingPoints;
import net.coderbot.iris.Iris;
import net.coderbot.iris.gl.program.ProgramImages;
import net.coderbot.iris.gl.program.ProgramSamplers;
import net.coderbot.iris.gl.program.ProgramUniforms;
import net.coderbot.iris.pipeline.SodiumTerrainPipeline;
import net.coderbot.iris.pipeline.WorldRenderingPipeline;
import net.coderbot.iris.shadows.ShadowRenderingState;
import net.coderbot.iris.sodium.IrisChunkShaderBindingPoints;
import net.minecraft.util.ResourceLocation;
import org.jetbrains.annotations.Nullable;

import java.util.EnumMap;
import java.util.Locale;
import java.util.Optional;

/**
 * Builds and caches Iris-provided replacements for Sodium's chunk shader programs.
 *
 * <p>One {@link ChunkProgram} is cached per {@link IrisTerrainPass}. The cache is rebuilt from the current
 * {@link SodiumTerrainPipeline} whenever the pipeline manager's Sodium shader reload counter differs from the
 * value observed on the previous call to {@link #getProgramOverride(RenderDevice, BlockRenderPass)}.
 */
public class IrisChunkProgramOverrides {
	private static final ShaderConstants EMPTY_CONSTANTS = ShaderConstants.builder().build();

	/** Compiled programs by pass; a {@code null} value means no override is available for that pass. */
	private final EnumMap<IrisTerrainPass, ChunkProgram> programs = new EnumMap<>(IrisTerrainPass.class);

	/**
	 * Reload counter observed when {@link #programs} was last rebuilt. Starts at -1 so the first lookup
	 * triggers a build (assuming the pipeline manager's counter is never -1).
	 */
	private int versionCounterForSodiumShaderReload = -1;

	/**
	 * Compiles a single shader stage for the given pass.
	 *
	 * @param device    device used to create the shader object
	 * @param type      shader stage type
	 * @param pass      terrain pass; its lower-cased name becomes part of the shader's debug name
	 * @param source    GLSL source, or {@code null} if the pack does not provide this stage
	 * @param extension file extension appended to the debug name (e.g. {@code ".vsh"})
	 * @return the compiled shader, or {@code null} when {@code source} is {@code null}
	 */
	@Nullable
	private static GlShader compileStage(RenderDevice device, ShaderType type, IrisTerrainPass pass,
										 @Nullable String source, String extension) {
		if (source == null) {
			return null;
		}

		return new GlShader(device, type, new ResourceLocation("iris",
			"sodium-terrain-" + pass.toString().toLowerCase(Locale.ROOT) + extension), source, EMPTY_CONSTANTS);
	}

	/** Deletes {@code shader} if it is non-null. */
	private static void deleteIfPresent(@Nullable GlShader shader) {
		if (shader != null) {
			shader.delete();
		}
	}

	/**
	 * Compiles the vertex stage for {@code pass}.
	 *
	 * @return the compiled shader, or {@code null} if the pipeline has no vertex source for this pass
	 * @throws IllegalArgumentException if {@code pass} is not one of the handled terrain passes
	 */
	@Nullable
	private GlShader createVertexShader(RenderDevice device, IrisTerrainPass pass, SodiumTerrainPipeline pipeline) {
		Optional<String> irisVertexShader;

		if (pass == IrisTerrainPass.SHADOW) {
			irisVertexShader = pipeline.getShadowVertexShaderSource();
		} else if (pass == IrisTerrainPass.GBUFFER_SOLID) {
			irisVertexShader = pipeline.getTerrainVertexShaderSource();
		} else if (pass == IrisTerrainPass.GBUFFER_TRANSLUCENT) {
			irisVertexShader = pipeline.getTranslucentVertexShaderSource();
		} else {
			throw new IllegalArgumentException("Unknown pass type " + pass);
		}

		return compileStage(device, ShaderType.VERTEX, pass, irisVertexShader.orElse(null), ".vsh");
	}

	/**
	 * Compiles the optional geometry stage for {@code pass}.
	 *
	 * @return the compiled shader, or {@code null} if the pipeline has no geometry source for this pass
	 * @throws IllegalArgumentException if {@code pass} is not one of the handled terrain passes
	 */
	@Nullable
	private GlShader createGeometryShader(RenderDevice device, IrisTerrainPass pass, SodiumTerrainPipeline pipeline) {
		Optional<String> irisGeometryShader;

		if (pass == IrisTerrainPass.SHADOW) {
			irisGeometryShader = pipeline.getShadowGeometryShaderSource();
		} else if (pass == IrisTerrainPass.GBUFFER_SOLID) {
			irisGeometryShader = pipeline.getTerrainGeometryShaderSource();
		} else if (pass == IrisTerrainPass.GBUFFER_TRANSLUCENT) {
			irisGeometryShader = pipeline.getTranslucentGeometryShaderSource();
		} else {
			throw new IllegalArgumentException("Unknown pass type " + pass);
		}

		return compileStage(device, IrisShaderTypes.GEOMETRY, pass, irisGeometryShader.orElse(null), ".gsh");
	}

	/**
	 * Compiles the fragment stage for {@code pass}.
	 *
	 * @return the compiled shader, or {@code null} if the pipeline has no fragment source for this pass
	 * @throws IllegalArgumentException if {@code pass} is not one of the handled terrain passes
	 */
	@Nullable
	private GlShader createFragmentShader(RenderDevice device, IrisTerrainPass pass, SodiumTerrainPipeline pipeline) {
		Optional<String> irisFragmentShader;

		if (pass == IrisTerrainPass.SHADOW) {
			irisFragmentShader = pipeline.getShadowFragmentShaderSource();
		} else if (pass == IrisTerrainPass.GBUFFER_SOLID) {
			irisFragmentShader = pipeline.getTerrainFragmentShaderSource();
		} else if (pass == IrisTerrainPass.GBUFFER_TRANSLUCENT) {
			irisFragmentShader = pipeline.getTranslucentFragmentShaderSource();
		} else {
			throw new IllegalArgumentException("Unknown pass type " + pass);
		}

		return compileStage(device, ShaderType.FRAGMENT, pass, irisFragmentShader.orElse(null), ".fsh");
	}

	/**
	 * Compiles and links the chunk program for one terrain pass.
	 *
	 * <p>The individual shader objects are always deleted before returning, including when a later stage
	 * fails to compile or linking throws, so no shader object is leaked.
	 *
	 * @return the linked program, or {@code null} if the pack lacks a vertex or fragment shader for this pass
	 */
	@Nullable
	private ChunkProgram createShader(RenderDevice device, IrisTerrainPass pass, SodiumTerrainPipeline pipeline) {
		GlShader vertShader = null;
		GlShader geomShader = null;
		GlShader fragShader = null;

		try {
			// Created inside the try so that a compile failure in a later stage still releases earlier stages.
			vertShader = createVertexShader(device, pass, pipeline);
			geomShader = createGeometryShader(device, pass, pipeline);
			fragShader = createFragmentShader(device, pass, pipeline);

			if (vertShader == null || fragShader == null) {
				// TODO: Partial shader programs?
				return null;
			}

			GlProgram.Builder builder = GlProgram.builder(new ResourceLocation("sodium", "chunk_shader_for_"
					+ pass.getName()));

			if (geomShader != null) {
				builder.attachShader(geomShader);
			}

			return builder.attachShader(vertShader)
					.attachShader(fragShader)
					.bindAttribute("iris_Pos", ChunkShaderBindingPoints.POSITION)
					.bindAttribute("iris_Color", ChunkShaderBindingPoints.COLOR)
					.bindAttribute("iris_TexCoord", ChunkShaderBindingPoints.TEX_COORD)
					.bindAttribute("iris_LightCoord", ChunkShaderBindingPoints.LIGHT_COORD)
					.bindAttribute("iris_Normal", IrisChunkShaderBindingPoints.NORMAL)
					.bindAttribute("at_tangent", IrisChunkShaderBindingPoints.TANGENT)
					.bindAttribute("mc_midTexCoord", IrisChunkShaderBindingPoints.MID_TEX_COORD)
					.bindAttribute("mc_Entity", IrisChunkShaderBindingPoints.BLOCK_ID)
					.bindAttribute("at_midBlock", IrisChunkShaderBindingPoints.MID_BLOCK)
					.bindAttribute("iris_ModelOffset", ChunkShaderBindingPoints.MODEL_OFFSET)
					.build((program, name) -> new IrisChunkProgram(device, program, name, pass == IrisTerrainPass.SHADOW, pipeline, pipeline.getCustomUniforms()));
		} finally {
			deleteIfPresent(vertShader);
			deleteIfPresent(geomShader);
			deleteIfPresent(fragShader);
		}
	}

	/**
	 * Rebuilds the program cache for the given pipeline.
	 *
	 * <p>Any previously cached programs are deleted first. If {@code sodiumTerrainPipeline} is {@code null}
	 * the cache is left empty. Otherwise every {@link IrisTerrainPass} gets an entry. That entry is
	 * {@code null} for the shadow pass when the pipeline has no shadow pass, or when a pass lacks a
	 * vertex/fragment shader.
	 *
	 * @param sodiumTerrainPipeline source of the shader programs, or {@code null} to clear the cache
	 * @param device                device used to compile the shaders
	 */
	public void createShaders(SodiumTerrainPipeline sodiumTerrainPipeline, RenderDevice device) {
		// Release existing programs: both clearing and overwriting map entries would otherwise leak them.
		deleteShaders();

		if (sodiumTerrainPipeline == null) {
			return;
		}

		for (IrisTerrainPass pass : IrisTerrainPass.values()) {
			if (pass == IrisTerrainPass.SHADOW && !sodiumTerrainPipeline.hasShadowPass()) {
				this.programs.put(pass, null);
				continue;
			}

			this.programs.put(pass, createShader(device, pass, sodiumTerrainPipeline));
		}
	}

	/**
	 * Returns the Iris program to use in place of Sodium's for {@code pass}. The cache is rebuilt first if
	 * the pipeline manager's Sodium shader reload counter has changed since the last call.
	 *
	 * @param device device used if programs need to be (re)compiled
	 * @param pass   Sodium render pass; only its translucency is consulted outside the shadow pass
	 * @return the override program, or {@code null} if none is available
	 * @throws IllegalStateException if shadows are being rendered but the active pipeline has no shadow pass
	 */
	@Nullable
	public ChunkProgram getProgramOverride(RenderDevice device, BlockRenderPass pass) {
		WorldRenderingPipeline worldRenderingPipeline = Iris.getPipelineManager().getPipelineNullable();
		SodiumTerrainPipeline sodiumTerrainPipeline = null;

		if (worldRenderingPipeline != null) {
			sodiumTerrainPipeline = worldRenderingPipeline.getSodiumTerrainPipeline();
		}

		// Read the counter once so the value compared is the value stored.
		int currentVersion = Iris.getPipelineManager().getVersionCounterForSodiumShaderReload();

		if (versionCounterForSodiumShaderReload != currentVersion) {
			versionCounterForSodiumShaderReload = currentVersion;
			// createShaders deletes the previously cached programs before rebuilding.
			createShaders(sodiumTerrainPipeline, device);
		}

		if (ShadowRenderingState.areShadowsCurrentlyBeingRendered()) {
			if (sodiumTerrainPipeline != null && !sodiumTerrainPipeline.hasShadowPass()) {
				throw new IllegalStateException("Shadow program requested, but the pack does not have a shadow pass?");
			}

			return this.programs.get(IrisTerrainPass.SHADOW);
		} else {
			return this.programs.get(pass.isTranslucent() ? IrisTerrainPass.GBUFFER_TRANSLUCENT : IrisTerrainPass.GBUFFER_SOLID);
		}
	}

	/** Deletes every cached program and empties the cache. */
	public void deleteShaders() {
		for (ChunkProgram program : this.programs.values()) {
			if (program != null) {
				program.delete();
			}
		}

		this.programs.clear();
	}
}
