/*
 * Decompiled with CFR 0.152.
 */
package io.github.flemmli97.fateubw.client.particles;

import com.mojang.blaze3d.vertex.PoseStack;
import com.mojang.blaze3d.vertex.VertexConsumer;
import com.mojang.datafixers.util.Pair;
import io.github.flemmli97.fateubw.common.particles.trail.TrailInfo;
import io.github.flemmli97.fateubw.common.particles.trail.TrailPositions;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import net.minecraft.client.Camera;
import net.minecraft.client.Minecraft;
import net.minecraft.util.Mth;
import net.minecraft.world.entity.Entity;
import net.minecraft.world.phys.Vec3;
import org.jetbrains.annotations.Nullable;
import org.joml.Matrix4f;
import org.joml.Matrix4fc;
import org.joml.Vector3f;
import org.joml.Vector3fc;
import org.joml.Vector4f;

/**
 * Client-side renderer that draws a {@link TrailPositions} history as a ribbon of textured quads.
 *
 * <p>The recorded points are resampled with Catmull-Rom splines ({@code info.interpolation()}
 * samples per segment). Each sample is paired with a side vector along which the ribbon is
 * extruded: either the point's own normal, or one derived from the segment direction and the
 * camera's look vector. Consecutive samples are then joined into quads whose width, colour and
 * U coordinate are interpolated along the trail.
 */
public class TrailRenderer {

    /** Packed light value passed to {@code setLight} for every trail vertex. */
    private static final int LIGHT = 0xFF00FF;

    /** Squared length below which a vector is treated as zero (normalizing it would give NaN). */
    private static final float DEGENERATE_LENGTH_SQ = 1.0e-10f;

    /**
     * Renders {@code positions} for {@code entity} using the entity render dispatcher's camera.
     *
     * <p>A fresh {@link PoseStack} is translated from the camera to the entity's interpolated
     * position. Sampled points are therefore made relative to that interpolated position, while the
     * head point is made relative to the entity's current (non-interpolated) position. The full
     * texture (u and v in {@code [0, 1]}) is mapped onto the ribbon.
     *
     * @param entity       entity the trail belongs to
     * @param info         trail appearance (widths, colours, samples per segment)
     * @param positions    recorded trail points; nothing is drawn if {@code null} or fewer than two
     * @param consumer     vertex sink receiving the quads
     * @param partialTicks frame partial tick used to interpolate the entity position
     */
    public static void render(Entity entity, TrailInfo info, TrailPositions positions, VertexConsumer consumer, float partialTicks) {
        Camera camera = Minecraft.getInstance().getEntityRenderDispatcher().camera;
        Vec3 cameraPos = camera.getPosition();
        double delta = partialTicks;
        double lerpX = Mth.lerp(delta, entity.xo, entity.getX());
        double lerpY = Mth.lerp(delta, entity.yo, entity.getY());
        double lerpZ = Mth.lerp(delta, entity.zo, entity.getZ());
        PoseStack stack = new PoseStack();
        stack.translate(lerpX - cameraPos.x(), lerpY - cameraPos.y(), lerpZ - cameraPos.z());
        render(info, positions, stack, consumer, camera,
                (float) lerpX, (float) lerpY, (float) lerpZ,
                (float) entity.getX(), (float) entity.getY(), (float) entity.getZ(),
                0.0f, 1.0f, 0.0f, 1.0f);
    }

    /**
     * Renders a trail into {@code buffer} using the current matrix of {@code stack}.
     *
     * <p>Width and colour are interpolated from the {@code *2} values of {@code info} (start of the
     * point list) towards the primary values (end of the list, i.e. the head taken from
     * {@code position.getLast()}).
     *
     * @param info     trail appearance: widths, colours and samples per segment
     * @param position recorded trail points; nothing is drawn if {@code null} or fewer than two
     * @param stack    pose whose current matrix transforms every vertex
     * @param buffer   vertex sink receiving the quads
     * @param camera   camera whose look vector orients segments that have no explicit normal
     * @param partialX x of the origin subtracted from every spline sample
     * @param partialY y of the origin subtracted from every spline sample
     * @param partialZ z of the origin subtracted from every spline sample
     * @param x        x of the origin subtracted from the head point
     * @param y        y of the origin subtracted from the head point
     * @param z        z of the origin subtracted from the head point
     * @param u0       U coordinate at the start of the trail
     * @param u1       U coordinate at the end of the trail
     * @param v0       V coordinate on one edge of the ribbon
     * @param v1       V coordinate on the other edge of the ribbon
     */
    public static void render(TrailInfo info, TrailPositions position, PoseStack stack, VertexConsumer buffer, Camera camera, float partialX, float partialY, float partialZ, float x, float y, float z, float u0, float u1, float v0, float v1) {
        if (position == null || position.size() < 2) {
            return;
        }
        List<Pair<Vector3f, Vector3f>> points = samplePoints(info, position, camera, partialX, partialY, partialZ);
        if (points.isEmpty()) {
            // Every entry was null or could not be sampled: there is nothing to draw and no
            // previous sample to connect the head to.
            return;
        }
        appendHead(points, position, camera, x, y, z);
        emitQuads(info, position, points, stack.last().pose(), buffer, u0, u1, v0, v1);
    }

    /** Resamples every segment with Catmull-Rom and returns (position, side vector) pairs. */
    private static List<Pair<Vector3f, Vector3f>> samplePoints(TrailInfo info, TrailPositions position, Camera camera, float partialX, float partialY, float partialZ) {
        int steps = Math.max(1, info.interpolation());
        List<Pair<Vector3f, Vector3f>> points = new ArrayList<>();
        for (int i = 0; i < position.size() - 1; ++i) {
            TrailPositions.TrailPosition pos = position.getAt(i);
            if (pos == null) {
                continue;
            }
            // Missing neighbours are replaced by the nearest known point so the spline stays defined.
            TrailPositions.TrailPosition previous = Objects.requireNonNullElse(position.getAt(i - 1), pos);
            TrailPositions.TrailPosition next = Objects.requireNonNullElse(position.getAt(i + 1), pos);
            TrailPositions.TrailPosition next2 = Objects.requireNonNullElse(position.getAt(i + 2), next);
            // Integer-driven sampling: accumulating a float step can overshoot or undershoot 1.0
            // and produce a spurious extra sample.
            for (int s = 0; s < steps; ++s) {
                float delta = s / (float) steps;
                Vector3f spline = catmullRom(delta, previous.pos(), pos.pos(), next.pos(), next2.pos());
                if (spline == null) {
                    continue;
                }
                Vector3f stepPos = spline.sub(partialX, partialY, partialZ);
                Vector3f stepNormal = catmullRom(delta, previous.normal(), pos.normal(), next.normal(), next2.normal());
                // Orient every segment from the earlier point to the later one so consecutive
                // side vectors do not flip.
                Vector3f from;
                Vector3f to;
                Vector3f previousNormal = null;
                if (!points.isEmpty()) {
                    Pair<Vector3f, Vector3f> last = points.getLast();
                    from = last.getFirst();
                    to = stepPos;
                    previousNormal = last.getSecond();
                } else if (i == 0) {
                    // No earlier point exists: look ahead to the next recorded point instead.
                    from = stepPos;
                    to = next.pos().toVector3f().sub(partialX, partialY, partialZ);
                } else {
                    from = previous.pos().toVector3f().sub(partialX, partialY, partialZ);
                    to = stepPos;
                }
                points.add(Pair.of(stepPos, calculateNormal(from, to, stepNormal, previousNormal, camera)));
            }
        }
        return points;
    }

    /** Appends the head ({@code position.getLast()}) relative to {@code (x, y, z)}; requires a non-empty list. */
    private static void appendHead(List<Pair<Vector3f, Vector3f>> points, TrailPositions position, Camera camera, float x, float y, float z) {
        TrailPositions.TrailPosition current = position.getLast();
        if (current == null) {
            return;
        }
        Vector3f currentPos = current.pos().toVector3f().sub(x, y, z);
        Pair<Vector3f, Vector3f> last = points.getLast();
        Vector3f explicitNormal = current.normal() != null ? current.normal().toVector3f() : null;
        points.add(Pair.of(currentPos, calculateNormal(last.getFirst(), currentPos, explicitNormal, last.getSecond(), camera)));
    }

    /** Emits one quad per consecutive pair of points, interpolating width, colour and U by progress. */
    private static void emitQuads(TrailInfo info, TrailPositions position, List<Pair<Vector3f, Vector3f>> points, Matrix4f pose, VertexConsumer buffer, float u0, float u1, float v0, float v1) {
        // Nominal segment count of a full-length trail; both factors are guarded so a zero length
        // or interpolation cannot cause a division by zero below.
        int size = Math.max(1, position.getLength() * Math.max(1, info.interpolation()));
        // With fewer points than nominal, this offset makes the last segment end at progress 1.
        int diff = Math.abs(size - (points.size() - 1));
        float ulen = u1 - u0;
        for (int i = 0; i < points.size() - 1; ++i) {
            Pair<Vector3f, Vector3f> pos = points.get(i);
            Pair<Vector3f, Vector3f> next = points.get(i + 1);
            float prog = Mth.clamp((i + diff) / (float) size, 0.0f, 1.0f);
            float progNext = Mth.clamp((i + diff + 1) / (float) size, 0.0f, 1.0f);
            Vector4f[] quad = vertices(info, pos.getFirst(), next.getFirst(), pos.getSecond(), next.getSecond(), prog, progNext);
            for (Vector4f vert : quad) {
                vert.mul(pose);
            }
            float u0p = Mth.clamp(u0 + prog * ulen, u0, u1);
            float u1p = Mth.clamp(u0 + progNext * ulen, u0, u1);
            float r = Mth.lerp(prog, info.r2(), info.r());
            float g = Mth.lerp(prog, info.g2(), info.g());
            float b = Mth.lerp(prog, info.b2(), info.b());
            float a = Mth.lerp(prog, info.a2(), info.a());
            float r2 = Mth.lerp(progNext, info.r2(), info.r());
            float g2 = Mth.lerp(progNext, info.g2(), info.g());
            float b2 = Mth.lerp(progNext, info.b2(), info.b());
            float a2 = Mth.lerp(progNext, info.a2(), info.a());
            draw(buffer, quad, u0p, u1p, v0, v1, r, g, b, a, r2, g2, b2, a2);
        }
    }

    /**
     * Evaluates the Catmull-Rom spline through {@code p1..p4} between {@code p2} and {@code p3}.
     *
     * @param delta position between {@code p2} (0) and {@code p3} (1)
     * @return the interpolated point. This is {@code p2} at delta 0 and {@code p3} at delta 1. It is
     *         {@code null} when a required control point is missing.
     */
    @Nullable
    protected static Vector3f catmullRom(float delta, @Nullable Vec3 p1, @Nullable Vec3 p2, @Nullable Vec3 p3, @Nullable Vec3 p4) {
        if (delta == 0.0f) {
            return p2 != null ? p2.toVector3f() : null;
        }
        if (delta == 1.0f) {
            return p3 != null ? p3.toVector3f() : null;
        }
        if (p1 == null || p2 == null || p3 == null || p4 == null) {
            return null;
        }
        if (p2.equals(p3)) {
            return p2.toVector3f();
        }
        return new Vector3f(
                Mth.catmullrom(delta, (float) p1.x(), (float) p2.x(), (float) p3.x(), (float) p4.x()),
                Mth.catmullrom(delta, (float) p1.y(), (float) p2.y(), (float) p3.y(), (float) p4.y()),
                Mth.catmullrom(delta, (float) p1.z(), (float) p2.z(), (float) p3.z(), (float) p4.z()));
    }

    /**
     * Computes the unit side vector along which the ribbon is extruded at a point.
     *
     * <p>A non-zero explicit {@code normal} is normalized in place and returned. Otherwise the
     * segment direction {@code to - from} is crossed with the camera look vector, so the ribbon
     * faces the camera. If that is degenerate (zero-length segment, or a segment parallel to the
     * view), {@code previousNormal} is reused, or {@code (0, 1, 0)} when there is none.
     */
    protected static Vector3f calculateNormal(Vector3f from, Vector3f to, @Nullable Vector3f normal, @Nullable Vector3f previousNormal, Camera camera) {
        if (normal != null && normal.lengthSquared() > DEGENERATE_LENGTH_SQ) {
            return normal.normalize();
        }
        Vector3f fallback = previousNormal != null ? previousNormal : new Vector3f(0.0f, 1.0f, 0.0f);
        if (from.equals(to)) {
            return fallback;
        }
        Vector3f direction = to.sub(from, new Vector3f());
        Vector3f side = direction.cross(camera.getLookVector(), new Vector3f());
        if (side.lengthSquared() <= DEGENERATE_LENGTH_SQ) {
            return fallback;
        }
        return side.normalize();
    }

    /**
     * Builds the four corners of one ribbon quad between {@code current} and {@code next}.
     *
     * <p>Each end is offset by ± its side vector scaled by the width interpolated from
     * {@code width2} (progress 0) to {@code width} (progress 1).
     *
     * @return corners in the order current-minus, current-plus, next-plus, next-minus (w = 1)
     */
    protected static Vector4f[] vertices(TrailInfo info, Vector3f current, Vector3f next, Vector3f currentNorm, Vector3f nextNorm, float progPrev, float prog) {
        float scale = Mth.lerp(progPrev, info.width2(), info.width());
        float scaleNext = Mth.lerp(prog, info.width2(), info.width());
        return new Vector4f[]{
                offset(current, currentNorm, -scale),
                offset(current, currentNorm, scale),
                offset(next, nextNorm, scaleNext),
                offset(next, nextNorm, -scaleNext)
        };
    }

    /** Returns {@code point + dir * amount} as a homogeneous point (w = 1). */
    private static Vector4f offset(Vector3f point, Vector3f dir, float amount) {
        return new Vector4f(point.x() + dir.x() * amount, point.y() + dir.y() * amount, point.z() + dir.z() * amount, 1.0f);
    }

    /**
     * Writes the vertices as quads (groups of four) into {@code buffer}.
     *
     * <p>The first two vertices of each group take the first colour and {@code u0}; the last two
     * take the second colour and {@code u1}. Any trailing incomplete group is ignored.
     */
    protected static void draw(VertexConsumer buffer, Vector4f[] vertices, float u0, float u1, float v0, float v1, float r, float g, float b, float a, float r2, float g2, float b2, float a2) {
        for (int i = 0; i + 3 < vertices.length; i += 4) {
            buffer.addVertex(vertices[i].x(), vertices[i].y(), vertices[i].z()).setUv(u0, v1).setColor(r, g, b, a).setLight(LIGHT);
            buffer.addVertex(vertices[i + 1].x(), vertices[i + 1].y(), vertices[i + 1].z()).setUv(u0, v0).setColor(r, g, b, a).setLight(LIGHT);
            buffer.addVertex(vertices[i + 2].x(), vertices[i + 2].y(), vertices[i + 2].z()).setUv(u1, v0).setColor(r2, g2, b2, a2).setLight(LIGHT);
            buffer.addVertex(vertices[i + 3].x(), vertices[i + 3].y(), vertices[i + 3].z()).setUv(u1, v1).setColor(r2, g2, b2, a2).setLight(LIGHT);
        }
    }
}

