package ai.topandrey15.reinforcemc.core;

import ai.topandrey15.reinforcemc.ReinforceMC;
import ai.topandrey15.reinforcemc.config.TrainingConfiguration;
import ai.topandrey15.reinforcemc.core.ExperienceBuffer;
import ai.topandrey15.reinforcemc.core.NeuralNetwork;
import java.io.Serializable;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

/* loaded from: input_file:ai/topandrey15/reinforcemc/core/NeuralNetworkTrainer.class */
public class NeuralNetworkTrainer {
    private final NeuralNetwork network;
    private final ExperienceBuffer experienceBuffer;
    private final ExecutorService trainingExecutor;
    private float[][] m_weights1;
    private float[][] v_weights1;
    private float[] m_biases1;
    private float[] v_biases1;
    private float[][] m_weights2;
    private float[][] v_weights2;
    private float[] m_biases2;
    private float[] v_biases2;
    private float[][] m_weights3;
    private float[][] v_weights3;
    private float[] m_biases3;
    private float[] v_biases3;
    private final AtomicBoolean isTraining = new AtomicBoolean(false);
    private volatile int pendingTrainingTasks = 0;
    private volatile int optimizerStep = 0;
    private volatile float averageLoss = 0.0f;
    private volatile long lastTrainingTime = 0;
    private volatile int totalTrainingOperations = 0;
    private final TrainingConfiguration config = TrainingConfiguration.getInstance();

    /* loaded from: input_file:ai/topandrey15/reinforcemc/core/NeuralNetworkTrainer$TrainerData.class */
    public static class TrainerData implements Serializable {
        private static final long serialVersionUID = 1;
        public int optimizerStep;
        public float averageLoss;
        public int totalTrainingOperations;
        public float[][] m_weights1;
        public float[][] v_weights1;
        public float[] m_biases1;
        public float[] v_biases1;
        public float[][] m_weights2;
        public float[][] v_weights2;
        public float[] m_biases2;
        public float[] v_biases2;
        public float[][] m_weights3;
        public float[][] v_weights3;
        public float[] m_biases3;
        public float[] v_biases3;
    }

    public NeuralNetworkTrainer(NeuralNetwork neuralNetwork, ExperienceBuffer experienceBuffer) {
        this.network = neuralNetwork;
        this.experienceBuffer = experienceBuffer;
        this.config.addChangeListener(this::onConfigurationChanged);
        this.trainingExecutor = Executors.newFixedThreadPool(this.config.getMaxConcurrentTraining(), runnable -> {
            Thread thread = new Thread(runnable, "RL-Training-Thread");
            thread.setDaemon(true);
            thread.setPriority(4);
            return thread;
        });
        initializeAdamParameters();
        ReinforceMC.LOGGER.info("FIXED: Asynchronous Adam optimizer initialized with {} training threads and config integration", Integer.valueOf(this.config.getMaxConcurrentTraining()));
    }

    private void onConfigurationChanged(TrainingConfiguration trainingConfiguration) {
        ReinforceMC.LOGGER.info("Training configuration updated: {}", trainingConfiguration.getConfigurationSummary());
    }

    private void initializeAdamParameters() {
        int hiddenLayer1Size = this.config.getHiddenLayer1Size();
        int hiddenLayer2Size = this.config.getHiddenLayer2Size();
        this.m_weights1 = new float[this.network.getTotalInputSize()][hiddenLayer1Size];
        this.v_weights1 = new float[this.network.getTotalInputSize()][hiddenLayer1Size];
        this.m_biases1 = new float[hiddenLayer1Size];
        this.v_biases1 = new float[hiddenLayer1Size];
        this.m_weights2 = new float[hiddenLayer1Size][hiddenLayer2Size];
        this.v_weights2 = new float[hiddenLayer1Size][hiddenLayer2Size];
        this.m_biases2 = new float[hiddenLayer2Size];
        this.v_biases2 = new float[hiddenLayer2Size];
        this.m_weights3 = new float[hiddenLayer2Size][this.network.getOutputSize()];
        this.v_weights3 = new float[hiddenLayer2Size][this.network.getOutputSize()];
        this.m_biases3 = new float[this.network.getOutputSize()];
        this.v_biases3 = new float[this.network.getOutputSize()];
    }

    public CompletableFuture<Void> trainOnBatchAsync() {
        if (this.pendingTrainingTasks >= this.config.getMaxConcurrentTraining()) {
            ReinforceMC.LOGGER.debug("Skipping training - too many pending operations: {}", Integer.valueOf(this.pendingTrainingTasks));
            return CompletableFuture.completedFuture(null);
        }
        int currentSize = this.experienceBuffer.getCurrentSize();
        boolean canTrain = this.experienceBuffer.canTrain();
        if (this.optimizerStep % 100 == 0) {
            ReinforceMC.LOGGER.info("DEBUG: Training attempt - bufferSize={}, canTrain={}, pendingTasks={}, step={}", Integer.valueOf(currentSize), Boolean.valueOf(canTrain), Integer.valueOf(this.pendingTrainingTasks), Integer.valueOf(this.optimizerStep));
        }
        if (this.experienceBuffer.canTrain()) {
            this.pendingTrainingTasks++;
            return CompletableFuture.runAsync(() -> {
                try {
                    if (this.isTraining.compareAndSet(false, true)) {
                        trainOnBatchInternal();
                        this.totalTrainingOperations++;
                        this.lastTrainingTime = System.currentTimeMillis();
                        if (this.optimizerStep % 50 == 0) {
                            ReinforceMC.LOGGER.info("SUCCESSFUL training step {} completed - weights updated", Integer.valueOf(this.optimizerStep));
                        }
                    } else {
                        ReinforceMC.LOGGER.debug("Training already in progress, skipping");
                    }
                } catch (Exception e) {
                    ReinforceMC.LOGGER.error("CRITICAL: Training failed in background thread", e);
                } finally {
                    this.isTraining.set(false);
                    this.pendingTrainingTasks--;
                }
            }, this.trainingExecutor).exceptionally(th -> {
                ReinforceMC.LOGGER.error("CRITICAL: Async training failed", th);
                this.pendingTrainingTasks--;
                this.isTraining.set(false);
                return null;
            });
        }
        if (this.optimizerStep % TrainingConfiguration.MIN_EPISODE_LENGTH == 0) {
            ReinforceMC.LOGGER.warn("DEBUG: Training blocked - experience buffer cannot train (size={})", Integer.valueOf(currentSize));
        }
        return CompletableFuture.completedFuture(null);
    }

    @Deprecated
    public void trainOnBatch() {
        ReinforceMC.LOGGER.warn("DEPRECATED: Synchronous trainOnBatch() called - this can block main thread! Use trainOnBatchAsync() instead");
        trainOnBatchAsync();
    }

    private void trainOnBatchInternal() {
        int batchSize = this.config.getBatchSize();
        if (batchSize <= 0) {
            batchSize = 32;
            ReinforceMC.LOGGER.warn("BATCH SIZE FIX: Invalid batch size from config ({}), using default 32", Integer.valueOf(this.config.getBatchSize()));
        }
        int currentSize = this.experienceBuffer.getCurrentSize();
        if (currentSize < 16) {
            ReinforceMC.LOGGER.debug("Buffer too small for training: {} < 16", Integer.valueOf(currentSize));
            return;
        }
        List<ExperienceBuffer.Experience> samplePrioritizedBatch = this.experienceBuffer.samplePrioritizedBatch(batchSize);
        if (samplePrioritizedBatch == null) {
            ReinforceMC.LOGGER.debug("Experience buffer returned null batch (buffer size: {})", Integer.valueOf(currentSize));
            return;
        }
        if (samplePrioritizedBatch.isEmpty()) {
            ReinforceMC.LOGGER.warn("RACE CONDITION DETECTED: Experience buffer returned empty batch despite pre-validation (buffer size was: {}, requested: {})", Integer.valueOf(currentSize), Integer.valueOf(batchSize));
            return;
        }
        if (samplePrioritizedBatch.size() == 0) {
            ReinforceMC.LOGGER.error("CRITICAL RACE CONDITION: Batch has size=0 but not isEmpty() - preventing batch_size=0.0 logging!");
            return;
        }
        if (samplePrioritizedBatch.size() != batchSize) {
            ReinforceMC.LOGGER.debug("BATCH SIZE FIX: Requested batch size {}, got actual size {}", Integer.valueOf(batchSize), Integer.valueOf(samplePrioritizedBatch.size()));
        }
        long nanoTime = System.nanoTime();
        this.optimizerStep++;
        float f = 0.0f;
        float[][][] initializeGradients = initializeGradients();
        for (ExperienceBuffer.Experience experience : samplePrioritizedBatch) {
            try {
                NeuralNetwork.NetworkState forwardPassWithState = this.network.forwardPassWithState(experience.state);
                float[] fArr = forwardPassWithState.output;
                float f2 = experience.reward;
                if (experience.nextState != null && !experience.done) {
                    f2 = (experience.reward + (this.config.getDiscountFactor() * this.network.calculateValue(this.network.predict(experience.nextState)))) - this.network.calculateValue(fArr);
                }
                float max = Math.max(-10.0f, Math.min(10.0f, f2));
                f += (-max) * ((float) Math.log(Math.max(fArr[experience.action], 1.0E-8f)));
                calculateGradients(forwardPassWithState, experience.action, max, initializeGradients);
            } catch (Exception e) {
                ReinforceMC.LOGGER.error("Error processing experience in training batch", e);
            }
        }
        synchronized (this) {
            try {
                applyGradientsWithAdam(initializeGradients);
                this.averageLoss = (this.averageLoss * 0.99f) + ((f / samplePrioritizedBatch.size()) * 0.01f);
                this.network.updateWeightMonitoring();
            } catch (Exception e2) {
                ReinforceMC.LOGGER.error("Error applying gradients", e2);
            }
        }
        ReinforceMC.LOGGER.debug("ASYNC training completed: loss={:.4f}, batch_size={}, step={}, time={:.2f}ms", Float.valueOf(f / samplePrioritizedBatch.size()), Integer.valueOf(samplePrioritizedBatch.size()), Integer.valueOf(this.optimizerStep), Float.valueOf(((float) (System.nanoTime() - nanoTime)) / 1000000.0f));
    }

    private void calculateGradients(NeuralNetwork.NetworkState networkState, int i, float f, float[][][] fArr) {
        int hiddenLayer1Size = this.config.getHiddenLayer1Size();
        int hiddenLayer2Size = this.config.getHiddenLayer2Size();
        float[] fArr2 = new float[this.network.getOutputSize()];
        for (int i2 = 0; i2 < this.network.getOutputSize(); i2++) {
            if (i2 == i) {
                fArr2[i2] = f * (1.0f - networkState.output[i2]);
            } else {
                fArr2[i2] = f * (-networkState.output[i2]);
            }
        }
        for (int i3 = 0; i3 < hiddenLayer2Size; i3++) {
            for (int i4 = 0; i4 < this.network.getOutputSize(); i4++) {
                float[] fArr3 = fArr[2][i3];
                int i5 = i4;
                fArr3[i5] = fArr3[i5] + (fArr2[i4] * networkState.hidden2[i3]);
            }
        }
        float[] fArr4 = new float[hiddenLayer2Size];
        for (int i6 = 0; i6 < hiddenLayer2Size; i6++) {
            float f2 = 0.0f;
            for (int i7 = 0; i7 < this.network.getOutputSize(); i7++) {
                f2 += fArr2[i7] * this.network.weights3[i6][i7];
            }
            fArr4[i6] = f2 * this.network.reluDerivative(networkState.hidden2[i6]);
        }
        for (int i8 = 0; i8 < hiddenLayer1Size; i8++) {
            for (int i9 = 0; i9 < hiddenLayer2Size; i9++) {
                float[] fArr5 = fArr[1][i8];
                int i10 = i9;
                fArr5[i10] = fArr5[i10] + (fArr4[i9] * networkState.hidden1[i8]);
            }
        }
        float[] fArr6 = new float[hiddenLayer1Size];
        for (int i11 = 0; i11 < hiddenLayer1Size; i11++) {
            float f3 = 0.0f;
            for (int i12 = 0; i12 < hiddenLayer2Size; i12++) {
                f3 += fArr4[i12] * this.network.weights2[i11][i12];
            }
            fArr6[i11] = f3 * this.network.reluDerivative(networkState.hidden1[i11]);
        }
        for (int i13 = 0; i13 < this.network.getTotalInputSize(); i13++) {
            for (int i14 = 0; i14 < hiddenLayer1Size; i14++) {
                float[] fArr7 = fArr[0][i13];
                int i15 = i14;
                fArr7[i15] = fArr7[i15] + (fArr6[i14] * networkState.input[i13]);
            }
        }
    }

    private synchronized void applyGradientsWithAdam(float[][][] fArr) {
        int hiddenLayer1Size = this.config.getHiddenLayer1Size();
        int hiddenLayer2Size = this.config.getHiddenLayer2Size();
        float learningRate = this.config.getLearningRate();
        float pow = (float) Math.pow(this.config.getBeta1(), this.optimizerStep);
        float pow2 = (float) Math.pow(this.config.getBeta2(), this.optimizerStep);
        float sqrt = (learningRate * ((float) Math.sqrt(1.0d - pow2))) / (1.0f - pow);
        for (int i = 0; i < this.network.getTotalInputSize(); i++) {
            for (int i2 = 0; i2 < hiddenLayer1Size; i2++) {
                float f = fArr[0][i][i2];
                this.m_weights1[i][i2] = (this.config.getBeta1() * this.m_weights1[i][i2]) + ((1.0f - this.config.getBeta1()) * f);
                this.v_weights1[i][i2] = (this.config.getBeta2() * this.v_weights1[i][i2]) + ((1.0f - this.config.getBeta2()) * f * f);
                float f2 = this.m_weights1[i][i2] / (1.0f - pow);
                float f3 = this.v_weights1[i][i2] / (1.0f - pow2);
                float[] fArr2 = this.network.weights1[i];
                int i3 = i2;
                fArr2[i3] = fArr2[i3] + ((sqrt * f2) / (((float) Math.sqrt(f3)) + this.config.getEpsilonAdam()));
            }
        }
        for (int i4 = 0; i4 < hiddenLayer1Size; i4++) {
            for (int i5 = 0; i5 < hiddenLayer2Size; i5++) {
                float f4 = fArr[1][i4][i5];
                this.m_weights2[i4][i5] = (this.config.getBeta1() * this.m_weights2[i4][i5]) + ((1.0f - this.config.getBeta1()) * f4);
                this.v_weights2[i4][i5] = (this.config.getBeta2() * this.v_weights2[i4][i5]) + ((1.0f - this.config.getBeta2()) * f4 * f4);
                float f5 = this.m_weights2[i4][i5] / (1.0f - pow);
                float f6 = this.v_weights2[i4][i5] / (1.0f - pow2);
                float[] fArr3 = this.network.weights2[i4];
                int i6 = i5;
                fArr3[i6] = fArr3[i6] + ((sqrt * f5) / (((float) Math.sqrt(f6)) + this.config.getEpsilonAdam()));
            }
        }
        for (int i7 = 0; i7 < hiddenLayer2Size; i7++) {
            for (int i8 = 0; i8 < this.network.getOutputSize(); i8++) {
                float f7 = fArr[2][i7][i8];
                this.m_weights3[i7][i8] = (this.config.getBeta1() * this.m_weights3[i7][i8]) + ((1.0f - this.config.getBeta1()) * f7);
                this.v_weights3[i7][i8] = (this.config.getBeta2() * this.v_weights3[i7][i8]) + ((1.0f - this.config.getBeta2()) * f7 * f7);
                float f8 = this.m_weights3[i7][i8] / (1.0f - pow);
                float f9 = this.v_weights3[i7][i8] / (1.0f - pow2);
                float[] fArr4 = this.network.weights3[i7];
                int i9 = i8;
                fArr4[i9] = fArr4[i9] + ((sqrt * f8) / (((float) Math.sqrt(f9)) + this.config.getEpsilonAdam()));
            }
        }
        ReinforceMC.LOGGER.debug("ASYNC weight update applied with Adam optimizer (step={})", Integer.valueOf(this.optimizerStep));
    }

    /* JADX WARN: Type inference failed for: r0v7, types: [float[][], float[][][]] */
    private float[][][] initializeGradients() {
        int hiddenLayer1Size = this.config.getHiddenLayer1Size();
        int hiddenLayer2Size = this.config.getHiddenLayer2Size();
        return new float[][]{new float[this.network.getTotalInputSize()][hiddenLayer1Size], new float[hiddenLayer1Size][hiddenLayer2Size], new float[hiddenLayer2Size][this.network.getOutputSize()]};
    }

    public boolean shouldTrain(int i) {
        return i % this.config.getTrainingFrequency() == 0 && this.experienceBuffer.canTrain() && this.pendingTrainingTasks < this.config.getMaxConcurrentTraining();
    }

    public void shutdown() {
        ReinforceMC.LOGGER.info("Shutting down async training executor...");
        this.trainingExecutor.shutdown();
        try {
            if (!this.trainingExecutor.awaitTermination(5L, TimeUnit.SECONDS)) {
                ReinforceMC.LOGGER.warn("Training executor did not terminate gracefully, forcing shutdown");
                this.trainingExecutor.shutdownNow();
            }
        } catch (InterruptedException e) {
            ReinforceMC.LOGGER.warn("Interrupted while waiting for training executor shutdown");
            this.trainingExecutor.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    public float getAverageLoss() {
        return this.averageLoss;
    }

    public int getOptimizerStep() {
        return this.optimizerStep;
    }

    public boolean isCurrentlyTraining() {
        return this.isTraining.get();
    }

    public int getPendingTrainingTasks() {
        return this.pendingTrainingTasks;
    }

    public long getLastTrainingTime() {
        return this.lastTrainingTime;
    }

    public int getTotalTrainingOperations() {
        return this.totalTrainingOperations;
    }

    public static float getDiscountFactor() {
        return TrainingConfiguration.getInstance().getDiscountFactor();
    }

    public TrainerData exportData() {
        TrainerData trainerData = new TrainerData();
        trainerData.optimizerStep = this.optimizerStep;
        trainerData.averageLoss = this.averageLoss;
        trainerData.totalTrainingOperations = this.totalTrainingOperations;
        trainerData.m_weights1 = cloneMatrix(this.m_weights1);
        trainerData.v_weights1 = cloneMatrix(this.v_weights1);
        trainerData.m_biases1 = (float[]) this.m_biases1.clone();
        trainerData.v_biases1 = (float[]) this.v_biases1.clone();
        trainerData.m_weights2 = cloneMatrix(this.m_weights2);
        trainerData.v_weights2 = cloneMatrix(this.v_weights2);
        trainerData.m_biases2 = (float[]) this.m_biases2.clone();
        trainerData.v_biases2 = (float[]) this.v_biases2.clone();
        trainerData.m_weights3 = cloneMatrix(this.m_weights3);
        trainerData.v_weights3 = cloneMatrix(this.v_weights3);
        trainerData.m_biases3 = (float[]) this.m_biases3.clone();
        trainerData.v_biases3 = (float[]) this.v_biases3.clone();
        return trainerData;
    }

    public void loadFromData(TrainerData trainerData) {
        if (trainerData == null) {
            ReinforceMC.LOGGER.warn("Cannot load trainer data: data is null");
            return;
        }
        this.optimizerStep = trainerData.optimizerStep;
        this.averageLoss = trainerData.averageLoss;
        this.totalTrainingOperations = trainerData.totalTrainingOperations;
        if (trainerData.m_weights1 != null) {
            this.m_weights1 = cloneMatrix(trainerData.m_weights1);
        }
        if (trainerData.v_weights1 != null) {
            this.v_weights1 = cloneMatrix(trainerData.v_weights1);
        }
        if (trainerData.m_biases1 != null) {
            this.m_biases1 = (float[]) trainerData.m_biases1.clone();
        }
        if (trainerData.v_biases1 != null) {
            this.v_biases1 = (float[]) trainerData.v_biases1.clone();
        }
        if (trainerData.m_weights2 != null) {
            this.m_weights2 = cloneMatrix(trainerData.m_weights2);
        }
        if (trainerData.v_weights2 != null) {
            this.v_weights2 = cloneMatrix(trainerData.v_weights2);
        }
        if (trainerData.m_biases2 != null) {
            this.m_biases2 = (float[]) trainerData.m_biases2.clone();
        }
        if (trainerData.v_biases2 != null) {
            this.v_biases2 = (float[]) trainerData.v_biases2.clone();
        }
        if (trainerData.m_weights3 != null) {
            this.m_weights3 = cloneMatrix(trainerData.m_weights3);
        }
        if (trainerData.v_weights3 != null) {
            this.v_weights3 = cloneMatrix(trainerData.v_weights3);
        }
        if (trainerData.m_biases3 != null) {
            this.m_biases3 = (float[]) trainerData.m_biases3.clone();
        }
        if (trainerData.v_biases3 != null) {
            this.v_biases3 = (float[]) trainerData.v_biases3.clone();
        }
        ReinforceMC.LOGGER.info("Neural network trainer state loaded from saved data");
    }

    /* JADX WARN: Type inference failed for: r0v3, types: [float[], float[][]] */
    private float[][] cloneMatrix(float[][] fArr) {
        if (fArr == null) {
            return (float[][]) null;
        }
        ?? r0 = new float[fArr.length];
        for (int i = 0; i < fArr.length; i++) {
            if (fArr[i] != null) {
                r0[i] = (float[]) fArr[i].clone();
            }
        }
        return r0;
    }
}
