package me.cortex.vulkanite.lib.pipeline;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.nio.ByteBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import me.cortex.vulkanite.lib.base.VContext;
import me.cortex.vulkanite.lib.descriptors.VDescriptorSetLayout;
import me.cortex.vulkanite.lib.memory.VBuffer;
import me.cortex.vulkanite.lib.other.VUtil;
import org.lwjgl.system.MemoryStack;
import org.lwjgl.system.MemoryUtil;
import org.lwjgl.vulkan.KHRRayTracingPipeline;
import org.lwjgl.vulkan.VK10;
import org.lwjgl.vulkan.VkAllocationCallbacks;
import org.lwjgl.vulkan.VkPhysicalDeviceRayTracingPipelinePropertiesKHR;
import org.lwjgl.vulkan.VkPipelineLayoutCreateInfo;
import org.lwjgl.vulkan.VkPipelineShaderStageCreateInfo;
import org.lwjgl.vulkan.VkRayTracingPipelineCreateInfoKHR;
import org.lwjgl.vulkan.VkRayTracingShaderGroupCreateInfoKHR;
import org.lwjgl.vulkan.VkStridedDeviceAddressRegionKHR;

/* loaded from: input_file:me/cortex/vulkanite/lib/pipeline/RaytracePipelineBuilder.class */
public class RaytracePipelineBuilder {
    private ShaderModule gen;
    private final Set<ShaderModule> shaders = new LinkedHashSet();
    private final List<ShaderModule> missGroups = new ArrayList();
    private final List<HitG> hitGroups = new ArrayList();
    private final List<ShaderModule> callGroups = new ArrayList();
    Set<VDescriptorSetLayout> layouts = new LinkedHashSet();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:me/cortex/vulkanite/lib/pipeline/RaytracePipelineBuilder$HitG.class */
    public static final class HitG extends Record {
        private final ShaderModule chit;
        private final ShaderModule ahit;
        private final ShaderModule intr;

        private HitG(ShaderModule shaderModule, ShaderModule shaderModule2, ShaderModule shaderModule3) {
            this.chit = shaderModule;
            this.ahit = shaderModule2;
            this.intr = shaderModule3;
        }

        @Override // java.lang.Record
        public final String toString() {
            return (String) ObjectMethods.bootstrap(MethodHandles.lookup(), "toString", MethodType.methodType(String.class, HitG.class), HitG.class, "chit;ahit;intr", "FIELD:Lme/cortex/vulkanite/lib/pipeline/RaytracePipelineBuilder$HitG;->chit:Lme/cortex/vulkanite/lib/pipeline/ShaderModule;", "FIELD:Lme/cortex/vulkanite/lib/pipeline/RaytracePipelineBuilder$HitG;->ahit:Lme/cortex/vulkanite/lib/pipeline/ShaderModule;", "FIELD:Lme/cortex/vulkanite/lib/pipeline/RaytracePipelineBuilder$HitG;->intr:Lme/cortex/vulkanite/lib/pipeline/ShaderModule;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final int hashCode() {
            return (int) ObjectMethods.bootstrap(MethodHandles.lookup(), "hashCode", MethodType.methodType(Integer.TYPE, HitG.class), HitG.class, "chit;ahit;intr", "FIELD:Lme/cortex/vulkanite/lib/pipeline/RaytracePipelineBuilder$HitG;->chit:Lme/cortex/vulkanite/lib/pipeline/ShaderModule;", "FIELD:Lme/cortex/vulkanite/lib/pipeline/RaytracePipelineBuilder$HitG;->ahit:Lme/cortex/vulkanite/lib/pipeline/ShaderModule;", "FIELD:Lme/cortex/vulkanite/lib/pipeline/RaytracePipelineBuilder$HitG;->intr:Lme/cortex/vulkanite/lib/pipeline/ShaderModule;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final boolean equals(Object obj) {
            return (boolean) ObjectMethods.bootstrap(MethodHandles.lookup(), "equals", MethodType.methodType(Boolean.TYPE, HitG.class, Object.class), HitG.class, "chit;ahit;intr", "FIELD:Lme/cortex/vulkanite/lib/pipeline/RaytracePipelineBuilder$HitG;->chit:Lme/cortex/vulkanite/lib/pipeline/ShaderModule;", "FIELD:Lme/cortex/vulkanite/lib/pipeline/RaytracePipelineBuilder$HitG;->ahit:Lme/cortex/vulkanite/lib/pipeline/ShaderModule;", "FIELD:Lme/cortex/vulkanite/lib/pipeline/RaytracePipelineBuilder$HitG;->intr:Lme/cortex/vulkanite/lib/pipeline/ShaderModule;").dynamicInvoker().invoke(this, obj) /* invoke-custom */;
        }

        public ShaderModule chit() {
            return this.chit;
        }

        public ShaderModule ahit() {
            return this.ahit;
        }

        public ShaderModule intr() {
            return this.intr;
        }
    }

    public RaytracePipelineBuilder setRayGen(ShaderModule shaderModule) {
        this.gen = shaderModule;
        return this;
    }

    public RaytracePipelineBuilder addMiss(ShaderModule shaderModule) {
        this.shaders.add(shaderModule);
        this.missGroups.add(shaderModule);
        return this;
    }

    public RaytracePipelineBuilder addHit(ShaderModule shaderModule, ShaderModule shaderModule2, ShaderModule shaderModule3) {
        if (shaderModule != null) {
            this.shaders.add(shaderModule);
        }
        if (shaderModule2 != null) {
            this.shaders.add(shaderModule2);
        }
        if (shaderModule3 != null) {
            this.shaders.add(shaderModule3);
        }
        this.hitGroups.add(new HitG(shaderModule, shaderModule2, shaderModule3));
        return this;
    }

    public RaytracePipelineBuilder addCallable(ShaderModule shaderModule) {
        this.shaders.add(shaderModule);
        this.callGroups.add(shaderModule);
        return this;
    }

    public RaytracePipelineBuilder addLayout(VDescriptorSetLayout vDescriptorSetLayout) {
        this.layouts.add(vDescriptorSetLayout);
        return this;
    }

    public VRaytracePipeline build(VContext vContext, int i) {
        this.shaders.add(this.gen);
        MemoryStack stackPush = MemoryStack.stackPush();
        try {
            VkPipelineShaderStageCreateInfo.Buffer calloc = VkPipelineShaderStageCreateInfo.calloc(this.shaders.size(), stackPush);
            HashMap hashMap = new HashMap();
            for (ShaderModule shaderModule : this.shaders) {
                VkPipelineShaderStageCreateInfo vkPipelineShaderStageCreateInfo = calloc.get(hashMap.size());
                hashMap.put(shaderModule, Integer.valueOf(hashMap.size()));
                shaderModule.setupStruct(stackPush, vkPipelineShaderStageCreateInfo);
            }
            VkRayTracingShaderGroupCreateInfoKHR.Buffer calloc2 = VkRayTracingShaderGroupCreateInfoKHR.calloc(1 + this.missGroups.size() + this.hitGroups.size() + this.callGroups.size(), stackPush);
            calloc2.forEach(vkRayTracingShaderGroupCreateInfoKHR -> {
                vkRayTracingShaderGroupCreateInfoKHR.sType$Default().generalShader(-1).intersectionShader(-1).closestHitShader(-1).anyHitShader(-1);
            });
            calloc2.get().type(0).generalShader(((Integer) hashMap.get(this.gen)).intValue());
            this.missGroups.forEach(shaderModule2 -> {
                calloc2.get().type(0).generalShader(((Integer) hashMap.get(shaderModule2)).intValue());
            });
            this.hitGroups.forEach(hitG -> {
                calloc2.get().type(hitG.intr == null ? 1 : 2).closestHitShader(hitG.chit == null ? -1 : ((Integer) hashMap.get(hitG.chit)).intValue()).anyHitShader(hitG.ahit == null ? -1 : ((Integer) hashMap.get(hitG.ahit)).intValue()).intersectionShader(hitG.intr == null ? -1 : ((Integer) hashMap.get(hitG.intr)).intValue());
            });
            calloc2.rewind();
            VkPipelineLayoutCreateInfo sType$Default = VkPipelineLayoutCreateInfo.calloc(stackPush).sType$Default();
            sType$Default.pSetLayouts(stackPush.longs(this.layouts.stream().mapToLong(vDescriptorSetLayout -> {
                return vDescriptorSetLayout.layout;
            }).toArray()));
            LongBuffer mallocLong = stackPush.mallocLong(1);
            VUtil._CHECK_(VK10.vkCreatePipelineLayout(vContext.device, sType$Default, (VkAllocationCallbacks) null, mallocLong));
            VkRayTracingPipelineCreateInfoKHR maxPipelineRayRecursionDepth = VkRayTracingPipelineCreateInfoKHR.calloc(stackPush).sType$Default().layout(mallocLong.get(0)).pStages(calloc).pGroups(calloc2).maxPipelineRayRecursionDepth(i);
            LongBuffer mallocLong2 = stackPush.mallocLong(1);
            VUtil._CHECK_(KHRRayTracingPipeline.vkCreateRayTracingPipelinesKHR(vContext.device, 0L, 0L, VkRayTracingPipelineCreateInfoKHR.create(maxPipelineRayRecursionDepth.address(), 1), (VkAllocationCallbacks) null, mallocLong2));
            VkPhysicalDeviceRayTracingPipelinePropertiesKHR vkPhysicalDeviceRayTracingPipelinePropertiesKHR = vContext.properties.rtPipelineProperties;
            long shaderGroupBaseAlignment = vkPhysicalDeviceRayTracingPipelinePropertiesKHR.shaderGroupBaseAlignment();
            long shaderGroupHandleSize = vkPhysicalDeviceRayTracingPipelinePropertiesKHR.shaderGroupHandleSize();
            long alignUp = VUtil.alignUp(shaderGroupHandleSize, vkPhysicalDeviceRayTracingPipelinePropertiesKHR.shaderGroupHandleAlignment());
            if (alignUp != shaderGroupHandleSize) {
                throw new IllegalStateException("Painpoint, handleSizeAligned != handleSize");
            }
            int capacity = calloc2.capacity();
            long alignUp2 = VUtil.alignUp(0 + alignUp, shaderGroupBaseAlignment);
            long size = this.missGroups.size();
            long alignUp3 = VUtil.alignUp(alignUp2 + (alignUp * size), shaderGroupBaseAlignment);
            long size2 = this.hitGroups.size();
            long alignUp4 = VUtil.alignUp(alignUp3 + (alignUp * size2), shaderGroupBaseAlignment);
            long alignUp5 = VUtil.alignUp(alignUp4 + (this.callGroups.size() * alignUp), shaderGroupBaseAlignment);
            ByteBuffer malloc = stackPush.malloc(capacity * ((int) shaderGroupHandleSize));
            VUtil._CHECK_(KHRRayTracingPipeline.vkGetRayTracingShaderGroupHandlesKHR(vContext.device, mallocLong2.get(0), 0, capacity, malloc), "Failed to obtain ray tracing group handles");
            long memAddress = MemoryUtil.memAddress(malloc);
            VBuffer createBufferGlobal = vContext.memory.createBufferGlobal(alignUp5, 132098, 3, 1024);
            long map = createBufferGlobal.map();
            MemoryUtil.memCopy(memAddress, map + 0, shaderGroupHandleSize);
            long j = memAddress + shaderGroupHandleSize;
            MemoryUtil.memCopy(j, map + alignUp2, shaderGroupHandleSize * size);
            long j2 = j + (shaderGroupHandleSize * size);
            MemoryUtil.memCopy(j2, map + alignUp3, shaderGroupHandleSize);
            long j3 = j2 + (shaderGroupHandleSize * size2);
            MemoryUtil.memCopy(j3, map + alignUp4, shaderGroupHandleSize);
            long j4 = j3 + (shaderGroupHandleSize * size);
            createBufferGlobal.unmap();
            createBufferGlobal.flush();
            VRaytracePipeline vRaytracePipeline = new VRaytracePipeline(vContext, mallocLong2.get(0), mallocLong.get(0), createBufferGlobal, VkStridedDeviceAddressRegionKHR.calloc().set(createBufferGlobal.deviceAddress() + 0, alignUp, alignUp), VkStridedDeviceAddressRegionKHR.calloc().set(createBufferGlobal.deviceAddress() + alignUp2, alignUp, alignUp * alignUp2), VkStridedDeviceAddressRegionKHR.calloc().set(createBufferGlobal.deviceAddress() + alignUp3, alignUp, alignUp * alignUp3), VkStridedDeviceAddressRegionKHR.calloc().set(createBufferGlobal.deviceAddress() + alignUp4, alignUp, alignUp * alignUp4), this.shaders);
            if (stackPush != null) {
                stackPush.close();
            }
            return vRaytracePipeline;
        } catch (Throwable th) {
            if (stackPush != null) {
                try {
                    stackPush.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }
}
