package ai.djl.modality.nlp.generate;

import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;
import java.util.Arrays;
import java.util.Collection;
import java.util.stream.Collectors;

/* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/modality/nlp/generate/ContrastiveSeqBatchScheduler.class */
/**
 * A {@link SeqBatchScheduler} whose decoding step is driven by {@link
 * StepGeneration#constrastiveStepGeneration}.
 *
 * <p>{@link #initForward(NDArray, NDArray)} runs the predictor once on the prompt and wraps the
 * result in a {@link SeqBatcher}. {@link #inferenceCall()} then advances every sequence held by
 * the scheduler's {@code seqBatcher} by one token: it takes the top-k candidate ids from the
 * stored logits, runs the predictor on all candidates at once, lets the step-generation routine
 * pick one candidate per sequence, and stores the selected state for the next call.
 */
public class ContrastiveSeqBatchScheduler extends SeqBatchScheduler {

    /**
     * Length of the {@code long[]} handed to {@link ContrastiveBatchTensorList} as its last
     * constructor argument (read back later through {@code getSeqDimOrder()}).
     *
     * <p>NOTE(review): presumably one entry per tensor held by the list (ids, mask, hidden states,
     * logits plus the key/value tensors) - confirm against the model's cache layout.
     */
    private static final int SEQ_DIM_ORDER_LENGTH = 28;

    /**
     * Creates a scheduler; both arguments are forwarded unchanged to {@link SeqBatchScheduler}.
     *
     * @param predictor the predictor mapping an {@link NDList} of inputs to a {@link
     *     CausalLMOutput}
     * @param searchConfig the search configuration (this class reads {@code getK()}, {@code
     *     getAlpha()}, {@code getMaxSeqLength()} and {@code getEosTokenId()} from it)
     */
    public ContrastiveSeqBatchScheduler(Predictor<NDList, CausalLMOutput> predictor, SearchConfig searchConfig) {
        super(predictor, searchConfig);
    }

    /**
     * Runs the first forward pass on a batch of input ids and packages the result as a {@link
     * SeqBatcher}.
     *
     * <p>As a side effect, the scheduler's {@code manager} is set to the manager of {@code
     * nDArray}.
     *
     * @param nDArray the input token ids; fed to the predictor and stored as the first tensor of
     *     the returned batch
     * @param nDArray2 passed as the second constructor argument of {@link SeqBatcher}
     *     (presumably the unique ids of the sequences in the batch - confirm against {@link
     *     SeqBatcher})
     * @return a {@link SeqBatcher} holding the state produced by the first forward pass
     * @throws TranslateException if the predictor fails
     */
    @Override // ai.djl.modality.nlp.generate.SeqBatchScheduler
    public SeqBatcher initForward(NDArray nDArray, NDArray nDArray2) throws TranslateException {
        try (NDScope scope = new NDScope()) {
            scope.suppressNotUsedWarning();
            manager = nDArray.getManager();

            NDArray initOffSets = computeOffSets(nDArray, config);
            NDArray attentionMask = computeAttentionMask(nDArray, config);
            NDArray positionIds = computePositionIds(nDArray, initOffSets, 0L, 1);

            CausalLMOutput output = predictor.predict(new NDList(nDArray, positionIds, attentionMask));
            // Keep only the logits at the last index of the second axis.
            NDArray lastLogits = output.getLogits().get(":, -1, :");

            // Entries 0..2 are 1, entry 3 is -1 and all remaining entries are 2.
            // NOTE(review): presumably each entry is the index of the sequence axis of the
            // corresponding tensor, with -1 meaning "no sequence axis" - confirm against
            // ContrastiveBatchTensorList.
            long[] seqDimOrder = new long[SEQ_DIM_ORDER_LENGTH];
            Arrays.fill(seqDimOrder, 0, 3, 1L);
            seqDimOrder[3] = -1;
            Arrays.fill(seqDimOrder, 4, seqDimOrder.length, 2L);

            ContrastiveBatchTensorList batchTensors =
                    new ContrastiveBatchTensorList(
                            nDArray,
                            attentionMask,
                            output.getHiddenState(),
                            lastLogits,
                            output.getPastKeyValuesList(),
                            seqDimOrder);
            SeqBatcher batcher = new SeqBatcher(batchTensors, nDArray2, initOffSets, manager);

            // Detach everything the returned batcher keeps, so that closing the scope does not
            // take these arrays with it.
            NDScope.unregister(output.getPastKeyValuesList());
            NDScope.unregister(output.getHiddenState(), attentionMask, lastLogits);
            NDScope.unregister(batcher.offSets, batcher.batchUid);

            return batcher;
        }
    }

    /**
     * Advances every sequence in the current {@code seqBatcher} by one token.
     *
     * <p>The method replaces {@code seqBatcher.data} with the extended state, increments {@code
     * seqBatcher.seqLength} and calls {@code seqBatcher.exitCriteria(...)} with the newly chosen
     * ids.
     *
     * @return the ids selected in this step (element 0 of the list returned by {@link
     *     StepGeneration#constrastiveStepGeneration})
     * @throws TranslateException if the predictor fails
     */
    @Override // ai.djl.modality.nlp.generate.SeqBatchScheduler
    public NDArray inferenceCall() throws TranslateException {
        try (NDScope scope = new NDScope()) {
            scope.suppressNotUsedWarning();
            final int k = config.getK();

            /* Prepare the inputs of one predictor call. */
            NDArray logits = ((ContrastiveBatchTensorList) seqBatcher.getData()).getLogits();
            // Element 1 of the topK result is used as token ids below (it must be INT64).
            NDArray topKIds = logits.topK(k, -1, true, false).get(1);
            ContrastiveBatchTensorList searchState = (ContrastiveBatchTensorList) seqBatcher.data;

            // Fold all candidates into the leading axis: one candidate id per row.
            NDArray candidateInputIds = topKIds.flatten().reshape(-1, 1);
            assert candidateInputIds.getDataType() == DataType.INT64
                    : "inputIds datatype should be int64";
            assert candidateInputIds.getShape().getShape().length == 2 : "shape not right";

            // Repeat every cached key/value tensor k times along axis 0 so that its leading axis
            // lines up with the folded candidates.
            NDList kCopyPastKeyValues =
                    new NDList(
                            searchState.getPastKeyValues().stream()
                                    .map(pastKeyValue -> pastKeyValue.repeat(0, k))
                                    .collect(Collectors.toList()));
            assert kCopyPastKeyValues.get(0).getDataType() == DataType.FLOAT32
                    : "pastKeyValues datatype should be Float32";

            // Repeat the mask the same way and append one column of ones for the candidate token.
            long numBatch = topKIds.getShape().get(0);
            NDArray kCopyPastAttentionMask =
                    searchState
                            .getPastAttentionMask()
                            .repeat(0, k)
                            .concat(manager.ones(new Shape(numBatch * k, 1), DataType.INT64), 1);
            assert kCopyPastKeyValues.get(0).getShape().get(2) + 1
                            == kCopyPastAttentionMask.getShape().getLastDimension()
                    : "attentionMask_seq = past_seq + new_input_seq";

            NDArray positionIds =
                    computePositionIds(
                            candidateInputIds,
                            seqBatcher.offSets,
                            searchState.getPastOutputIds().getShape().getLastDimension(),
                            k);
            NDList candidateInputs = new NDList(candidateInputIds, positionIds, kCopyPastAttentionMask);
            candidateInputs.addAll(kCopyPastKeyValues);

            /* Run the predictor on all candidates and let the step routine pick one each. */
            CausalLMOutput candidateOutput = predictor.predict(candidateInputs);
            NDList generatedOutput =
                    StepGeneration.constrastiveStepGeneration(
                            topKIds,
                            logits,
                            searchState.getPastHiddenStates(),
                            candidateOutput.getHiddenState(),
                            seqBatcher.offSets,
                            config.getAlpha());

            /* Pick the winning candidate's slice out of every predictor output. */
            long logitsDim = logits.getShape().get(1);
            long numHeads = searchState.getPastKeyValues().get(0).getShape().get(1);
            long kvDim = searchState.getPastKeyValues().get(0).getShape().get(3);
            long pastSeqLength = searchState.getPastOutputIds().getShape().get(1);
            long hiddenDim = searchState.getPastHiddenStates().getShape().get(2);

            // Element 1 of the step result is used as the index of the chosen candidate: paired
            // with the row index it selects one entry along the k axis for every row.
            NDArray select = generatedOutput.get(1);
            NDIndex selectIndex =
                    new NDIndex(
                            "{}, {}, ...",
                            manager.arange(0.0f, (float) numBatch, 1.0f, DataType.INT64),
                            select.flatten());

            NDArray nextLogits = candidateOutput.getLogits().reshape(numBatch, k, logitsDim).get(selectIndex);
            // Un-fold the leading axis into (numBatch, k) before selecting; the cache is one step
            // longer than the ids stored so far.
            NDList nextPastKeyValues =
                    new NDList(
                            candidateOutput.getPastKeyValuesList().stream()
                                    .map(
                                            keyValue ->
                                                    keyValue.reshape(
                                                                    numBatch,
                                                                    k,
                                                                    numHeads,
                                                                    pastSeqLength + 1,
                                                                    kvDim)
                                                            .get(selectIndex))
                                    .collect(Collectors.toList()));

            NDArray newHiddenState = candidateOutput.getHiddenState();
            assert newHiddenState.getManager() == manager : "possible leaky memory";
            NDArray nextPastHiddenStates =
                    searchState
                            .getPastHiddenStates()
                            .concat(newHiddenState.reshape(numBatch, k, 1, hiddenDim).get(selectIndex), 1);

            /* Extend the stored state by the chosen token. */
            NDArray outputIds = generatedOutput.get(0);
            NDArray nextOutputIds = searchState.getPastOutputIds().concat(outputIds, 1);
            NDArray nextAttentionMask =
                    searchState
                            .getPastAttentionMask()
                            .concat(manager.ones(new Shape(numBatch, 1), DataType.INT64), 1);

            seqBatcher.seqLength++;
            seqBatcher.data =
                    new ContrastiveBatchTensorList(
                            nextOutputIds,
                            nextAttentionMask,
                            nextPastHiddenStates,
                            nextLogits,
                            nextPastKeyValues,
                            searchState.getSeqDimOrder());
            seqBatcher.exitCriteria(outputIds, config.getMaxSeqLength(), config.getEosTokenId());

            // Detach everything that outlives this call, so that closing the scope does not take
            // these arrays with it.
            NDScope.unregister(nextOutputIds);
            NDScope.unregister(nextAttentionMask);
            NDScope.unregister(nextPastHiddenStates);
            NDScope.unregister(nextLogits);
            NDScope.unregister(nextPastKeyValues);
            NDScope.unregister(outputIds);

            return outputIds;
        }
    }
}
