package io.github.dennisochulor.tickrate.mixin.core;

import io.github.dennisochulor.tickrate.TickRateS2CUpdatePayload;
import io.github.dennisochulor.tickrate.injected_interface.TickRateTickManager;
import io.github.dennisochulor.tickrate.TickState;
import net.fabricmc.fabric.api.networking.v1.ServerPlayNetworking;
import net.minecraft.class_1297;
import net.minecraft.class_1923;
import net.minecraft.class_1937;
import net.minecraft.class_2487;
import net.minecraft.class_2507;
import net.minecraft.class_2509;
import net.minecraft.class_3222;
import net.minecraft.class_4802;
import net.minecraft.class_5218;
import net.minecraft.class_5321;
import net.minecraft.class_8915;
import net.minecraft.class_8921;
import net.minecraft.server.MinecraftServer;
import org.spongepowered.asm.mixin.*;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfo;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfoReturnable;

import java.io.File;
import java.io.IOException;
import java.util.*;

/**
 * Server-side tick manager mixin that layers per-entity and per-chunk tick rates, freezing,
 * stepping and sprinting on top of the vanilla server-wide tick state.
 *
 * <p>Entity state is keyed by the entity's UUID string; chunk state is keyed by
 * {@code "<world registry id>-<packed chunk pos>"} (see {@link #chunkKey}). Rates of entities and
 * chunks that are not currently loaded live in separate "unloaded" maps. Those rates, the loaded
 * rates and the nominal server rate are written to {@code data/TickRateData.nbt}, resolved against
 * the server save path returned for {@code class_5218.field_24188}.
 */
@Mixin(class_8915.class)
public abstract class ServerTickManagerMixin extends class_8921 implements TickRateTickManager {

    @Unique private float nominalTickRate = 20.0f;
    @Unique private int ticks = 0; // position within the current one-second cycle of mainloop ticks
    @Unique private final Map<String,Float> entities = new HashMap<>(); // uuid -> tickRate
    @Unique private final Map<String,Float> chunks = new HashMap<>(); // world-longChunkPos -> tickRate
    @Unique private final Map<String,Float> unloadedEntities = new HashMap<>(); // uuid -> tickRate
    @Unique private final Map<String,Float> unloadedChunks = new HashMap<>(); // world-longChunkPos -> tickRate
    @Unique private final Map<String,Boolean> ticked = new HashMap<>(); // world-longChunkPos -> hasTickedThisMainloopTick, needed to ensure ChunkTickState is only updated ONCE per mainloop tick
    @Unique private final Map<String,Integer> steps = new HashMap<>(); // uuid/world-longChunkPos -> steps, if steps==0, then it's frozen
    @Unique private final Map<String,Integer> sprinting = new HashMap<>(); // uuid/world-longChunkPos -> sprintTicks
    @Unique private final Set<class_3222> playersWithMod = new HashSet<>(); // stores players that have this mod client-side
    @Unique private int sprintAvgTicksPerSecond = -1; // measured ticks/second while an individual sprint runs; -1 when not sprinting
    @Unique private File datafile;

    @Shadow public abstract void method_54671(float tickRate);
    @Shadow @Final private MinecraftServer server;
    @Shadow private long scheduledSprintTicks;

    /**
     * Runs when a server-wide step begins. It adds one step tick to compensate for a skipped
     * first tick, then sets the tick rate back to the nominal rate.
     */
    @Inject(method = "step", at = @At(value = "INVOKE", target = "Lnet/minecraft/server/ServerTickManager;sendStepPacket()V"))
    public void serverTickManager$step(int ticks, CallbackInfoReturnable<Boolean> cir) {
        this.field_46963++; // for some reason, the first tick is always skipped. so artificially add one :P
        method_54671(nominalTickRate);
    }

    /** Runs when stepping is stopped early; recomputes the fastest tick rate. */
    @Inject(method = "stopStepping", at = @At(value = "INVOKE", target = "Lnet/minecraft/server/ServerTickManager;sendStepPacket()V"))
    public void stopStepping(CallbackInfoReturnable<Boolean> cir) {
        updateFastestTicker();
    }

    /** Runs after a server sprint finishes; pushes fresh state to modded clients. */
    @Inject(method = "finishSprinting", at = @At("TAIL"))
    public void finishSprinting(CallbackInfo ci) {
        tickRate$sendUpdatePacket(); // tell client to stop sprinting
    }

    /**
     * @author Ninjaking312
     * @reason individual sprint or server sprint are both just sprint to code that doesn't know the difference
     */
    @Overwrite
    public boolean isSprinting() {
        return tickRate$isServerSprint() || tickRate$isIndividualSprint();
    }

    /**
     * Per-mainloop-tick update of the inherited step state.
     *
     * <p>Ticking is enabled when the server is not frozen or step ticks remain. Each call consumes
     * one step tick. When the last step tick is consumed, the fastest rate is recomputed and
     * modded clients are notified.
     */
    @Override
    public void method_54755() {
        this.field_46964 = !this.field_46965 || this.field_46963 > 0;
        if (this.field_46963 > 0) {
            this.field_46963--;
            if(this.field_46963 == 0) {
                updateFastestTicker();
                tickRate$sendUpdatePacket(); // tell client to stop stepping
            }
        }
    }

    /**
     * Locates the data file. If it exists, loads the nominal rate and the saved entity and chunk
     * rates, then recomputes the fastest rate.
     *
     * <p>Saved rates go into the "unloaded" maps. They move to the active maps when the
     * corresponding entity or chunk loads.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read
     */
    public void tickRate$serverStarted() {
        datafile = server.method_27050(class_5218.field_24188).resolve("data/TickRateData.nbt").toFile();
        if(!datafile.exists()) return;
        try {
            class_2487 nbt = class_2507.method_10633(datafile.toPath());
            nominalTickRate = nbt.method_10583("nominalTickRate");
            readRateMap(nbt, "entities", unloadedEntities);
            readRateMap(nbt, "chunks", unloadedChunks);
            updateFastestTicker();
        }
        catch (IOException e) {
            throw new RuntimeException("Failed to read tick rate data from " + datafile, e);
        }
    }

    /**
     * Writes the nominal rate and all entity and chunk rates, loaded or not, to the data file.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be written
     */
    public void tickRate$saveData() {
        class_2487 nbt = new class_2487();
        nbt.method_10548("nominalTickRate",nominalTickRate);
        var entitiesNbt = class_2509.field_11560.mapBuilder();
        entities.forEach((k,v) -> entitiesNbt.add(k,class_2509.field_11560.method_10662(v)));
        unloadedEntities.forEach((k,v) -> entitiesNbt.add(k,class_2509.field_11560.method_10662(v)));
        var chunksNbt = class_2509.field_11560.mapBuilder();
        chunks.forEach((k,v) -> chunksNbt.add(k,class_2509.field_11560.method_10662(v)));
        unloadedChunks.forEach((k,v) -> chunksNbt.add(k,class_2509.field_11560.method_10662(v)));
        nbt.method_10566("entities", entitiesNbt.build(class_2509.field_11560.method_10668()).getOrThrow());
        nbt.method_10566("chunks", chunksNbt.build(class_2509.field_11560.method_10668()).getOrThrow());
        try {
            class_2507.method_10630(nbt,datafile.toPath());
        }
        catch (IOException e) {
            throw new RuntimeException("Failed to write tick rate data to " + datafile, e);
        }
    }

    /** Registers a player as having the mod client-side, so it receives update packets. */
    public void tickRate$addPlayerWithMod(class_3222 player) {
        playersWithMod.add(player);
    }

    /** Unregisters a player from receiving update packets. */
    public void tickRate$removePlayerWithMod(class_3222 player) {
        playersWithMod.remove(player);
    }

    /** Returns whether the player was registered via {@link #tickRate$addPlayerWithMod}. */
    public boolean tickRate$hasClientMod(class_3222 player) {
        return playersWithMod.contains(player);
    }

    /**
     * Builds one update payload per world and sends each modded player the payload for the
     * world it is currently in.
     *
     * <p>Each payload holds the server tick state, the states of that world's loaded entities
     * (by network id) and the states of that world's chunks (by packed position). Entities and
     * chunks that have a rate, or that are frozen, stepping or sprinting, are included.
     */
    public void tickRate$sendUpdatePacket() {
        TickState server1 = tickRate$getServerTickState();
        Map<String,TickState> entities1 = new HashMap<>();
        Map<String,TickState> chunks1 = new HashMap<>();
        this.entities.keySet().forEach(key -> entities1.put(key,getEntityTickState(key)));
        this.chunks.keySet().forEach(key -> chunks1.put(key,getChunkTickState(key)));
        this.steps.keySet().forEach(key -> { // for frozen stuff that has no specific rate
            if(key.contains(":")) chunks1.putIfAbsent(key, getChunkTickState(key)); // only chunk keys have : in them
            else entities1.putIfAbsent(key, getEntityTickState(key));
        });
        this.sprinting.keySet().forEach(key -> { // for sprinting stuff that has no specific rate
            if(key.contains(":")) chunks1.putIfAbsent(key, getChunkTickState(key)); // only chunk keys have : in them
            else entities1.putIfAbsent(key, getEntityTickState(key));
        });

        Map<class_5321<class_1937>,TickRateS2CUpdatePayload> worldPayloads = new HashMap<>();
        server.method_3738().forEach(serverWorld -> {
            Map<Integer,TickState> entities2 = new HashMap<>();
            Map<Long,TickState> chunks2 = new HashMap<>();
            String worldPrefix = serverWorld.method_27983().method_29177() + "-";
            entities1.forEach((uuid,state) -> {
                class_1297 e = serverWorld.method_14190(UUID.fromString(uuid));
                if(e != null) entities2.put(e.method_5628(), state);
            });
            chunks1.forEach((key,state) -> {
                // Match the whole "<worldId>-" prefix rather than splitting on the first '-':
                // identifiers may contain '-' themselves, and the packed chunk pos may be negative.
                if(!key.startsWith(worldPrefix)) return;
                try {
                    chunks2.put(Long.parseLong(key.substring(worldPrefix.length())), state);
                }
                catch (NumberFormatException ignored) {
                    // The key belongs to another world whose id merely starts with "<thisWorldId>-".
                }
            });
            worldPayloads.put(serverWorld.method_27983(), new TickRateS2CUpdatePayload(server1, entities2, chunks2));
        });
        playersWithMod.forEach(player -> ServerPlayNetworking.send(player, worldPayloads.get(player.method_37908().method_27983())));
    }

    /**
     * Decides whether the entity ticks during the current mainloop tick.
     *
     * <p>Decision order:
     * <ul>
     *   <li>a cached result from earlier in this tick;</li>
     *   <li>server sprint, which always ticks;</li>
     *   <li>server frozen, where players always tick and others follow the server step state;</li>
     *   <li>individual sprint;</li>
     *   <li>individual freeze;</li>
     *   <li>the entity's own rate, falling back to its chunk's rate.</li>
     * </ul>
     * When the entity has a pending step count, one step is consumed if it ticks.
     */
    public boolean tickRate$shouldTickEntity(class_1297 entity) {
        String key = entity.method_5845();
        Boolean cached = ticked.get(key);
        if(cached != null) return cached;
        if(tickRate$isServerSprint()) return true;
        if(method_54754()) {
            if(entity instanceof class_3222) return true;
            return method_54752();
        }

        // Consume one sprint tick; a remaining count of 0 removes the entry and ends the sprint.
        if(sprinting.computeIfPresent(key, (k,v) -> {
            if(v == 0) return null;
            return --v;
        }) != null)
        {
            ticked.put(key,true);
            return true;
        }

        if(steps.getOrDefault(key,-1) == 0) return false; // frozen with no steps left

        Float tickRate = entities.get(key);
        boolean shouldTick;
        if(tickRate != null)
            shouldTick = internalShouldTick(tickRate);
        else
            shouldTick = tickRate$shouldTickChunk(entity.method_37908(),entity.method_31476().method_8324());

        steps.computeIfPresent(key,(k,v) -> {
            if(v > 0 && shouldTick) return --v;
            return v;
        });

        ticked.put(key,shouldTick);
        return shouldTick;
    }

    /**
     * Decides whether the chunk ticks during the current mainloop tick.
     *
     * <p>Uses the same precedence as {@link #tickRate$shouldTickEntity}, except that chunks with
     * no explicit rate follow {@link #tickRate$shouldTickServer()}.
     */
    public boolean tickRate$shouldTickChunk(class_1937 world, long chunkPos) {
        String key = chunkKey(world, chunkPos);
        Boolean cached = ticked.get(key);
        if(cached != null) return cached;
        if(tickRate$isServerSprint()) return true;
        if(method_54754()) return method_54752();

        // Consume one sprint tick; a remaining count of 0 removes the entry and ends the sprint.
        if(sprinting.computeIfPresent(key, (k,v) -> {
            if(v == 0) return null;
            return --v;
        }) != null)
        {
            ticked.put(key,true);
            return true;
        }

        if(steps.getOrDefault(key,-1) == 0) return false; // frozen with no steps left

        Float tickRate = chunks.get(key);
        boolean shouldTick;
        if(tickRate == null) // follow nominal rate
            shouldTick = tickRate$shouldTickServer();
        else
            shouldTick = internalShouldTick(tickRate);

        steps.computeIfPresent(key,(k,v) -> {
            if(v > 0 && shouldTick) return --v;
            return v;
        });

        ticked.put(key,shouldTick);
        return shouldTick;
    }

    /** Decides whether things that follow the nominal server rate tick this mainloop tick. */
    public boolean tickRate$shouldTickServer() {
        if(tickRate$isServerSprint()) return true;
        if(method_54754()) return method_54752();
        return internalShouldTick(nominalTickRate);
    }

    /**
     * Moves a chunk's rate between the loaded and unloaded maps and recomputes the fastest rate
     * when the change could affect it.
     *
     * <p>Unloading a chunk also cancels any sprint on it.
     */
    public void tickRate$updateChunkLoad(class_1937 world, long chunkPos, boolean loaded) {
        String key = chunkKey(world, chunkPos);
        if(loaded) {
            Float rate = unloadedChunks.remove(key);
            if(rate == null) return;
            chunks.put(key,rate);
            if(rate > field_46961) updateFastestTicker();
        }
        else {
            sprinting.remove(key); // just remove sprint if it unloads
            Float rate = chunks.remove(key);
            if(rate == null) return;
            unloadedChunks.put(key,rate);
            if(rate == field_46961) updateFastestTicker();
        }
    }

    /**
     * Moves an entity's rate between the loaded and unloaded maps, depending on why it was
     * removed.
     *
     * <p>Removal reasons are handled as follows:
     * <ul>
     *   <li>{@code field_26998}/{@code field_26999}: drop steps and sprint; drop the rate unless
     *       the entity is a player.</li>
     *   <li>{@code field_27000}/{@code field_27001}: keep the rate in the unloaded map.</li>
     *   <li>{@code field_27002}: nothing is changed.</li>
     * </ul>
     */
    public void tickRate$updateEntityLoad(class_1297 entity, boolean loaded) {
        String key = entity.method_5845();
        if(loaded) {
            Float rate = unloadedEntities.remove(key);
            if(rate == null) return;
            entities.put(key,rate);
            if(rate > field_46961) updateFastestTicker();
        }
        else {
            Float rate = entities.get(key);
            sprinting.remove(key); // just remove sprint if it unloads
            if(rate == null) return;

            class_1297.class_5529 reason = entity.method_35049();
            if(reason != null) {
                switch (reason) {
                    case field_26998,field_26999 -> tickRate$removeEntity(entity,!entity.method_31747(),true,true);
                    case field_27000,field_27001 -> {
                        tickRate$removeEntity(entity,true,false,true);
                        unloadedEntities.put(key,rate);
                    }
                    case field_27002 -> {} // NO-OP
                }
            }
            else {
                tickRate$removeEntity(entity,true,false,true); // removed for no reason?? wtf
                unloadedEntities.put(key,rate); // just have to save even if that's not the correct thing to do
            }
            if(rate == field_46961) updateFastestTicker();
        }
    }

    /** Sets the nominal server tick rate and recomputes the fastest rate. */
    public void tickRate$setServerRate(float rate) {
        nominalTickRate = rate;
        updateFastestTicker();
    }

    /** Returns the nominal server tick rate. */
    public float tickRate$getServerRate() {
        return nominalTickRate;
    }

    /** Returns a snapshot of the server-wide rate and its frozen, stepping and sprinting flags. */
    public TickState tickRate$getServerTickState() {
        return new TickState(tickRate$getServerRate(),method_54754(),method_54752(),tickRate$isServerSprint());
    }

    /**
     * Runs at the end of a mainloop tick. It advances the per-second tick counter, periodically
     * re-sends state to modded clients and clears the per-tick decision cache.
     *
     * <p>During an individual sprint the cycle length is the measured ticks per second;
     * otherwise it is the current fastest rate.
     */
    public void tickRate$ticked() {
        if(tickRate$isIndividualSprint()) {
            ticks++;
            if(ticks > sprintAvgTicksPerSecond) {
                ticks = 1;
                sprintAvgTicksPerSecond = (int) (class_4802.field_33868 / server.method_54834());
                tickRate$sendUpdatePacket();
            }
        }
        else {
            sprintAvgTicksPerSecond = -1;
            ticks++;
            if(ticks > field_46961) {
                ticks = 1;
                tickRate$sendUpdatePacket();
            }
        }
        ticked.clear();
    }

    /** Returns whether any entity or chunk is individually sprinting. */
    public boolean tickRate$isIndividualSprint() {
        return !sprinting.isEmpty();
    }

    /** Returns whether a server-wide sprint is scheduled. */
    public boolean tickRate$isServerSprint() {
        return scheduledSprintTicks > 0L;
    }

    /** Removes the selected kinds of per-entity state (rate, steps, sprint) for the entity. */
    public void tickRate$removeEntity(class_1297 entity, boolean rate, boolean steps, boolean sprint) {
        String uuid = entity.method_5845();
        if(rate) entities.remove(uuid);
        if(steps) this.steps.remove(uuid);
        if(sprint) sprinting.remove(uuid);
    }

    /**
     * Sets the tick rate of the given entities and recomputes the fastest rate.
     *
     * <p>A rate of 0 clears the per-entity rate instead.
     */
    public void tickRate$setEntityRate(float rate, Collection<? extends class_1297> entities) {
        if(rate == 0) entities.forEach(e -> this.entities.remove(e.method_5845()));
        else entities.forEach(e -> this.entities.put(e.method_5845(), rate));
        updateFastestTicker();
    }

    /**
     * Returns the effective rate of an entity. While the server is stepping this is the nominal
     * rate. Passengers use their root vehicle's rate. Otherwise the entity's own rate is used,
     * falling back to its chunk's rate, then to the nominal rate.
     */
    public float tickRate$getEntityRate(class_1297 entity) {
        if(method_54752()) return nominalTickRate; // server step override
        if(entity.method_5765()) return tickRate$getEntityRate(entity.method_5668()); //passengers follow tick rate of root vehicle
        Float rate = entities.get(entity.method_5845());
        if(rate != null) return rate;
        rate = chunks.get(chunkKey(entity.method_37908(), class_1923.method_37232(entity.method_24515())));
        if(rate != null) return rate;
        return nominalTickRate;
    }

    /** Freezes the entities, cancelling any sprint, or unfreezes them. */
    public void tickRate$setEntityFrozen(boolean frozen, Collection<? extends class_1297> entities) {
        if(frozen) {
            entities.forEach(e -> {
                steps.put(e.method_5845(),0);
                sprinting.remove(e.method_5845()); // if sprinting, stop the sprint
            });
        }
        else {
            entities.forEach(e -> steps.remove(e.method_5845()));
        }
    }

    /**
     * Schedules {@code steps} ticks for frozen entities.
     *
     * @return {@code false}, changing nothing, if any entity is not frozen or is sprinting
     */
    public boolean tickRate$stepEntity(int steps, Collection<? extends class_1297> entities) {
        if(entities.stream().anyMatch(e -> !this.steps.containsKey(e.method_5845()) || this.sprinting.containsKey(e.method_5845()))) {
            return false; // some are not frozen or are sprinting, error
        }
        entities.forEach(e -> this.steps.put(e.method_5845(), steps));
        return true;
    }

    /**
     * Starts an individual sprint of {@code ticks} ticks; 0 cancels it.
     *
     * @return {@code false}, changing nothing, if any entity is currently stepping
     */
    public boolean tickRate$sprintEntity(int ticks, Collection<? extends class_1297> entities) {
        if(entities.stream().anyMatch(e -> this.steps.getOrDefault(e.method_5845(),-1) > 0)) {
            return false; // some are stepping, error
        }
        if(ticks == 0) entities.forEach(e -> this.sprinting.remove(e.method_5845()));
        else entities.forEach(e -> this.sprinting.put(e.method_5845(), ticks));
        return true;
    }

    /** Returns the entity's own state only; the rate is -1 when no per-entity rate is set. */
    public TickState tickRate$getEntityTickStateShallow(class_1297 entity) {
        return getEntityTickState(entity.method_5845());
    }

    /**
     * Returns the entity's effective state, resolving passengers to their root vehicle, a missing
     * rate to the chunk's deep state, and letting server frozen/sprint/step state override it.
     */
    public TickState tickRate$getEntityTickStateDeep(class_1297 entity) {
        if(entity.method_5765()) return tickRate$getEntityTickStateDeep(entity.method_5668()); // all passengers will follow TPS of the root entity
        TickState state = tickRate$getEntityTickStateShallow(entity);
        float rate = state.rate();
        TickState serverState = tickRate$getServerTickState();

        if(rate == -1.0f) rate = tickRate$getChunkTickStateDeep(entity.method_37908(), entity.method_31476().method_8324()).rate();
        if(serverState.frozen() || serverState.sprinting() || serverState.stepping())
            return new TickState(serverState.stepping() ? serverState.rate() : rate,serverState.frozen(),serverState.stepping(),serverState.sprinting());
        return new TickState(rate,state.frozen(),state.stepping(),state.sprinting());
    }

    /**
     * Sets the tick rate of the given chunks and recomputes the fastest rate.
     *
     * <p>A rate of 0 clears the per-chunk rate instead.
     */
    public void tickRate$setChunkRate(float rate, class_1937 world, Collection<class_1923> chunks) {
        if(rate == 0) chunks.forEach(chunkPos -> this.chunks.remove(chunkKey(world, chunkPos.method_8324())));
        else chunks.forEach(chunkPos -> this.chunks.put(chunkKey(world, chunkPos.method_8324()), rate));
        updateFastestTicker();
    }

    /**
     * Returns the chunk's explicit rate, or the nominal rate if it has none or the server is
     * stepping.
     */
    public float tickRate$getChunkRate(class_1937 world, long chunkPos) {
        if(method_54752()) return nominalTickRate; // server step override
        Float rate = chunks.get(chunkKey(world, chunkPos));
        if(rate != null) return rate;
        return nominalTickRate;
    }

    /** Freezes the chunks, cancelling any sprint, or unfreezes them. */
    public void tickRate$setChunkFrozen(boolean frozen, class_1937 world, Collection<class_1923> chunks) {
        if(frozen) {
            chunks.forEach(chunkPos -> {
                String key = chunkKey(world, chunkPos.method_8324());
                steps.put(key,0);
                sprinting.remove(key); // if sprinting, stop the sprint
            });
        }
        else {
            chunks.forEach(chunkPos -> steps.remove(chunkKey(world, chunkPos.method_8324())));
        }
    }

    /**
     * Schedules {@code steps} ticks for frozen chunks.
     *
     * @return {@code false}, changing nothing, if any chunk is not frozen or is sprinting
     */
    public boolean tickRate$stepChunk(int steps, class_1937 world, Collection<class_1923> chunks) {
        boolean error = chunks.stream().anyMatch(chunkPos -> {
            String key = chunkKey(world, chunkPos.method_8324());
            return !this.steps.containsKey(key) || this.sprinting.containsKey(key);
        });
        if(error) return false; // some are not frozen or are sprinting, error

        chunks.forEach(chunkPos -> this.steps.put(chunkKey(world, chunkPos.method_8324()), steps));
        return true;
    }

    /**
     * Starts an individual sprint of {@code ticks} ticks on the chunks; 0 cancels it.
     *
     * @return {@code false}, changing nothing, if any chunk is currently stepping
     */
    public boolean tickRate$sprintChunk(int ticks, class_1937 world, Collection<class_1923> chunks) {
        if(chunks.stream().anyMatch(chunkPos -> this.steps.getOrDefault(chunkKey(world, chunkPos.method_8324()),-1) > 0))
            return false; // some are stepping, error

        if(ticks == 0) chunks.forEach(chunkPos -> this.sprinting.remove(chunkKey(world, chunkPos.method_8324())));
        else chunks.forEach(chunkPos -> this.sprinting.put(chunkKey(world, chunkPos.method_8324()), ticks));
        return true;
    }

    /** Returns the chunk's own state only; the rate is -1 when no per-chunk rate is set. */
    public TickState tickRate$getChunkTickStateShallow(class_1937 world, long chunkPos) {
        return getChunkTickState(chunkKey(world, chunkPos));
    }

    /**
     * Returns the chunk's effective state, resolving a missing rate to the server rate and
     * letting server frozen/sprint/step state override it.
     */
    public TickState tickRate$getChunkTickStateDeep(class_1937 world, long chunkPos) {
        TickState state = tickRate$getChunkTickStateShallow(world, chunkPos);
        float rate = state.rate();
        TickState serverState = tickRate$getServerTickState();

        if(state.rate() == -1.0f) rate = serverState.rate();
        if(serverState.frozen() || serverState.sprinting() || serverState.stepping())
            return new TickState(serverState.stepping() ? serverState.rate() : rate,serverState.frozen(),serverState.stepping(),serverState.sprinting());
        return new TickState(rate,state.frozen(),state.stepping(),state.sprinting());
    }


    // PRIVATE METHODS

    /**
     * Builds the map key for per-chunk state: {@code "<world registry id>-<packed chunk pos>"}.
     * The packed position may be negative, so a key can contain {@code "--"}.
     */
    @Unique
    private String chunkKey(class_1937 world, long chunkPos) {
        return world.method_27983().method_29177() + "-" + chunkPos;
    }

    /** Reads a string-to-float NBT map stored under {@code name} into {@code target}. */
    @Unique
    private void readRateMap(class_2487 nbt, String name, Map<String,Float> target) {
        class_2509.field_11560.method_29163(nbt.method_10580(name)).getOrThrow().entries().forEach(pair -> {
            String key = class_2509.field_11560.method_10656(pair.getFirst()).getOrThrow();
            float value = class_2509.field_11560.method_10645(pair.getSecond()).getOrThrow().floatValue();
            target.put(key,value);
        });
    }

    /**
     * Decides whether something running at {@code tickRate} ticks on the current value of
     * {@code ticks}. The aim is to spread {@code tickRate} ticks evenly across one cycle of the
     * fastest rate.
     */
    @Unique
    private boolean internalShouldTick(float tickRate) {
        // attempt to evenly space out the exact number of ticks
        // sprintAvgTicksPerSecond is -1 (or possibly 0) until tickRate$ticked measures it during a
        // sprint; fall back to the current fastest rate, which `ticks` is still counting against.
        float fastestTickRate = (tickRate$isIndividualSprint() && sprintAvgTicksPerSecond > 0) ? sprintAvgTicksPerSecond : this.field_46961;

        double d = (fastestTickRate-1)/(tickRate+1);
        if(tickRate == fastestTickRate) return true;
        if(ticks == 1) return Math.ceil(1+(1*d)) == 1;

        double eventsToTick = (ticks-1)/d;
        if(eventsToTick >= tickRate) return Math.ceil(1+(tickRate*d)) == ticks;
        double floorEventToTick = Math.floor(eventsToTick);
        double ceilEventToTick = Math.ceil(eventsToTick);
        if(Math.ceil(1+(floorEventToTick*d)) == ticks) return true;
        return Math.ceil(1+(ceilEventToTick*d)) == ticks;
    }

    /**
     * Sets the inherited tick rate to the largest of 1, the nominal rate and every loaded entity
     * and chunk rate. The cycle counter is reset when the rate changes.
     *
     * <p>Does nothing while the server is stepping.
     */
    @Unique
    private void updateFastestTicker() {
        if(method_54752()) return;
        float fastest = Math.max(1.0f, nominalTickRate);
        for(float rate : entities.values())
            fastest = Math.max(fastest,rate);
        for(float rate : chunks.values())
            fastest = Math.max(fastest,rate);
        if(fastest != field_46961) {
            method_54671(fastest);
            ticks = 1; // reset it
        }
    }

    /** Returns a chunk's own state by key; the rate is -1 when no per-chunk rate is set. */
    @Unique
    private TickState getChunkTickState(String key) {
        float rate = chunks.getOrDefault(key, -1.0f);
        boolean frozen = steps.containsKey(key);
        boolean stepping = frozen && steps.get(key) != 0;
        boolean sprinting = this.sprinting.containsKey(key);
        return new TickState(rate,frozen,stepping,sprinting);
    }

    /** Returns an entity's own state by UUID key; the rate is -1 when no per-entity rate is set. */
    @Unique
    private TickState getEntityTickState(String key) {
        float rate = entities.getOrDefault(key, -1.0f);
        boolean frozen = steps.containsKey(key);
        boolean stepping = frozen && steps.get(key) != 0;
        boolean sprinting = this.sprinting.containsKey(key);
        return new TickState(rate,frozen,stepping,sprinting);
    }

}
