package ai.djl.training.optimizer;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.ParameterTracker;
import ai.djl.training.tracker.Tracker;
import ai.djl.util.Preconditions;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/training/optimizer/Adam.class */
/**
 * An {@link Optimizer} that applies the Adam update rule.
 *
 * <p>For every parameter, a first-moment ("mean") and second-moment ("variance") state array is
 * kept per {@link Device}. Both are lazily initialised to zeros shaped like the weight. The
 * numeric update itself is delegated to the engine through {@code adamUpdate}.
 *
 * <p>Instances are created through {@link #builder()}.
 */
public class Adam extends Optimizer {

    private final ParameterTracker learningRateTracker;
    private final float beta1;
    private final float beta2;
    private final float epsilon;

    /** First-moment state, keyed by parameter id and then by device. */
    private final Map<String, Map<Device, NDArray>> means;

    /** Second-moment state, keyed by parameter id and then by device. */
    private final Map<String, Map<Device, NDArray>> variances;

    /**
     * Builder for {@link Adam}.
     *
     * <p>Defaults: learning rate {@code Tracker.fixed(0.001f)}, {@code beta1 = 0.9},
     * {@code beta2 = 0.999}, {@code epsilon = 1e-8}.
     */
    public static final class Builder extends Optimizer.OptimizerBuilder<Builder> {

        private ParameterTracker learningRateTracker = Tracker.fixed(0.001f);
        private float beta1 = 0.9f;
        private float beta2 = 0.999f;
        private float epsilon = 1.0E-8f;

        /** {@inheritDoc} */
        @Override
        public Builder self() {
            return this;
        }

        /**
         * Sets the tracker that supplies the (uncorrected) learning rate for each update.
         *
         * @param learningRateTracker the learning-rate tracker
         * @return this builder
         */
        public Builder optLearningRateTracker(ParameterTracker learningRateTracker) {
            this.learningRateTracker = learningRateTracker;
            return this;
        }

        /**
         * Sets the decay rate used for the first-moment estimate and its bias correction.
         *
         * @param beta1 the first-moment decay rate
         * @return this builder
         */
        public Builder optBeta1(float beta1) {
            this.beta1 = beta1;
            return this;
        }

        /**
         * Sets the decay rate used for the second-moment estimate and its bias correction.
         *
         * @param beta2 the second-moment decay rate
         * @return this builder
         */
        public Builder optBeta2(float beta2) {
            this.beta2 = beta2;
            return this;
        }

        /**
         * Sets the epsilon value that is forwarded to the engine's Adam update.
         *
         * @param epsilon the epsilon value
         * @return this builder
         */
        public Builder optEpsilon(float epsilon) {
            this.epsilon = epsilon;
            return this;
        }

        /**
         * Builds an {@link Adam} optimizer from the current builder settings.
         *
         * @return a new {@link Adam} instance
         */
        public Adam build() {
            return new Adam(this);
        }
    }

    /**
     * Creates an Adam optimizer from the given builder.
     *
     * @param builder the builder holding the hyper-parameters
     */
    protected Adam(Builder builder) {
        super(builder);
        this.learningRateTracker = builder.learningRateTracker;
        this.beta1 = builder.beta1;
        this.beta2 = builder.beta2;
        this.epsilon = builder.epsilon;
        this.means = new ConcurrentHashMap<>();
        this.variances = new ConcurrentHashMap<>();
    }

    /**
     * Applies one Adam step to {@code weight} using {@code grad}.
     *
     * @param parameterId the id of the parameter being updated; keys the per-parameter state
     * @param weight the weight array, which is also the output of the update
     * @param grad the gradient for {@code weight}
     * @throws IllegalArgumentException if the bias-corrected learning rate or the weight decay is
     *     NaN or infinite
     */
    @Override
    public void update(String parameterId, NDArray weight, NDArray grad) {
        int t = updateCount(parameterId);

        // Bias-correction terms (1 - beta^t) for the first and second moment estimates.
        double firstMomentCorrection = 1.0d - Math.pow(beta1, t);
        double secondMomentCorrection = 1.0d - Math.pow(beta2, t);

        float learningRate = learningRateTracker.getNewValue(parameterId, t);
        // lr * sqrt(1 - beta2^t) / (1 - beta1^t)
        float correctedLearningRate =
                (float) ((learningRate * Math.sqrt(secondMomentCorrection)) / firstMomentCorrection);
        float weightDecay = getWeightDecay();

        Preconditions.checkArgument(
                Float.isFinite(correctedLearningRate) && Float.isFinite(weightDecay),
                "learning rate or weight decay is nan or infinite");

        // Moment state is created lazily, zero-filled and shaped like the weight.
        Device device = weight.getDevice();
        NDArray mean = withDefaultState(means, parameterId, device, k -> weight.zerosLike());
        NDArray variance =
                withDefaultState(variances, parameterId, device, k -> weight.zerosLike());

        NDList inputs = new NDList(weight, grad, mean, variance);
        NDList outputs = new NDList(weight);

        // The two trailing booleans are presumably (lazyUpdate = true, adamw = false);
        // confirm against NDArrayEx#adamUpdate.
        weight.getNDArrayInternal()
                .adamUpdate(
                        inputs,
                        outputs,
                        learningRate,
                        correctedLearningRate,
                        weightDecay,
                        this.rescaleGrad,
                        this.clipGrad,
                        beta1,
                        beta2,
                        epsilon,
                        true,
                        false);
    }

    /**
     * Creates a new {@link Builder} with default Adam hyper-parameters.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
