/*
 * Decompiled with CFR 0.152.
 */
package com.pulsar.soulforge.ai;

import com.pulsar.soulforge.SoulForge;
import java.util.Random;

/**
 * A fully connected feed-forward neural network with a sigmoid activation on every
 * non-input layer, trained one sample at a time by gradient descent on the loss
 * {@code 0.5 * sum((output - target)^2)}.
 *
 * <p>Layer {@code 0} is the input layer. For each {@code i} in
 * {@code [0, layerSizes.length - 1)}, {@code weights[i][j][k]} connects neuron {@code k}
 * of layer {@code i} to neuron {@code j} of layer {@code i + 1}, and {@code biases[i][j]}
 * is the bias of that neuron {@code j}.
 *
 * <p>{@link #activations} and {@link #z} are scratch buffers overwritten by every forward
 * pass without any synchronization, so an instance is not safe for concurrent use.
 */
public class NeuralNetwork {
    /** Neuron count of each layer, input layer first. */
    public int[] layerSizes;
    /** {@code weights[i][j][k]}: weight from neuron k of layer i to neuron j of layer i + 1. */
    public double[][][] weights;
    /** {@code biases[i][j]}: bias of neuron j of layer i + 1. */
    public double[][] biases;
    /**
     * Layer outputs of the most recent forward pass. After a pass, {@code activations[0]} is
     * the caller's input array itself (not a copy).
     */
    public double[][] activations;
    /** Pre-activation sums ({@code weights · previous activations + bias}) of the most recent forward pass. */
    public double[][] z;

    /**
     * Creates a network whose weights and biases are drawn independently from a standard
     * normal distribution via {@link Random#nextGaussian()}.
     *
     * @param layerSizes neuron count per layer, input layer first; stored without copying
     */
    public NeuralNetwork(int[] layerSizes) {
        this.layerSizes = layerSizes;
        this.initializeNetwork();
    }

    /**
     * Creates a network from existing parameters.
     *
     * <p>The arrays are stored by reference, not copied, so {@link #backpropagate} updates
     * the caller's arrays in place. Their shapes are not checked against {@code layerSizes}.
     *
     * @param layerSizes neuron count per layer, input layer first
     * @param weights    {@code weights[i][j][k]} as described on the class
     * @param biases     {@code biases[i][j]} as described on the class
     */
    public NeuralNetwork(int[] layerSizes, double[][][] weights, double[][] biases) {
        this.layerSizes = layerSizes;
        SoulForge.LOGGER.info("creating network with layer sizes: {}", (Object)layerSizes);
        this.weights = weights;
        this.biases = biases;
        // Parameterised loggers typically render array arguments element by element, so this
        // emits every weight and bias; keep it out of the default INFO output.
        SoulForge.LOGGER.debug("weights: {}, biases: {}", (Object)weights, (Object)biases);
        this.allocateBuffers();
    }

    /**
     * Fills {@link #weights} and {@link #biases} with standard-normal samples.
     *
     * <p>For each neuron, the bias is drawn first and then its incoming weights. The scratch
     * buffers are allocated afterwards.
     */
    private void initializeNetwork() {
        Random rand = new Random();
        int layerCount = this.layerSizes.length;
        this.weights = new double[layerCount - 1][][];
        this.biases = new double[layerCount - 1][];
        for (int i = 0; i < layerCount - 1; ++i) {
            int fanIn = this.layerSizes[i];
            int fanOut = this.layerSizes[i + 1];
            this.weights[i] = new double[fanOut][fanIn];
            this.biases[i] = new double[fanOut];
            for (int j = 0; j < fanOut; ++j) {
                this.biases[i][j] = rand.nextGaussian();
                for (int k = 0; k < fanIn; ++k) {
                    this.weights[i][j][k] = rand.nextGaussian();
                }
            }
        }
        this.allocateBuffers();
    }

    /** Allocates zeroed {@link #activations} and {@link #z} buffers sized from {@link #layerSizes}. */
    private void allocateBuffers() {
        int layerCount = this.layerSizes.length;
        this.activations = new double[layerCount][];
        this.z = new double[layerCount - 1][];
        for (int i = 0; i < layerCount; ++i) {
            this.activations[i] = new double[this.layerSizes[i]];
        }
        for (int i = 0; i < layerCount - 1; ++i) {
            this.z[i] = new double[this.layerSizes[i + 1]];
        }
    }

    /**
     * Logistic function {@code 1 / (1 + e^-x)}.
     *
     * <p>For large |x|, {@code Math.exp} saturates to 0 or Infinity, giving 1.0 or 0.0 rather
     * than NaN.
     */
    private static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Sigmoid derivative expressed through the sigmoid's own output: if {@code a = sigmoid(x)},
     * then {@code sigmoid'(x) = a * (1 - a)}.
     *
     * <p>Reusing the activation stored by the forward pass avoids recomputing {@code exp}.
     */
    private static double sigmoidDerivativeFromOutput(double a) {
        return a * (1.0 - a);
    }

    /**
     * Runs a forward pass and records every layer's sums and activations.
     *
     * @param input input-layer values; must hold at least {@code layerSizes[0]} entries (extra
     *              entries are ignored). It is stored as {@code activations[0]} without
     *              copying and is not modified.
     * @return the output-layer activations. This is the internal buffer, so it is overwritten
     *         by the next call.
     */
    public double[] feedforward(double[] input) {
        this.activations[0] = input;
        for (int i = 0; i < this.weights.length; ++i) {
            double[] prev = this.activations[i];
            double[] next = this.activations[i + 1];
            double[] sums = this.z[i];
            for (int j = 0; j < this.weights[i].length; ++j) {
                double[] row = this.weights[i][j];
                double sum = this.biases[i][j];
                for (int k = 0; k < row.length; ++k) {
                    sum += row[k] * prev[k];
                }
                sums[j] = sum;
                next[j] = sigmoid(sum);
            }
        }
        return this.activations[this.activations.length - 1];
    }

    /**
     * Performs one gradient-descent step toward {@code target} for a single sample.
     *
     * <p>The loss is {@code 0.5 * sum((output - target)^2)}. Weights and biases are updated
     * in place.
     *
     * @param input        input-layer values, as for {@link #feedforward(double[])}
     * @param target       expected output values; must hold at least as many entries as the
     *                     output layer (extra entries are ignored)
     * @param learningRate step size applied to every gradient
     */
    public void backpropagate(double[] input, double[] target, double learningRate) {
        this.feedforward(input);
        if (this.layerSizes.length < 2) {
            // A network consisting only of an input layer has no weights or biases to train;
            // without this guard the output-layer lookup would index delta[-1].
            return;
        }
        double[][] delta = this.computeDeltas(target);
        this.applyGradients(delta, learningRate);
    }

    /**
     * Computes {@code dLoss/dz} for every non-input neuron, from the output layer backwards.
     *
     * <p>Uses the activations left by the preceding forward pass.
     */
    private double[][] computeDeltas(double[] target) {
        int last = this.layerSizes.length - 2; // output layer's index into delta / z / weights
        double[][] delta = new double[last + 1][];
        for (int i = 0; i <= last; ++i) {
            delta[i] = new double[this.layerSizes[i + 1]];
        }
        double[] output = this.activations[this.activations.length - 1];
        for (int j = 0; j < delta[last].length; ++j) {
            delta[last][j] = (output[j] - target[j]) * sigmoidDerivativeFromOutput(output[j]);
        }
        for (int i = last - 1; i >= 0; --i) {
            double[][] nextWeights = this.weights[i + 1];
            double[] nextDelta = delta[i + 1];
            double[] layerOut = this.activations[i + 1];
            for (int j = 0; j < delta[i].length; ++j) {
                double error = 0.0;
                for (int k = 0; k < nextDelta.length; ++k) {
                    error += nextDelta[k] * nextWeights[k][j];
                }
                delta[i][j] = error * sigmoidDerivativeFromOutput(layerOut[j]);
            }
        }
        return delta;
    }

    /**
     * Subtracts {@code learningRate} times the gradient from every weight and bias.
     *
     * <p>Each gradient is formed from {@code delta} and the forward-pass activations.
     */
    private void applyGradients(double[][] delta, double learningRate) {
        for (int i = 0; i < this.weights.length; ++i) {
            double[] layerIn = this.activations[i];
            for (int j = 0; j < this.weights[i].length; ++j) {
                double[] row = this.weights[i][j];
                double step = learningRate * delta[i][j];
                for (int k = 0; k < row.length; ++k) {
                    row[k] -= step * layerIn[k];
                }
                this.biases[i][j] -= step;
            }
        }
    }
}

