package ai.djl.training.listener;

import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.listener.TrainingListener;
import java.time.Duration;

/* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/training/listener/EarlyStoppingListener.class */
/**
 * A {@link TrainingListener} that ends training early by throwing an {@link
 * EarlyStoppedException} from {@link #onEpoch(Trainer)}.
 *
 * <p>Once the epoch reported by the trainer's {@link TrainingResult} is at least {@code
 * minEpochs}, training is stopped when any of the following holds:
 *
 * <ul>
 *   <li>the epoch loss is strictly below {@code objectiveSuccess};
 *   <li>the time elapsed since {@link #onTrainingBegin(Trainer)} reaches {@code maxMillis};
 *   <li>the loss fails to shrink by at least {@code earlyStopPctImprovement} percent relative to
 *       the previous finite epoch loss, {@code epochPatience} times in a row (a patience of 0
 *       behaves like a patience of 1).
 * </ul>
 *
 * <p>The loss inspected is the validation loss when the result has one, otherwise the training
 * loss; if neither is present it is treated as NaN, which never satisfies the objective check.
 *
 * <p>Instances hold mutable per-run state without synchronization. NOTE(review): presumably a
 * single instance is driven by one training loop at a time — confirm with callers.
 */
public final class EarlyStoppingListener implements TrainingListener {
    private final double objectiveSuccess;
    private final int minEpochs;
    private final long maxMillis;
    private final double earlyStopPctImprovement;
    private final int epochPatience;

    /** {@link System#nanoTime()} reading taken at training start; monotonic, unlike wall time. */
    private long startTimeNanos;
    /** Most recent finite epoch loss, or NaN before any finite loss has been observed. */
    private double prevLoss;
    /** Consecutive eligible epochs whose loss did not meet the required improvement. */
    private int numberOfEpochsWithoutImprovements;

    private EarlyStoppingListener(
            double objectiveSuccess,
            int minEpochs,
            long maxMillis,
            double earlyStopPctImprovement,
            int epochPatience) {
        this.objectiveSuccess = objectiveSuccess;
        this.minEpochs = minEpochs;
        this.maxMillis = maxMillis;
        this.earlyStopPctImprovement = earlyStopPctImprovement;
        this.epochPatience = epochPatience;
        // Safe defaults in case onTrainingBegin is never invoked (it resets these anyway):
        // without them prevLoss would be a finite 0.0 and elapsed time would be measured from 0.
        this.startTimeNanos = System.nanoTime();
        this.prevLoss = Double.NaN;
        this.numberOfEpochsWithoutImprovements = 0;
    }

    /**
     * Creates a builder for configuring an {@link EarlyStoppingListener}.
     *
     * @return a new builder initialised with default values
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Evaluates the stopping criteria after an epoch and records the epoch loss.
     *
     * @param trainer the trainer whose latest {@link TrainingResult} is inspected
     * @throws EarlyStoppedException if any stopping criterion is met
     */
    @Override // ai.djl.training.listener.TrainingListener
    public void onEpoch(Trainer trainer) {
        TrainingResult result = trainer.getTrainingResult();
        int epoch = result.getEpoch();
        double loss = getLoss(result);
        if (epoch >= minEpochs) {
            if (loss < objectiveSuccess) {
                throw new EarlyStoppedException(
                        epoch,
                        String.format(
                                "validation loss %s < objectiveSuccess %s",
                                loss, objectiveSuccess));
            }
            long elapsedMillis = Duration.ofNanos(System.nanoTime() - startTimeNanos).toMillis();
            if (elapsedMillis >= maxMillis) {
                throw new EarlyStoppedException(
                        epoch,
                        String.format(
                                "%s ms elapsed >= %s maxMillis", elapsedMillis, maxMillis));
            }
            // Without a finite baseline there is nothing to compare against yet.
            if (Double.isFinite(prevLoss)) {
                // The loss must shrink to at most (100 - pct)% of the previous loss to count as
                // an improvement; a NaN loss fails this comparison and counts as no improvement.
                double requiredLoss = prevLoss * (100.0d - earlyStopPctImprovement) / 100.0d;
                if (loss <= requiredLoss) {
                    numberOfEpochsWithoutImprovements = 0;
                } else {
                    numberOfEpochsWithoutImprovements++;
                    if (numberOfEpochsWithoutImprovements >= epochPatience) {
                        throw new EarlyStoppedException(
                                epoch,
                                String.format(
                                        "failed to achieve %s%% improvement %s times in a row",
                                        earlyStopPctImprovement, epochPatience));
                    }
                }
            }
        }
        // Recorded even before minEpochs so the first eligible epoch already has a baseline.
        if (Double.isFinite(loss)) {
            prevLoss = loss;
        }
    }

    /**
     * Picks the loss to monitor from a training result.
     *
     * @param trainingResult the result to read
     * @return the validation loss if present, else the training loss if present, else NaN
     */
    private static double getLoss(TrainingResult trainingResult) {
        Float validateLoss = trainingResult.getValidateLoss();
        if (validateLoss != null) {
            return validateLoss.doubleValue();
        }
        Float trainLoss = trainingResult.getTrainLoss();
        if (trainLoss == null) {
            return Double.NaN;
        }
        return trainLoss.doubleValue();
    }

    /** {@inheritDoc} Per-batch events are not needed by this listener. */
    @Override // ai.djl.training.listener.TrainingListener
    public void onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData) {}

    /** {@inheritDoc} Per-batch events are not needed by this listener. */
    @Override // ai.djl.training.listener.TrainingListener
    public void onValidationBatch(Trainer trainer, TrainingListener.BatchData batchData) {}

    /**
     * Resets the timer, loss baseline and patience counter for a new training run.
     *
     * @param trainer the trainer that is starting (unused)
     */
    @Override // ai.djl.training.listener.TrainingListener
    public void onTrainingBegin(Trainer trainer) {
        startTimeNanos = System.nanoTime();
        prevLoss = Double.NaN;
        numberOfEpochsWithoutImprovements = 0;
    }

    /** {@inheritDoc} Nothing to clean up at the end of training. */
    @Override // ai.djl.training.listener.TrainingListener
    public void onTrainingEnd(Trainer trainer) {}

    /** Builder for {@link EarlyStoppingListener}; every option has a permissive default. */
    public static final class Builder {
        private double objectiveSuccess;
        private int minEpochs;
        private long maxMillis;
        private double earlyStopPctImprovement;
        private int epochPatience;

        /**
         * Creates a builder with defaults: objective 0 (never reached by a non-negative loss),
         * no minimum epochs, unlimited time, 0% required improvement and a patience of 0.
         */
        public Builder() {
            this.objectiveSuccess = 0.0d;
            this.minEpochs = 0;
            this.maxMillis = Long.MAX_VALUE;
            this.earlyStopPctImprovement = 0.0d;
            this.epochPatience = 0;
        }

        /**
         * Sets the loss below which training is considered successful and stopped.
         *
         * @param objectiveSuccess the target loss (strict {@code <} comparison)
         * @return this builder
         */
        public Builder optObjectiveSuccess(double objectiveSuccess) {
            this.objectiveSuccess = objectiveSuccess;
            return this;
        }

        /**
         * Sets the epoch number before which no stopping criterion is evaluated.
         *
         * @param minEpochs the minimum epoch at which early stopping may trigger
         * @return this builder
         */
        public Builder optMinEpochs(int minEpochs) {
            this.minEpochs = minEpochs;
            return this;
        }

        /**
         * Sets the maximum training time.
         *
         * @param duration the time budget, truncated to whole milliseconds
         * @return this builder
         */
        public Builder optMaxDuration(Duration duration) {
            this.maxMillis = duration.toMillis();
            return this;
        }

        /**
         * Sets the maximum training time in milliseconds.
         *
         * @param maxMillis the time budget in milliseconds
         * @return this builder
         */
        public Builder optMaxMillis(int maxMillis) {
            return optMaxMillis((long) maxMillis);
        }

        /**
         * Sets the maximum training time in milliseconds, allowing budgets beyond {@code
         * Integer.MAX_VALUE} ms.
         *
         * @param maxMillis the time budget in milliseconds
         * @return this builder
         */
        public Builder optMaxMillis(long maxMillis) {
            this.maxMillis = maxMillis;
            return this;
        }

        /**
         * Sets the minimum relative loss reduction, in percent, that counts as an improvement.
         *
         * @param earlyStopPctImprovement required improvement in percent (e.g. 5 for 5%)
         * @return this builder
         */
        public Builder optEarlyStopPctImprovement(double earlyStopPctImprovement) {
            this.earlyStopPctImprovement = earlyStopPctImprovement;
            return this;
        }

        /**
         * Sets how many consecutive non-improving epochs are tolerated before stopping.
         *
         * @param epochPatience the patience; values 0 and 1 both stop on the first miss
         * @return this builder
         */
        public Builder optEpochPatience(int epochPatience) {
            this.epochPatience = epochPatience;
            return this;
        }

        /**
         * Builds the listener from the current settings.
         *
         * @return a new {@link EarlyStoppingListener}
         */
        public EarlyStoppingListener build() {
            return new EarlyStoppingListener(
                    objectiveSuccess, minEpochs, maxMillis, earlyStopPctImprovement, epochPatience);
        }
    }

    /** Thrown by {@link EarlyStoppingListener#onEpoch(Trainer)} to signal an early stop. */
    public static class EarlyStoppedException extends RuntimeException {
        private static final long serialVersionUID = 1L;
        private final int stopEpoch;

        /**
         * Creates the exception.
         *
         * @param stopEpoch the epoch at which training was stopped
         * @param message a description of the criterion that triggered the stop
         */
        public EarlyStoppedException(int stopEpoch, String message) {
            super(message);
            this.stopEpoch = stopEpoch;
        }

        /**
         * Returns the epoch at which training was stopped.
         *
         * @return the stop epoch
         */
        public int getStopEpoch() {
            return stopEpoch;
        }
    }
}
