package ai.djl.training.evaluator;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.util.Pair;
import java.util.stream.IntStream;

/* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/training/evaluator/TopKAccuracy.class */
/**
 * {@code TopKAccuracy} is an {@link AbstractAccuracy} that counts a sample as correct when its
 * label equals one of the last {@code topK} indices produced by sorting the prediction along the
 * evaluator's axis.
 *
 * <p>NOTE(review): this relies on {@code NDArray.argSort(int)} sorting in ascending order, so that
 * the trailing columns hold the highest-scoring classes - confirm against the engine in use.
 */
public class TopKAccuracy extends AbstractAccuracy {
    /** Configured number of top-ranked candidates; never modified after construction. */
    private final int topK;

    /**
     * Creates a {@code TopKAccuracy} evaluator with the given name.
     *
     * @param name the name passed to the parent evaluator
     * @param topK the number of top-ranked candidates to compare against the label
     * @throws IllegalArgumentException if {@code topK} is not greater than 1
     */
    public TopKAccuracy(String name, int topK) {
        super(name);
        if (topK <= 1) {
            throw new IllegalArgumentException("Please use TopKAccuracy with topK more than 1");
        }
        this.topK = topK;
    }

    /**
     * Creates a {@code TopKAccuracy} evaluator named {@code "Top_<topK>_Accuracy"}.
     *
     * @param topK the number of top-ranked candidates to compare against the label
     * @throws IllegalArgumentException if {@code topK} is not greater than 1
     */
    public TopKAccuracy(int topK) {
        this("Top_" + topK + "_Accuracy", topK);
    }

    /**
     * Counts how many labels match the ranked predictions of the first array in each list.
     *
     * @param labels the label list; only its head is used
     * @param predictions the prediction list; only its head is used
     * @return a pair of the size of the label's first dimension and an {@link NDArray} holding
     *     the number of matches
     * @throws IllegalArgumentException if the sorted prediction is neither 1- nor 2-dimensional
     */
    @Override // ai.djl.training.evaluator.AbstractAccuracy
    protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
        NDArray label = labels.head();
        NDArray prediction = predictions.head();
        checkLabelShapes(label, prediction);

        // Indices of the prediction entries in sorted order along the evaluator's axis.
        NDArray ranked = prediction.argSort(this.axis).toType(DataType.INT64, false);
        int numDims = ranked.getShape().dimension();

        NDArray numCorrect;
        if (numDims == 1) {
            // 1-D case: element-wise comparison of the sorted indices with the labels.
            numCorrect = ranked.flatten().eq(label.flatten()).countNonzero();
        } else if (numDims == 2) {
            numCorrect = countTopKMatches(ranked, label);
        } else {
            throw new IllegalArgumentException(
                    "Prediction should be 1 or 2 dimensions, but got " + numDims);
        }

        long total = label.getShape().get(0);
        return new Pair<>(Long.valueOf(total), numCorrect);
    }

    /**
     * Sums, over the last {@code k} columns of {@code ranked}, the number of rows whose index
     * equals the label, where {@code k} is {@code topK} clamped to the number of columns.
     */
    private NDArray countTopKMatches(NDArray ranked, NDArray label) {
        int numClasses = (int) ranked.getShape().get(1);
        // Clamp locally: writing the clamped value back to the field would permanently shrink
        // topK for every later batch once a narrower prediction had been seen.
        int k = Math.min(this.topK, numClasses);
        // Loop-invariant, so converted once instead of once per column.
        NDArray flatLabel = label.flatten().toType(DataType.INT64, false);

        NDArray[] hits =
                IntStream.range(0, k)
                        // Walk columns from the last one backwards.
                        .mapToObj(
                                j ->
                                        ranked.get(":, {}", Integer.valueOf(numClasses - j - 1))
                                                .flatten()
                                                .eq(flatLabel)
                                                .countNonzero())
                        .toArray(NDArray[]::new);

        if (hits.length == 1) {
            // NOTE(review): NDArrays.add appears to require at least two operands - confirm.
            // A single column needs no summation either way.
            return hits[0];
        }
        return NDArrays.add(hits);
    }
}
