package ai.topandrey15.reinforcemc.core;

import ai.topandrey15.reinforcemc.ReinforceMC;
import java.io.Serializable;
import java.util.Locale;

/* loaded from: input_file:ai/topandrey15/reinforcemc/core/NetworkDataManager.class */
public class NetworkDataManager {
    private float lastWeightsSum = 0.0f;
    private long weightsUpdateCount = 0;
    private float totalWeightChange = 0.0f;

    /* loaded from: input_file:ai/topandrey15/reinforcemc/core/NetworkDataManager$NetworkData.class */
    public static class NetworkData implements Serializable {
        private static final long serialVersionUID = 2;
        public int outputSize;
        public float[] biases1;
        public float[] biases2;
        public float[] biases3;
        public float[] weights1Flat;
        public int weights1Rows;
        public int weights1Cols;
        public float[] weights2Flat;
        public int weights2Rows;
        public int weights2Cols;
        public float[] weights3Flat;
        public int weights3Rows;
        public int weights3Cols;
    }

    /* loaded from: input_file:ai/topandrey15/reinforcemc/core/NetworkDataManager$WeightStats.class */
    public static class WeightStats {
        public float currentWeightsSum;
        public float totalWeightChange;
        public long updateCount;
        public float averageChangePerUpdate;
    }

    public NetworkData exportData(float[][] fArr, float[] fArr2, float[][] fArr3, float[] fArr4, float[][] fArr5, float[] fArr6, int i) {
        NetworkData networkData = new NetworkData();
        networkData.outputSize = i;
        networkData.biases1 = (float[]) fArr2.clone();
        networkData.biases2 = (float[]) fArr4.clone();
        networkData.biases3 = (float[]) fArr6.clone();
        networkData.weights1Flat = flattenMatrix(fArr);
        networkData.weights1Rows = fArr.length;
        networkData.weights1Cols = fArr[0].length;
        networkData.weights2Flat = flattenMatrix(fArr3);
        networkData.weights2Rows = fArr3.length;
        networkData.weights2Cols = fArr3[0].length;
        networkData.weights3Flat = flattenMatrix(fArr5);
        networkData.weights3Rows = fArr5.length;
        networkData.weights3Cols = fArr5[0].length;
        return networkData;
    }

    public boolean loadFromData(NetworkData networkData, float[][] fArr, float[] fArr2, float[][] fArr3, float[] fArr4, float[][] fArr5, float[] fArr6, int i) {
        if (networkData == null || networkData.outputSize != i) {
            ReinforceMC.LOGGER.warn("Cannot load network data: incompatible structure");
            return false;
        }
        float[][] reconstructMatrix = reconstructMatrix(networkData.weights1Flat, networkData.weights1Rows, networkData.weights1Cols);
        float[][] reconstructMatrix2 = reconstructMatrix(networkData.weights2Flat, networkData.weights2Rows, networkData.weights2Cols);
        float[][] reconstructMatrix3 = reconstructMatrix(networkData.weights3Flat, networkData.weights3Rows, networkData.weights3Cols);
        if (reconstructMatrix == null || reconstructMatrix2 == null || reconstructMatrix3 == null) {
            ReinforceMC.LOGGER.error("Failed to reconstruct weight matrices");
            return false;
        }
        copyMatrix(reconstructMatrix, fArr);
        copyMatrix(reconstructMatrix2, fArr3);
        copyMatrix(reconstructMatrix3, fArr5);
        System.arraycopy(networkData.biases1, 0, fArr2, 0, networkData.biases1.length);
        System.arraycopy(networkData.biases2, 0, fArr4, 0, networkData.biases2.length);
        System.arraycopy(networkData.biases3, 0, fArr6, 0, networkData.biases3.length);
        updateWeightMonitoring(fArr, fArr2, fArr3, fArr4, fArr5, fArr6);
        ReinforceMC.LOGGER.info("Neural network weights loaded from saved data (weights sum: {:.6f})", Float.valueOf(calculateCurrentWeightsSum(fArr, fArr2, fArr3, fArr4, fArr5, fArr6)));
        return true;
    }

    public float calculateCurrentWeightsSum(float[][] fArr, float[] fArr2, float[][] fArr3, float[] fArr4, float[][] fArr5, float[] fArr6) {
        float f = 0.0f;
        for (int i = 0; i < fArr.length; i++) {
            for (int i2 = 0; i2 < fArr[i].length; i2++) {
                f += Math.abs(fArr[i][i2]);
            }
        }
        for (int i3 = 0; i3 < fArr3.length; i3++) {
            for (int i4 = 0; i4 < fArr3[i3].length; i4++) {
                f += Math.abs(fArr3[i3][i4]);
            }
        }
        for (int i5 = 0; i5 < fArr5.length; i5++) {
            for (int i6 = 0; i6 < fArr5[i5].length; i6++) {
                f += Math.abs(fArr5[i5][i6]);
            }
        }
        for (float f2 : fArr2) {
            f += Math.abs(f2);
        }
        for (float f3 : fArr4) {
            f += Math.abs(f3);
        }
        for (float f4 : fArr6) {
            f += Math.abs(f4);
        }
        return f;
    }

    public void updateWeightMonitoring(float[][] fArr, float[] fArr2, float[][] fArr3, float[] fArr4, float[][] fArr5, float[] fArr6) {
        float calculateCurrentWeightsSum = calculateCurrentWeightsSum(fArr, fArr2, fArr3, fArr4, fArr5, fArr6);
        if (this.weightsUpdateCount > 0) {
            float abs = Math.abs(calculateCurrentWeightsSum - this.lastWeightsSum);
            this.totalWeightChange += abs;
            if (abs > 0.01f && this.weightsUpdateCount % 100 == 0) {
                ReinforceMC.LOGGER.info("Weight change detected: {:.6f} (total change: {:.6f}, updates: {})", Float.valueOf(abs), Float.valueOf(this.totalWeightChange), Long.valueOf(this.weightsUpdateCount));
            }
        }
        this.lastWeightsSum = calculateCurrentWeightsSum;
        this.weightsUpdateCount++;
    }

    public WeightStats getWeightStats(float[][] fArr, float[] fArr2, float[][] fArr3, float[] fArr4, float[][] fArr5, float[] fArr6) {
        WeightStats weightStats = new WeightStats();
        weightStats.currentWeightsSum = calculateCurrentWeightsSum(fArr, fArr2, fArr3, fArr4, fArr5, fArr6);
        weightStats.totalWeightChange = this.totalWeightChange;
        weightStats.updateCount = this.weightsUpdateCount;
        weightStats.averageChangePerUpdate = this.weightsUpdateCount > 0 ? this.totalWeightChange / ((float) this.weightsUpdateCount) : 0.0f;
        return weightStats;
    }

    private float[] flattenMatrix(float[][] fArr) {
        if (fArr == null) {
            return null;
        }
        int length = fArr.length;
        int length2 = fArr[0].length;
        float[] fArr2 = new float[length * length2];
        int i = 0;
        for (float[] fArr3 : fArr) {
            for (int i2 = 0; i2 < length2; i2++) {
                int i3 = i;
                i++;
                fArr2[i3] = fArr3[i2];
            }
        }
        return fArr2;
    }

    private float[][] reconstructMatrix(float[] fArr, int i, int i2) {
        if (fArr == null) {
            return (float[][]) null;
        }
        if (fArr.length != i * i2) {
            ReinforceMC.LOGGER.error("Matrix reconstruction error: Expected {} elements, got {}", Integer.valueOf(i * i2), Integer.valueOf(fArr.length));
            return new float[i][i2];
        }
        float[][] fArr2 = new float[i][i2];
        int i3 = 0;
        for (int i4 = 0; i4 < i; i4++) {
            for (int i5 = 0; i5 < i2; i5++) {
                int i6 = i3;
                i3++;
                fArr2[i4][i5] = fArr[i6];
            }
        }
        return fArr2;
    }

    private void copyMatrix(float[][] fArr, float[][] fArr2) {
        for (int i = 0; i < fArr.length; i++) {
            System.arraycopy(fArr[i], 0, fArr2[i], 0, fArr[i].length);
        }
    }

    /* JADX WARN: Type inference failed for: r0v3, types: [float[], float[][]] */
    public float[][] cloneMatrix(float[][] fArr) {
        if (fArr == null) {
            return (float[][]) null;
        }
        ?? r0 = new float[fArr.length];
        for (int i = 0; i < fArr.length; i++) {
            r0[i] = (float[]) fArr[i].clone();
        }
        return r0;
    }
}
