package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/training/loss/L2Loss.class */
public class L2Loss extends Loss {

    /** Multiplier applied to every squared difference before the mean is taken. */
    private final float weight;

    /** Creates an L2 loss named {@code "L2Loss"} with a weight of {@code 0.5}. */
    public L2Loss() {
        this("L2Loss");
    }

    /**
     * Creates an L2 loss with the given name and a weight of {@code 0.5}.
     *
     * @param name the name passed to the {@link Loss} superclass
     */
    public L2Loss(String name) {
        this(name, 0.5f);
    }

    /**
     * Creates an L2 loss with the given name and weight.
     *
     * @param name the name passed to the {@link Loss} superclass
     * @param weight multiplier applied to each squared difference
     */
    public L2Loss(String name, float weight) {
        super(name);
        this.weight = weight;
    }

    /**
     * Computes {@code mean(weight * (label - prediction)^2)}.
     *
     * <p>Each list is expected to hold a single array; {@code singletonOrThrow()} is used to
     * extract it. The label is reshaped to the prediction's shape before the difference is
     * taken.
     *
     * @param label the list holding the target array
     * @param prediction the list holding the predicted array
     * @return the weighted squared error reduced with {@code NDArray#mean()}
     */
    @Override // ai.djl.training.evaluator.Evaluator
    public NDArray evaluate(NDList label, NDList prediction) {
        NDArray predicted = prediction.singletonOrThrow();
        // Align the label with the prediction so the element-wise subtraction lines up.
        NDArray target = label.singletonOrThrow().reshape(predicted.getShape());
        NDArray weightedSquares = target.sub(predicted).square().mul(Float.valueOf(weight));
        return weightedSquares.mean();
    }
}
