/*
 * Decompiled with CFR 0.152.
 */
package com.solegendary.reignofnether.bot.ml;

import com.solegendary.reignofnether.ReignOfNether;
import com.solegendary.reignofnether.bot.ml.NeuralNetworkManager;
import com.solegendary.reignofnether.bot.ml.TrainingDataManager;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.linalg.dataset.DataSet;

public class LearningPipeline {
    private NeuralNetworkManager networkManager;
    private TrainingDataManager dataManager;
    private final ExecutorService trainingExecutor = Executors.newSingleThreadExecutor(r -> {
        Thread t = new Thread(r, "ML-Training-Pipeline");
        t.setDaemon(true);
        return t;
    });
    private final AtomicBoolean trainingInProgress = new AtomicBoolean(false);
    private final AtomicBoolean pipelineShutdown = new AtomicBoolean(false);
    private static final int TRAINING_BATCH_SIZE = 64;
    private static final int TRAINING_EPOCHS = 10;
    private static final long MIN_TRAINING_INTERVAL_MS = 300000L;
    private long lastTrainingTime = 0L;
    private int totalTrainingSessions = 0;

    public void initialize(NeuralNetworkManager networkManager, TrainingDataManager dataManager) {
        this.networkManager = networkManager;
        this.dataManager = dataManager;
        ReignOfNether.LOGGER.info("Learning Pipeline initialized and ready for training");
    }

    public void triggerTraining() {
        if (this.pipelineShutdown.get()) {
            return;
        }
        long currentTime = System.currentTimeMillis();
        if (currentTime - this.lastTrainingTime < 300000L) {
            return;
        }
        if (this.trainingInProgress.compareAndSet(false, true)) {
            this.lastTrainingTime = currentTime;
            CompletableFuture.runAsync(this::executeTrainingSession, this.trainingExecutor).whenComplete((result, throwable) -> {
                this.trainingInProgress.set(false);
                if (throwable != null) {
                    ReignOfNether.LOGGER.error("Training session failed: {}", (Object)throwable.getMessage());
                } else {
                    ++this.totalTrainingSessions;
                    ReignOfNether.LOGGER.info("Training session completed successfully (session #{})", (Object)this.totalTrainingSessions);
                }
            });
            ReignOfNether.LOGGER.info("Training session triggered and queued");
        }
    }

    private void executeTrainingSession() {
        try {
            ReignOfNether.LOGGER.info("Starting ML training session...");
            this.trainBotNetworks();
            this.performImitationLearning();
            this.updatePerformanceMetrics();
            ReignOfNether.LOGGER.info("ML training session completed successfully");
        }
        catch (Exception e) {
            ReignOfNether.LOGGER.error("Error during training session: {}", (Object)e.getMessage());
            throw new RuntimeException("Training session failed", e);
        }
    }

    private void trainBotNetworks() {
        ReignOfNether.LOGGER.debug("Training bot networks with decision data...");
        for (String botName : this.getBotNamesWithData()) {
            try {
                DataSet trainingData = this.dataManager.getTrainingDataSet(botName);
                if (trainingData == null) continue;
                ReignOfNether.LOGGER.debug("Training network for bot '{}' with {} examples", (Object)botName, (Object)trainingData.numExamples());
                int numBatches = Math.max(1, trainingData.numExamples() / 64);
                for (int epoch = 0; epoch < 10; ++epoch) {
                    int endIdx;
                    int startIdx;
                    trainingData.shuffle();
                    for (int batch = 0; batch < numBatches && (startIdx = batch * 64) < (endIdx = Math.min(startIdx + 64, trainingData.numExamples())); ++batch) {
                        DataSet batchData = (DataSet)trainingData.getRange(startIdx, endIdx);
                        this.networkManager.trainNetwork(botName, batchData);
                    }
                }
                ReignOfNether.LOGGER.debug("Completed training for bot '{}' ({} epochs, {} batches)", (Object)botName, (Object)10, (Object)numBatches);
            }
            catch (Exception e) {
                ReignOfNether.LOGGER.error("Error training network for bot '{}': {}", (Object)botName, (Object)e.getMessage());
            }
        }
    }

    private void performImitationLearning() {
        ReignOfNether.LOGGER.debug("Performing imitation learning from human data...");
        try {
            DataSet imitationData = this.dataManager.getImitationDataSet();
            if (imitationData == null) {
                ReignOfNether.LOGGER.debug("No human data available for imitation learning");
                return;
            }
            ReignOfNether.LOGGER.debug("Training all networks with {} human examples", (Object)imitationData.numExamples());
            for (String botName : this.getBotNamesWithData()) {
                try {
                    int imitationEpochs = Math.max(1, 3);
                    for (int epoch = 0; epoch < imitationEpochs; ++epoch) {
                        imitationData.shuffle();
                        this.networkManager.trainNetwork(botName, imitationData);
                    }
                }
                catch (Exception e) {
                    ReignOfNether.LOGGER.error("Error applying imitation learning to bot '{}': {}", (Object)botName, (Object)e.getMessage());
                }
            }
        }
        catch (Exception e) {
            ReignOfNether.LOGGER.error("Error in imitation learning: {}", (Object)e.getMessage());
        }
    }

    private void updatePerformanceMetrics() {
        ReignOfNether.LOGGER.debug("Updating network performance metrics...");
        int activeNetworks = this.networkManager.getActiveNetworkCount();
        int totalIterations = this.networkManager.getTrainingIterations();
        double avgAccuracy = this.networkManager.getAverageAccuracy();
        ReignOfNether.LOGGER.info("ML Performance Update: {} networks, {} training iterations, {:.2f}% average accuracy", (Object)activeNetworks, (Object)totalIterations, (Object)(avgAccuracy * 100.0));
    }

    private Set<String> getBotNamesWithData() {
        return new HashSet<String>();
    }

    public void shutdown() {
        if (this.pipelineShutdown.compareAndSet(false, true)) {
            ReignOfNether.LOGGER.info("Shutting down Learning Pipeline...");
            while (this.trainingInProgress.get()) {
                try {
                    Thread.sleep(1000L);
                }
                catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    break;
                }
            }
            this.trainingExecutor.shutdown();
            try {
                if (!this.trainingExecutor.awaitTermination(10L, TimeUnit.SECONDS)) {
                    this.trainingExecutor.shutdownNow();
                }
            }
            catch (InterruptedException e) {
                this.trainingExecutor.shutdownNow();
                Thread.currentThread().interrupt();
            }
            ReignOfNether.LOGGER.info("Learning Pipeline shutdown complete");
        }
    }

    public boolean isTrainingInProgress() {
        return this.trainingInProgress.get();
    }

    public int getTotalTrainingSessions() {
        return this.totalTrainingSessions;
    }

    public long getTimeSinceLastTraining() {
        return System.currentTimeMillis() - this.lastTrainingTime;
    }

    public void forceTraining() {
        if (this.pipelineShutdown.get()) {
            return;
        }
        this.lastTrainingTime = 0L;
        this.triggerTraining();
        ReignOfNether.LOGGER.info("Forced ML training session initiated");
    }

    public void configureTraining(int batchSize, int epochs) {
        ReignOfNether.LOGGER.info("Training configuration: batch size {}, epochs {} (currently hardcoded)", (Object)batchSize, (Object)epochs);
    }
}

