/*
 * Decompiled with CFR 0.152.
 */
package net.caffeinemc.mods.sodium.mixin.features.render.immediate.buffer_builder.sorting;

import com.llamalad7.mixinextras.injector.wrapoperation.Operation;
import com.llamalad7.mixinextras.injector.wrapoperation.WrapOperation;
import com.mojang.blaze3d.vertex.ByteBufferBuilder;
import com.mojang.blaze3d.vertex.MeshData;
import com.mojang.blaze3d.vertex.VertexFormat;
import com.mojang.blaze3d.vertex.VertexSorting;
import net.caffeinemc.mods.sodium.client.util.sorting.VertexSorters;
import net.caffeinemc.mods.sodium.client.util.sorting.VertexSortingExtended;
import net.caffeinemc.mods.sodium.mixin.features.render.immediate.buffer_builder.sorting.MeshDataAccessor;
import net.minecraft.client.renderer.MultiBufferSource;
import org.lwjgl.system.MemoryUtil;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.Unique;
import org.spongepowered.asm.mixin.injection.At;

@Mixin(value={MultiBufferSource.BufferSource.class})
public class MultiBufferSourceMixin {
    @Unique
    private static final int VERTICES_PER_QUAD = 6;

    @WrapOperation(method={"endBatch(Lnet/minecraft/client/renderer/RenderType;Lcom/mojang/blaze3d/vertex/BufferBuilder;)V"}, at={@At(value="INVOKE", target="Lcom/mojang/blaze3d/vertex/MeshData;sortQuads(Lcom/mojang/blaze3d/vertex/ByteBufferBuilder;Lcom/mojang/blaze3d/vertex/VertexSorting;)Lcom/mojang/blaze3d/vertex/MeshData$SortState;")})
    private MeshData.SortState redirectSortQuads(MeshData meshData, ByteBufferBuilder bufferBuilder, VertexSorting sorting, Operation<MeshData.SortState> original) {
        if (!(sorting instanceof VertexSortingExtended)) {
            return (MeshData.SortState)original.call(new Object[]{meshData, bufferBuilder, sorting});
        }
        VertexSortingExtended sortingExtended = (VertexSortingExtended)sorting;
        MultiBufferSourceMixin.acceleratedSort(meshData, bufferBuilder, sortingExtended);
        return null;
    }

    @Unique
    private static void acceleratedSort(MeshData meshData, ByteBufferBuilder bufferBuilder, VertexSortingExtended sorting) {
        MeshData.DrawState drawState = meshData.drawState();
        if (drawState.mode() != VertexFormat.Mode.QUADS) {
            return;
        }
        int[] sortedPrimitiveIds = VertexSorters.sort(meshData.vertexBuffer(), drawState.vertexCount(), drawState.format().getVertexSize(), sorting);
        ByteBufferBuilder.Result sortedIndexBuffer = MultiBufferSourceMixin.buildSortedIndexBuffer(meshData, bufferBuilder, sortedPrimitiveIds);
        ((MeshDataAccessor)meshData).setIndexBuffer(sortedIndexBuffer);
    }

    @Unique
    private static ByteBufferBuilder.Result buildSortedIndexBuffer(MeshData meshData, ByteBufferBuilder bufferBuilder, int[] primitiveIds) {
        VertexFormat.IndexType indexType = meshData.drawState().indexType();
        long ptr = bufferBuilder.reserve(primitiveIds.length * 6 * indexType.bytes);
        if (indexType == VertexFormat.IndexType.SHORT) {
            MultiBufferSourceMixin.writeIndexBufferShort(ptr, primitiveIds);
        } else if (indexType == VertexFormat.IndexType.INT) {
            MultiBufferSourceMixin.writeIndexBufferInt(ptr, primitiveIds);
        } else {
            throw new UnsupportedOperationException();
        }
        return bufferBuilder.build();
    }

    @Unique
    private static void writeIndexBufferInt(long ptr, int[] primitiveIds) {
        for (int primitiveId : primitiveIds) {
            MemoryUtil.memPutInt((long)(ptr + 0L), (int)(primitiveId * 4 + 0));
            MemoryUtil.memPutInt((long)(ptr + 4L), (int)(primitiveId * 4 + 1));
            MemoryUtil.memPutInt((long)(ptr + 8L), (int)(primitiveId * 4 + 2));
            MemoryUtil.memPutInt((long)(ptr + 12L), (int)(primitiveId * 4 + 2));
            MemoryUtil.memPutInt((long)(ptr + 16L), (int)(primitiveId * 4 + 3));
            MemoryUtil.memPutInt((long)(ptr + 20L), (int)(primitiveId * 4 + 0));
            ptr += 24L;
        }
    }

    @Unique
    private static void writeIndexBufferShort(long ptr, int[] primitiveIds) {
        for (int primitiveId : primitiveIds) {
            MemoryUtil.memPutShort((long)(ptr + 0L), (short)((short)(primitiveId * 4 + 0)));
            MemoryUtil.memPutShort((long)(ptr + 2L), (short)((short)(primitiveId * 4 + 1)));
            MemoryUtil.memPutShort((long)(ptr + 4L), (short)((short)(primitiveId * 4 + 2)));
            MemoryUtil.memPutShort((long)(ptr + 6L), (short)((short)(primitiveId * 4 + 2)));
            MemoryUtil.memPutShort((long)(ptr + 8L), (short)((short)(primitiveId * 4 + 3)));
            MemoryUtil.memPutShort((long)(ptr + 10L), (short)((short)(primitiveId * 4 + 0)));
            ptr += 12L;
        }
    }
}

