/*
 * Ex Deorum
 * Copyright (c) 2024 thedarkcolour
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package thedarkcolour.exdeorum.client;

import com.google.common.collect.ImmutableMap;
import com.mojang.blaze3d.vertex.DefaultVertexFormat;
import com.mojang.blaze3d.vertex.PoseStack;
import com.mojang.blaze3d.vertex.VertexConsumer;
import com.mojang.blaze3d.vertex.VertexFormat;
import net.irisshaders.iris.api.v0.IrisApi;
import net.minecraft.client.Minecraft;
import net.minecraft.client.renderer.MultiBufferSource;
import net.minecraft.client.renderer.RenderStateShard;
import net.minecraft.client.renderer.RenderType;
import net.minecraft.client.renderer.ShaderInstance;
import net.minecraft.client.renderer.Sheets;
import net.minecraft.client.renderer.texture.MissingTextureAtlasSprite;
import net.minecraft.client.renderer.texture.TextureAtlas;
import net.minecraft.client.renderer.texture.TextureAtlasSprite;
import net.minecraft.client.resources.model.BakedModel;
import net.minecraft.core.BlockPos;
import net.minecraft.core.registries.BuiltInRegistries;
import net.minecraft.resources.ResourceLocation;
import net.minecraft.util.Mth;
import net.minecraft.util.RandomSource;
import net.minecraft.world.inventory.InventoryMenu;
import net.minecraft.world.level.Level;
import net.minecraft.world.level.block.Block;
import net.minecraft.world.level.levelgen.LegacyRandomSource;
import net.minecraft.world.level.material.Fluid;
import net.neoforged.neoforge.client.extensions.common.IClientFluidTypeExtensions;
import net.neoforged.neoforge.client.model.CompositeModel;
import net.neoforged.neoforge.client.model.data.ModelData;
import org.joml.Vector3f;
import thedarkcolour.exdeorum.ExDeorum;
import thedarkcolour.exdeorum.client.ter.SieveRenderer;

import java.awt.Color;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.util.HashMap;
import java.util.Map;

public class RenderUtil {
    private static final VarHandle COMPOSITE_MODEL_CHILDREN;
    private static final Map<Block, RenderFace> TOP_FACES = new HashMap<>();
    public static final RenderStateShard.ShaderStateShard RENDER_TYPE_TINTED_CUTOUT_MIPPED_SHADER = new RenderStateShard.ShaderStateShard(RenderUtil::getRenderTypeTintedCutoutMippedShader);
    public static final RenderType TINTED_CUTOUT_MIPPED = RenderType.create(ExDeorum.ID + ":tinted_cutout_mipped", DefaultVertexFormat.NEW_ENTITY, VertexFormat.Mode.QUADS, RenderType.SMALL_BUFFER_SIZE, false, false, RenderType.CompositeState.builder().setLightmapState(RenderStateShard.LIGHTMAP).setShaderState(RENDER_TYPE_TINTED_CUTOUT_MIPPED_SHADER).setTextureState(RenderStateShard.BLOCK_SHEET_MIPPED).createCompositeState(true));
    public static TextureAtlas blockAtlas;
    public static ShaderInstance renderTypeTintedCutoutMippedShader;
    public static final IrisAccess IRIS_ACCESS;

    static {
        IrisAccess irisAccess;
        try {
            Class.forName("net.irisshaders.iris.api.v0.IrisApi");
            irisAccess = IrisApi.getInstance()::isShaderPackInUse;
        } catch (ClassNotFoundException e) {
            irisAccess = () -> false;
        }
        IRIS_ACCESS = irisAccess;

        var lookup = MethodHandles.lookup();
        try {
            COMPOSITE_MODEL_CHILDREN = MethodHandles.privateLookupIn(CompositeModel.Baked.class, lookup).findVarHandle(CompositeModel.Baked.class, "children", ImmutableMap.class);
        } catch (NoSuchFieldException | IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    public static void reload() {
        invalidateCaches();
        blockAtlas = Minecraft.getInstance().getModelManager().getAtlas(InventoryMenu.BLOCK_ATLAS);
    }

    public static void invalidateCaches() {
        SieveRenderer.MESH_TEXTURES.clear();
        TOP_FACES.clear();
        blockAtlas = null;
    }

    public static RenderFace getTopFaceOrDefault(Block block, Block defaultBlock) {
        var face = getTopFace(block);
        if (face.isMissingTexture()) {
            return getTopFace(defaultBlock);
        } else {
            return face;
        }
    }

    public static RenderFace getTopFace(Block block) {
        if (TOP_FACES.containsKey(block)) {
            return TOP_FACES.get(block);
        } else {
            var rand = new LegacyRandomSource(block.hashCode());
            BakedModel model = Minecraft.getInstance().getBlockRenderer().getBlockModel(block.defaultBlockState());
            RenderFace face;

            if (model instanceof CompositeModel.Baked composite) {
                @SuppressWarnings("unchecked")
                ImmutableMap<String, BakedModel> children = (ImmutableMap<String, BakedModel>) COMPOSITE_MODEL_CHILDREN.get(composite);
                RenderFace.CompositeLayer[] layers = new RenderFace.CompositeLayer[children.size()];
                int i = 0;

                for (var childModel : children.values()) {
                    var singleFace = getFaceFromModel(block, rand, childModel);
                    layers[i++] = new RenderFace.CompositeLayer(singleFace.renderType(), singleFace.sprite());
                }

                face = new RenderFace.Composite(layers);
            } else {
                face = getFaceFromModel(block, rand, model);
            }

            TOP_FACES.put(block, face);

            return face;
        }
    }

    private static RenderFace.Single getFaceFromModel(Block block, RandomSource rand, BakedModel model) {
        var texture = getTopTexture(block, model);
        var blockTypes = model.getRenderTypes(block.defaultBlockState(), rand, ModelData.EMPTY);
        for (var bufferLayer : RenderType.chunkBufferLayers()) {
            if (blockTypes.contains(bufferLayer)) {
                return new RenderFace.Single(bufferLayer, texture);
            }
        }
        throw new IllegalStateException("No render type found for block " + block);
    }

    private static TextureAtlasSprite getTopTexture(Block block, BakedModel model) {
        var registryName = BuiltInRegistries.BLOCK.getKey(block);
        var sprite = blockAtlas.getSprite(registryName.withPrefix("block/"));
        // for stuff like azalea bush, retry to get the top texture
        if (isMissingTexture(sprite)) {
            sprite = blockAtlas.getSprite(ResourceLocation.fromNamespaceAndPath(registryName.getNamespace(), "block/" + registryName.getPath() + "_top"));
        }
        if (isMissingTexture(sprite)) {
            sprite = model.getParticleIcon(ModelData.EMPTY);
        }
        return sprite;
    }

    public static boolean isMissingTexture(TextureAtlasSprite sprite) {
        return sprite.contents().name() == MissingTextureAtlasSprite.getLocation();
    }

    public static void renderFlatFluidSprite(MultiBufferSource buffers, PoseStack stack, Level level, BlockPos pos, float y, float edge, int light, int r, int g, int b, Fluid fluid) {
        var extensions = IClientFluidTypeExtensions.of(fluid);
        var state = fluid.defaultFluidState();
        var builder = buffers.getBuffer(Sheets.translucentCullBlockSheet());

        RenderUtil.renderFlatSprite(builder, stack, y, r, g, b, RenderUtil.blockAtlas.getSprite(extensions.getStillTexture(state, level, pos)), light, edge);
    }

    @SuppressWarnings("DuplicatedCode")
    public static void renderFluidCube(MultiBufferSource buffers, PoseStack stack, Level level, BlockPos pos, float minY, float maxY, float edge, int light, int r, int g, int b, Fluid fluid) {
        var extensions = IClientFluidTypeExtensions.of(fluid);
        var state = fluid.defaultFluidState();
        var builder = buffers.getBuffer(Sheets.translucentCullBlockSheet());
        var pose = stack.last().pose();
        var poseNormal = stack.last().normal();

        Vector3f normal;
        TextureAtlasSprite sprite = RenderUtil.blockAtlas.getSprite(extensions.getStillTexture(state, level, pos));
        float uMin = sprite.getU0();
        float uMax = sprite.getU1();
        float vMin = sprite.getV0();
        float vMax = sprite.getV1();

        float edgeMin = edge / 16f;
        float edgeMax = 1f - edge / 16f;

        // Top face
        normal = poseNormal.transform(new Vector3f(0, 1, 0));
        builder.addVertex(pose, edgeMin, maxY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMin).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, maxY, edgeMax).setColor(r, g, b, 255).setUv(uMin, vMax).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, maxY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMax).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, maxY, edgeMin).setColor(r, g, b, 255).setUv(uMax, vMin).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        // Bottom face
        normal = poseNormal.transform(new Vector3f(0, -1, 0));
        builder.addVertex(pose, edgeMin, minY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMin).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, minY, edgeMin).setColor(r, g, b, 255).setUv(uMax, vMin).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, minY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMax).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, minY, edgeMax).setColor(r, g, b, 255).setUv(uMin, vMax).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);

        // Flowing texture coordinates
        //sprite = RenderUtil.blockAtlas.getSprite(extensions.getFlowingTexture(state, level, pos));
        //uMin = sprite.getU0();
        //uMax = sprite.getU(8);
        //vMin = sprite.getV0();
        //vMax = sprite.getV(8);

        // South face
        normal = poseNormal.transform(new Vector3f(0, 0, 1));
        builder.addVertex(pose, edgeMax, maxY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMin).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, maxY, edgeMax).setColor(r, g, b, 255).setUv(uMin, vMin).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, minY, edgeMax).setColor(r, g, b, 255).setUv(uMin, vMax).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, minY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMax).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        // North face
        normal = poseNormal.transform(new Vector3f(0, 0, -1));
        builder.addVertex(pose, edgeMin, maxY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMin).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, maxY, edgeMin).setColor(r, g, b, 255).setUv(uMax, vMin).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, minY, edgeMin).setColor(r, g, b, 255).setUv(uMax, vMax).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, minY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMax).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        // East face
        normal = poseNormal.transform(new Vector3f(1, 0, 0));
        builder.addVertex(pose, edgeMax, maxY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMin).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, maxY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMin).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, minY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMax).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, minY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMax).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        // West face
        normal = poseNormal.transform(new Vector3f(-1, 0, 0));
        builder.addVertex(pose, edgeMin, maxY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMin).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, maxY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMin).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, minY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMax).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, minY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMax).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
    }

    // Renders a sprite inside the barrel with the height determined by how full the barrel is.
    public static void renderFlatSpriteLerp(VertexConsumer builder, PoseStack stack, float percentage, int r, int g, int b, TextureAtlasSprite sprite, int light, float edge, float yMin, float yMax) {
        float y = Mth.lerp(percentage, yMin, yMax) / 16f;

        renderFlatSprite(builder, stack, y, r, g, b, sprite, light, edge);
    }

    // Renders a sprite (y should be between 0 and 1)
    @SuppressWarnings("DuplicatedCode")
    public static void renderFlatSprite(VertexConsumer builder, PoseStack stack, float y, int r, int g, int b, TextureAtlasSprite sprite, int light, float edge) {
        var pose = stack.last().pose();
        var normal = stack.last().normal().transform(new Vector3f(0, 1, 0));

        // Position coordinates
        float edgeMin = edge / 16.0f;
        float edgeMax = (16.0f - edge) / 16.0f;

        // Texture coordinates
        float uMin = sprite.getU0();
        float uMax = sprite.getU1();
        float vMin = sprite.getV0();
        float vMax = sprite.getV1();

        // overlayCoords(0, 10) is NO_OVERLAY (0xA0000)
        builder.addVertex(pose, edgeMin, y, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMin).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, y, edgeMax).setColor(r, g, b, 255).setUv(uMin, vMax).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, y, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMax).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, y, edgeMin).setColor(r, g, b, 255).setUv(uMax, vMin).setUv1(0, 10).setLight(light).setNormal(normal.x, normal.y, normal.z);
    }

    public static Color getRainbowColor(long time, float partialTicks) {
        return Color.getHSBColor((180 * Mth.sin((time + partialTicks) / 30.0f) - 180) / 360.0f, 0.5f, 0.8f);
    }

    public static ShaderInstance getRenderTypeTintedCutoutMippedShader() {
        return renderTypeTintedCutoutMippedShader;
    }

    public static int getFluidColor(Fluid fluid, Level level, BlockPos pos) {
        return IClientFluidTypeExtensions.of(fluid).getTintColor(fluid.defaultFluidState(), level, pos);
    }

    // todo use ambient occlusion
    // Renders a cuboid using the same side sprite on all six sides
    public static void renderCuboid(VertexConsumer builder, PoseStack stack, float minY, float maxY, int r, int g, int b, TextureAtlasSprite sprite, int light, float edge) {
        var pose = stack.last().pose();
        var poseNormal = stack.last().normal();

        Vector3f normal;
        float uMin = sprite.getU0();
        float uMax = sprite.getU1();
        float vMin = sprite.getV0();
        float vMax = sprite.getV1();

        float edgeMin = edge / 16f;
        float edgeMax = 1f - edge / 16f;

        int lightU = light & '\uffff';
        int lightV = light >> 16 & '\uffff';

        // Top face
        normal = poseNormal.transform(new Vector3f(0, 1, 0));
        builder.addVertex(pose, edgeMin, maxY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMin).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, maxY, edgeMax).setColor(r, g, b, 255).setUv(uMin, vMax).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, maxY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMax).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, maxY, edgeMin).setColor(r, g, b, 255).setUv(uMax, vMin).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        // Bottom face
        normal = poseNormal.transform(new Vector3f(0, -1, 0));
        builder.addVertex(pose, edgeMin, minY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMin).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, minY, edgeMin).setColor(r, g, b, 255).setUv(uMax, vMin).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, minY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMax).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, minY, edgeMax).setColor(r, g, b, 255).setUv(uMin, vMax).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);

        // Adjust UV based on height of cuboid, rendering from the top down to the bottom of the texture
        float f = sprite.getV1() - sprite.getV0();
        vMax = sprite.getV0() + f * (maxY - minY);

        // South face
        normal = poseNormal.transform(new Vector3f(0, 0, -1));
        builder.addVertex(pose, edgeMax, maxY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMin).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, maxY, edgeMax).setColor(r, g, b, 255).setUv(uMin, vMin).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, minY, edgeMax).setColor(r, g, b, 255).setUv(uMin, vMax).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, minY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMax).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        // North face
        normal = poseNormal.transform(new Vector3f(0, 0, -1));
        builder.addVertex(pose, edgeMin, maxY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMin).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, maxY, edgeMin).setColor(r, g, b, 255).setUv(uMax, vMin).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, minY, edgeMin).setColor(r, g, b, 255).setUv(uMax, vMax).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, minY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMax).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        // East face
        normal = poseNormal.transform(new Vector3f(1, 0, 0));
        builder.addVertex(pose, edgeMax, maxY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMin).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, maxY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMin).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, minY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMax).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMax, minY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMax).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        // West face
        normal = poseNormal.transform(new Vector3f(-1, 0, 0));
        builder.addVertex(pose, edgeMin, maxY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMin).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, maxY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMin).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, minY, edgeMin).setColor(r, g, b, 255).setUv(uMin, vMax).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
        builder.addVertex(pose, edgeMin, minY, edgeMax).setColor(r, g, b, 255).setUv(uMax, vMax).setUv1(0, 10).setUv2(lightU, lightV).setNormal(normal.x, normal.y, normal.z);
    }

    public interface IrisAccess {
        boolean areShadersEnabled();
    }
}
