package ai.djl.modality.nlp.generate;

import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;
import java.util.Arrays;
import java.util.Collection;
import java.util.stream.Collectors;

/* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/modality/nlp/generate/TextGenerator.class */
/**
 * Drives autoregressive token generation on top of a {@link Predictor} that maps an {@link NDList}
 * of model inputs to a {@link CausalLMOutput}.
 *
 * <p>Three search strategies are provided: {@link #greedySearch(NDArray)}, {@link
 * #beamSearch(NDArray)} and {@link #contrastiveSearch(NDArray)}. {@link #generate(NDArray)}
 * dispatches to one of them based on the search name given at construction time.
 *
 * <p>Every search call overwrites {@link #getPositionOffset()} and {@link #getEndPosition()}, so an
 * instance holds the bookkeeping of the most recent call only.
 */
public class TextGenerator {

    private final String searchName;
    private final SearchConfig config;
    private final Predictor<NDList, CausalLMOutput> predictor;

    /** Per-row offset subtracted from position ids; rebuilt by prepareAttentionMaskOffset. */
    private NDArray positionOffset;

    /** One entry per input row; initialised to the configured maximum sequence length. */
    private long[] endPosition;

    /**
     * Creates a generator.
     *
     * @param predictor the predictor used to run the language model
     * @param searchName one of {@code "greedy"}, {@code "beam"} or {@code "contrastive"}; only
     *     consulted by {@link #generate(NDArray)}
     * @param searchConfig the search settings (max sequence length, pad/eos token ids, beam size,
     *     k, alpha, padding side)
     */
    public TextGenerator(
            Predictor<NDList, CausalLMOutput> predictor,
            String searchName,
            SearchConfig searchConfig) {
        this.predictor = predictor;
        this.searchName = searchName;
        this.config = searchConfig;
    }

    /**
     * Generates tokens one step at a time, feeding the ids returned by {@code
     * StepGeneration.greedyStepGen} back into the model.
     *
     * @param inputIds the prompt token ids; dimension 0 is used as the batch size and dimension 1
     *     as the sequence length (assumed 2-D int64 - TODO confirm against callers)
     * @return all previously consumed ids concatenated with the last generated ids along axis 1
     * @throws TranslateException if the predictor fails
     */
    public NDArray greedySearch(NDArray inputIds) throws TranslateException {
        initEndPosition(inputIds);
        NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config);
        NDManager manager = inputIds.getManager();
        GreedyBatchTensorList searchState =
                new GreedyBatchTensorList(inputIds, null, null, attentionMask);

        do {
            try (NDScope ignore = new NDScope()) {
                NDArray pastOutputIds = searchState.getPastOutputIds();
                NDArray nextInputIds = searchState.getNextInputIds();
                NDArray pastAttentionMask = searchState.getPastAttentionMask();
                NDList pastKeyValues = searchState.getPastKeyValues();

                // On the first step nothing has been consumed yet, so positions start at 0.
                long pastSeqLength =
                        pastOutputIds == null ? 0L : pastOutputIds.getShape().getLastDimension();
                NDList modelInput = prepareInput(nextInputIds, pastAttentionMask, pastSeqLength, 1);
                if (pastKeyValues != null) {
                    modelInput.addAll(pastKeyValues);
                }
                CausalLMOutput modelOutput = predictor.predict(modelInput);
                NDArray outputIds = StepGeneration.greedyStepGen(modelOutput.getLogits());

                // The ids just fed to the model become part of the history.
                NDArray newPastOutputIds =
                        pastOutputIds == null
                                ? nextInputIds
                                : pastOutputIds.concat(nextInputIds, 1);
                searchState.setPastOutputIds(newPastOutputIds);
                searchState.setNextInputIds(outputIds);

                NDList newPastKeyValues = modelOutput.getPastKeyValuesList();
                searchState.setPastKeyValues(newPastKeyValues);

                // Extend the mask by one attended column for the token just consumed.
                NDArray newAttentionMask =
                        pastAttentionMask.concat(
                                manager.ones(
                                        new Shape(inputIds.getShape().get(0), 1), DataType.INT64),
                                1);
                searchState.setPastAttentionMask(newAttentionMask);

                // Keep the carried-over state alive after the scope closes.
                NDScope.unregister(outputIds, newAttentionMask, newPastOutputIds);
                NDScope.unregister(newPastKeyValues);
            }

            markEndPositions(
                    searchState.getNextInputIds().toLongArray(),
                    searchState.getPastOutputIds().getShape().get(1) + 1);
        } while (searchState.getPastOutputIds().getShape().get(1) + 1 < config.getMaxSeqLength());

        return searchState.getPastOutputIds().concat(searchState.getNextInputIds(), 1);
    }

    /**
     * Generates tokens while tracking {@code config.getBeam()} candidates per input row.
     *
     * @param inputIds the prompt token ids; dimension 0 is used as the batch size (assumed 2-D
     *     int64 - TODO confirm against callers)
     * @return the generated ids reshaped to {@code (batch * beam, -1)}
     * @throws TranslateException if the predictor fails
     */
    public NDArray beamSearch(NDArray inputIds) throws TranslateException {
        initEndPosition(inputIds);
        NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config);
        NDManager manager = inputIds.getManager();
        long beam = config.getBeam();
        long batch = inputIds.getShape().get(0);

        BeamBatchTensorList searchState = new BeamBatchTensorList();
        // Dimensions 2 and last of the beam-expanded key/value tensors, captured at initialisation
        // (presumably attention-head count and per-head size - verify against the model).
        long numHeads = 0;
        long kvDim = 0;

        do {
            if (searchState.getPastAttentionMask() == null) {
                // First pass: run the whole prompt once and seed the beams from the top entries
                // of the softmax over the logits at the last position.
                CausalLMOutput prefill =
                        predictor.predict(prepareInput(inputIds, attentionMask, 0L, 1));
                NDList topK =
                        prefill.getLogits()
                                .get(":, -1, :")
                                .softmax(1)
                                .topK(Math.toIntExact(beam), -1, true, false);
                NDArray nextInputIds = topK.get(1).expandDims(2);
                NDArray lastProbs = topK.get(0).normalize(1.0d, 1L);
                assert nextInputIds.getShape().getShape().length == 3 : "Wrong shape";
                assert lastProbs.getShape().getShape().length == 2 : "Wrong Shape";

                // Add one attended column, then insert a beam axis at dimension 1.
                attentionMask =
                        attentionMask
                                .concat(manager.ones(new Shape(batch, 1), DataType.INT64), -1)
                                .expandDims(1)
                                .repeat(1, beam);
                NDList pastKeyValues =
                        new NDList(
                                prefill.getPastKeyValuesList().stream()
                                        .map(kv -> kv.expandDims(1).repeat(1, beam))
                                        .collect(Collectors.toList()));
                NDArray pastOutputIds = inputIds.expandDims(1).repeat(1, beam);

                searchState =
                        new BeamBatchTensorList(
                                nextInputIds,
                                pastOutputIds,
                                pastKeyValues,
                                attentionMask,
                                lastProbs);
                numHeads = pastKeyValues.get(0).getShape().get(2);
                kvDim = pastKeyValues.get(0).getShape().getLastDimension();
            }

            try (NDScope ignore = new NDScope()) {
                long pastSeqLength = searchState.getPastOutputIds().getShape().getLastDimension();

                // The model sees batch and beam folded into a single leading axis.
                NDList modelInput =
                        prepareInput(
                                searchState.getNextInputIds().reshape(batch * beam, 1),
                                searchState.getPastAttentionMask().reshape(batch * beam, -1),
                                pastSeqLength,
                                config.getBeam());
                // Effectively-final copies so the lambda below may capture them.
                final long heads = numHeads;
                final long dim = kvDim;
                NDList flattenedKeyValues =
                        new NDList(
                                searchState.getPastKeyValues().stream()
                                        .map(
                                                kv ->
                                                        kv.reshape(
                                                                batch * beam,
                                                                heads,
                                                                pastSeqLength,
                                                                dim))
                                        .collect(Collectors.toList()));
                modelInput.addAll(flattenedKeyValues);

                CausalLMOutput modelOutput = predictor.predict(modelInput);
                NDList generatedOutput =
                        StepGeneration.beamStepGeneration(
                                searchState.getLastProbs(), modelOutput.getLogits(), batch, beam);
                searchState = updateSearchState(searchState, modelOutput, generatedOutput, manager);

                // Keep the carried-over state alive after the scope closes.
                NDScope.unregister(
                        searchState.getNextInputIds(),
                        searchState.getPastOutputIds(),
                        searchState.getPastAttentionMask(),
                        searchState.getLastProbs());
                NDScope.unregister(searchState.getPastKeyValues());
            }

            markEndPositions(
                    searchState.getNextInputIds().toLongArray(),
                    searchState.getPastOutputIds().getShape().get(1) + 1);
        } while (searchState.getPastOutputIds().getShape().getLastDimension() + 1
                < config.getMaxSeqLength());

        return searchState
                .getPastOutputIds()
                .concat(searchState.getNextInputIds(), -1)
                .reshape(batch * beam, -1);
    }

    /**
     * Generates tokens by evaluating the top {@code config.getK()} candidates at every step and
     * letting {@code StepGeneration.constrastiveStepGeneration} pick among them.
     *
     * @param inputIds the prompt token ids; dimension 0 is used as the batch size (assumed 2-D
     *     int64 - TODO confirm against callers)
     * @return the accumulated output ids held by the final search state
     * @throws TranslateException if the predictor fails
     */
    public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException {
        initEndPosition(inputIds);
        NDManager manager = inputIds.getManager();
        NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config);
        ContrastiveBatchTensorList searchState = new ContrastiveBatchTensorList();

        do {
            if (searchState.getPastKeyValues() == null) {
                // First pass: run the whole prompt once to obtain hidden states, the logits at
                // the last position, and the key/value cache.
                CausalLMOutput prefill =
                        predictor.predict(prepareInput(inputIds, attentionMask, 0L, 1));
                NDArray hiddenState = prefill.getHiddenState();
                NDArray lastLogits = prefill.getLogits().get(":, -1, :");
                searchState =
                        new ContrastiveBatchTensorList(
                                inputIds,
                                attentionMask,
                                hiddenState,
                                lastLogits,
                                prefill.getPastKeyValuesList(),
                                new long[0]);
            }

            try (NDScope ignore = new NDScope()) {
                NDArray topKIds =
                        searchState.getLogits().topK(config.getK(), -1, true, false).get(1);
                // One candidate token per row: (rows * k, 1).
                NDArray candidateInputIds = topKIds.flatten().reshape(-1, 1);
                assert candidateInputIds.getDataType() == DataType.INT64
                        : "inputIds datatype should be int64";
                assert candidateInputIds.getShape().getShape().length == 2 : "shape not right";

                // Replicate the cache and the mask k times along axis 0 so that every candidate
                // gets its own copy of the history.
                NDList kCopyPastKeyValues =
                        new NDList(
                                searchState.getPastKeyValues().stream()
                                        .map(kv -> kv.repeat(0, config.getK()))
                                        .collect(Collectors.toList()));
                assert kCopyPastKeyValues.get(0).getDataType() == DataType.FLOAT32
                        : "inputIds datatype should be Float32";

                NDArray kCopyAttentionMask =
                        searchState
                                .getPastAttentionMask()
                                .repeat(0, config.getK())
                                .concat(
                                        manager.ones(
                                                new Shape(
                                                        topKIds.getShape().get(0) * config.getK(),
                                                        1),
                                                DataType.INT64),
                                        1);
                assert kCopyPastKeyValues.get(0).getShape().get(2) + 1
                                == kCopyAttentionMask.getShape().getLastDimension()
                        : "attentionMask_seq = past_seq + new_input_seq";

                NDList candidateInput =
                        prepareInput(
                                candidateInputIds,
                                kCopyAttentionMask,
                                searchState.getPastOutputIds().getShape().getLastDimension(),
                                config.getK());
                candidateInput.addAll(kCopyPastKeyValues);
                CausalLMOutput candidateOutput = predictor.predict(candidateInput);

                NDList generatedOutput =
                        StepGeneration.constrastiveStepGeneration(
                                topKIds,
                                searchState.getLogits(),
                                searchState.getPastHiddenStates(),
                                candidateOutput.getHiddenState(),
                                positionOffset,
                                config.getAlpha());
                searchState =
                        updateSearchState(searchState, candidateOutput, generatedOutput, manager);

                // Keep the carried-over state alive after the scope closes.
                NDScope.unregister(
                        searchState.getPastOutputIds(),
                        searchState.getPastAttentionMask(),
                        searchState.getLogits(),
                        searchState.getPastHiddenStates());
                NDScope.unregister(searchState.getPastKeyValues());
            }

            // Unlike the other searches, the whole output history is scanned here and the
            // recorded position has no "+ 1".
            markEndPositions(
                    searchState.getPastOutputIds().toLongArray(),
                    searchState.getPastOutputIds().getShape().get(1));
        } while (searchState.getPastOutputIds().getShape().get(1) < config.getMaxSeqLength());

        return searchState.getPastOutputIds();
    }

    /**
     * Builds the next beam-search state from the model output of the current step.
     *
     * @param searchState the state that produced {@code modelOutput}
     * @param modelOutput the model output for the flattened {@code (batch * beam)} input
     * @param generatedOutput element 0: next input ids (rank 3), element 1: new probabilities,
     *     element 2: indices applied to the beam axis when re-ordering the carried-over tensors
     * @param manager the manager used to allocate helper arrays
     * @return the new search state
     */
    private static BeamBatchTensorList updateSearchState(
            BeamBatchTensorList searchState,
            CausalLMOutput modelOutput,
            NDList generatedOutput,
            NDManager manager) {
        NDList pastKeyValues = searchState.getPastKeyValues();
        long numHeads = pastKeyValues.get(0).getShape().get(2);
        long kvDim = pastKeyValues.get(0).getShape().getLastDimension();
        long numBatch = searchState.getPastOutputIds().getShape().get(0);
        long numBeam = searchState.getPastOutputIds().getShape().get(1);
        long pastSeqLength = searchState.getPastOutputIds().getShape().getLastDimension();

        NDArray nextInputIds = generatedOutput.get(0);
        assert nextInputIds.getShape().getShape().length == 3 : "Wrong Shape";
        NDArray newProbs = generatedOutput.get(1);

        // Pair every (batch, beam) slot with its own batch index on axis 0 and the selected
        // index on axis 1; trailing axes are taken whole.
        NDArray batchIndices =
                manager.arange(0.0f, (float) numBatch, 1.0f, DataType.INT64)
                        .expandDims(1)
                        .repeat(1, numBeam);
        NDIndex selection = new NDIndex("{}, {}, ...", batchIndices, generatedOutput.get(2));

        NDArray pastOutputIds =
                searchState.getPastOutputIds().concat(searchState.getNextInputIds(), -1).get(selection);
        NDList newPastKeyValues =
                new NDList(
                        modelOutput.getPastKeyValuesList().stream()
                                .map(
                                        kv ->
                                                kv.reshape(
                                                                numBatch,
                                                                numBeam,
                                                                numHeads,
                                                                pastSeqLength + 1,
                                                                kvDim)
                                                        .get(selection))
                                .collect(Collectors.toList()));
        NDArray pastAttentionMask =
                searchState
                        .getPastAttentionMask()
                        .concat(manager.ones(new Shape(numBatch, numBeam, 1), DataType.INT64), -1)
                        .get(selection);

        return new BeamBatchTensorList(
                nextInputIds, pastOutputIds, newPastKeyValues, pastAttentionMask, newProbs);
    }

    /**
     * Builds the next contrastive-search state from the model output of the current step.
     *
     * @param searchState the state that produced {@code modelOutput}
     * @param modelOutput the model output for the k-times replicated candidates
     * @param generatedOutput element 0: ids appended to the output history, element 1: per-row
     *     index of the chosen candidate
     * @param manager the manager used to allocate helper arrays
     * @return the new search state
     */
    private static ContrastiveBatchTensorList updateSearchState(
            ContrastiveBatchTensorList searchState,
            CausalLMOutput modelOutput,
            NDList generatedOutput,
            NDManager manager) {
        assert modelOutput.getLogits().getShape().get(1) == 1
                : "dimension check: here, outputLogits corresponds to inputSeq == 1";

        long numBatch = searchState.getLogits().getShape().get(0);
        long logitsDim = searchState.getLogits().getShape().get(1);
        long pastSeqLength = searchState.getPastOutputIds().getShape().get(1);
        long numHeads = searchState.getPastKeyValues().get(0).getShape().get(1);
        long kvDim = searchState.getPastKeyValues().get(0).getShape().get(3);
        long hiddenDim = searchState.getPastHiddenStates().getShape().get(2);
        // Number of candidates evaluated per row.
        long k = modelOutput.getLogits().getShape().get(0) / numBatch;

        // Pick, for every row, the slice belonging to its chosen candidate.
        NDIndex selection =
                new NDIndex(
                        "{}, {}, ...",
                        manager.arange(0.0f, (float) numBatch, 1.0f, DataType.INT64),
                        generatedOutput.get(1).flatten());

        NDArray nextLogits = modelOutput.getLogits().reshape(numBatch, k, logitsDim).get(selection);
        NDList nextPastKeyValues =
                new NDList(
                        modelOutput.getPastKeyValuesList().stream()
                                .map(
                                        kv ->
                                                kv.reshape(
                                                                numBatch,
                                                                k,
                                                                numHeads,
                                                                pastSeqLength + 1,
                                                                kvDim)
                                                        .get(selection))
                                .collect(Collectors.toList()));

        NDArray newHiddenState = modelOutput.getHiddenState();
        assert newHiddenState.getManager() == manager : "possible leaky memory";

        NDArray nextOutputIds = searchState.getPastOutputIds().concat(generatedOutput.get(0), 1);
        NDArray nextAttentionMask =
                searchState
                        .getPastAttentionMask()
                        .concat(manager.ones(new Shape(numBatch, 1), DataType.INT64), 1);
        NDArray nextHiddenStates =
                searchState
                        .getPastHiddenStates()
                        .concat(newHiddenState.reshape(numBatch, k, 1, hiddenDim).get(selection), 1);

        return new ContrastiveBatchTensorList(
                nextOutputIds,
                nextAttentionMask,
                nextHiddenStates,
                nextLogits,
                nextPastKeyValues,
                new long[0]);
    }

    /**
     * Builds the initial attention mask and, as a side effect, replaces {@link #positionOffset}.
     *
     * <p>Each row is scanned from the left. With suffix padding the scan stops at the first pad
     * token and the mask is zeroed from there to the end of the row. Otherwise the scan stops at
     * the first non-pad token, the mask is zeroed before it, and the stop index is recorded as the
     * row's position offset. Rows keep a zero offset under suffix padding.
     *
     * @param inputIds the prompt token ids; indexed as {@code (row, column)}
     * @param searchConfig supplies the pad token id and the padding side
     * @return an int64 mask holding 1 for kept positions and 0 for padding
     */
    private NDArray prepareAttentionMaskOffset(NDArray inputIds, SearchConfig searchConfig) {
        boolean suffixPadding = searchConfig.isSuffixPadding();
        NDManager manager = inputIds.getManager();
        int numBatch = Math.toIntExact(inputIds.getShape().get(0));
        int initSeqSize = Math.toIntExact(inputIds.getShape().get(1));

        NDArray attentionMask =
                manager.ones(new Shape(1, inputIds.getShape().getLastDimension()), DataType.INT64)
                        .reshape(1, -1)
                        .repeat(0, numBatch);

        long[][] offsets = new long[numBatch][1];
        for (int row = 0; row < numBatch; row++) {
            long[] sequence = inputIds.get("{},:", row).toLongArray();

            int idx = 0;
            while (idx < initSeqSize) {
                boolean isPad = sequence[idx] == searchConfig.getPadTokenId();
                // Suffix padding: stop on the first pad. Prefix padding: stop on the first
                // real token.
                if (isPad == suffixPadding) {
                    break;
                }
                idx++;
            }

            int zeroFrom = suffixPadding ? idx : 0;
            int zeroTo = suffixPadding ? initSeqSize : idx;
            attentionMask.set(new NDIndex("{},{}:{}", row, zeroFrom, zeroTo), 0);
            if (!suffixPadding) {
                offsets[row][0] = idx;
            }
        }

        this.positionOffset = manager.create(offsets);
        return attentionMask;
    }

    /**
     * Assembles the model input list {@code [inputIds, positionIds, attentionMask]}.
     *
     * <p>Position ids run from {@code pastSeqLength} up to {@code pastSeqLength + newTokens},
     * are shifted down by the per-row {@link #positionOffset}, and are clamped at zero.
     *
     * @param inputIds the ids fed to the model in this step
     * @param attentionMask the attention mask matching the full sequence
     * @param pastSeqLength the number of positions already consumed
     * @param repeat how many times each row of {@link #positionOffset} is repeated along axis 0
     *     so that it lines up with the rows of {@code inputIds}
     * @return the input list, without key/value cache entries
     */
    private NDList prepareInput(
            NDArray inputIds, NDArray attentionMask, long pastSeqLength, int repeat) {
        long newTokens = inputIds.getShape().getLastDimension();
        long rows = inputIds.getShape().get(0);

        NDArray positionIds =
                inputIds.getManager()
                        .arange(
                                (float) pastSeqLength,
                                (float) (pastSeqLength + newTokens),
                                1.0f,
                                DataType.INT64)
                        .expandDims(0)
                        .repeat(0, rows)
                        .subi(positionOffset.repeat(0, repeat));
        NDArray clampedPositionIds = positionIds.maximum(positionIds.zerosLike());

        return new NDList(inputIds, clampedPositionIds, attentionMask);
    }

    /**
     * Runs the search strategy selected by the search name given at construction time.
     *
     * @param inputIds the prompt token ids
     * @return the result of the selected search method
     * @throws TranslateException if the predictor fails
     * @throws IllegalArgumentException if the search name is not one of {@code "greedy"}, {@code
     *     "beam"} or {@code "contrastive"}
     */
    public NDArray generate(NDArray inputIds) throws TranslateException {
        switch (searchName) {
            case "greedy":
                return greedySearch(inputIds);
            case "beam":
                return beamSearch(inputIds);
            case "contrastive":
                return contrastiveSearch(inputIds);
            default:
                throw new IllegalArgumentException("searchName not correctly specified. Please choose among: {greedy, beam, contrastive}");
        }
    }

    /**
     * Returns the position offsets computed by the most recent search call.
     *
     * @return the offsets, or {@code null} if no search has run yet
     */
    public NDArray getPositionOffset() {
        return this.positionOffset;
    }

    /**
     * Returns the end positions recorded by the most recent search call.
     *
     * @return the internal array (not a copy), or {@code null} if no search has run yet
     */
    public long[] getEndPosition() {
        return this.endPosition;
    }

    /** Allocates one end position per input row and sets each to the maximum sequence length. */
    private void initEndPosition(NDArray inputIds) {
        this.endPosition = new long[Math.toIntExact(inputIds.getShape().get(0))];
        Arrays.fill(this.endPosition, config.getMaxSeqLength());
    }

    /**
     * Sets every end position to {@code position} if any of {@code tokenIds} is the end-of-sequence
     * token.
     *
     * <p>NOTE(review): the check is batch-wide, not per row - an end-of-sequence token anywhere in
     * {@code tokenIds} updates all rows, and a later hit overwrites an earlier one. This mirrors
     * the existing behaviour; confirm whether per-row tracking was intended.
     */
    private void markEndPositions(long[] tokenIds, long position) {
        long eosTokenId = config.getEosTokenId();
        for (long tokenId : tokenIds) {
            if (tokenId == eosTokenId) {
                Arrays.fill(this.endPosition, position);
                return;
            }
        }
    }
}
