/*
 * Decompiled with CFR 0.152.
 */
package io.homo.superresolution.core.vulkan.shader;

import io.homo.superresolution.core.vulkan.VkDeviceManager;
import io.homo.superresolution.core.vulkan.VkException;
import io.homo.superresolution.core.vulkan.shader.VkShaderUniform;
import io.homo.superresolution.core.vulkan.shader.VkShaderUniformType;
import java.nio.ByteBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import org.lwjgl.system.MemoryStack;
import org.lwjgl.system.MemoryUtil;
import org.lwjgl.system.Struct;
import org.lwjgl.vulkan.VK10;
import org.lwjgl.vulkan.VkComputePipelineCreateInfo;
import org.lwjgl.vulkan.VkDescriptorBufferInfo;
import org.lwjgl.vulkan.VkDescriptorSetAllocateInfo;
import org.lwjgl.vulkan.VkDescriptorSetLayoutBinding;
import org.lwjgl.vulkan.VkDescriptorSetLayoutCreateInfo;
import org.lwjgl.vulkan.VkDevice;
import org.lwjgl.vulkan.VkPipelineLayoutCreateInfo;
import org.lwjgl.vulkan.VkPipelineShaderStageCreateInfo;
import org.lwjgl.vulkan.VkPushConstantRange;
import org.lwjgl.vulkan.VkShaderModuleCreateInfo;
import org.lwjgl.vulkan.VkWriteDescriptorSet;

/**
 * Builder for a Vulkan compute pipeline and its single descriptor set.
 *
 * <p>Typical use: {@link #setShaderBin(ByteBuffer)} with SPIR-V code, register bindings with
 * {@link #addUniform(VkShaderUniform)}, then call {@link #build()}. After a successful build the
 * public handle fields hold the created Vulkan objects. This class provides no method to destroy
 * them; the caller is responsible for their lifetime.
 */
public class VkComputeShader {
    private final VkDeviceManager deviceManager;
    /** Descriptor set layout handle; {@code -1L} until {@link #build()} succeeds. */
    public long descriptorSetLayout = -1L;
    /** Descriptor set allocated from the device manager's pool; {@code -1L} until built. */
    public long descriptorSet = -1L;
    /** Pipeline layout handle; {@code -1L} until built. */
    public long pipelineLayout = -1L;
    /** Shader module created from {@link #getShaderBin()}; {@code -1L} until built. */
    public long shaderModule = -1L;
    /** Compute pipeline handle; {@code -1L} until built. */
    public long pipeline = -1L;
    /** Uniforms in registration order; this order defines the layout binding order. */
    public ArrayList<VkShaderUniform> uniforms = new ArrayList<>();
    /** Uniforms indexed by their binding number. */
    public HashMap<Integer, VkShaderUniform> uniformsMap = new HashMap<>();
    private ByteBuffer shaderBin;

    public VkComputeShader(VkDeviceManager deviceManager) {
        this.deviceManager = deviceManager;
    }

    /**
     * Creates the shader module, descriptor set layout, descriptor set, pipeline layout and
     * compute pipeline.
     *
     * @return this instance, for chaining
     * @throws VkException if no shader binary was set or any Vulkan call fails
     */
    public VkComputeShader build() {
        if (this.shaderBin == null) {
            throw new VkException();
        }
        this.loadShader();
        this.createPipeline();
        return this;
    }

    public ByteBuffer getShaderBin() {
        return this.shaderBin;
    }

    /**
     * Sets the shader code passed to {@code vkCreateShaderModule}. NOTE(review): LWJGL expects a
     * direct buffer holding SPIR-V; confirm callers supply one.
     */
    public VkComputeShader setShaderBin(ByteBuffer bytes) {
        this.shaderBin = bytes;
        return this;
    }

    /** Registers a uniform in both the ordered list and the binding-indexed map. */
    public VkComputeShader addUniform(VkShaderUniform uniform) {
        this.uniforms.add(uniform);
        this.uniformsMap.put(uniform.binding, uniform);
        return this;
    }

    /** Returns the uniform registered for {@code binding}, or {@code null} if none was. */
    public VkShaderUniform getUniform(int binding) {
        return this.uniformsMap.get(binding);
    }

    /** Throws {@link VkException} unless {@code result} is {@code VK_SUCCESS}. */
    private static void check(int result) {
        if (result != VK10.VK_SUCCESS) {
            throw new VkException();
        }
    }

    /** Creates {@link #shaderModule} from {@link #shaderBin}. */
    private void loadShader() {
        // A pushed frame ensures every stack allocation below is released on return.
        try (MemoryStack stack = MemoryStack.stackPush()) {
            VkShaderModuleCreateInfo createInfo = VkShaderModuleCreateInfo.calloc(stack)
                    .sType(VK10.VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO)
                    .pCode(this.shaderBin);
            LongBuffer pShaderModule = stack.mallocLong(1);
            check(VK10.vkCreateShaderModule(this.deviceManager.device, createInfo, null, pShaderModule));
            this.shaderModule = pShaderModule.get(0);
        }
    }

    /**
     * Builds every object that depends on the shader module.
     *
     * <p>Steps run in dependency order: set layout, descriptor set, descriptor writes, pipeline
     * layout, pipeline.
     */
    private void createPipeline() {
        if (this.shaderModule == -1L) {
            throw new VkException();
        }
        VkDevice device = this.deviceManager.device;
        try (MemoryStack stack = MemoryStack.stackPush()) {
            this.descriptorSetLayout = createDescriptorSetLayout(stack, device);
            this.descriptorSet = allocateDescriptorSet(stack, device);
            writeBufferDescriptors(stack, device);
            this.pipelineLayout = createPipelineLayout(stack, device);
            this.pipeline = createComputePipeline(stack, device);
        }
    }

    /** Creates a descriptor set layout with one binding per registered uniform. */
    private long createDescriptorSetLayout(MemoryStack stack, VkDevice device) {
        VkDescriptorSetLayoutBinding.Buffer bindings =
                VkDescriptorSetLayoutBinding.calloc(this.uniforms.size(), stack);
        for (VkShaderUniform uniform : this.uniforms) {
            bindings.put(uniform.build());
        }
        bindings.flip();
        VkDescriptorSetLayoutCreateInfo info = VkDescriptorSetLayoutCreateInfo.calloc(stack)
                .sType(VK10.VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO)
                .pBindings(bindings);
        LongBuffer pLayout = stack.mallocLong(1);
        check(VK10.vkCreateDescriptorSetLayout(device, info, null, pLayout));
        return pLayout.get(0);
    }

    /** Allocates one descriptor set for {@link #descriptorSetLayout} from the device's pool. */
    private long allocateDescriptorSet(MemoryStack stack, VkDevice device) {
        VkDescriptorSetAllocateInfo info = VkDescriptorSetAllocateInfo.calloc(stack)
                .sType(VK10.VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO)
                .descriptorPool(this.deviceManager.descriptorPool)
                .pSetLayouts(stack.longs(this.descriptorSetLayout));
        LongBuffer pSet = stack.mallocLong(1);
        check(VK10.vkAllocateDescriptorSets(device, info, pSet));
        return pSet.get(0);
    }

    /** Returns how many registered uniforms are of type {@code buffer}. */
    private int countBufferUniforms() {
        int count = 0;
        for (VkShaderUniform uniform : this.uniforms) {
            if (uniform.type == VkShaderUniformType.buffer) {
                count++;
            }
        }
        return count;
    }

    /** Writes one buffer descriptor into {@link #descriptorSet} for each buffer uniform. */
    private void writeBufferDescriptors(MemoryStack stack, VkDevice device) {
        int bufferCount = countBufferUniforms();
        if (bufferCount == 0) {
            return; // vkUpdateDescriptorSets with zero writes would be a no-op anyway.
        }
        VkWriteDescriptorSet.Buffer writes = VkWriteDescriptorSet.calloc(bufferCount, stack);
        for (VkShaderUniform uniform : this.uniforms) {
            if (uniform.type != VkShaderUniformType.buffer) {
                continue;
            }
            // Lives in the same stack frame, so it stays valid until vkUpdateDescriptorSets.
            VkDescriptorBufferInfo.Buffer bufferInfo = VkDescriptorBufferInfo.calloc(1, stack);
            bufferInfo.put(0, uniform.createBufferInfo(this.deviceManager));
            writes.get()
                    .sType(VK10.VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET)
                    .dstSet(this.descriptorSet)
                    .dstBinding(uniform.binding)
                    .descriptorCount(1)
                    .descriptorType(uniform.type.getValue())
                    .dstArrayElement(0)
                    .pBufferInfo(bufferInfo);
        }
        writes.flip();
        VK10.vkUpdateDescriptorSets(device, writes, null);
    }

    /**
     * Creates the pipeline layout over {@link #descriptorSetLayout}.
     *
     * <p>The original code declared one compute-stage push-constant range per buffer uniform, all
     * at offset 0. Vulkan forbids two ranges that include the same stage, so that layout was
     * invalid with two or more buffer uniforms. A single range sized to the largest buffer
     * uniform is declared instead. This is identical when there is at most one buffer uniform.
     */
    private long createPipelineLayout(MemoryStack stack, VkDevice device) {
        boolean hasBufferUniform = false;
        int pushConstantSize = 0;
        for (VkShaderUniform uniform : this.uniforms) {
            if (uniform.type == VkShaderUniformType.buffer) {
                hasBufferUniform = true;
                pushConstantSize = Math.max(pushConstantSize, uniform.size);
            }
        }
        VkPushConstantRange.Buffer pushConstantRanges = null;
        if (hasBufferUniform) {
            pushConstantRanges = VkPushConstantRange.calloc(1, stack);
            pushConstantRanges.get(0)
                    .stageFlags(VK10.VK_SHADER_STAGE_COMPUTE_BIT)
                    .offset(0)
                    .size(pushConstantSize);
        }
        VkPipelineLayoutCreateInfo info = VkPipelineLayoutCreateInfo.calloc(stack)
                .sType(VK10.VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO)
                .pSetLayouts(stack.longs(this.descriptorSetLayout))
                .pPushConstantRanges(pushConstantRanges);
        LongBuffer pLayout = stack.mallocLong(1);
        check(VK10.vkCreatePipelineLayout(device, info, null, pLayout));
        return pLayout.get(0);
    }

    /** Creates the compute pipeline from {@link #shaderModule} (entry point "main"). */
    private long createComputePipeline(MemoryStack stack, VkDevice device) {
        VkPipelineShaderStageCreateInfo stageInfo = VkPipelineShaderStageCreateInfo.calloc(stack)
                .sType(VK10.VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO)
                .stage(VK10.VK_SHADER_STAGE_COMPUTE_BIT)
                .module(this.shaderModule)
                .pName(stack.UTF8("main"));
        VkComputePipelineCreateInfo.Buffer infos = VkComputePipelineCreateInfo.calloc(1, stack);
        infos.get(0)
                .sType(VK10.VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO)
                .stage(stageInfo)
                .layout(this.pipelineLayout);
        LongBuffer pPipeline = stack.mallocLong(1);
        check(VK10.vkCreateComputePipelines(device, VK10.VK_NULL_HANDLE, infos, null, pPipeline));
        return pPipeline.get(0);
    }
}

