package ai.djl.training.hyperparameter;

import ai.djl.Model;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.hyperparameter.optimizer.HpORandom;
import ai.djl.training.hyperparameter.param.HpSet;
import ai.djl.translate.TranslateException;
import ai.djl.util.Pair;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/training/hyperparameter/EasyHpo.class */
/**
 * Template for a simple hyperparameter search.
 *
 * <p>Subclasses supply the datasets, the hyperparameter set, a model factory and the training
 * configuration. {@link #fit()} trains one model per configuration drawn from an {@link HpORandom}
 * optimizer, reports each trial's validation loss back to that optimizer, and then trains a final
 * model with the best configuration the optimizer reports.
 */
public abstract class EasyHpo {
    private static final Logger logger = LoggerFactory.getLogger(EasyHpo.class);

    /**
     * Runs the hyperparameter search and trains the final model.
     *
     * <p>Each of the {@link #numHyperParameterTests()} trials works the same way. It draws a
     * configuration from the optimizer, trains a model with it, closes that model, and reports the
     * trial's validation loss back to the optimizer. A final model is then trained with the
     * optimizer's best configuration and passed to {@link #saveModel(Model, TrainingResult)}.
     *
     * @return the final model together with its training result; the caller owns the returned
     *     model and is responsible for closing it
     * @throws IOException if loading a dataset, training, or saving the model fails
     * @throws TranslateException if thrown while training a model
     * @throws IllegalStateException if a trial finishes without recording a validation loss
     */
    public Pair<Model, TrainingResult> fit() throws IOException, TranslateException {
        RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN);
        // The TEST split is what each trial is validated (and therefore ranked) against.
        RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST);

        HpORandom hpOptimizer = new HpORandom(setupHyperParams());
        int numTests = numHyperParameterTests();
        for (int i = 0; i < numTests; i++) {
            HpSet hpVals = hpOptimizer.nextConfig();
            Pair<Model, TrainingResult> trial = train(hpVals, trainingSet, validateSet);
            // Only the loss is needed to rank this configuration, so release the model right away.
            trial.getKey().close();
            float loss = requireValidateLoss(trial.getValue(), hpVals);
            hpOptimizer.update(hpVals, loss);
            // Trials are reported 1-based so the last one reads "n/n".
            logger.info("--------- hp test {}/{} - Loss {} - {}", i + 1, numTests, loss, hpVals);
        }

        HpSet bestHpVals = hpOptimizer.getBest().getKey();
        Pair<Model, TrainingResult> best = train(bestHpVals, trainingSet, validateSet);
        try {
            saveModel(best.getKey(), best.getValue());
        } catch (Throwable t) {
            // The caller never receives the model when saving fails, so release it here.
            closeSuppressingErrors(best.getKey(), t);
            throw t;
        }
        return best;
    }

    /**
     * Builds and trains a fresh model for one hyperparameter configuration.
     *
     * <p>The trainer is always closed. On success the model is returned open and the caller must
     * close it. On failure the model is closed before the exception propagates.
     *
     * @param hpSet the configuration to train with
     * @param trainingSet the dataset to train on
     * @param validateSet the dataset passed to {@link EasyTrain#fit} for validation
     * @return the trained model and the trainer's training result
     */
    private Pair<Model, TrainingResult> train(
            HpSet hpSet, RandomAccessDataset trainingSet, RandomAccessDataset validateSet)
            throws IOException, TranslateException {
        Model model = buildModel(hpSet);
        try (Trainer trainer = model.newTrainer(setupTrainingConfig(hpSet))) {
            trainer.setMetrics(new Metrics());
            trainer.initialize(inputShape(hpSet));
            EasyTrain.fit(trainer, numEpochs(hpSet), trainingSet, validateSet);
            return new Pair<>(model, trainer.getTrainingResult());
        } catch (Throwable t) {
            // Without this the model would leak: the caller only receives it on success.
            closeSuppressingErrors(model, t);
            throw t;
        }
    }

    /**
     * Returns the validation loss recorded for a finished trial.
     *
     * @param result the trial's training result
     * @param hpVals the configuration the trial used; only used in the error message
     * @return the validation loss
     * @throws IllegalStateException if no validation loss was recorded (presumably because no
     *     validation data was supplied)
     */
    private static float requireValidateLoss(TrainingResult result, HpSet hpVals) {
        Float loss = result.getValidateLoss();
        if (loss == null) {
            throw new IllegalStateException(
                    "Training with "
                            + hpVals
                            + " recorded no validation loss; check that"
                            + " getDataset(Dataset.Usage.TEST) returns a validation dataset");
        }
        return loss;
    }

    /** Closes {@code model}, recording any failure to close as suppressed by {@code cause}. */
    private static void closeSuppressingErrors(Model model, Throwable cause) {
        try {
            model.close();
        } catch (Throwable closeError) {
            cause.addSuppressed(closeError);
        }
    }

    /**
     * Returns the hyperparameters that the {@link HpORandom} optimizer draws configurations from.
     *
     * @return the hyperparameter set
     */
    protected abstract HpSet setupHyperParams();

    /**
     * Returns the dataset for the given usage. {@link #fit()} requests {@code TRAIN} for training
     * and {@code TEST} as the validation set that each configuration is scored against.
     *
     * @param usage which split to return
     * @return the dataset for {@code usage}
     * @throws IOException if the dataset cannot be loaded
     */
    protected abstract RandomAccessDataset getDataset(Dataset.Usage usage) throws IOException;

    /**
     * Returns the training configuration used to create a trainer for {@code hpSet}.
     *
     * @param hpSet the configuration being trained
     * @return the training configuration
     */
    protected abstract TrainingConfig setupTrainingConfig(HpSet hpSet);

    /**
     * Builds a new, untrained model for {@code hpSet}. A new model is requested for every trial and
     * for the final training run.
     *
     * @param hpSet the configuration being trained
     * @return a new model
     */
    protected abstract Model buildModel(HpSet hpSet);

    /**
     * Returns the input shape passed to {@link Trainer#initialize} for {@code hpSet}.
     *
     * @param hpSet the configuration being trained
     * @return the input shape
     */
    protected abstract Shape inputShape(HpSet hpSet);

    /**
     * Returns the number of epochs to train for {@code hpSet}.
     *
     * @param hpSet the configuration being trained
     * @return the number of epochs
     */
    protected abstract int numEpochs(HpSet hpSet);

    /**
     * Returns how many configurations to try before the final training run.
     *
     * @return the number of trial trainings
     */
    protected abstract int numHyperParameterTests();

    /**
     * Hook invoked with the final model and its training result before {@link #fit()} returns.
     * The default implementation does nothing.
     *
     * @param model the final trained model; it remains owned by the caller of {@link #fit()}
     * @param trainingResult the final model's training result
     * @throws IOException if saving fails; {@link #fit()} then closes the model and rethrows
     */
    protected void saveModel(Model model, TrainingResult trainingResult) throws IOException {
    }
}
