package com.zurrtum.create.mixin;

import com.google.common.collect.Maps;
import com.llamalad7.mixinextras.injector.ModifyExpressionValue;
import com.mojang.serialization.Codec;
import com.zurrtum.create.content.trains.station.StationBlockEntity;
import com.zurrtum.create.content.trains.station.StationMapData;
import com.zurrtum.create.content.trains.station.StationMarker;
import org.spongepowered.asm.mixin.Final;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.Shadow;
import org.spongepowered.asm.mixin.Unique;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfo;

import java.util.*;
import net.minecraft.class_1922;
import net.minecraft.class_1936;
import net.minecraft.class_20;
import net.minecraft.class_22;
import net.minecraft.class_2338;

/**
 * Mixin into the vanilla map saved data ({@code class_22}, i.e. {@code MapItemSavedData} per the
 * injector descriptors below) that keeps a set of Create train-station markers next to the vanilla
 * decorations and renders each one as a map decoration.
 */
@Mixin(class_22.class)
public abstract class MapItemSavedDataMixin implements StationMapData {
    // ---------------------------------------------------------------------------------------------
    // Shadowed vanilla state
    // ---------------------------------------------------------------------------------------------

    @Shadow
    @Final
    public int centerX;

    @Shadow
    @Final
    public int centerZ;

    @Shadow
    @Final
    public byte scale;

    @Shadow
    @Final
    Map<String, class_20> decorations;

    @Shadow
    private int trackedDecorationCount;

    // ---------------------------------------------------------------------------------------------
    // Shadowed vanilla methods
    // ---------------------------------------------------------------------------------------------

    @Shadow
    protected abstract void removeDecoration(String id);

    @Shadow
    protected abstract void setDecorationsDirty();

    @Shadow
    public abstract boolean isTrackedCountOverLimit(int trackedCount);

    // ---------------------------------------------------------------------------------------------
    // Added state
    // ---------------------------------------------------------------------------------------------

    /** Station markers known to this map, keyed by {@link StationMarker#getId()}. */
    @Unique
    private final Map<String, StationMarker> create$stationMarkers = new HashMap<>();

    /**
     * Replaces the {@code CODEC} field value read inside {@code type(MapId)} with the codec returned by
     * {@code StationMarker.WrapperCodec.get(...)}. Presumably this is how station markers are
     * persisted alongside the vanilla map data; confirm against {@code StationMarker}.
     *
     * @param original the codec vanilla would have used
     * @return the wrapped codec
     */
    @ModifyExpressionValue(method = "type(Lnet/minecraft/world/level/saveddata/maps/MapId;)Lnet/minecraft/world/level/saveddata/SavedDataType;", at = @At(value = "FIELD", target = "Lnet/minecraft/world/level/saveddata/maps/MapItemSavedData;CODEC:Lcom/mojang/serialization/Codec;"))
    private static Codec<class_22> saveCodec(Codec<class_22> original) {
        Codec<class_22> wrapped = StationMarker.WrapperCodec.get(original);
        return wrapped;
    }

    /**
     * Returns the live (mutable) map of station markers held by this map data.
     *
     * @return the internal marker map, not a copy
     */
    @Override
    public Map<String, StationMarker> create$getStationMarkers() {
        return create$stationMarkers;
    }

    /**
     * Stores {@code marker} and (re)builds its map decoration.
     *
     * <p>The marker is always recorded in the station-marker map. If its target lies more than 63
     * local units from the map centre on either axis, any decoration under the marker's id is removed
     * and nothing else happens. Otherwise, a station decoration is built at the target's position.
     * The tracked-decoration counter and the dirty flag are only touched when the new decoration
     * differs from the one it replaces.
     *
     * @param marker the station marker to add or refresh
     */
    @Override
    public void create$addStationMarker(StationMarker marker) {
        String id = marker.getId();
        create$stationMarkers.put(id, marker);

        class_2338 target = marker.getTarget();
        float mapX = create$toLocalCoordinate(target.method_10263(), centerX);
        float mapZ = create$toLocalCoordinate(target.method_10260(), centerZ);

        if (create$isOutsideMap(mapX, mapZ)) {
            // Off this map: make sure no stale decoration is left behind.
            removeDecoration(id);
            return;
        }

        class_20 updated = StationMarker.createStationDecoration(
                create$toDecorationByte(mapX),
                create$toDecorationByte(mapZ),
                Optional.of(marker.getName()));
        class_20 previous = decorations.put(id, updated);

        if (updated.equals(previous)) {
            // Identical decoration already present: counters and dirty flag stay untouched.
            return;
        }

        if (previous != null && create$tracksCount(previous)) {
            trackedDecorationCount--;
        }
        if (create$tracksCount(updated)) {
            trackedDecorationCount++;
        }
        setDecorationsDirty();
    }

    /**
     * Toggles the station marker for the block at {@code pos}.
     *
     * <p>Fails when the block centre is more than 63 local units from the map centre, or when
     * {@code StationMarker.fromWorld} yields no marker. An equal, already-present marker is removed
     * together with its decoration. Otherwise, the marker is added unless the tracked-decoration limit
     * is reached.
     *
     * @param level the level to read the station from
     * @param pos the station block position
     * @param stationBlockEntity the station block entity (not read here)
     * @return {@code true} if a marker was added or removed
     */
    @Override
    public boolean create$toggleStation(class_1936 level, class_2338 pos, StationBlockEntity stationBlockEntity) {
        double divisor = 1 << scale;
        double localX = (pos.method_10263() + 0.5D - centerX) / divisor;
        double localZ = (pos.method_10260() + 0.5D - centerZ) / divisor;

        if (create$isOutsideMap(localX, localZ)) {
            return false;
        }

        StationMarker marker = StationMarker.fromWorld(level, pos);
        if (marker == null) {
            return false;
        }

        String id = marker.getId();
        if (create$stationMarkers.remove(id, marker)) {
            // Same marker was already on the map: toggling removes it.
            removeDecoration(id);
            return true;
        }

        // 256 presumably mirrors vanilla's banner-marker cap — confirm against vanilla toggleBanner.
        if (isTrackedCountOverLimit(256)) {
            return false;
        }

        create$addStationMarker(marker);
        return true;
    }

    /**
     * Runs when vanilla {@code checkBanners(BlockGetter, int, int)} returns, and applies the same
     * re-validation to station markers in column ({@code x}, {@code z}).
     */
    @Inject(method = "checkBanners(Lnet/minecraft/world/level/BlockGetter;II)V", at = @At("RETURN"))
    public void create$onCheckBanners(class_1922 blockGetter, int x, int z, CallbackInfo ci) {
        create$checkStations(blockGetter, x, z);
    }

    /**
     * Re-reads every station marker whose target is in column ({@code x}, {@code z}).
     *
     * <p>A marker that no longer equals what {@code StationMarker.fromWorld} reports for its source is
     * dropped along with its decoration. If the fresh marker exists and keeps the same target, it is
     * re-added after iteration finishes, so the marker map is never modified by anything but the
     * iterator while being walked.
     */
    @Unique
    private void create$checkStations(class_1922 blockGetter, int x, int z) {
        List<StationMarker> replacements = new ArrayList<>();

        for (Iterator<StationMarker> it = create$stationMarkers.values().iterator(); it.hasNext(); ) {
            StationMarker existing = it.next();
            class_2338 target = existing.getTarget();
            if (target.method_10263() != x || target.method_10260() != z) {
                continue;
            }

            StationMarker current = StationMarker.fromWorld(blockGetter, existing.getSource());
            if (existing.equals(current)) {
                continue;
            }

            it.remove();
            removeDecoration(existing.getId());

            if (current != null && target.equals(current.getTarget())) {
                replacements.add(current);
            }
        }

        for (StationMarker replacement : replacements) {
            create$addStationMarker(replacement);
        }
    }

    /** Projects a world coordinate onto this map's local axis, scaled by {@code 1 << scale}. */
    @Unique
    private float create$toLocalCoordinate(int worldCoord, int centerCoord) {
        return (worldCoord - centerCoord) / (float) (1 << scale);
    }

    /** Returns {@code true} if either local coordinate lies beyond ±63 map units. */
    @Unique
    private boolean create$isOutsideMap(double localX, double localZ) {
        return localX < -63.0D || localX > 63.0D || localZ < -63.0D || localZ > 63.0D;
    }

    /** Converts a local map coordinate to the byte position used by a decoration. */
    @Unique
    private byte create$toDecorationByte(float local) {
        return (byte) (int) (local * 2.0F + 0.5F);
    }

    /**
     * Reads the flag on the decoration's type that decides whether it counts towards
     * {@code trackedDecorationCount}. Presumably this is the type's "track count" property
     * (intermediary {@code comp_1842().comp_349().comp_2518()}) — verify against mappings.
     */
    @Unique
    private boolean create$tracksCount(class_20 decoration) {
        return decoration.comp_1842().comp_349().comp_2518();
    }
}
