/*
 * Decompiled with CFR 0.152.
 */
package com.solegendary.reignofnether.bot.ml;

import com.solegendary.reignofnether.ReignOfNether;
import com.solegendary.reignofnether.bot.ml.MLDecision;
import java.io.File;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class NeuralNetworkManager {
    private static final int INPUT_SIZE = 64;
    private static final int OUTPUT_SIZE = MLDecision.DecisionType.values().length;
    private static final int HIDDEN_LAYER_SIZE = 128;
    private static final double LEARNING_RATE = 0.001;
    private final Map<String, MultiLayerNetwork> botNetworks = new ConcurrentHashMap<String, MultiLayerNetwork>();
    private final Map<String, Integer> trainingIterations = new ConcurrentHashMap<String, Integer>();
    private final Map<String, Double> networkAccuracy = new ConcurrentHashMap<String, Double>();
    private final String modelsDirectory = "config/reignofnether/ml_models/";
    private MultiLayerNetwork masterNetwork;
    private boolean initialized = false;

    public void initialize() {
        try {
            ReignOfNether.LOGGER.info("Initializing Neural Network Manager...");
            File modelsDir = new File("config/reignofnether/ml_models/");
            if (!modelsDir.exists()) {
                modelsDir.mkdirs();
            }
            ReignOfNether.LOGGER.debug("Testing DeepLearning4J availability...");
            this.masterNetwork = this.createNetwork();
            ReignOfNether.LOGGER.debug("Created master network successfully");
            this.loadExistingModels();
            this.initialized = true;
            ReignOfNether.LOGGER.info("Neural Network Manager initialized with {} bot networks", (Object)this.botNetworks.size());
        }
        catch (Exception e) {
            ReignOfNether.LOGGER.error("Failed to initialize Neural Network Manager: {}", (Object)e.getMessage());
            ReignOfNether.LOGGER.error("This might be due to missing DeepLearning4J native libraries or CPU compatibility");
            ReignOfNether.LOGGER.error("Neural network details: ", (Throwable)e);
            throw new RuntimeException("Neural network initialization failed", e);
        }
    }

    private MultiLayerNetwork createNetwork() {
        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().seed(123L).weightInit(WeightInit.XAVIER).updater((IUpdater)new Adam(0.001)).list().layer(0, (Layer)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)new DenseLayer.Builder().nIn(64)).nOut(128)).activation(Activation.RELU)).build()).layer(1, (Layer)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)new DenseLayer.Builder().nIn(128)).nOut(64)).activation(Activation.RELU)).build()).layer(2, (Layer)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nIn(64)).nOut(OUTPUT_SIZE)).activation(Activation.SOFTMAX)).build()).build();
        MultiLayerNetwork network = new MultiLayerNetwork(config);
        network.init();
        return network;
    }

    public MultiLayerNetwork getNetworkForBot(String botName) {
        return this.botNetworks.computeIfAbsent(botName, name -> {
            MultiLayerNetwork network = this.createNetwork();
            this.trainingIterations.put((String)name, 0);
            this.networkAccuracy.put((String)name, 0.0);
            ReignOfNether.LOGGER.debug("Created new neural network for bot '{}'", name);
            return network;
        });
    }

    public MLDecision predictDecision(double[] stateVector, String botName) {
        if (!this.initialized) {
            return new MLDecision(MLDecision.DecisionType.FALLBACK_TO_RULES, 0.0);
        }
        try {
            MultiLayerNetwork network = this.getNetworkForBot(botName);
            INDArray input = Nd4j.create((double[])stateVector).reshape(1L, 64L);
            INDArray output = network.output(input);
            int maxIndex = 0;
            double maxConfidence = output.getDouble(0L, 0L);
            for (int i = 1; i < OUTPUT_SIZE; ++i) {
                double confidence = output.getDouble(0L, (long)i);
                if (!(confidence > maxConfidence)) continue;
                maxConfidence = confidence;
                maxIndex = i;
            }
            MLDecision.DecisionType decisionType = MLDecision.DecisionType.values()[maxIndex];
            if (maxConfidence < 0.6) {
                return new MLDecision(MLDecision.DecisionType.FALLBACK_TO_RULES, maxConfidence);
            }
            return new MLDecision(decisionType, maxConfidence, this.extractParameters(output), "Neural network prediction");
        }
        catch (Exception e) {
            ReignOfNether.LOGGER.error("Error making ML prediction for bot '{}': {}", (Object)botName, (Object)e.getMessage());
            return new MLDecision(MLDecision.DecisionType.FALLBACK_TO_RULES, 0.0);
        }
    }

    public void trainNetwork(String botName, DataSet trainingData) {
        if (!this.initialized) {
            return;
        }
        try {
            MultiLayerNetwork network = this.getNetworkForBot(botName);
            network.fit((org.nd4j.linalg.dataset.api.DataSet)trainingData);
            int iterations = this.trainingIterations.get(botName) + 1;
            this.trainingIterations.put(botName, iterations);
            if (iterations % 10 == 0) {
                double accuracy = this.evaluateNetwork(network, trainingData);
                this.networkAccuracy.put(botName, accuracy);
                ReignOfNether.LOGGER.debug("Bot '{}' network training iteration {}, accuracy: {:.2f}%", (Object)botName, (Object)iterations, (Object)(accuracy * 100.0));
            }
        }
        catch (Exception e) {
            ReignOfNether.LOGGER.error("Error training network for bot '{}': {}", (Object)botName, (Object)e.getMessage());
        }
    }

    public void performOnlineTraining(String botName, DataSet onlineData, double adaptiveLearningRate) {
        if (!this.initialized) {
            return;
        }
        try {
            MultiLayerNetwork network = this.getNetworkForBot(botName);
            int iterations = Math.max(1, (int)(adaptiveLearningRate * 10.0));
            for (int i = 0; i < iterations; ++i) {
                onlineData.shuffle();
                network.fit((org.nd4j.linalg.dataset.api.DataSet)onlineData);
            }
            ReignOfNether.LOGGER.debug("Online training completed for bot '{}' with {} iterations (adaptive rate: {:.4f})", (Object)botName, (Object)iterations, (Object)adaptiveLearningRate);
        }
        catch (Exception e) {
            ReignOfNether.LOGGER.error("Error during online training for bot '{}': {}", (Object)botName, (Object)e.getMessage());
        }
    }

    private double evaluateNetwork(MultiLayerNetwork network, DataSet testData) {
        try {
            INDArray predictions = network.output(testData.getFeatures());
            INDArray labels = testData.getLabels();
            int correct = 0;
            int total = (int)predictions.size(0);
            for (int i = 0; i < total; ++i) {
                int predictedClass = Nd4j.argMax((INDArray)predictions.getRow((long)i), (int[])new int[]{1}).getInt(new int[]{0});
                int actualClass = Nd4j.argMax((INDArray)labels.getRow((long)i), (int[])new int[]{1}).getInt(new int[]{0});
                if (predictedClass != actualClass) continue;
                ++correct;
            }
            return (double)correct / (double)total;
        }
        catch (Exception e) {
            ReignOfNether.LOGGER.error("Error evaluating network accuracy: {}", (Object)e.getMessage());
            return 0.0;
        }
    }

    private Map<String, Double> extractParameters(INDArray output) {
        ConcurrentHashMap<String, Double> parameters = new ConcurrentHashMap<String, Double>();
        for (int i = 0; i < OUTPUT_SIZE; ++i) {
            MLDecision.DecisionType type = MLDecision.DecisionType.values()[i];
            parameters.put(type.name() + "_confidence", output.getDouble(0L, (long)i));
        }
        double entropy = 0.0;
        for (int i = 0; i < OUTPUT_SIZE; ++i) {
            double prob = output.getDouble(0L, (long)i);
            if (!(prob > 0.0)) continue;
            entropy -= prob * Math.log(prob);
        }
        parameters.put("decision_entropy", entropy);
        return parameters;
    }

    public void saveAllModels() {
        if (!this.initialized) {
            return;
        }
        ReignOfNether.LOGGER.info("Saving {} neural network models...", (Object)this.botNetworks.size());
        for (Map.Entry<String, MultiLayerNetwork> entry : this.botNetworks.entrySet()) {
            String botName = entry.getKey();
            MultiLayerNetwork network = entry.getValue();
            try {
                String filename = "config/reignofnether/ml_models/bot_" + this.sanitizeFilename(botName) + ".zip";
                network.save(new File(filename));
                ReignOfNether.LOGGER.debug("Saved model for bot '{}' to {}", (Object)botName, (Object)filename);
            }
            catch (IOException e) {
                ReignOfNether.LOGGER.error("Failed to save model for bot '{}': {}", (Object)botName, (Object)e.getMessage());
            }
        }
    }

    private void loadExistingModels() {
        File modelsDir = new File("config/reignofnether/ml_models/");
        if (!modelsDir.exists() || !modelsDir.isDirectory()) {
            return;
        }
        File[] modelFiles = modelsDir.listFiles((dir, name) -> name.startsWith("bot_") && name.endsWith(".zip"));
        if (modelFiles == null) {
            return;
        }
        for (File modelFile : modelFiles) {
            try {
                String filename = modelFile.getName();
                String botName = filename.substring(4, filename.length() - 4);
                botName = botName.replaceAll("_", " ");
                MultiLayerNetwork network = MultiLayerNetwork.load((File)modelFile, (boolean)true);
                this.botNetworks.put(botName, network);
                this.trainingIterations.put(botName, 0);
                this.networkAccuracy.put(botName, 0.0);
                ReignOfNether.LOGGER.info("Loaded existing model for bot '{}'", (Object)botName);
            }
            catch (IOException e) {
                ReignOfNether.LOGGER.error("Failed to load model from {}: {}", (Object)modelFile.getName(), (Object)e.getMessage());
            }
        }
    }

    private String sanitizeFilename(String name) {
        return name.replaceAll("[^a-zA-Z0-9_-]", "_");
    }

    public int getTrainingIterations() {
        return this.trainingIterations.values().stream().mapToInt(Integer::intValue).sum();
    }

    public double getAverageAccuracy() {
        if (this.networkAccuracy.isEmpty()) {
            return 0.0;
        }
        return this.networkAccuracy.values().stream().mapToDouble(Double::doubleValue).average().orElse(0.0);
    }

    public int getActiveNetworkCount() {
        return this.botNetworks.size();
    }

    public void resetBotNetwork(String botName) {
        if (this.botNetworks.containsKey(botName)) {
            this.botNetworks.put(botName, this.createNetwork());
            this.trainingIterations.put(botName, 0);
            this.networkAccuracy.put(botName, 0.0);
            ReignOfNether.LOGGER.info("Reset neural network for bot '{}'", (Object)botName);
        }
    }

    public boolean isInitialized() {
        return this.initialized;
    }
}

