package com.zurrtum.create.client.content.trains.track;

import com.mojang.blaze3d.vertex.PoseStack;
import com.mojang.blaze3d.vertex.VertexConsumer;
import com.zurrtum.create.AllItemTags;
import com.zurrtum.create.AllShapes;
import com.zurrtum.create.catnip.data.Iterate;
import com.zurrtum.create.catnip.data.WorldAttached;
import com.zurrtum.create.catnip.math.AngleHelper;
import com.zurrtum.create.catnip.math.VecHelper;
import com.zurrtum.create.client.catnip.animation.AnimationTickHolder;
import com.zurrtum.create.client.flywheel.lib.transform.TransformStack;
import com.zurrtum.create.client.foundation.utility.RaycastHelper;
import com.zurrtum.create.content.trains.track.*;
import com.zurrtum.create.infrastructure.component.BezierTrackPointLocation;
import net.minecraft.client.Minecraft;
import net.minecraft.client.player.LocalPlayer;
import net.minecraft.client.renderer.MultiBufferSource;
import net.minecraft.client.renderer.rendertype.RenderTypes;
import net.minecraft.core.BlockPos;
import net.minecraft.core.Direction;
import net.minecraft.core.Direction.Axis;
import net.minecraft.util.Mth;
import net.minecraft.world.entity.ai.attributes.Attributes;
import net.minecraft.world.level.GameType;
import net.minecraft.world.level.block.state.BlockState;
import net.minecraft.world.phys.AABB;
import net.minecraft.world.phys.BlockHitResult;
import net.minecraft.world.phys.HitResult.Type;
import net.minecraft.world.phys.Vec3;
import net.minecraft.world.phys.shapes.Shapes;
import net.minecraft.world.phys.shapes.VoxelShape;
import org.apache.commons.lang3.mutable.MutableBoolean;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;

public class TrackBlockOutline {

    /**
     * Track block entities taking part in curve picking, keyed by their block position.
     * Entries are added by {@link #registerToCurveInteraction(TrackBlockEntity)}, removed by
     * {@link #removeFromCurveInteraction(TrackBlockEntity)} and iterated by {@link #pickCurves(Minecraft)}.
     * NOTE(review): presumably {@code WorldAttached} keeps one independent map per level — confirm.
     */
    public static WorldAttached<Map<BlockPos, TrackBlockEntity>> TRACKS_WITH_TURNS = new WorldAttached<>(w -> new HashMap<>());

    /**
     * Curve segment selection written by {@link #pickCurves(Minecraft)}; {@code null} while no curve
     * segment is selected. Read by {@link #drawCurveSelection(Minecraft, PoseStack, MultiBufferSource, Vec3)}.
     */
    public static BezierPointSelection result;

    /**
     * Casts a ray from the camera player's eye against every registered curved track connection and
     * stores the selected segment in {@link #result}.
     *
     * <p>Only primary connections whose bounds contain the eye position or are crossed by the ray are
     * examined. Each segment is tested by moving the ray into the segment's local space (translating
     * to the segment anchor and rotating by the negated model angles) and clipping it against the
     * re-centred bounding box of the south-facing orthogonal track shape. Within one connection the hit
     * closest to the local point {@code (0, 0.25, 0)} wins; hits whose squared distance from the eye
     * exceeds the current limit are rejected, and after a connection produced a hit the limit shrinks
     * to that hit's squared distance.
     *
     * <p>When a segment was selected and the current {@code mc.hitResult} is not a miss, the hit result
     * is replaced by a miss at the same location.
     *
     * @param mc the client instance providing camera entity, level and current hit result
     */
    public static void pickCurves(Minecraft mc) {
        // Clear the previous selection first so that no early exit can leave a stale one behind.
        result = null;

        if (!(mc.getCameraEntity() instanceof LocalPlayer player))
            return;
        if (mc.level == null)
            return;

        Vec3 origin = player.getEyePosition(AnimationTickHolder.getPartialTicks(mc.level));

        // Squared distance to whatever is currently targeted; curve hits beyond it are ignored.
        double maxRange = mc.hitResult == null ? Double.MAX_VALUE : mc.hitResult.getLocation().distanceToSqr(origin);

        double range = player.getAttributeValue(Attributes.BLOCK_INTERACTION_RANGE);
        // NOTE(review): maxRange is a squared distance but range is a linear one, so this min() mixes
        // units — left unchanged here; confirm whether that is intended.
        Vec3 target = RaycastHelper.getTraceTarget(player, Math.min(maxRange, range) + 1, origin);
        Map<BlockPos, TrackBlockEntity> turns = TRACKS_WITH_TURNS.get(mc.level);
        if (turns.isEmpty())
            return;

        // Both values are identical for every connection and segment, so compute them once.
        Vec3 ray = target.subtract(origin);
        AABB segmentBounds = AllShapes.TRACK_ORTHO.get(Direction.SOUTH).bounds();
        segmentBounds = segmentBounds.move(-.5, segmentBounds.getYsize() / -2, -.5);

        for (TrackBlockEntity be : turns.values()) {
            for (BezierConnection bc : be.getConnections().values()) {
                if (!bc.isPrimary())
                    continue;

                // Cheap rejection: the ray has to start inside or cross the connection's bounds.
                AABB bounds = bc.getBounds();
                if (!bounds.contains(origin) && bounds.clip(origin, target).isEmpty())
                    continue;

                float[] stepLUT = bc.getStepLUT();
                int segments = (int) (bc.getLength() * 2);

                int bestSegment = -1;
                double bestDistance = Double.MAX_VALUE;
                double newMaxRange = maxRange;

                for (int i = 0; i < stepLUT.length - 2; i++) {
                    float t = stepLUT[i] * i / segments;
                    float t1 = stepLUT[i + 1] * (i + 1) / segments;
                    float t2 = stepLUT[i + 2] * (i + 2) / segments;

                    Vec3 v1 = bc.getPosition(t);
                    Vec3 v2 = bc.getPosition(t2);
                    Vec3 diff = v2.subtract(v1);
                    Vec3 angles = TrackRenderer.getModelAngles(bc.getNormal(t1), diff);

                    // Express the ray relative to the segment's midpoint and undo the segment's rotation.
                    Vec3 anchor = v1.add(diff.scale(.5));
                    Vec3 localOrigin = origin.subtract(anchor);
                    Vec3 localDirection = ray;
                    localOrigin = VecHelper.rotate(localOrigin, AngleHelper.deg(-angles.x), Axis.X);
                    localOrigin = VecHelper.rotate(localOrigin, AngleHelper.deg(-angles.y), Axis.Y);
                    localDirection = VecHelper.rotate(localDirection, AngleHelper.deg(-angles.x), Axis.X);
                    localDirection = VecHelper.rotate(localDirection, AngleHelper.deg(-angles.y), Axis.Y);

                    Optional<Vec3> clip = segmentBounds.clip(localOrigin, localOrigin.add(localDirection));
                    if (clip.isEmpty())
                        continue;

                    Vec3 hit = clip.get();
                    double centerDistance = hit.distanceToSqr(0, 0.25f, 0);

                    // Keep the segment of this connection whose hit lies closest to its centre.
                    if (bestSegment != -1 && bestDistance < centerDistance)
                        continue;

                    double distanceToSqr = hit.distanceToSqr(localOrigin);
                    if (distanceToSqr > maxRange)
                        continue;

                    bestSegment = i;
                    newMaxRange = distanceToSqr;
                    bestDistance = centerDistance;

                    BezierTrackPointLocation location = new BezierTrackPointLocation(bc.getKey(), i);
                    result = new BezierPointSelection(be, location, anchor, angles, diff.normalize());
                }

                // Later connections must produce a hit at least as close as this one.
                if (bestSegment != -1)
                    maxRange = newMaxRange;
            }
        }

        if (result == null)
            return;

        // A curve was picked: turn the existing hit into a miss at the same location.
        if (mc.hitResult != null && mc.hitResult.getType() != Type.MISS) {
            Vec3 priorLoc = mc.hitResult.getLocation();
            mc.hitResult = BlockHitResult.miss(priorLoc, Direction.UP, BlockPos.containing(priorLoc));
        }
    }

    /**
     * Draws the outline of the curve segment currently stored in {@link #result}.
     *
     * <p>Nothing is drawn when the GUI is hidden, when the player mode is spectator, when no segment
     * is selected, or when the client has no player or game mode. The outline uses the south-facing
     * orthogonal track shape, positioned at the selection's anchor relative to the camera and rotated
     * by the selection's Y and X angles. If the main-hand item is tagged {@code AllItemTags.TRACKS}
     * the shape is rendered with {@code valid = false}, otherwise with {@code valid = null}.
     *
     * @param mc     the client instance
     * @param ms     pose stack to transform; restored to its incoming depth before returning
     * @param buffer source of the line vertex buffer
     * @param camera camera position subtracted from the selection's world position
     */
    public static void drawCurveSelection(Minecraft mc, PoseStack ms, MultiBufferSource buffer, Vec3 camera) {
        // Player and game mode are dereferenced below; without them there is nothing to draw.
        if (mc.player == null || mc.gameMode == null)
            return;
        if (mc.options.hideGui || mc.gameMode.getPlayerMode() == GameType.SPECTATOR)
            return;

        BezierPointSelection result = TrackBlockOutline.result;
        if (result == null)
            return;

        VertexConsumer vb = buffer.getBuffer(RenderTypes.lines());
        Vec3 vec = result.vec().subtract(camera);
        Vec3 angles = result.angles();
        // Rotate around a point raised by .125, then shift so the shape is centred on the anchor.
        TransformStack.of(ms).pushPose().translate(vec.x, vec.y + .125f, vec.z).rotateY((float) angles.y).rotateX((float) angles.x)
            .translate(-.5, -.125f, -.5);

        boolean holdingTrack = mc.player.getMainHandItem().is(AllItemTags.TRACKS);
        Boolean validity = holdingTrack ? Boolean.FALSE : null;
        renderShape(AllShapes.TRACK_ORTHO.get(Direction.SOUTH), ms, vb, validity);
        ms.popPose();
    }

    /**
     * Draws a track-specific selection outline for the block at {@code pos}.
     *
     * <p>Returns {@code false} without drawing when the client has no level or player, when the block
     * is not a {@link TrackBlock}, or when the position is outside the world border. Otherwise the
     * outline shapes for the block's {@link TrackShape} are rendered relative to the camera. While the
     * main-hand item is tagged {@code AllItemTags.TRACKS}, the shapes are rendered as valid only if the
     * shape is not a junction and the block entity at {@code pos} is not a tilted
     * {@link TrackBlockEntity}; otherwise they are rendered with {@code valid = null}.
     *
     * @param mc              the client instance
     * @param pos             position of the targeted block
     * @param vertexConsumers source of the line vertex buffer
     * @param camPos          camera position subtracted from the block position
     * @param ms              pose stack to transform; restored to its incoming depth before returning
     * @return {@code true} if at least one shape was rendered
     *         (NOTE(review): presumably the caller then skips the default outline — confirm)
     */
    public static boolean drawCustomBlockSelection(Minecraft mc, BlockPos pos, MultiBufferSource vertexConsumers, Vec3 camPos, PoseStack ms) {
        // Level and player are dereferenced below; report "nothing drawn" rather than failing on null.
        if (mc.level == null || mc.player == null)
            return false;

        BlockState blockstate = mc.level.getBlockState(pos);

        if (!(blockstate.getBlock() instanceof TrackBlock))
            return false;
        if (!mc.level.getWorldBorder().isWithinBounds(pos))
            return false;

        VertexConsumer vb = vertexConsumers.getBuffer(RenderTypes.lines());

        ms.pushPose();
        ms.translate(pos.getX() - camPos.x, pos.getY() - camPos.y, pos.getZ() - camPos.z);

        boolean holdingTrack = mc.player.getMainHandItem().is(AllItemTags.TRACKS);
        TrackShape shape = blockstate.getValue(TrackBlock.SHAPE);
        boolean canConnectFrom = !shape.isJunction() && !(mc.level.getBlockEntity(pos) instanceof TrackBlockEntity tbe && tbe.isTilted());

        // Set as soon as any shape has been rendered for this track shape.
        MutableBoolean cancel = new MutableBoolean();
        walkShapes(
            shape, TransformStack.of(ms), s -> {
                renderShape(s, ms, vb, holdingTrack ? canConnectFrom : null);
                cancel.setTrue();
            }
        );

        ms.popPose();
        return cancel.isTrue();
    }

    /**
     * Emits every edge of a voxel shape as a line (two vertices) into the given vertex consumer,
     * transformed by the top entry of the pose stack.
     *
     * <p>Line colour depends on {@code valid}: {@code null} gives black, {@code true} gives white and
     * {@code false} gives {@code (1, 0.25, 0.125)}. Alpha is always {@code 0.4} and line width {@code 1}.
     *
     * @param s     shape whose edges are drawn
     * @param ms    pose stack supplying the current transform
     * @param vb    vertex consumer receiving the line vertices
     * @param valid tri-state colour selector, may be {@code null}
     */
    public static void renderShape(VoxelShape s, PoseStack ms, VertexConsumer vb, Boolean valid) {
        final PoseStack.Pose pose = ms.last();

        // The colour is the same for all edges, so pick it once up front.
        final float red;
        final float green;
        final float blue;
        if (valid == null) {
            red = 0f;
            green = 0f;
            blue = 0f;
        } else if (valid) {
            red = 1f;
            green = 1f;
            blue = 1f;
        } else {
            red = 1f;
            green = 0.25f;
            blue = 0.125f;
        }

        s.forAllEdges((fromX, fromY, fromZ, toX, toY, toZ) -> {
            float dx = (float) (toX - fromX);
            float dy = (float) (toY - fromY);
            float dz = (float) (toZ - fromZ);
            float edgeLength = Mth.sqrt(dx * dx + dy * dy + dz * dz);

            // Unit vector along the edge, used as the line's normal.
            float nx = dx / edgeLength;
            float ny = dy / edgeLength;
            float nz = dz / edgeLength;

            vb.addVertex(pose.pose(), (float) fromX, (float) fromY, (float) fromZ)
                .setColor(red, green, blue, .4f)
                .setNormal(pose.copy(), nx, ny, nz)
                .setLineWidth(1);
            vb.addVertex(pose.pose(), (float) toX, (float) toY, (float) toZ)
                .setColor(red, green, blue, .4f)
                .setNormal(pose.copy(), nx, ny, nz)
                .setLineWidth(1);
        });
    }

    // Union of the long orthogonal Z and X shapes; walkShapes renders it (rotated) for TrackShape.CR_D.
    private static final VoxelShape LONG_CROSS = Shapes.or(TrackVoxelShapes.longOrthogonalZ(), TrackVoxelShapes.longOrthogonalX());
    // Long orthogonal Z shape; walkShapes renders it for the diagonal and ascending track shapes.
    private static final VoxelShape LONG_ORTHO = TrackVoxelShapes.longOrthogonalZ();
    // Offset variant of the long orthogonal Z shape; walkShapes renders it for portal track shapes.
    private static final VoxelShape LONG_ORTHO_OFFSET = TrackVoxelShapes.longOrthogonalZOffset();

    /**
     * Feeds the outline shapes belonging to a track shape to {@code renderer}, applying the required
     * rotations and translations to {@code msr} before each shape is handed over.
     *
     * <p>Transforms accumulate on {@code msr}; nothing is pushed or popped here. A portal shape that
     * matches one of the horizontal directions ends the walk right after its shape was rendered.
     *
     * @param shape    the track shape to outline
     * @param msr      transform stack that receives the per-shape transforms
     * @param renderer callback invoked once for every outline shape
     */
    private static void walkShapes(TrackShape shape, TransformStack<?> msr, Consumer<VoxelShape> renderer) {
        final float quarterTurn = Mth.PI / 4;

        // Straight part along X or Z, also present on the orthogonal/diagonal crossings.
        boolean alongX = shape == TrackShape.XO || shape == TrackShape.CR_NDX || shape == TrackShape.CR_PDX;
        boolean alongZ = shape == TrackShape.ZO || shape == TrackShape.CR_NDZ || shape == TrackShape.CR_PDZ;
        if (alongX) {
            renderer.accept(AllShapes.TRACK_ORTHO.get(Direction.EAST));
        } else if (alongZ) {
            renderer.accept(AllShapes.TRACK_ORTHO.get(Direction.SOUTH));
        }

        if (shape.isPortal()) {
            for (Direction facing : Iterate.horizontalDirections) {
                if (TrackShape.asPortal(facing) == shape) {
                    msr.rotateCentered(AngleHelper.rad(AngleHelper.horizontalAngle(facing)), Direction.UP);
                    renderer.accept(LONG_ORTHO_OFFSET);
                    return;
                }
            }
        }

        // Diagonal part: same long shape, turned by a positive or negative eighth of a circle.
        boolean positiveDiagonal = shape == TrackShape.PD || shape == TrackShape.CR_PDX || shape == TrackShape.CR_PDZ;
        boolean negativeDiagonal = shape == TrackShape.ND || shape == TrackShape.CR_NDX || shape == TrackShape.CR_NDZ;
        if (positiveDiagonal || negativeDiagonal) {
            msr.rotateCentered(positiveDiagonal ? quarterTurn : -quarterTurn, Direction.UP);
            renderer.accept(LONG_ORTHO);
        }

        if (shape == TrackShape.CR_O) {
            renderer.accept(AllShapes.TRACK_CROSS);
        } else if (shape == TrackShape.CR_D) {
            msr.rotateCentered(quarterTurn, Direction.UP);
            renderer.accept(LONG_CROSS);
        }

        boolean ascending = shape == TrackShape.AE || shape == TrackShape.AN || shape == TrackShape.AW || shape == TrackShape.AS;
        if (ascending) {
            msr.translate(0, 1, 0);
            msr.rotateCentered(Mth.PI - AngleHelper.rad(shape.getModelRotation()), Direction.UP);
            msr.rotateX(quarterTurn);
            msr.translate(0, -3 / 16f, 1 / 16f);
            renderer.accept(LONG_ORTHO);
        }
    }

    /**
     * A picked segment of a curved track connection, as produced by {@code pickCurves}.
     *
     * @param blockEntity the track block entity whose connection was hit
     * @param loc         location built from the connection's key and the index of the hit segment
     * @param vec         world-space anchor of the segment (midpoint between its two sample positions)
     * @param angles      model angles of the segment as returned by {@code TrackRenderer.getModelAngles}
     * @param direction   normalized vector between the segment's two sample positions
     */
    public record BezierPointSelection(
        TrackBlockEntity blockEntity, BezierTrackPointLocation loc, Vec3 vec, Vec3 angles, Vec3 direction
    ) {
    }

    /**
     * Adds the given track block entity to the curve-picking map of its level, keyed by its block
     * position. An existing entry for the same position is replaced.
     *
     * @param be the track block entity to register
     */
    public static void registerToCurveInteraction(TrackBlockEntity be) {
        Map<BlockPos, TrackBlockEntity> tracked = TRACKS_WITH_TURNS.get(be.getLevel());
        tracked.put(be.getBlockPos(), be);
    }

    /**
     * Removes whatever entry is stored under the given track block entity's block position from the
     * curve-picking map of its level.
     *
     * @param be the track block entity to unregister
     */
    public static void removeFromCurveInteraction(TrackBlockEntity be) {
        Map<BlockPos, TrackBlockEntity> tracked = TRACKS_WITH_TURNS.get(be.getLevel());
        tracked.remove(be.getBlockPos());
    }
}
