package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.training.loss.Loss;

/* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/nn/transformer/BertMaskedLanguageModelLoss.class */
/**
 * Loss for the masked-language-model task of BERT.
 *
 * <p>The loss reads the target ids and the mask from the label list and the log-probabilities
 * from the prediction list, using the indices supplied at construction time. The evaluator name
 * passed to the superclass is {@code "BertMLLoss"}.
 */
public class BertMaskedLanguageModelLoss extends Loss {

    /** Added to the mask sum in {@link #evaluate} so an all-zero mask does not divide by zero. */
    private static final float EPSILON = 1.0E-5f;

    private final int labelIdx;
    private final int maskIdx;
    private final int logProbsIdx;

    /**
     * Creates the loss.
     *
     * @param labelIdx index of the target-id array within the label list
     * @param maskIdx index of the mask array within the label list
     * @param logProbsIdx index of the log-probability array within the prediction list
     */
    public BertMaskedLanguageModelLoss(int labelIdx, int maskIdx, int logProbsIdx) {
        super("BertMLLoss");
        this.labelIdx = labelIdx;
        this.maskIdx = maskIdx;
        this.logProbsIdx = logProbsIdx;
    }

    /**
     * Computes the masked negative log-likelihood of the targets.
     *
     * <p>The value is {@code sum(-logProb[target] * mask) / (sum(mask) + 1e-5)}. The
     * log-probability array must have at least two dimensions: dimension 1 is used as the one-hot
     * depth and is the axis that gets reduced. Presumably it is laid out as
     * (flattened positions, vocabulary size) -- confirm against the producing block.
     *
     * @param labels list holding the target ids at {@code labelIdx} and the mask at {@code maskIdx}
     * @param predictions list holding the log-probabilities at {@code logProbsIdx}
     * @return the loss value, handed back from the temporary sub-manager via {@code ret}
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        // All intermediates live in a sub-manager that is closed on exit; only the value passed
        // through ret(...) is returned to the caller.
        try (NDManager scope = NDManager.subManagerOf(labels)) {
            scope.tempAttachAll(labels, predictions);

            NDArray logProbs = predictions.get(logProbsIdx);
            int dictionarySize = (int) logProbs.getShape().get(1);
            NDArray targetIds = labels.get(labelIdx).flatten();
            NDArray mask = labels.get(maskIdx).flatten().toType(DataType.FLOAT32, false);

            // Multiplying by the one-hot targets and reducing over axis 1 selects, per row, the
            // log-probability of the target id; negating gives the per-row loss.
            NDArray targetOneHots = targetIds.oneHot(dictionarySize);
            NDArray negLogLikelihood = logProbs.mul(targetOneHots).sum(new int[] {1}).mul(-1);

            // Zero out rows where the mask is zero, then normalise by the mask sum.
            NDArray lossSum = negLogLikelihood.mul(mask).sum();
            NDArray normalizer = mask.sum().add(EPSILON);
            return scope.ret(lossSum.div(normalizer));
        }
    }

    /**
     * Computes the fraction of masked positions whose arg-max prediction equals the target id.
     *
     * <p>The value is {@code sum((argMax(logProbs, 1) == target) * mask) / sum(mask)}, computed
     * in {@code FLOAT32}. Unlike {@link #evaluate}, no epsilon is added to the denominator, so a
     * mask that sums to zero results in a zero-by-zero division.
     *
     * @param labels list holding the target ids at {@code labelIdx} and the mask at {@code maskIdx}
     * @param predictions list holding the log-probabilities at {@code logProbsIdx}
     * @return the accuracy value, handed back from the temporary sub-manager via {@code ret}
     */
    public NDArray accuracy(NDList labels, NDList predictions) {
        try (NDManager scope = NDManager.subManagerOf(labels)) {
            scope.tempAttachAll(labels, predictions);

            NDArray mask = labels.get(maskIdx).flatten();
            // argMax output is cast to INT32 before comparison; assumes the target ids are
            // INT32 as well -- TODO confirm against the data pipeline.
            NDArray predictedIds = predictions.get(logProbsIdx).argMax(1).toType(DataType.INT32, false);
            NDArray targetIds = labels.get(labelIdx).flatten();

            NDArray correctCount =
                    predictedIds.eq(targetIds).mul(mask).sum().toType(DataType.FLOAT32, false);
            NDArray maskedCount = mask.sum().toType(DataType.FLOAT32, false);
            return scope.ret(correctCount.div(maskedCount));
        }
    }
}
