package ai.djl.training.loss;

import ai.djl.modality.cv.MultiBoxTarget;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.util.Pair;
import java.util.Arrays;

/* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/training/loss/SingleShotDetectionLoss.class */
/**
 * Composite loss made of two components, registered in this order:
 *
 * <ol>
 *   <li>index 0: a softmax cross-entropy loss named {@code "ClassLoss"}
 *   <li>index 1: an L1 loss named {@code "BoundingBoxLoss"}
 * </ol>
 *
 * <p>The inputs handed to each component are derived from the output of a default-configured
 * {@link MultiBoxTarget}.
 */
public class SingleShotDetectionLoss extends AbstractCompositeLoss {
    /** Target generator built with the builder's default settings. */
    private MultiBoxTarget multiBoxTarget = MultiBoxTarget.builder().build();

    /** Creates the loss and registers its two components (class loss, then bounding-box loss). */
    public SingleShotDetectionLoss() {
        super("SingleShotDetectionLoss");
        Loss classLoss = Loss.softmaxCrossEntropyLoss("ClassLoss");
        Loss boundingBoxLoss = Loss.l1Loss("BoundingBoxLoss");
        this.components = Arrays.asList(classLoss, boundingBoxLoss);
    }

    /**
     * Builds the pair of lists that is passed to the component at index {@code i}.
     *
     * <p>NOTE(review): the semantic names used below (anchors, class predictions, bounding-box
     * labels/masks, class labels) follow the DJL {@code MultiBoxTarget} documentation and are not
     * verifiable from this file alone -- confirm against the library before relying on them.
     *
     * @param i component index; {@code 0} selects the class loss, {@code 1} the bounding-box loss
     * @param nDList first input list (presumably the labels); only its head element is used
     * @param nDList2 second input list (presumably the predictions); elements 0 and 1 are always
     *     read, element 2 only for component {@code 1}
     * @return the pair of single-element lists for the selected component
     * @throws IllegalArgumentException if {@code i} is neither {@code 0} nor {@code 1}
     */
    @Override // ai.djl.training.loss.AbstractCompositeLoss
    protected Pair<NDList, NDList> inputForComponent(int i, NDList nDList, NDList nDList2) {
        NDArray anchors = nDList2.get(0);
        NDArray classPredictions = nDList2.get(1);

        // The targets are computed before the index is validated, so an invalid index
        // still pays for (and can fail inside) the target computation.
        NDList targetInputs =
                new NDList(anchors, nDList.head(), classPredictions.transpose(0, 2, 1));
        NDList targets = this.multiBoxTarget.target(targetInputs);

        if (i == 0) {
            // Class loss: third target output against the untransposed second prediction.
            NDList classLabels = new NDList(targets.get(2));
            return new Pair<>(classLabels, new NDList(classPredictions));
        }

        if (i == 1) {
            NDArray boxPredictions = nDList2.get(2);
            NDArray boxLabels = targets.get(0);
            NDArray boxMasks = targets.get(1);
            // Both sides are multiplied by the same second target output before comparison.
            NDList maskedLabels = new NDList(boxLabels.mul(boxMasks));
            NDList maskedPredictions = new NDList(boxPredictions.mul(boxMasks));
            return new Pair<>(maskedLabels, maskedPredictions);
        }

        throw new IllegalArgumentException("Invalid component index");
    }
}
