package com.mna.tools.manaweave.neural;

/* loaded from: input_file:com/mna/tools/manaweave/neural/TrainSelfOrganizingMap.class */
/**
 * Trains a {@link SelfOrganizingMap} on a fixed set of training vectors.
 *
 * <p>Each call to {@link #iteration()} does the following:
 * <ol>
 *   <li>Assigns every training case to its winning output neuron.</li>
 *   <li>Records the largest case-to-winner distance as the iteration error.</li>
 *   <li>Snapshots the weights if that error is the best seen so far.</li>
 *   <li>Either moves winning neurons toward their cases, or (while some neurons win nothing)
 *       re-seeds one idle neuron from a training case.</li>
 * </ol>
 */
public class TrainSelfOrganizingMap {
    /** Map being trained; its output weights are modified in place. */
    private final SelfOrganizingMap som;
    /** Weight-update rule selected at construction. */
    private final LearningMethod learnMethod;
    /** Current learning rate; multiplied by {@link #reduction} after each weight adjustment while above 0.01. */
    private double learnRate;
    /** Decay factor applied to {@link #learnRate}. */
    private final double reduction = 0.99d;
    /** Error computed by the most recent {@link #evaluateErrors()} call. */
    private double globalError;
    /** Error of the most recent {@link #iteration()}. */
    private double totalError;
    /** Lowest {@link #totalError} observed so far. */
    private double bestError;
    private final int inputNeuronCount;
    private final int outputNeuronCount;
    /** Receives a copy of the weights whenever a new {@link #bestError} is reached. NOTE(review): never read in this class. */
    private final SelfOrganizingMap bestNet;
    /** Training cases, one row per case (held by reference, not copied). */
    private final double[][] train;
    /** Number of training cases won by each output neuron in the latest evaluation. */
    private final int[] won;
    /** Scratch row used only by the ADDITIVE method; {@code null} otherwise. */
    private final Matrix work;
    /** Per-neuron accumulated corrections; the extra last column pairs with {@code NormalizeInput.getSynth()}. */
    private final Matrix correc;

    /**
     * Creates a trainer and re-initializes the map's output weights via {@link #initialize()}.
     *
     * @param selfOrganizingMap the map to train; its output weights are overwritten
     * @param dArr training cases, one row per case
     * @param learningMethod weight-update rule to apply
     * @param d initial learning rate
     * @throws RuntimeException if any training case has a vector length below 1e-30
     */
    public TrainSelfOrganizingMap(SelfOrganizingMap selfOrganizingMap, double[][] dArr, LearningMethod learningMethod, double d) {
        this.som = selfOrganizingMap;
        this.train = dArr;
        this.totalError = 1.0d;
        this.learnMethod = learningMethod;
        this.learnRate = d;
        this.outputNeuronCount = selfOrganizingMap.countOutputNeurons();
        this.inputNeuronCount = selfOrganizingMap.countInputNeurons();
        // Reject (near-)zero-length cases up front; per the message, normalization cannot handle them.
        for (final double[] trainingCase : this.train) {
            if (MatrixMath.vectorLength(Matrix.createColumnMatrix(trainingCase)) < 1.0E-30d) {
                throw new RuntimeException("Multiplicative normalization has null training case");
            }
        }
        this.bestNet = new SelfOrganizingMap(this.inputNeuronCount, this.outputNeuronCount, this.som.getNormalizationType());
        this.won = new int[this.outputNeuronCount];
        // One extra column for the synthetic component produced by NormalizeInput.
        this.correc = new Matrix(this.outputNeuronCount, this.inputNeuronCount + 1);
        this.work = (this.learnMethod == LearningMethod.ADDITIVE) ? new Matrix(1, this.inputNeuronCount + 1) : null;
        // NOTE(review): initialize() is public and overridable yet called from the constructor.
        initialize();
        this.bestError = Double.MAX_VALUE;
    }

    /** Randomizes the output weights in the range -1..1, then normalizes every neuron's row. */
    public void initialize() {
        final Matrix outputWeights = this.som.getOutputWeights();
        outputWeights.randomize(-1.0d, 1.0d);
        for (int neuron = 0; neuron < this.outputNeuronCount; neuron++) {
            normalizeWeight(outputWeights, neuron);
        }
    }

    /** @return the lowest iteration error observed so far ({@code Double.MAX_VALUE} before any iteration) */
    public double getBestError() {
        return this.bestError;
    }

    /** @return the error of the most recent iteration (1.0 before any iteration) */
    public double getTotalError() {
        return this.totalError;
    }

    /**
     * Adds each winning neuron's accumulated correction to its weights.
     *
     * <p>The correction is averaged over the cases the neuron won. For SUBTRACTIVE learning it
     * is also scaled by the learning rate. Neurons that won nothing are left untouched.
     */
    protected void adjustWeights() {
        final Matrix outputWeights = this.som.getOutputWeights();
        for (int neuron = 0; neuron < this.outputNeuronCount; neuron++) {
            if (this.won[neuron] == 0) {
                continue;
            }
            double factor = 1.0d / this.won[neuron];
            if (this.learnMethod == LearningMethod.SUBTRACTIVE) {
                factor *= this.learnRate;
            }
            // Inclusive bound: the last column is the synthetic component.
            for (int col = 0; col <= this.inputNeuronCount; col++) {
                outputWeights.add(neuron, col, factor * this.correc.get(neuron, col));
            }
        }
    }

    /** Copies the output weights of {@code selfOrganizingMap} into {@code selfOrganizingMap2}. */
    private void copyWeights(SelfOrganizingMap selfOrganizingMap, SelfOrganizingMap selfOrganizingMap2) {
        MatrixMath.copy(selfOrganizingMap.getOutputWeights(), selfOrganizingMap2.getOutputWeights());
    }

    /**
     * Finds the winner of every training case and rebuilds {@link #won} and {@link #correc}.
     *
     * <p>Sets {@link #globalError} to the largest Euclidean distance between a normalized
     * training case (plus its synthetic component) and its winner's weight row. The corrections
     * are computed as follows:
     * <ul>
     *   <li>SUBTRACTIVE: the raw difference (input - weight).</li>
     *   <li>Otherwise: the change that moves the weight row to the normalized
     *       {@code learnRate * input + weight}.</li>
     * </ul>
     */
    void evaluateErrors() throws RuntimeException {
        this.correc.clear();
        for (int neuron = 0; neuron < this.won.length; neuron++) {
            this.won[neuron] = 0;
        }
        double worstSquared = 0.0d;
        for (final double[] trainingCase : this.train) {
            final NormalizeInput normalizeInput = new NormalizeInput(trainingCase, this.som.getNormalizationType());
            final int winner = this.som.winner(normalizeInput);
            this.won[winner]++;
            final Matrix row = this.som.getOutputWeights().getRow(winner);
            final double normFac = normalizeInput.getNormFac();

            double squared = 0.0d;
            for (int col = 0; col < this.inputNeuronCount; col++) {
                final double diff = trainingCase[col] * normFac - row.get(0, col);
                squared += diff * diff;
                if (this.learnMethod == LearningMethod.SUBTRACTIVE) {
                    this.correc.add(winner, col, diff);
                } else {
                    this.work.set(0, col, this.learnRate * trainingCase[col] * normFac + row.get(0, col));
                }
            }

            // The synthetic component occupies the extra last column.
            final double synthDiff = normalizeInput.getSynth() - row.get(0, this.inputNeuronCount);
            squared += synthDiff * synthDiff;
            if (this.learnMethod == LearningMethod.SUBTRACTIVE) {
                this.correc.add(winner, this.inputNeuronCount, synthDiff);
            } else {
                this.work.set(0, this.inputNeuronCount, this.learnRate * normalizeInput.getSynth() + row.get(0, this.inputNeuronCount));
            }

            if (squared > worstSquared) {
                worstSquared = squared;
            }

            if (this.learnMethod == LearningMethod.ADDITIVE) {
                normalizeWeight(this.work, 0);
                for (int col = 0; col <= this.inputNeuronCount; col++) {
                    this.correc.add(winner, col, this.work.get(0, col) - row.get(0, col));
                }
            }
        }
        this.globalError = Math.sqrt(worstSquared);
    }

    /**
     * Re-seeds one neuron that won no training case in the latest {@link #evaluateErrors()}.
     *
     * <p>The seed is the training case whose winning neuron gave the lowest output. Among idle
     * neurons, the one with the highest output for that case has its weights replaced by the
     * case's normalized input matrix, then renormalized. Does nothing if there are no training
     * cases or no idle neuron can be selected.
     */
    void forceWin() throws RuntimeException {
        if (this.train.length == 0) {
            return;
        }
        final Matrix outputWeights = this.som.getOutputWeights();

        // Step 1: find the training case whose winner responds most weakly.
        int weakestCase = 0;
        double weakestOutput = Double.MAX_VALUE;
        for (int c = 0; c < this.train.length; c++) {
            final int winner = this.som.winner(this.train[c]);
            final double[] output = this.som.getOutput();
            if (output[winner] < weakestOutput) {
                weakestOutput = output[winner];
                weakestCase = c;
            }
        }

        final NormalizeInput normalizeInput = new NormalizeInput(this.train[weakestCase], this.som.getNormalizationType());
        this.som.winner(normalizeInput); // evaluates the map on this case; getOutput() is read next
        final double[] output = this.som.getOutput();

        // Step 2: pick the idle neuron with the strongest response. Start from NEGATIVE_INFINITY:
        // Double.MIN_VALUE is the smallest *positive* double, so an output <= 0 would never be
        // chosen, and a training-case index would leak through as a neuron index.
        // Scanning high-to-low with a strict '>' keeps the original tie-breaking.
        int neuron = -1;
        double strongest = Double.NEGATIVE_INFINITY;
        for (int n = this.outputNeuronCount - 1; n >= 0; n--) {
            if (this.won[n] == 0 && output[n] > strongest) {
                strongest = output[n];
                neuron = n;
            }
        }
        if (neuron < 0) {
            return; // every neuron already wins something (or outputs are NaN); nothing to force
        }

        // Step 3: copy the normalized input matrix into the selected neuron's row.
        final Matrix inputMatrix = normalizeInput.getInputMatrix();
        for (int col = 0; col < inputMatrix.getCols(); col++) {
            outputWeights.set(neuron, col, inputMatrix.get(0, col));
        }
        normalizeWeight(outputWeights, neuron);
    }

    /**
     * Performs one training step.
     *
     * <p>The step proceeds as follows:
     * <ol>
     *   <li>Evaluates errors.</li>
     *   <li>Records the best weights if this iteration's error is the lowest so far.</li>
     *   <li>If fewer neurons won than both the neuron count and the case count, re-seeds an idle
     *       neuron and stops.</li>
     *   <li>Otherwise applies the accumulated corrections and decays the learning rate while it
     *       exceeds 0.01.</li>
     * </ol>
     */
    public void iteration() throws RuntimeException {
        evaluateErrors();
        this.totalError = this.globalError;
        if (this.totalError < this.bestError) {
            this.bestError = this.totalError;
            copyWeights(this.som, this.bestNet);
        }
        int winners = 0;
        for (final int wins : this.won) {
            if (wins != 0) {
                winners++;
            }
        }
        if (winners < this.outputNeuronCount && winners < this.train.length) {
            forceWin();
            return;
        }
        adjustWeights();
        if (this.learnRate > 0.01d) {
            this.learnRate *= this.reduction;
        }
    }

    /**
     * Scales the first {@code inputNeuronCount} entries of {@code matrix} row {@code i} by the
     * inverse of the whole row's vector length, then zeroes the last (synthetic) column.
     * The length is floored at 1e-30 to avoid division by zero.
     */
    protected void normalizeWeight(Matrix matrix, int i) {
        final double scale = 1.0d / Math.max(MatrixMath.vectorLength(matrix.getRow(i)), 1.0E-30d);
        for (int col = 0; col < this.inputNeuronCount; col++) {
            matrix.set(i, col, matrix.get(i, col) * scale);
        }
        matrix.set(i, this.inputNeuronCount, 0.0d);
    }
}
