package io.github.dennisochulor.tickrate.mixin.core;

import com.llamalad7.mixinextras.injector.ModifyExpressionValue;
import com.mojang.brigadier.arguments.FloatArgumentType;
import com.mojang.brigadier.arguments.IntegerArgumentType;
import com.mojang.brigadier.builder.LiteralArgumentBuilder;
import com.mojang.brigadier.context.CommandContext;
import com.mojang.brigadier.exceptions.CommandSyntaxException;
import com.mojang.brigadier.tree.LiteralCommandNode;
import io.github.dennisochulor.tickrate.api.TickRateEvents;
import net.fabricmc.fabric.api.attachment.v1.AttachmentTarget;
import net.minecraft.class_1297;
import net.minecraft.class_1923;
import net.minecraft.class_2168;
import net.minecraft.class_2170;
import net.minecraft.class_2172;
import net.minecraft.class_2186;
import net.minecraft.class_2245;
import net.minecraft.class_2262;
import net.minecraft.class_2264;
import net.minecraft.class_2265;
import net.minecraft.class_2561;
import net.minecraft.class_2806;
import net.minecraft.class_2818;
import net.minecraft.class_3194;
import net.minecraft.class_4076;
import net.minecraft.class_4802;
import net.minecraft.class_8012;
import net.minecraft.class_8915;
import net.minecraft.class_8916;
import org.spongepowered.asm.mixin.*;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.ModifyArg;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfoReturnable;

import java.util.*;

/**
 * Mixin into {@code class_8916}, the command registered under the {@code "tick"} literal.
 *
 * <p>It:
 * <ul>
 *   <li>replaces the permission level passed to {@code requirePermissionLevel} with 2;</li>
 *   <li>adds {@code entity <entities> ...} and {@code chunk <from> [to <to> | radius <radius>] ...}
 *       subcommands (query, rate, freeze, unfreeze, step, sprint);</li>
 *   <li>overwrites the server-wide {@code rate} and {@code query} handlers so the server tick rate
 *       is reported separately from the mainloop tick rate;</li>
 *   <li>fires {@link TickRateEvents} after the original sprint/freeze/step handlers.</li>
 * </ul>
 *
 * <p>NOTE(review): the {@code @Unique} executors refer to Yarn-named types ({@code ServerTickManager},
 * {@code Entity}, {@code WorldChunk}, {@code Text}, {@code Colors}) that are not imported in this
 * file, while the rest of the class uses intermediary names. Confirm how this file is remapped/built.
 */
@Mixin(class_8916.class)
public class TickCommandMixin {

    @Shadow @Final private static String DEFAULT_TICK_RATE_STRING;
    @Shadow private static String format(long nanos) { return ""; }

    /**
     * Replaces the permission level that {@code register} passes to
     * {@code CommandManager.requirePermissionLevel} with 2.
     */
    @ModifyArg(method = "register", at = @At(value = "INVOKE", target = "Lnet/minecraft/server/command/CommandManager;requirePermissionLevel(I)Lnet/minecraft/command/PermissionLevelPredicate;"))
    private static int modifyPermissionLevel(int requiredLevel) {
        return 2;
    }

    /**
     * Registers the {@code entity} and {@code chunk} subcommands by modifying the value returned by
     * the first {@code literal("tick")} call in {@code register}.
     *
     * <p>The chunk operation nodes are built once and attached under three parents: {@code from},
     * {@code to <to>} and {@code radius <radius>}. The depth passed to {@link #chunkTargets} is the
     * number of nodes between the executing node and that parent, so that {@link #getChunks} can
     * find out which chunk selection form was used.
     *
     * @param builder the {@code tick} literal builder
     * @return the same builder, with the extra subcommands attached
     */
    @ModifyExpressionValue(method = "register", at = @At(value = "INVOKE", target = "Lnet/minecraft/server/command/CommandManager;literal(Ljava/lang/String;)Lcom/mojang/brigadier/builder/LiteralArgumentBuilder;", args = "ldc=tick", ordinal = 0))
    private static LiteralArgumentBuilder<class_2168> register(LiteralArgumentBuilder<class_2168> builder) {
        LiteralCommandNode<class_2168> chunkQuery = class_2170.method_9247("query")
                .executes(context -> executeQuery(context.getSource(), chunkTargets(context, 1))).build();

        // "rate reset" and "rate <rate>" execute two nodes below the selection parent.
        LiteralCommandNode<class_2168> chunkRate = class_2170.method_9247("rate")
                .then(class_2170.method_9247("reset").executes(context -> executeRate(context.getSource(), chunkTargets(context, 2), -1.0f)))
                .then(class_2170.method_9244("rate", FloatArgumentType.floatArg(1.0F, 10000.0F))
                        .suggests((context, suggestionsBuilder) -> class_2172.method_9253(new String[]{DEFAULT_TICK_RATE_STRING,"reset"}, suggestionsBuilder))
                        .executes(context -> executeRate(context.getSource(), chunkTargets(context, 2), FloatArgumentType.getFloat(context, "rate")))).build();

        LiteralCommandNode<class_2168> chunkUnfreeze = class_2170.method_9247("unfreeze")
                .executes(context -> executeFreeze(context.getSource(), chunkTargets(context, 1), false)).build();

        LiteralCommandNode<class_2168> chunkFreeze = class_2170.method_9247("freeze")
                .executes(context -> executeFreeze(context.getSource(), chunkTargets(context, 1), true)).build();

        LiteralCommandNode<class_2168> chunkStep = class_2170.method_9247("step")
                .executes(context -> executeStep(context.getSource(), chunkTargets(context, 1), 1))
                .then(class_2170.method_9247("stop").executes(context -> executeStep(context.getSource(), chunkTargets(context, 2), 0)))
                .then(class_2170.method_9244("time", class_2245.method_48287(1))
                        .suggests((context, suggestionsBuilder) -> class_2172.method_9253(new String[]{"1t", "1s"}, suggestionsBuilder))
                        .executes(context -> executeStep(context.getSource(), chunkTargets(context, 2), IntegerArgumentType.getInteger(context, "time")))).build();

        LiteralCommandNode<class_2168> chunkSprint = class_2170.method_9247("sprint")
                .then(class_2170.method_9247("stop").executes(context -> executeSprint(context.getSource(), chunkTargets(context, 2), 0)))
                .then(class_2170.method_9244("time", class_2245.method_48287(1))
                        .suggests((context, suggestionsBuilder) -> class_2172.method_9253(new String[]{"60s", "1d", "3d"}, suggestionsBuilder))
                        .executes(context -> executeSprint(context.getSource(), chunkTargets(context, 2), IntegerArgumentType.getInteger(context, "time")))).build();

        builder.then(
                class_2170.method_9247("entity")
                        .then(class_2170.method_9244("entities", class_2186.method_9306())
                                .then(class_2170.method_9247("query")
                                        .executes(context -> executeQuery(context.getSource(), entityTargets(context)))
                                )
                                .then(class_2170.method_9247("rate")
                                        .then(class_2170.method_9247("reset").executes(context -> executeRate(context.getSource(), entityTargets(context), -1.0f)))
                                        .then(class_2170.method_9244("rate", FloatArgumentType.floatArg(1.0F, 10000.0F))
                                                .suggests((context, suggestionsBuilder) -> class_2172.method_9253(new String[]{DEFAULT_TICK_RATE_STRING,"reset"}, suggestionsBuilder))
                                                .executes(context -> executeRate(context.getSource(), entityTargets(context), FloatArgumentType.getFloat(context, "rate"))))
                                )
                                .then(class_2170.method_9247("unfreeze")
                                        .executes(context -> executeFreeze(context.getSource(), entityTargets(context), false))
                                )
                                .then(class_2170.method_9247("freeze")
                                        .executes(context -> executeFreeze(context.getSource(), entityTargets(context), true))
                                )
                                .then(class_2170.method_9247("step")
                                        .executes(context -> executeStep(context.getSource(), entityTargets(context), 1))
                                        .then(class_2170.method_9247("stop").executes(context -> executeStep(context.getSource(), entityTargets(context), 0)))
                                        .then(class_2170.method_9244("time", class_2245.method_48287(1))
                                                .suggests((context, suggestionsBuilder) -> class_2172.method_9253(new String[]{"1t", "1s"}, suggestionsBuilder))
                                                .executes(context -> executeStep(context.getSource(), entityTargets(context), IntegerArgumentType.getInteger(context, "time"))))
                                )
                                .then(class_2170.method_9247("sprint")
                                        .then(class_2170.method_9247("stop").executes(context -> executeSprint(context.getSource(), entityTargets(context), 0)))
                                        .then(class_2170.method_9244("time", class_2245.method_48287(1))
                                                .suggests((context, suggestionsBuilder) -> class_2172.method_9253(new String[]{"60s", "1d", "3d"}, suggestionsBuilder))
                                                .executes(context -> executeSprint(context.getSource(), entityTargets(context), IntegerArgumentType.getInteger(context, "time"))))
                                )
                        )
                )
                .then(
                        class_2170.method_9247("chunk")
                                .then(class_2170.method_9244("from", class_2264.method_9701())
                                        .then(chunkQuery)
                                        .then(chunkRate)
                                        .then(chunkUnfreeze)
                                        .then(chunkFreeze)
                                        .then(chunkStep)
                                        .then(chunkSprint)
                                        .then(class_2170.method_9244("to", class_2264.method_9701())
                                                .then(chunkQuery)
                                                .then(chunkRate)
                                                .then(chunkUnfreeze)
                                                .then(chunkFreeze)
                                                .then(chunkStep)
                                                .then(chunkSprint)
                                        )
                                        .then(class_2170.method_9247("radius")
                                                .then(class_2170.method_9244("radius", FloatArgumentType.floatArg(1))
                                                        .then(chunkQuery)
                                                        .then(chunkRate)
                                                        .then(chunkUnfreeze)
                                                        .then(chunkFreeze)
                                                        .then(chunkStep)
                                                        .then(chunkSprint)
                                                )
                                        )
                                )
                );
        return builder;
    }

    /**
     * Sets the server tick rate, rounded to the nearest integer, then fires
     * {@code TickRateEvents.SERVER_RATE} and sends the translatable success feedback.
     *
     * @author Ninjaking312
     * @reason To round the rate and set server (not mainloop) rate
     */
    @Overwrite
    private static int executeRate(class_2168 source, float rate) {
        int roundRate = Math.round(rate); // can't actually accept decimals
        class_8915 tickManager = source.method_9211().method_54833();
        tickManager.tickRate$setServerRate(roundRate);
        TickRateEvents.SERVER_RATE.invoker().onServerRate(source.method_9211(), roundRate);
        source.method_9226(() -> class_2561.method_43469("commands.tick.rate.success", roundRate), true);
        return roundRate;
    }

    /**
     * Reports the server status, the server and mainloop target tick rates, and
     * average/P50/P95/P99 tick-time statistics.
     *
     * @return the server tick rate
     * @author Ninjaking312
     * @reason To add distinction between mainloop and server tick rate
     */
    @Overwrite
    private static int executeQuery(class_2168 source) {
        class_8915 serverTickManager = source.method_9211().method_54833();
        if (serverTickManager.method_54670()) {
            source.method_9226(() -> class_2561.method_43471("commands.tick.status.sprinting"), false);
        }
        else {
            if (serverTickManager.method_54754()) {
                source.method_9226(() -> class_2561.method_43471("commands.tick.status.frozen"), false);
            } else if (serverTickManager.method_54750() < source.method_9211().method_54834()) {
                source.method_9226(() -> class_2561.method_43471("commands.tick.status.lagging"), false);
            } else {
                source.method_9226(() -> class_2561.method_43471("commands.tick.status.running"), false);
            }
        }

        source.method_9226(() -> class_2561.method_43470("Server's target tick rate: " + serverTickManager.tickRate$getServerRate() + " per second (" + format((long)((double) class_4802.field_33868 / (double)serverTickManager.tickRate$getServerRate())) + " mspt)"), false);
        source.method_9226(() -> class_2561.method_43470("Mainloop's target tick rate: " + Math.round(serverTickManager.method_54748()) + " per second (" + format((long)((double) class_4802.field_33868 / (double)serverTickManager.method_54748())) + " mspt)"), false);

        // Sort a copy so the server's own tick-time array is left untouched.
        long[] tickTimes = source.method_9211().method_54835();
        long[] ls = Arrays.copyOf(tickTimes, tickTimes.length);
        Arrays.sort(ls);
        String p50 = format(ls[ls.length / 2]);
        String p95 = format(ls[(int)((double)ls.length * 0.95)]);
        String p99 = format(ls[(int)((double)ls.length * 0.99)]);
        float avg = source.method_9211().method_54832();
        source.method_9226(() -> class_2561.method_43470("Avg: %.1fms P50: %sms P95: %sms P99: %sms, sample: %s".formatted(avg,p50,p95,p99,ls.length)), false);
        return serverTickManager.tickRate$getServerRate();
    }

    /** Fires {@code TickRateEvents.SERVER_SPRINT} after the original sprint handler returns. */
    @Inject(method = "executeSprint", at = @At("TAIL"))
    private static void executeSprint(class_2168 source, int ticks, CallbackInfoReturnable<Integer> cir) {
        TickRateEvents.SERVER_SPRINT.invoker().onServerSprint(source.method_9211(), ticks);
    }

    /** Fires {@code TickRateEvents.SERVER_FREEZE} after the original freeze handler returns. */
    @Inject(method = "executeFreeze", at = @At("TAIL"))
    private static void executeFreeze(class_2168 source, boolean frozen, CallbackInfoReturnable<Integer> cir) {
        TickRateEvents.SERVER_FREEZE.invoker().onServerFreeze(source.method_9211(), frozen);
    }

    /** Fires {@code TickRateEvents.SERVER_STEP} after the original step handler returns. */
    @Inject(method = "executeStep", at = @At("TAIL"))
    private static void executeStep(class_2168 source, int ticks, CallbackInfoReturnable<Integer> cir) {
        TickRateEvents.SERVER_STEP.invoker().onServerStep(source.method_9211(), ticks);
    }


    /**
     * Sets (or resets, when {@code rate == -1}) the tick rate of the given entities or chunks, and
     * fires the matching per-target rate event. A reset is reported to listeners as rate 0.
     *
     * @param targets entities or chunks; {@code null} or empty means validation failed, and nothing is done
     * @return the rounded rate, or 0 for a reset or when there are no targets
     */
    @Unique
    private static int executeRate(class_2168 source, List<? extends AttachmentTarget> targets, float rate) {
        if(targets == null || targets.isEmpty()) return 0;

        int roundRate = Math.round(rate); // can't actually accept decimals
        ServerTickManager tickManager = source.getServer().getTickManager();
        tickManager.tickRate$setRate(roundRate, targets);

        String targetType;
        switch(targets.getFirst()) {
            case Entity ignored -> {
                targetType = "entities";
                targets.forEach(target -> TickRateEvents.ENTITY_RATE.invoker().onEntityRate((Entity) target, roundRate==-1 ? 0 : roundRate));
            }
            case WorldChunk ignored -> {
                targetType = "chunks";
                targets.forEach(target -> TickRateEvents.CHUNK_RATE.invoker().onChunkRate((WorldChunk) target, roundRate==-1 ? 0 : roundRate));
            }
            default -> throw new IllegalArgumentException("Unknown target type: " + targets.getFirst());
        }

        if(roundRate != -1) {
            source.sendFeedback(() -> Text.literal("Set tick rate of " + targets.size() + " " + targetType + " to " + roundRate + " TPS."), false);
            return roundRate;
        }
        else {
            source.sendFeedback(() -> Text.literal("Reset the target rate of " + targets.size() + " " + targetType), false);
            return 0;
        }
    }

    /**
     * Sends one feedback message listing the tick rate of every target.
     *
     * @param targets entities or chunks; {@code null} or empty means validation failed, and nothing is done
     * @return the tick rate of the first target, or 0 when there are no targets
     */
    @Unique
    private static int executeQuery(class_2168 source, List<? extends AttachmentTarget> targets) {
        if(targets == null || targets.isEmpty()) return 0;

        ServerTickManager tickManager = source.getServer().getTickManager();
        int firstRate;
        String targetType;
        StringBuilder sb = new StringBuilder();
        switch(targets.getFirst()) {
            case Entity first -> {
                targetType = "entities";
                firstRate = tickManager.tickRate$getEntityRate(first);
                targets.forEach(e -> {
                    Entity entity = (Entity) e;
                    sb.append(entity.getType().getName().getString()).append(" ").append(entity.getNameForScoreboard()).append(" - ").append(tickManager.tickRate$getEntityRate(entity)).append(" TPS").append("\n");
                });
            }
            case WorldChunk first -> {
                targetType = "chunks";
                firstRate = tickManager.tickRate$getChunkRate(first);
                targets.forEach(chunk -> {
                    WorldChunk worldChunk = (WorldChunk) chunk;
                    sb.append("Chunk ").append(worldChunk.getPos().toString()).append(" - ").append(tickManager.tickRate$getChunkRate(worldChunk)).append(" TPS").append("\n");
                });
            }
            default -> throw new IllegalArgumentException("Unknown target type: " + targets.getFirst());
        }

        sb.insert(0, "The tick rates of the specified " + targetType + " are as follows:\n");
        sb.deleteCharAt(sb.length()-1); // to remove last \n
        source.sendFeedback(() -> Text.literal(sb.toString()), false);
        return firstRate;
    }

    /**
     * Freezes or unfreezes the given entities or chunks and fires the matching per-target freeze event.
     *
     * @param targets entities or chunks; {@code null} or empty means validation failed, and nothing is done
     * @return 1 on success, or 0 when there are no targets
     */
    @Unique
    private static int executeFreeze(class_2168 source, List<? extends AttachmentTarget> targets, boolean frozen) {
        if(targets == null || targets.isEmpty()) return 0;

        ServerTickManager tickManager = source.getServer().getTickManager();
        tickManager.tickRate$setFrozen(frozen, targets);

        String targetType;
        switch(targets.getFirst()) {
            case Entity ignored -> {
                targetType = "entities";
                targets.forEach(entity -> TickRateEvents.ENTITY_FREEZE.invoker().onEntityFreeze((Entity) entity, frozen));
            }
            case WorldChunk ignored -> {
                targetType = "chunks";
                targets.forEach(chunk -> TickRateEvents.CHUNK_FREEZE.invoker().onChunkFreeze((WorldChunk) chunk, frozen));
            }
            default -> throw new IllegalArgumentException("Unknown target type: " + targets.getFirst());
        }

        source.sendFeedback(() -> Text.literal(targets.size() + " " + targetType + " have been " + (frozen ? "frozen." : "unfrozen.")), false);
        return 1;
    }

    /**
     * Steps the given entities or chunks by {@code steps} ticks. A value of 0 stops stepping.
     * Step events are fired only when the tick manager reports success and {@code steps != 0}.
     *
     * @param targets entities or chunks; {@code null} or empty means validation failed, and nothing is done
     * @return 1 if the tick manager accepted the step, otherwise 0
     */
    @Unique
    private static int executeStep(class_2168 source, List<? extends AttachmentTarget> targets, int steps) {
        if(targets == null || targets.isEmpty()) return 0;

        ServerTickManager tickManager = source.getServer().getTickManager();
        boolean success = tickManager.tickRate$step(steps, targets);

        String targetType;
        switch(targets.getFirst()) {
            case Entity ignored -> {
                targetType = "entities";
                if(success && steps != 0) targets.forEach(entity -> TickRateEvents.ENTITY_STEP.invoker().onEntityStep((Entity) entity, steps));
            }
            case WorldChunk ignored -> {
                targetType = "chunks";
                if(success && steps != 0) targets.forEach(chunk -> TickRateEvents.CHUNK_STEP.invoker().onChunkStep((WorldChunk) chunk, steps));
            }
            default -> throw new IllegalArgumentException("Unknown target type: " + targets.getFirst());
        }

        if(success) {
            if(steps != 0) source.sendFeedback(() -> Text.literal(targets.size() + " " + targetType + " will step " + steps + " ticks."), false);
            else source.sendFeedback(() -> Text.literal(targets.size() + " " + targetType + " have stopped stepping."), false);
        }
        else source.sendFeedback(() -> Text.literal("All of the specified " + targetType + " must be frozen first and cannot be sprinting!").withColor(Colors.LIGHT_RED), false);
        return success ? 1 : 0;
    }

    /**
     * Sprints the given entities or chunks for {@code ticks} ticks. A value of 0 stops sprinting.
     * Sprint events are fired only when the tick manager reports success and {@code ticks != 0}.
     *
     * @param targets entities or chunks; {@code null} or empty means validation failed, and nothing is done
     * @return 1 if the tick manager accepted the sprint, otherwise 0
     */
    @Unique
    private static int executeSprint(class_2168 source, List<? extends AttachmentTarget> targets, int ticks) {
        if(targets == null || targets.isEmpty()) return 0;

        ServerTickManager tickManager = source.getServer().getTickManager();
        boolean success = tickManager.tickRate$sprint(ticks, targets);

        String targetType;
        switch(targets.getFirst()) {
            case Entity ignored -> {
                targetType = "entities";
                if(success && ticks != 0) targets.forEach(entity -> TickRateEvents.ENTITY_SPRINT.invoker().onEntitySprint((Entity) entity, ticks));
            }
            case WorldChunk ignored -> {
                targetType = "chunks";
                if(success && ticks != 0) targets.forEach(chunk -> TickRateEvents.CHUNK_SPRINT.invoker().onChunkSprint((WorldChunk) chunk, ticks));
            }
            default -> throw new IllegalArgumentException("Unknown target type: " + targets.getFirst());
        }

        if(success) {
            if(ticks != 0) source.sendFeedback(() -> Text.literal(targets.size() + " " + targetType + " will sprint " + ticks + " ticks."), false);
            else source.sendFeedback(() -> Text.literal(targets.size() + " " + targetType + " have stopped sprinting."), false);
        }
        else source.sendFeedback(() -> Text.literal("All of the specified " + targetType + " must not be stepping!").withColor(Colors.LIGHT_RED), false);
        return success ? 1 : 0;
    }



    /**
     * Reads the {@code "entities"} argument and passes it through {@link #entityCheck}.
     *
     * @throws CommandSyntaxException if the {@code "entities"} argument cannot be resolved
     */
    @Unique
    private static List<? extends class_1297> entityTargets(CommandContext<class_2168> context) throws CommandSyntaxException {
        return entityCheck(class_2186.method_9317(context, "entities"), context.getSource());
    }

    /**
     * Resolves the chunk selection for the executing node and passes it through {@link #chunkCheck}.
     *
     * @param depth see {@link #getChunks}
     * @return the loaded chunks, or {@code null} if any of them is not loaded
     * @throws CommandSyntaxException if the chunk arguments are invalid or out of bounds
     */
    @Unique
    private static List<class_2818> chunkTargets(CommandContext<class_2168> context, int depth) throws CommandSyntaxException {
        return chunkCheck(getChunks(context, depth), context.getSource());
    }

    /**
     * Converts the resolved entities into the {@code List} form the executors expect.
     *
     * <p>Currently no entity is rejected, so this never returns {@code null}. The executors still
     * treat {@code null} as "validation failed". A {@code List} argument is returned as-is; any other
     * collection is copied, rather than blindly cast (which could throw {@code ClassCastException}).
     */
    @Unique
    private static List<? extends class_1297> entityCheck(Collection<? extends class_1297> entities, class_2168 source) {
        if (entities instanceof List<? extends class_1297> list) return list;
        return new ArrayList<class_1297>(entities);
    }

    /**
     * Looks up the already-loaded chunk for every position.
     *
     * <p>If any chunk lookup yields {@code null}, or a chunk whose level type is
     * {@code class_3194.field_19334} (treated as "not loaded"), an error message is sent to
     * {@code source} and {@code null} is returned.
     *
     * @return the chunks in the same order as {@code chunks}, or {@code null} if any is not loaded
     */
    @Unique
    private static List<class_2818> chunkCheck(List<class_1923> chunks, class_2168 source) {
        List<class_2818> worldChunks = new ArrayList<>(chunks.size());
        for (class_1923 chunkPos : chunks) {
            // NOTE(review): trailing 'false' presumably means "do not load/create the chunk" — confirm
            class_2818 worldChunk = (class_2818) source.method_9225().method_8402(chunkPos.field_9181,chunkPos.field_9180,class_2806.field_12803,false);
            if (worldChunk == null || worldChunk.method_12225() == class_3194.field_19334) {
                source.method_9226(() -> class_2561.method_43470("Some of the specified chunks are not loaded!").method_54663(class_8012.field_46652), false);
                return null;
            }
            worldChunks.add(worldChunk);
        }
        return worldChunks;
    }

    /**
     * Computes the chunk positions selected by the {@code chunk} subcommand.
     *
     * <p>The selection form is identified by the name of the node {@code depth} levels above the
     * executing node:
     * <ul>
     *   <li>{@code from}: the single chunk containing {@code from};</li>
     *   <li>{@code to}: every chunk in the rectangle spanned by {@code from} and {@code to};</li>
     *   <li>{@code radius}: every chunk overlapping the circle of {@code radius} blocks around {@code from}.</li>
     * </ul>
     *
     * @param depth number of steps back up the command tree to get to the node right before the chunkOperations
     * @return the selected chunk positions
     * @throws CommandSyntaxException if an argument cannot be read or the area reaches outside
     *         the ±30,000,000 bounds ({@code class_2262.field_10704})
     * @throws IllegalStateException if the node at that depth is not one of the forms above
     */
    @Unique
    private static List<class_1923> getChunks(CommandContext<class_2168> context, int depth) throws CommandSyntaxException {
        // CommandContext#getArgument is not used because it throws an Exception when not found, which is not great for performance
        String lastNode = context.getNodes().get(context.getNodes().size() - depth - 1).getNode().getName();
        return switch(lastNode) {
            case "from" -> {
                class_2265 from = class_2264.method_9702(context, "from");
                if (from.comp_638() < -30000000 || from.comp_639() < -30000000 || from.comp_638() >= 30000000 || from.comp_639() >= 30000000)
                    throw class_2262.field_10704.create();
                yield List.of(from.method_34873());
            }
            case "to" -> {
                // logic taken from ForceLoadCommand :)
                class_2265 from = class_2264.method_9702(context, "from");
                class_2265 to = class_2264.method_9702(context, "to");
                int minX = Math.min(from.comp_638(), to.comp_638());
                int minZ = Math.min(from.comp_639(), to.comp_639());
                int maxX = Math.max(from.comp_638(), to.comp_638());
                int maxZ = Math.max(from.comp_639(), to.comp_639());
                if (minX < -30000000 || minZ < -30000000 || maxX >= 30000000 || maxZ >= 30000000)
                    throw class_2262.field_10704.create();

                int chunkMinX = class_4076.method_18675(minX);
                int chunkMinZ = class_4076.method_18675(minZ);
                int chunkMaxX = class_4076.method_18675(maxX);
                int chunkMaxZ = class_4076.method_18675(maxZ);
                List<class_1923> chunks = new ArrayList<>();
                for(int chunkX = chunkMinX; chunkX <= chunkMaxX; chunkX++) {
                    for(int chunkZ = chunkMinZ; chunkZ <= chunkMaxZ; chunkZ++) {
                        chunks.add(new class_1923(chunkX, chunkZ));
                    }
                }
                yield chunks;
            }
            case "radius" -> {
                // logic taken from ForceLoadCommand :)
                class_2265 circleCentre = class_2264.method_9702(context, "from");
                float radius = FloatArgumentType.getFloat(context, "radius");
                float minX = circleCentre.comp_638() - radius;
                float minZ = circleCentre.comp_639() - radius;
                float maxX = circleCentre.comp_638() + radius;
                float maxZ = circleCentre.comp_639() + radius;
                if (minX < -30000000 || minZ < -30000000 || maxX >= 30000000 || maxZ >= 30000000)
                    throw class_2262.field_10704.create();

                int chunkMinX = class_4076.method_32204(minX);
                int chunkMinZ = class_4076.method_32204(minZ);
                int chunkMaxX = class_4076.method_32204(maxX);
                int chunkMaxZ = class_4076.method_32204(maxZ);
                int xc = circleCentre.comp_638();
                int zc = circleCentre.comp_639();
                float radiusSq = radius * radius;
                List<class_1923> chunks = new ArrayList<>();
                for(int chunkX = chunkMinX; chunkX <= chunkMaxX; chunkX++) {
                    for(int chunkZ = chunkMinZ; chunkZ <= chunkMaxZ; chunkZ++) {
                        // https://www.geeksforgeeks.org/check-if-any-point-overlaps-the-given-circle-and-rectangle/
                        class_1923 chunkPos = new class_1923(chunkX, chunkZ);
                        class_4076 chunkSectionPos = class_4076.method_18681(chunkPos,0);
                        int x1 = chunkSectionPos.method_19527();
                        int x2 = chunkSectionPos.method_19530();
                        int z1 = chunkSectionPos.method_19529();
                        int z2 = chunkSectionPos.method_19532();

                        // find closest point of chunk to centre of circle
                        int xn = Math.max(x1, Math.min(xc, x2));
                        int zn = Math.max(z1, Math.min(zc, z2));

                        // Squared distance between that point and the centre. Squaring is done in long:
                        // with radii above ~46340 blocks, int multiplication would overflow.
                        long dx = xn - xc;
                        long dz = zn - zc;
                        if((dx * dx + dz * dz) <= radiusSq) chunks.add(chunkPos); // if overlap, add it
                    }
                }
                yield chunks;
            }
            default -> throw new IllegalStateException("Unexpected value: " + lastNode);
        };
    }

}
