package com.zurrtum.create.content.trains.entity;

import com.mojang.serialization.*;
import com.zurrtum.create.Create;
import com.zurrtum.create.catnip.data.Couple;
import com.zurrtum.create.catnip.data.Pair;
import com.zurrtum.create.content.trains.graph.*;
import com.zurrtum.create.content.trains.signal.TrackEdgePoint;
import org.jetbrains.annotations.Nullable;

import java.util.*;
import java.util.function.*;
import net.minecraft.class_11368;
import net.minecraft.class_11372;
import net.minecraft.class_243;
import net.minecraft.class_3532;

/**
 * A point that moves along the directed edges of a {@link TrackGraph}.
 *
 * <p>The point sits on the edge from {@link #node1} to {@link #node2}. {@link #position} is the distance
 * along that edge, measured from {@link #node1}.
 */
public class TravellingPoint {

    /** Maximum number of breadth-first rounds {@link #follow} spends per candidate before giving up. */
    private static final int FOLLOW_SEARCH_ROUNDS = 20;

    /** Start and end node of the directed edge this point is currently on. */
    public TrackNode node1, node2;
    /** The edge from {@link #node1} to {@link #node2}. */
    public TrackEdge edge;
    /** Distance along {@link #edge}, measured from {@link #node1}. */
    public double position;
    /**
     * Set by {@code travel} when the last movement was cut short because there was no onward connection or
     * the portal listener refused to pass. It is not set when an edge point listener stops the point.
     */
    public boolean blocked;
    /** Negates the track-normal offset applied by {@link #getPositionWithOffset}. */
    public boolean upsideDown;

    /** Steering preference at junctions; {@code targetDot} is the sideways dot product {@link #steer} aims for. */
    public enum SteerDirection {
        NONE(0),
        LEFT(-1),
        RIGHT(1);

        final float targetDot;

        SteerDirection(float targetDot) {
            this.targetDot = targetDot;
        }
    }

    /**
     * Picks one of several onward connections at a junction.
     *
     * <p>The first element of the pair is {@code true} when travelling forward. The second element holds
     * the candidate (node, edge) entries.
     */
    @FunctionalInterface
    public interface ITrackSelector extends BiFunction<TrackGraph, Pair<Boolean, List<Map.Entry<TrackNode, TrackEdge>>>, Map.Entry<TrackNode, TrackEdge>> {
    }

    /**
     * Notified when an edge point is passed.
     *
     * <p>Receives the distance travelled so far and the point together with its (directed) node pair.
     * Returning {@code true} stops the traveller at that point.
     */
    @FunctionalInterface
    public interface IEdgePointListener extends BiPredicate<Double, Pair<TrackEdgePoint, Couple<TrackNode>>> {
    }

    /** Notified with the distance travelled so far whenever the traveller is on a turn edge. */
    @FunctionalInterface
    public interface ITurnListener extends BiConsumer<Double, TrackEdge> {
    }

    /** Tested before crossing a zero-length edge; returning {@code true} blocks the traveller there. */
    @FunctionalInterface
    public interface IPortalListener extends Predicate<Couple<TrackNodeLocation>> {
    }

    public TravellingPoint() {
    }

    public TravellingPoint(TrackNode node1, TrackNode node2, TrackEdge edge, double position, boolean upsideDown) {
        this.node1 = node1;
        this.node2 = node2;
        this.edge = edge;
        this.position = position;
        this.upsideDown = upsideDown;
    }

    /** @return an edge point listener that never blocks */
    public IEdgePointListener ignoreEdgePoints() {
        return (d, c) -> false;
    }

    /** @return a turn listener that does nothing */
    public ITurnListener ignoreTurns() {
        return (d, c) -> {
        };
    }

    /** @return a portal listener that never blocks */
    public IPortalListener ignorePortals() {
        return $ -> false;
    }

    /** Equivalent to {@code follow(other, null)}. */
    public ITrackSelector follow(TravellingPoint other) {
        return follow(other, null);
    }

    /**
     * Returns a selector that chooses the junction branch leading toward {@code other}.
     *
     * <p>A candidate that connects directly to one of {@code other}'s nodes is taken immediately.
     * Otherwise, a breadth-first search is run from every candidate in lock-step for at most
     * {@value #FOLLOW_SEARCH_ROUNDS} rounds. The first candidate whose search reaches the target node wins;
     * ties go to the earlier candidate. If nothing is found, the first candidate is returned.
     *
     * @param other   the point to follow
     * @param success optional callback, told whether a path toward {@code other} was found
     */
    public ITrackSelector follow(TravellingPoint other, @Nullable Consumer<Boolean> success) {
        return (graph, pair) -> {
            List<Map.Entry<TrackNode, TrackEdge>> validTargets = pair.getSecond();
            boolean forward = pair.getFirst();
            TrackNode target = forward ? other.node1 : other.node2;
            TrackNode secondary = forward ? other.node2 : other.node1;

            // Cheap case: a candidate leads straight onto a node of the followed point.
            for (Map.Entry<TrackNode, TrackEdge> entry : validTargets)
                if (entry.getKey() == target || entry.getKey() == secondary) {
                    if (success != null)
                        success.accept(true);
                    return entry;
                }

            // One search frontier and visited-edge set per candidate.
            List<Deque<Map.Entry<TrackNode, TrackEdge>>> frontiers = new ArrayList<>(validTargets.size());
            List<Set<TrackEdge>> visiteds = new ArrayList<>(validTargets.size());

            for (Map.Entry<TrackNode, TrackEdge> validTarget : validTargets) {
                Deque<Map.Entry<TrackNode, TrackEdge>> frontier = new ArrayDeque<>();
                frontier.addLast(validTarget);
                frontiers.add(frontier);
                Set<TrackEdge> visited = new HashSet<>();
                visited.add(validTarget.getValue());
                visiteds.add(visited);
            }

            for (int round = 0; round < FOLLOW_SEARCH_ROUNDS; round++) {
                for (int j = 0; j < validTargets.size(); j++) {
                    Deque<Map.Entry<TrackNode, TrackEdge>> frontier = frontiers.get(j);
                    if (frontier.isEmpty())
                        continue;

                    Map.Entry<TrackNode, TrackEdge> currentEntry = frontier.removeFirst();
                    for (Map.Entry<TrackNode, TrackEdge> nextEntry : graph.getConnectionsFrom(currentEntry.getKey()).entrySet()) {
                        TrackEdge nextEdge = nextEntry.getValue();
                        // Check reachability before marking the edge visited. An edge that cannot be entered
                        // from this predecessor may still be reachable from another one later in the search.
                        if (!currentEntry.getValue().canTravelTo(nextEdge))
                            continue;
                        if (!visiteds.get(j).add(nextEdge))
                            continue;

                        if (nextEntry.getKey() == target) {
                            if (success != null)
                                success.accept(true);
                            return validTargets.get(j);
                        }

                        frontier.addLast(nextEntry);
                    }
                }
            }

            if (success != null)
                success.accept(false);
            return validTargets.get(0);
        };
    }

    /**
     * Returns a selector that picks the candidate edge whose direction best matches {@code direction}.
     *
     * <p>"Best" means the sideways dot product relative to the current edge is closest to
     * {@link SteerDirection#targetDot}. On ties, the later candidate wins.
     *
     * @param upNormal the up vector used to derive the current edge's sideways axis
     */
    public ITrackSelector steer(SteerDirection direction, class_243 upNormal) {
        return (graph, pair) -> {
            List<Map.Entry<TrackNode, TrackEdge>> validTargets = pair.getSecond();
            double closest = Double.MAX_VALUE;
            Map.Entry<TrackNode, TrackEdge> best = null;

            // These depend only on the current edge, not on the candidate, so compute them once.
            class_243 trajectory = edge.getDirection(false);
            class_243 normal = trajectory.method_1036(upNormal);

            for (Map.Entry<TrackNode, TrackEdge> entry : validTargets) {
                class_243 entryTrajectory = entry.getValue().getDirection(true);
                double dot = normal.method_1026(entryTrajectory);
                double diff = Math.abs(direction.targetDot - dot);
                if (diff > closest)
                    continue;

                closest = diff;
                best = entry;
            }

            if (best == null) {
                Create.LOGGER.warn("Couldn't find steer target, choosing first");
                return validTargets.getFirst();
            }

            return best;
        };
    }

    /** Travels ignoring edge points, turns and portals. */
    public double travel(TrackGraph graph, double distance, ITrackSelector trackSelector) {
        return travel(graph, distance, trackSelector, ignoreEdgePoints());
    }

    /** Travels ignoring turns and portals. */
    public double travel(TrackGraph graph, double distance, ITrackSelector trackSelector, IEdgePointListener signalListener) {
        return travel(graph, distance, trackSelector, signalListener, ignoreTurns());
    }

    /** Travels ignoring portals. */
    public double travel(
        TrackGraph graph,
        double distance,
        ITrackSelector trackSelector,
        IEdgePointListener signalListener,
        ITurnListener turnListener
    ) {
        return travel(graph, distance, trackSelector, signalListener, turnListener, ignorePortals());
    }

    /**
     * Moves this point by {@code distance} along the graph; negative values move backwards.
     *
     * <p>At junctions with more than one onward connection, {@code trackSelector} decides the branch.
     * Movement stops early when:
     * <ul>
     *   <li>there is no onward connection ({@link #blocked} is set);</li>
     *   <li>the portal listener refuses a zero-length edge ({@link #blocked} is set);</li>
     *   <li>the edge point listener returns {@code true}.</li>
     * </ul>
     *
     * @return the signed distance actually travelled; returns 0 if there is no edge or {@code distance} is ~0
     */
    public double travel(
        TrackGraph graph,
        double distance,
        ITrackSelector trackSelector,
        IEdgePointListener signalListener,
        ITurnListener turnListener,
        IPortalListener portalListener
    ) {
        blocked = false;
        if (edge == null)
            return 0;
        double edgeLength = edge.getLength();
        if (class_3532.method_20390(distance, 0))
            return 0;

        double prevPos = position;
        double traveled = distance;
        double currentT = edgeLength == 0 ? 0 : position / edgeLength;
        double incrementT = edge.incrementT(currentT, distance);
        position = incrementT * edgeLength;

        // FIXME: using incrementT like this becomes inaccurate at medium-long distances
        // travelling points would travel only 50m instead of 100m due to the low
        // incrementT at their starting position (e.g. bezier turn)
        // In an ideal scenario the amount added to position would iterate the traversed
        // edges for context first

        // A workaround was added in TrackEdge::incrementT

        List<Map.Entry<TrackNode, TrackEdge>> validTargets = new ArrayList<>();

        boolean forward = distance > 0;
        // Signed distance from the starting position to the start (in travel direction) of the current edge.
        double collectedDistance = forward ? -prevPos : -edgeLength + prevPos;

        Double blockedLocation = edgeTraversedFrom(graph, forward, signalListener, turnListener, prevPos, collectedDistance);
        if (blockedLocation != null) {
            position = blockedLocation;
            traveled = position - prevPos;
            return traveled;
        }

        if (forward) {
            // Moving forward
            while (position > edgeLength) {
                validTargets.clear();

                for (Map.Entry<TrackNode, TrackEdge> entry : graph.getConnectionsFrom(node2).entrySet()) {
                    TrackNode newNode = entry.getKey();
                    if (newNode == node1)
                        continue;

                    TrackEdge newEdge = entry.getValue();
                    if (!edge.canTravelTo(newEdge))
                        continue;

                    validTargets.add(entry);
                }

                if (validTargets.isEmpty()) {
                    traveled -= position - edgeLength;
                    position = edgeLength;
                    blocked = true;
                    break;
                }

                Map.Entry<TrackNode, TrackEdge> entry = validTargets.size() == 1 ? validTargets.get(0) : trackSelector.apply(
                    graph,
                    Pair.of(true, validTargets)
                );

                if (entry.getValue().getLength() == 0 && portalListener.test(Couple.create(node2.getLocation(), entry.getKey().getLocation()))) {
                    traveled -= position - edgeLength;
                    position = edgeLength;
                    blocked = true;
                    break;
                }

                node1 = node2;
                node2 = entry.getKey();
                edge = entry.getValue();
                position -= edgeLength;

                collectedDistance += edgeLength;
                // NOTE(review): edgeTraversedFrom below also notifies turnListener for turn edges, so forward
                // travel reports each turn edge twice. Confirm listeners tolerate this before changing it.
                if (edge.isTurn())
                    turnListener.accept(collectedDistance, edge);

                blockedLocation = edgeTraversedFrom(graph, forward, signalListener, turnListener, 0, collectedDistance);

                if (blockedLocation != null) {
                    traveled -= position;
                    position = blockedLocation;
                    traveled += position;
                    break;
                }

                prevPos = 0;
                edgeLength = edge.getLength();
            }

        } else {
            // Moving backwards
            while (position < 0) {
                validTargets.clear();

                for (Map.Entry<TrackNode, TrackEdge> entry : graph.getConnectionsFrom(node1).entrySet()) {
                    TrackNode newNode = entry.getKey();
                    if (newNode == node2)
                        continue;
                    // Reachability is checked on the edge that leads back into node1.
                    if (!graph.getConnectionsFrom(newNode).get(node1).canTravelTo(edge))
                        continue;

                    validTargets.add(entry);
                }

                if (validTargets.isEmpty()) {
                    traveled -= position;
                    position = 0;
                    blocked = true;
                    break;
                }

                Map.Entry<TrackNode, TrackEdge> entry = validTargets.size() == 1 ? validTargets.get(0) : trackSelector.apply(
                    graph,
                    Pair.of(false, validTargets)
                );

                if (entry.getValue().getLength() == 0 && portalListener.test(Couple.create(entry.getKey().getLocation(), node1.getLocation()))) {
                    traveled -= position;
                    position = 0;
                    blocked = true;
                    break;
                }

                node2 = node1;
                node1 = entry.getKey();
                edge = graph.getConnectionsFrom(node1).get(node2);
                collectedDistance += edgeLength;
                edgeLength = edge.getLength();
                position += edgeLength;

                blockedLocation = edgeTraversedFrom(graph, forward, signalListener, turnListener, edgeLength, collectedDistance);

                if (blockedLocation != null) {
                    traveled -= position;
                    position = blockedLocation;
                    traveled += position;
                    break;
                }
            }

        }

        return traveled;
    }

    /**
     * Reports the turn (if any) and the edge points passed on the current edge between {@code prevPos} and
     * {@link #position}. Points are visited in travel order.
     *
     * @param totalDistance signed distance from the travel start to the start of this edge (in travel direction)
     * @return the location on the edge where the edge point listener asked to stop, or {@code null}
     */
    protected Double edgeTraversedFrom(
        TrackGraph graph,
        boolean forward,
        IEdgePointListener edgePointListener,
        ITurnListener turnListener,
        double prevPos,
        double totalDistance
    ) {
        if (edge.isTurn())
            turnListener.accept(Math.max(0, totalDistance), edge);

        double from = forward ? prevPos : position;
        double to = forward ? position : prevPos;

        EdgeData edgeData = edge.getEdgeData();
        List<TrackEdgePoint> edgePoints = edgeData.getPoints();

        double length = edge.getLength();
        for (int i = 0; i < edgePoints.size(); i++) {
            int index = forward ? i : edgePoints.size() - i - 1;
            TrackEdgePoint nextBoundary = edgePoints.get(index);
            double locationOn = nextBoundary.getLocationOn(edge);
            double distance = forward ? locationOn : length - locationOn;
            // Half-open interval in travel direction: include the point we leave from, exclude the one we arrive at.
            if (forward ? (locationOn < from || locationOn >= to) : (locationOn <= from || locationOn > to))
                continue;
            Couple<TrackNode> nodes = Couple.create(node1, node2);
            if (edgePointListener.test(totalDistance + distance, Pair.of(nextBoundary, forward ? nodes : nodes.swap())))
                return locationOn;
        }

        return null;
    }

    /** Flips this point onto the opposite directed edge while keeping its physical location. */
    public void reverse(TrackGraph graph) {
        TrackNode n = node1;
        node1 = node2;
        node2 = n;
        position = edge.getLength() - position;
        edge = graph.getConnectionsFrom(node1).get(node2);
    }

    public class_243 getPosition(@Nullable TrackGraph trackGraph) {
        return getPosition(trackGraph, false);
    }

    public class_243 getPosition(@Nullable TrackGraph trackGraph, boolean flipUpsideDown) {
        return getPositionWithOffset(trackGraph, 0, flipUpsideDown);
    }

    /**
     * Returns the world position {@code offset} further along the current edge.
     *
     * <p>The position is shifted by the edge normal. The normal is negated when exactly one of
     * {@link #upsideDown} and {@code flipUpsideDown} is set.
     */
    public class_243 getPositionWithOffset(@Nullable TrackGraph trackGraph, double offset, boolean flipUpsideDown) {
        double length = edge.getLength();
        // Zero-length edges (travel() treats these as portal links) would otherwise give t = NaN/Infinity.
        double t = length == 0 ? 0 : (position + offset) / length;
        return edge.getPosition(trackGraph, t).method_1019(edge.getNormal(trackGraph, t).method_1021(upsideDown ^ flipUpsideDown ? -1 : 1));
    }

    /** Re-anchors this point onto the first location in {@code locations}, which is removed from the list. */
    public void migrateTo(List<TrackGraphLocation> locations) {
        TrackGraphLocation location = locations.removeFirst();
        TrackGraph graph = location.graph;
        node1 = graph.locateNode(location.edge.getFirst());
        node2 = graph.locateNode(location.edge.getSecond());
        position = location.position;
        edge = graph.getConnectionsFrom(node1).get(node2);
    }

    /** Writes nodes, position and orientation; writes nothing if the point is not on a graph. */
    public void write(class_11372 view, DimensionPalette dimensions) {
        if (Objects.isNull(node1) || Objects.isNull(node2))
            return;
        class_11372.class_11374 list = view.method_71476("Nodes");
        node1.getLocation().write(list.method_71480(), dimensions);
        node2.getLocation().write(list.method_71480(), dimensions);
        view.method_71463("Position", position);
        view.method_71472("UpsideDown", upsideDown);
    }

    /** Codec counterpart of {@link #write}; encodes an empty map if the point is not on a graph. */
    public static <T> DataResult<T> encode(final TravellingPoint input, final DynamicOps<T> ops, final T empty, DimensionPalette dimensions) {
        RecordBuilder<T> map = ops.mapBuilder();
        if (Objects.isNull(input.node1) || Objects.isNull(input.node2))
            return map.build(empty);
        ListBuilder<T> list = ops.listBuilder();
        list.add(TrackNodeLocation.encode(input.node1.getLocation(), ops, empty, dimensions));
        list.add(TrackNodeLocation.encode(input.node2.getLocation(), ops, empty, dimensions));
        map.add("Nodes", list.build(empty));
        map.add("Position", ops.createDouble(input.position));
        map.add("UpsideDown", ops.createBoolean(input.upsideDown));
        return map.build(empty);
    }

    /**
     * Reads a point written by {@link #write}. Returns an unattached point if there is no graph or the
     * nodes cannot be located in it.
     */
    public static TravellingPoint read(class_11368 view, TrackGraph graph, DimensionPalette dimensions) {
        if (graph == null)
            return new TravellingPoint(null, null, null, 0, false);

        Couple<TrackNode> locs = view.method_71436("Nodes").map(list -> {
            Iterator<class_11368> iterator = list.iterator();
            return Couple.create(
                graph.locateNode(TrackNodeLocation.read(iterator.next(), dimensions)),
                graph.locateNode(TrackNodeLocation.read(iterator.next(), dimensions))
            );
        }).orElseGet(() -> Couple.create(null, null));

        if (locs.either(Objects::isNull))
            return new TravellingPoint(null, null, null, 0, false);

        double position = view.method_71422("Position", 0);
        return new TravellingPoint(
            locs.getFirst(),
            locs.getSecond(),
            graph.getConnectionsFrom(locs.getFirst()).get(locs.getSecond()),
            position,
            view.method_71433("UpsideDown", false)
        );
    }

    /**
     * Codec counterpart of {@link #read}.
     *
     * <p>Returns an unattached point if there is no graph, no {@code Nodes} entry, or the nodes cannot be
     * located. A missing {@code Position} defaults to 0 and a missing {@code UpsideDown} defaults to
     * {@code false}, matching {@link #read}.
     */
    public static <T> TravellingPoint decode(DynamicOps<T> ops, T input, TrackGraph graph, DimensionPalette dimensions) {
        if (graph == null)
            return new TravellingPoint(null, null, null, 0, false);

        MapLike<T> map = ops.getMap(input).getOrThrow();
        // encode() writes an empty map for unattached points, so "Nodes" may legitimately be absent.
        T nodesValue = map.get("Nodes");
        if (nodesValue == null)
            return new TravellingPoint(null, null, null, 0, false);

        Couple<TrackNode> locs = ops.getStream(nodesValue).result().map(stream -> {
            Iterator<T> iterator = stream.iterator();
            return Couple.create(
                graph.locateNode(TrackNodeLocation.decode(ops, iterator.next(), dimensions)),
                graph.locateNode(TrackNodeLocation.decode(ops, iterator.next(), dimensions))
            );
        }).orElseGet(() -> Couple.create(null, null));

        if (locs.either(Objects::isNull))
            return new TravellingPoint(null, null, null, 0, false);

        T positionValue = map.get("Position");
        double position = positionValue == null ? 0 : ops.getNumberValue(positionValue, 0).doubleValue();
        T upsideDownValue = map.get("UpsideDown");
        boolean upsideDown = upsideDownValue != null && ops.getBooleanValue(upsideDownValue).result().orElse(false);
        return new TravellingPoint(
            locs.getFirst(),
            locs.getSecond(),
            graph.getConnectionsFrom(locs.getFirst()).get(locs.getSecond()),
            position,
            upsideDown
        );
    }

}
