/*
 * Decompiled with CFR 0.152.
 */
package sircow.preservedinferno.trade;

import java.util.ArrayList;
import java.util.List;
import net.minecraft.class_1914;
import net.minecraft.class_3989;
import net.minecraft.class_5819;
import sircow.preservedinferno.trade.WeightedTradeEntry;

/**
 * A pool of trade entries from which a trader's offers are generated.
 *
 * @param entries the candidate entries; entries whose {@code guaranteed()} is {@code true} are
 *     always asked for an offer, the rest are only reachable through weighted rolls
 * @param rolls how many weighted draws to make from the non-guaranteed entries; a value of zero
 *     or less makes no draws
 */
public record WeightedTradePool(List<WeightedTradeEntry> entries, int rolls) {
    /**
     * Generates the offers produced by this pool.
     *
     * <p>Every guaranteed entry is asked for an offer exactly once, in list order. Then
     * {@link #rolls()} draws are made from the non-guaranteed entries, weighted by
     * {@code WeightedTradeEntry#weight()} and with replacement, so the same entry may be picked
     * more than once. Entries whose {@code getOffer} returns {@code null} contribute nothing. If
     * there are no non-guaranteed entries, no draws are made.
     *
     * @param trader the trader, forwarded unchanged to {@code WeightedTradeEntry#getOffer}
     * @param random the random source used for weighted selection and forwarded to
     *     {@code getOffer}
     * @return a new mutable list whose guaranteed results come first, followed by rolled results
     */
    public List<WeightedTradeResult> rollOffers(class_3989 trader, class_5819 random) {
        ArrayList<WeightedTradeResult> results = new ArrayList<WeightedTradeResult>();
        // Split the pool in one pass. Filtering and summing consume no randomness, so hoisting them
        // out of the roll loop keeps the sequence of random draws unchanged.
        List<WeightedTradeEntry> candidates = new ArrayList<WeightedTradeEntry>();
        int total = 0;
        for (WeightedTradeEntry entry : this.entries) {
            if (entry.guaranteed()) {
                class_1914 offer = entry.getOffer(trader, random);
                if (offer != null) {
                    results.add(new WeightedTradeResult(entry, offer));
                }
            } else {
                candidates.add(entry);
                total += entry.weight();
            }
        }
        // With nothing to roll from, pickWeighted would ask method_43048 for a bound of 0.
        // That is presumably rejected like java.util.Random#nextInt(0), so skip the draws.
        if (candidates.isEmpty()) {
            return results;
        }
        for (int i = 0; i < this.rolls; ++i) {
            WeightedTradeEntry entry = this.pickWeighted(candidates, total, random);
            class_1914 offer = entry.getOffer(trader, random);
            if (offer != null) {
                results.add(new WeightedTradeResult(entry, offer));
            }
        }
        return results;
    }

    /**
     * Picks one entry from {@code candidates} with probability proportional to its weight.
     *
     * <p>When {@code total} is zero or negative, it picks an entry uniformly at random instead.
     *
     * @param candidates the non-guaranteed entries; must not be empty
     * @param total the sum of the weights of {@code candidates}
     * @param random the random source to draw from
     * @return the selected entry
     */
    private WeightedTradeEntry pickWeighted(List<WeightedTradeEntry> candidates, int total, class_5819 random) {
        if (total <= 0) {
            return candidates.get(random.method_43048(candidates.size()));
        }
        int choice = random.method_43048(total);
        int cumulative = 0;
        for (WeightedTradeEntry candidate : candidates) {
            cumulative += candidate.weight();
            if (choice < cumulative) {
                return candidate;
            }
        }
        // Reached only if some weights are negative, so the running sum never exceeds the choice.
        return candidates.getLast();
    }

    /**
     * One generated offer together with the pool entry that produced it.
     *
     * @param entry the entry whose {@code getOffer} produced {@code offer}
     * @param offer the non-null offer returned by {@code entry}
     */
    public record WeightedTradeResult(WeightedTradeEntry entry, class_1914 offer) {
    }
}

