package com.abdelaziz.canary.mixin.chunk.block_counting;

import com.abdelaziz.canary.common.block.BlockCountingSection;
import com.abdelaziz.canary.common.block.BlockStateFlagHolder;
import com.abdelaziz.canary.common.block.BlockStateFlags;
import com.abdelaziz.canary.common.block.TrackedBlockStatePredicate;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import net.minecraft.network.FriendlyByteBuf;
import net.minecraft.world.level.block.state.BlockState;
import net.minecraft.world.level.chunk.LevelChunkSection;
import net.minecraft.world.level.chunk.PalettedContainer;
import org.spongepowered.asm.mixin.Final;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.Shadow;
import org.spongepowered.asm.mixin.Unique;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.Redirect;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfo;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfoReturnable;
import org.spongepowered.asm.mixin.injection.callback.LocalCapture;

/**
 * Mixin for {@link LevelChunkSection} that keeps one counter per tracked block-state flag so that
 * {@link BlockCountingSection#anyMatch} can be answered from the counters instead of scanning the
 * section's block states.
 *
 * <p>The counters are filled in one of two ways: while {@code recalcBlockCounts} runs (piggybacking
 * on the count pass that method already performs), or asynchronously through a
 * {@link CompletableFuture} that is scheduled the first time {@code anyMatch} finds no counters.
 * Individual block changes then adjust the counters in place.
 */
@Mixin({LevelChunkSection.class})
/* loaded from: input_file:com/abdelaziz/canary/mixin/chunk/block_counting/LevelChunkSectionMixin.class */
public abstract class LevelChunkSectionMixin implements BlockCountingSection {

    /** Block-state storage of the target section, shadowed from the target class. */
    @Shadow
    @Final
    private PalettedContainer<BlockState> states;

    /** Per-flag block counts, indexed by flag bit position; {@code null} while not yet available. */
    @Unique
    private short[] countsByFlag = null;

    /** Pending or completed asynchronous computation of the counts; {@code null} when there is none. */
    @Unique
    private CompletableFuture<short[]> countsByFlagFuture;

    /**
     * Reports whether the section holds at least one block state carrying the flag tracked by the
     * given predicate.
     *
     * @param trackedBlockStatePredicate predicate whose {@code getIndex()} selects the counter
     * @param z value returned when the counters are not available yet
     * @return {@code true} if the selected counter is non-zero; {@code z} if no counters exist and
     *     they could not be obtained right now
     */
    @Override // com.abdelaziz.canary.common.block.BlockCountingSection
    public boolean anyMatch(TrackedBlockStatePredicate trackedBlockStatePredicate, boolean z) {
        if (this.countsByFlag == null && !tryInitializeCountsByFlag()) {
            return z;
        }
        return this.countsByFlag[trackedBlockStatePredicate.getIndex()] != 0;
    }

    /**
     * Adopts the result of a finished asynchronous count, or schedules a count if none is pending.
     *
     * @return {@code true} if {@link #countsByFlag} was assigned from a completed future,
     *     {@code false} if the counts are still being computed or were only just scheduled
     */
    private boolean tryInitializeCountsByFlag() {
        CompletableFuture<short[]> pending = this.countsByFlagFuture;
        if (pending != null && pending.isDone()) {
            try {
                this.countsByFlag = pending.get();
                return true;
            } catch (InterruptedException | CancellationException | ExecutionException ignored) {
                // The computation did not produce a result; drop it so a new one is scheduled below.
                this.countsByFlagFuture = null;
            }
        }
        if (this.countsByFlagFuture == null) {
            // Copy the field into a local so the lambda does not capture the mixin instance.
            PalettedContainer<BlockState> container = this.states;
            this.countsByFlagFuture = CompletableFuture.supplyAsync(() -> calculateLithiumCounts(container));
        }
        return false;
    }

    /**
     * Builds a fresh counter array for the given container.
     *
     * @param palettedContainer container whose entries are visited through {@code m_63099_}
     *     (NOTE(review): appears to be the obfuscated name of {@code PalettedContainer.count} - confirm)
     * @return a new array of length {@link BlockStateFlags#NUM_FLAGS} holding the per-flag totals
     */
    private static short[] calculateLithiumCounts(PalettedContainer<BlockState> palettedContainer) {
        short[] counts = new short[BlockStateFlags.NUM_FLAGS];
        palettedContainer.m_63099_((state, amount) -> addToFlagCount(counts, state, amount));
        return counts;
    }

    /**
     * Wraps the count pass performed by {@code recalcBlockCounts} so that every visited entry is
     * forwarded to the original consumer and also added to the flag counters.
     *
     * @param palettedContainer the container the redirected call was made on
     * @param countConsumer the consumer originally passed to the redirected call
     */
    @Redirect(method = {"recalcBlockCounts()V"}, at = @At(value = "INVOKE", target = "Lnet/minecraft/world/level/chunk/PalettedContainer;count(Lnet/minecraft/world/level/chunk/PalettedContainer$CountConsumer;)V"))
    private void initFlagCounters(PalettedContainer<BlockState> palettedContainer, PalettedContainer.CountConsumer<BlockState> countConsumer) {
        palettedContainer.m_63099_((state, amount) -> {
            countConsumer.m_63144_(state, amount);
            addToFlagCount(this.countsByFlag, state, amount);
        });
    }

    /**
     * Adds {@code count} to the counter of every flag bit set on the given block state.
     *
     * @param counts counter array indexed by flag bit position
     * @param blockState state whose flags are read through {@link BlockStateFlagHolder#getAllFlags()}
     * @param count amount added to each matching counter
     */
    private static void addToFlagCount(short[] counts, BlockState blockState, int count) {
        int remaining = ((BlockStateFlagHolder) blockState).getAllFlags();
        while (remaining != 0) {
            int bit = Integer.numberOfTrailingZeros(remaining);
            counts[bit] = (short) (counts[bit] + count);
            // Clear the bit that was just handled.
            remaining &= ~(1 << bit);
        }
    }

    /**
     * Allocates an empty counter array at the start of {@code recalcBlockCounts}; it is filled by
     * {@link #initFlagCounters} during the same call.
     */
    @Inject(method = {"recalcBlockCounts()V"}, at = {@At("HEAD")})
    private void createFlagCounters(CallbackInfo callbackInfo) {
        this.countsByFlag = new short[BlockStateFlags.NUM_FLAGS];
    }

    /**
     * Before a block state is written, waits for any outstanding asynchronous count and adopts its
     * result, then clears the future.
     */
    @Inject(method = {"setBlockState(IIILnet/minecraft/world/level/block/state/BlockState;Z)Lnet/minecraft/world/level/block/state/BlockState;"}, at = {@At("HEAD")})
    private void joinFuture(int i, int i2, int i3, BlockState blockState, boolean z, CallbackInfoReturnable<BlockState> callbackInfoReturnable) {
        CompletableFuture<short[]> pending = this.countsByFlagFuture;
        if (pending != null) {
            this.countsByFlag = pending.join();
            this.countsByFlagFuture = null;
        }
    }

    /** Discards the counters and any pending computation at the start of {@code read}. */
    @Inject(method = {"read"}, at = {@At("HEAD")})
    private void resetData(FriendlyByteBuf friendlyByteBuf, CallbackInfo callbackInfo) {
        this.countsByFlag = null;
        this.countsByFlagFuture = null;
    }

    /**
     * Adjusts the counters for a single block change. For every flag that differs between the two
     * states, the counter is decremented if {@code blockState2} carries the flag and incremented
     * otherwise. Does nothing while no counters exist.
     *
     * @param blockState the state passed to {@code setBlockState}
     * @param blockState2 local captured at the injection point
     *     (NOTE(review): presumably the state previously stored at the position - confirm)
     */
    @Inject(method = {"setBlockState(IIILnet/minecraft/world/level/block/state/BlockState;Z)Lnet/minecraft/world/level/block/state/BlockState;"}, at = {@At(value = "INVOKE", target = "Lnet/minecraft/world/level/block/state/BlockState;getFluidState()Lnet/minecraft/world/level/material/FluidState;", ordinal = 0, shift = At.Shift.BEFORE)}, locals = LocalCapture.CAPTURE_FAILHARD)
    private void updateFlagCounters(int i, int i2, int i3, BlockState blockState, boolean z, CallbackInfoReturnable<BlockState> callbackInfoReturnable, BlockState blockState2) {
        short[] counts = this.countsByFlag;
        if (counts == null) {
            return;
        }
        int outgoingFlags = ((BlockStateFlagHolder) blockState2).getAllFlags();
        int changedFlags = outgoingFlags ^ ((BlockStateFlagHolder) blockState).getAllFlags();
        while (changedFlags != 0) {
            int bit = Integer.numberOfTrailingZeros(changedFlags);
            // -1 when the captured state had the flag set, +1 when it did not.
            int delta = 1 - (((outgoingFlags >>> bit) & 1) << 1);
            counts[bit] = (short) (counts[bit] + delta);
            changedFlags &= ~(1 << bit);
        }
    }
}
