package edu.cmu.sphinx.linguist.acoustic.tiedstate.HTK;

import edu.cmu.sphinx.util.LogMath;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.StringTokenizer;

/* loaded from: input_file:META-INF/jars/sphinx4-core-5prealpha-SNAPSHOT.jar:edu/cmu/sphinx/linguist/acoustic/tiedstate/HTK/GMMDiag.class */
/**
 * Gaussian mixture model with diagonal covariance matrices.
 *
 * <p>Internal representation (see {@link #setWeight} and {@link #setVar}):
 * <ul>
 *   <li>{@code weights[g]} holds the mixture weight of component {@code g} in the
 *       {@link LogMath} log base;</li>
 *   <li>{@code covar[g][c]} holds {@code -1 / (2 * variance)} rather than the variance,
 *       so {@link #computeLogLikes(float[])} needs only one multiply per coefficient;</li>
 *   <li>{@code logPreComputedGaussianFactor[g]} caches the log normalisation term and must
 *       be refreshed with {@link #precomputeDistance()} after variances change.</li>
 * </ul>
 */
public class GMMDiag {
    /**
     * Count stored with the model: the (truncated) sum of component counts in
     * {@link #loadScaleKMeans(String)}, and the trailing line of the
     * {@link #save(String)} / {@link #load(String)} text format.
     */
    public int nT;
    /** Optional model name, set through {@link #setNom(String)}. */
    public String nom;
    /** Log-domain arithmetic helper; initialised whenever the arrays are (re)allocated. */
    public LogMath logMath;
    private int ncoefs;
    private int ngauss;
    protected float[] weights;
    protected float[][] means;
    protected float[][] covar;
    private float[] logPreComputedGaussianFactor;
    /** Per-component log-likelihoods written by the last {@link #computeLogLikes(float[])} call. */
    protected float[] loglikes;
    /** Lower bound applied to each per-component log density. */
    private static final float DIST_FLOOR = -Float.MAX_VALUE;

    /** Creates an empty model; populate it with one of the {@code load*} methods. */
    public GMMDiag() {
    }

    /**
     * Creates a model with uniform weights and allocated (but unset) means and variances.
     *
     * @param ngauss number of mixture components
     * @param ncoefs dimensionality of each component
     */
    public GMMDiag(int ngauss, int ncoefs) {
        this.ngauss = ngauss;
        this.ncoefs = ncoefs;
        allocate();
    }

    /** @return the number of mixture components */
    public int getNgauss() {
        return this.ngauss;
    }

    /**
     * @param g component index
     * @return the linear-domain mixture weight of component {@code g}
     */
    public float getWeight(int g) {
        return (float) this.logMath.logToLinear(this.weights[g]);
    }

    /**
     * @param g component index
     * @param c coefficient index
     * @return the variance, recovered from the stored {@code -1/(2*variance)} form
     */
    public float getVar(int g, int c) {
        return -1.0f / (2.0f * this.covar[g][c]);
    }

    /**
     * Sets a mixture weight; the value is stored in the LogMath log base.
     *
     * @param g component index
     * @param weight linear-domain weight
     */
    public void setWeight(int g, float weight) {
        if (this.weights == null) {
            this.weights = new float[this.ngauss];
        }
        this.weights[g] = this.logMath.linearToLog(weight);
    }

    /**
     * Sets a variance, stored as {@code -1/(2*variance)}. Non-positive values are
     * accepted but reported on stderr. Call {@link #precomputeDistance()} afterwards.
     *
     * @param g component index
     * @param c coefficient index
     * @param variance linear variance
     */
    public void setVar(int g, int c, float variance) {
        if (variance <= 0.0f) {
            System.err.println("WARNING: setVar " + variance);
        }
        this.covar[g][c] = -1.0f / (2.0f * variance);
    }

    /** Sets mean coefficient {@code c} of component {@code g}. */
    public void setMean(int g, int c, float mean) {
        this.means[g][c] = mean;
    }

    /** @return mean coefficient {@code c} of component {@code g} */
    public float getMean(int g, int c) {
        return this.means[g][c];
    }

    /**
     * Writes the model in the plain-text format read by {@link #load(String)}:
     * a {@code "ngauss ncoefs"} line; then, per component, a {@code "gauss <g> <weight>"}
     * line, a line of means and a line of variances; and finally {@link #nT}.
     * I/O errors are printed and otherwise ignored.
     *
     * @param filename destination file
     */
    public void save(String filename) {
        try (PrintWriter out = new PrintWriter(new FileWriter(filename))) {
            out.println(this.ngauss + " " + this.ncoefs);
            for (int g = 0; g < this.ngauss; g++) {
                out.println("gauss " + g + ' ' + getWeight(g));
                for (int c = 0; c < this.ncoefs; c++) {
                    out.print(this.means[g][c] + " ");
                }
                out.println();
                for (int c = 0; c < this.ncoefs; c++) {
                    out.print(getVar(g, c) + " ");
                }
                out.println();
            }
            out.println(this.nT);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads a model written by {@link #save(String)}, replacing all current contents.
     * I/O errors are printed and otherwise ignored.
     *
     * @param filename source file
     * @throws IllegalArgumentException if the file is truncated or a component header
     *     is not {@code "gauss <expected index> ..."}
     */
    public void load(String filename) {
        try (BufferedReader in = new BufferedReader(new FileReader(filename))) {
            String[] sizes = requireLine(in, filename).split(" ");
            this.ngauss = Integer.parseInt(sizes[0]);
            this.ncoefs = Integer.parseInt(sizes[1]);
            allocate();
            for (int g = 0; g < this.ngauss; g++) {
                String header = requireLine(in, filename);
                String[] fields = header.split(" ");
                if (!fields[0].equals("gauss") || Integer.parseInt(fields[1]) != g) {
                    // Previously System.exit(1); a library must not terminate the JVM.
                    throw new IllegalArgumentException("Error loading GMM " + header + ' ' + g);
                }
                setWeight(g, Float.parseFloat(fields[2]));
                String[] meanFields = requireLine(in, filename).split(" ");
                for (int c = 0; c < this.ncoefs; c++) {
                    setMean(g, c, Float.parseFloat(meanFields[c]));
                }
                String[] varFields = requireLine(in, filename).split(" ");
                for (int c = 0; c < this.ncoefs; c++) {
                    setVar(g, c, Float.parseFloat(varFields[c]));
                }
            }
            // The trailing count line is optional.
            String countLine = in.readLine();
            if (countLine != null) {
                this.nT = Integer.parseInt(countLine);
            }
            precomputeDistance();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Same as {@link #saveHTK(String, String, String)} with parameter kind {@code <USER>}. */
    public void saveHTK(String filename, String hmmName) {
        saveHTK(filename, hmmName, "<USER>");
    }

    /**
     * Opens {@code filename} and writes the HTK global header (options, regression tree).
     * The caller owns the returned writer and must close it.
     *
     * @param filename destination file
     * @param parmKind HTK parameter kind inserted into the {@code <VECSIZE>} line
     * @return the open writer, or {@code null} if the file could not be opened
     */
    public PrintWriter saveHTKheader(String filename, String parmKind) {
        try {
            PrintWriter out = new PrintWriter(new FileWriter(filename));
            writeHTKHeader(out, parmKind);
            return out;
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    /** Writes this mixture as the body of one HTK {@code <STATE>} ({@code <NUMMIXES>} onwards). */
    public void saveHTKState(PrintWriter out) {
        writeMixtures(out);
    }

    /**
     * Writes the HTK {@code <TRANSP>} matrix and {@code <ENDHMM>} for a left-to-right model:
     * the entry state moves to state 2 with probability 1, every emitting state has
     * self-loop 0.5 and forward 0.5, and the exit state has no outgoing transitions.
     * Each of the {@code nstates} rows is written on its own line.
     *
     * @param nstates total number of states, including the non-emitting entry and exit
     * @param out destination writer
     */
    public void saveHTKtailer(int nstates, PrintWriter out) {
        out.println("<TRANSP> " + nstates);
        for (int row = 0; row < nstates; row++) {
            StringBuilder line = new StringBuilder();
            for (int col = 0; col < nstates; col++) {
                if (col > 0) {
                    line.append(' ');
                }
                line.append(transitionProbability(nstates, row, col));
            }
            out.println(line);
        }
        out.println("<ENDHMM>");
    }

    /** Entry of the left-to-right transition matrix written by {@link #saveHTKtailer}. */
    private static String transitionProbability(int nstates, int row, int col) {
        if (row == 0) {
            return col == 1 ? "1" : "0";
        }
        if (row < nstates - 1 && (col == row || col == row + 1)) {
            return "0.5";
        }
        return "0";
    }

    /**
     * Writes a complete single-emitting-state (3-state) HTK model file.
     * I/O errors are printed and otherwise ignored.
     *
     * @param filename destination file
     * @param hmmName name written in the {@code ~h} macro
     * @param parmKind HTK parameter kind inserted into the {@code <VECSIZE>} line
     */
    public void saveHTK(String filename, String hmmName, String parmKind) {
        try (PrintWriter out = new PrintWriter(new FileWriter(filename))) {
            writeHTKHeader(out, parmKind);
            out.println("~h \"" + hmmName + '"');
            out.println("<BEGINHMM>");
            out.println("<NUMSTATES> 3");
            out.println("<STATE> 2");
            writeMixtures(out);
            out.println("<TRANSP> 3");
            out.println("0 1 0");
            out.println("0 0.7 0.3");
            out.println("0 0 0");
            out.println("<ENDHMM>");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Writes the global HTK header shared by {@link #saveHTKheader} and {@link #saveHTK}. */
    private void writeHTKHeader(PrintWriter out, String parmKind) {
        out.println("~o");
        out.println("<HMMSETID> tree");
        out.println("<STREAMINFO> 1 " + getNcoefs());
        out.println("<VECSIZE> " + getNcoefs() + "<NULLD>" + parmKind + "<DIAGC>");
        out.println("~r \"rtree_1\"");
        out.println("<REGTREE> 1");
        out.println("<TNODE> 1 " + getNgauss());
    }

    /** Writes {@code <NUMMIXES>} and every {@code <MIXTURE>} with its mean and variance. */
    private void writeMixtures(PrintWriter out) {
        out.println("<NUMMIXES> " + getNgauss());
        for (int g = 0; g < getNgauss(); g++) {
            out.println("<MIXTURE> " + (g + 1) + ' ' + getWeight(g));
            out.println("<RCLASS> 1");
            out.println("<MEAN> " + getNcoefs());
            for (int c = 0; c < getNcoefs(); c++) {
                out.print(getMean(g, c) + " ");
            }
            out.println();
            out.println("<VARIANCE> " + getNcoefs());
            for (int c = 0; c < getNcoefs(); c++) {
                out.print(getVar(g, c) + " ");
            }
            out.println();
        }
    }

    /**
     * Loads components from an HTK text model. Each line containing {@code <MEAN>} starts a
     * component: the next line holds the means, followed by a {@code <VARIANCE>} line and a
     * line of variances. The component count is the number of {@code <MEAN>} lines in the
     * whole file; the dimension comes from the first {@code <MEAN>} line.
     * {@code <MIXTURE>} weights are not read: every component gets weight {@code 1/ngauss}.
     * I/O errors (including a missing {@code <VARIANCE>} line) are printed and otherwise ignored.
     *
     * @param filename source file
     * @throws IllegalArgumentException if the file ends in the middle of a component
     */
    public void loadHTK(String filename) {
        try {
            int gaussCount = 0;
            int coefCount = 0;
            try (BufferedReader in = new BufferedReader(new FileReader(filename))) {
                for (String line = in.readLine(); line != null; line = in.readLine()) {
                    if (line.contains("<MEAN>")) {
                        gaussCount++;
                        if (coefCount == 0) {
                            StringTokenizer tokens = new StringTokenizer(line);
                            tokens.nextToken(); // skip the "<MEAN>" keyword
                            coefCount = Integer.parseInt(tokens.nextToken());
                        }
                    }
                }
            }
            this.ngauss = gaussCount;
            this.ncoefs = coefCount;
            allocate();

            try (BufferedReader in = new BufferedReader(new FileReader(filename))) {
                int g = 0;
                for (String line = in.readLine(); line != null; line = in.readLine()) {
                    if (!line.contains("<MEAN>")) {
                        continue;
                    }
                    StringTokenizer meanTokens = new StringTokenizer(requireLine(in, filename));
                    for (int c = 0; meanTokens.hasMoreTokens(); c++) {
                        setMean(g, c, Float.parseFloat(meanTokens.nextToken()));
                    }
                    String varHeader = requireLine(in, filename);
                    if (!varHeader.contains("<VARIANCE>")) {
                        throw new IOException("Expected <VARIANCE> after the means of component "
                                + g + " in " + filename + " but found: " + varHeader);
                    }
                    StringTokenizer varTokens = new StringTokenizer(requireLine(in, filename));
                    for (int c = 0; varTokens.hasMoreTokens(); c++) {
                        setVar(g, c, Float.parseFloat(varTokens.nextToken()));
                    }
                    g++;
                }
            }
            precomputeDistance();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads a k-means result file made of line pairs: {@code "<count> <mean_1> ... <mean_n>"}
     * followed by a line of {@code n} variances. Weights are the counts normalised by
     * {@link #nT}, which receives the sum of counts (truncated to {@code int} after each add).
     * I/O errors are printed and otherwise ignored.
     *
     * @param filename source file
     * @throws IllegalArgumentException if the file is empty
     */
    public void loadScaleKMeans(String filename) {
        try {
            int lineCount = 0;
            String firstLine = null;
            try (BufferedReader in = new BufferedReader(new FileReader(filename))) {
                for (String line = in.readLine(); line != null; line = in.readLine()) {
                    if (firstLine == null) {
                        firstLine = line;
                    }
                    lineCount++;
                }
            }
            if (firstLine == null) {
                throw new IllegalArgumentException("Empty k-means file " + filename);
            }
            this.ngauss = lineCount / 2;
            this.ncoefs = firstLine.split(" ").length - 1; // first field is the count
            allocate();

            // Raw counts are kept apart from the log-domain weights array.
            float[] counts = new float[this.ngauss];
            this.nT = 0;
            try (BufferedReader in = new BufferedReader(new FileReader(filename))) {
                for (int g = 0; g < this.ngauss; g++) {
                    String[] fields = requireLine(in, filename).split(" ");
                    counts[g] = Float.parseFloat(fields[0]);
                    this.nT = (int) (this.nT + counts[g]);
                    for (int c = 0; c < this.ncoefs; c++) {
                        setMean(g, c, Float.parseFloat(fields[c + 1]));
                    }
                    String[] varFields = requireLine(in, filename).split(" ");
                    for (int c = 0; c < this.ncoefs; c++) {
                        setVar(g, c, Float.parseFloat(varFields[c]));
                    }
                }
            }
            for (int g = 0; g < this.ngauss; g++) {
                setWeight(g, counts[g] / this.nT);
            }
            precomputeDistance();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Reads the next line, failing with a descriptive error if the file ends early. */
    private static String requireLine(BufferedReader in, String filename) throws IOException {
        String line = in.readLine();
        if (line == null) {
            throw new IllegalArgumentException("Unexpected end of file in " + filename);
        }
        return line;
    }

    /** Initialises {@link #logMath} and resets all weights to the uniform {@code 1/ngauss}. */
    private void allocateWeights() {
        this.logMath = LogMath.getLogMath();
        this.weights = new float[this.ngauss];
        for (int g = 0; g < this.ngauss; g++) {
            setWeight(g, 1.0f / this.ngauss);
        }
    }

    /**
     * Recomputes, for each component, {@code 0.5 * (sum_c log var_c + ncoefs * log(2*pi))}
     * in the LogMath base. Must be called after variances change and before
     * {@link #computeLogLikes(float[])}.
     */
    public void precomputeDistance() {
        for (int g = 0; g < this.ngauss; g++) {
            float sumLogVar = 0.0f;
            for (int c = 0; c < this.ncoefs; c++) {
                sumLogVar += this.logMath.linearToLog(getVar(g, c));
            }
            float logTwoPiTerm = this.logMath.linearToLog(2.0 * Math.PI) * this.ncoefs;
            this.logPreComputedGaussianFactor[g] = (sumLogVar + logTwoPiTerm) * 0.5f;
        }
    }

    /**
     * (Re)allocates every per-component array for the current {@code ngauss}/{@code ncoefs}.
     * Always allocates fresh arrays so an instance can be reloaded with different dimensions
     * without keeping stale or wrongly-sized data.
     */
    private void allocate() {
        allocateWeights();
        this.loglikes = new float[this.ngauss];
        this.means = new float[this.ngauss][this.ncoefs];
        this.covar = new float[this.ngauss][this.ncoefs];
        this.logPreComputedGaussianFactor = new float[this.ngauss];
    }

    /**
     * Computes the weighted log density (LogMath base) of {@code fArr} under each component
     * and stores it in {@link #loglikes}. NaN results and {@code -Infinity} are clamped to
     * {@code -Float.MAX_VALUE}.
     *
     * @param fArr observation vector; its length must not exceed {@link #getNcoefs()}
     */
    public void computeLogLikes(float[] fArr) {
        if (this.ngauss == 0) {
            return;
        }
        // The exponent below is accumulated in natural-log units (covar = -1/(2*var)), while
        // the normalisation factor and the weights are in the LogMath base: convert first.
        float lnToLogBase = this.logMath.linearToLog(Math.E);
        for (int g = 0; g < this.ngauss; g++) {
            float exponent = 0.0f;
            for (int c = 0; c < fArr.length; c++) {
                float diff = fArr[c] - this.means[g][c];
                exponent += diff * diff * this.covar[g][c];
            }
            float gs2 = exponent * lnToLogBase - this.logPreComputedGaussianFactor[g];
            if (Float.isNaN(gs2)) {
                System.err.println("gs2 is Nan, converting to 0 debug " + g + ' ' + this.logPreComputedGaussianFactor[g] + ' ' + this.means[g][0] + ' ' + this.covar[g][0]);
                gs2 = DIST_FLOOR;
            }
            if (gs2 < DIST_FLOOR) {
                gs2 = DIST_FLOOR;
            }
            this.loglikes[g] = this.weights[g] + gs2;
        }
    }

    /** @return the log-sum (LogMath base) of the component log-likelihoods from the last {@link #computeLogLikes} */
    public float getLogLike() {
        float total = this.loglikes[0];
        for (int g = 1; g < this.ngauss; g++) {
            total = this.logMath.addAsLinear(total, this.loglikes[g]);
        }
        return total;
    }

    /** @return index of the component with the highest log-likelihood (first one on ties) */
    public int getWinningGauss() {
        int best = 0;
        for (int g = 1; g < this.ngauss; g++) {
            if (this.loglikes[g] > this.loglikes[best]) {
                best = g;
            }
        }
        return best;
    }

    /** @return the dimensionality of each component */
    public int getNcoefs() {
        return this.ncoefs;
    }

    /**
     * Builds the marginal model over the coefficients selected by {@code zArr}
     * (same components and weights, fewer dimensions).
     *
     * @param zArr selection mask, indexed by coefficient
     * @return a new, precomputed model
     */
    public GMMDiag getMarginal(boolean[] zArr) {
        int kept = 0;
        for (boolean keep : zArr) {
            if (keep) {
                kept++;
            }
        }
        GMMDiag marginal = new GMMDiag(getNgauss(), kept);
        int dst = 0;
        for (int c = 0; c < this.ncoefs; c++) {
            if (!zArr[c]) {
                continue;
            }
            for (int g = 0; g < this.ngauss; g++) {
                marginal.setMean(g, dst, getMean(g, c));
                // Copy the stored -1/(2*var) form directly to avoid a lossy round trip.
                marginal.covar[g][dst] = this.covar[g][c];
            }
            dst++;
        }
        // Both instances use LogMath.getLogMath(), so log weights can be copied as-is.
        System.arraycopy(this.weights, 0, marginal.weights, 0, this.ngauss);
        marginal.precomputeDistance();
        return marginal;
    }

    /**
     * Concatenates the components of this model and {@code gMMDiag} into a new model.
     * This model's weights are scaled by {@code f} and the other's by {@code 1 - f}.
     * Both models are assumed to have the same {@link #getNcoefs()}.
     *
     * @param gMMDiag the other model
     * @param f share of the total weight given to this model's components
     * @return a new, precomputed model
     */
    public GMMDiag merge(GMMDiag gMMDiag, float f) {
        GMMDiag merged = new GMMDiag(getNgauss() + gMMDiag.getNgauss(), getNcoefs());
        for (int g = 0; g < getNgauss(); g++) {
            System.arraycopy(this.means[g], 0, merged.means[g], 0, getNcoefs());
            System.arraycopy(this.covar[g], 0, merged.covar[g], 0, getNcoefs());
            merged.setWeight(g, getWeight(g) * f);
        }
        for (int g = 0; g < gMMDiag.getNgauss(); g++) {
            int target = this.ngauss + g;
            System.arraycopy(gMMDiag.means[g], 0, merged.means[target], 0, getNcoefs());
            System.arraycopy(gMMDiag.covar[g], 0, merged.covar[target], 0, getNcoefs());
            merged.setWeight(target, gMMDiag.getWeight(g) * (1.0f - f));
        }
        merged.precomputeDistance();
        return merged;
    }

    /**
     * @param g component index
     * @return a new single-component model (weight 1) holding component {@code g}
     */
    public GMMDiag getGauss(int g) {
        GMMDiag single = new GMMDiag(1, getNcoefs());
        System.arraycopy(this.means[g], 0, single.means[0], 0, getNcoefs());
        System.arraycopy(this.covar[g], 0, single.covar[0], 0, getNcoefs());
        single.setWeight(0, 1.0f);
        single.precomputeDistance();
        return single;
    }

    /** Sets the optional model name stored in {@link #nom}. */
    public void setNom(String name) {
        this.nom = name;
    }

    /**
     * Returns true when both models have the same component count and dimensionality and
     * every weight, mean and variance agrees within a 1% relative difference.
     *
     * @param gMMDiag the model to compare with
     */
    public boolean isEqual(GMMDiag gMMDiag) {
        // Fixed: the dimension check previously compared getNgauss() with getNcoefs().
        if (getNgauss() != gMMDiag.getNgauss() || getNcoefs() != gMMDiag.getNcoefs()) {
            return false;
        }
        for (int g = 0; g < getNgauss(); g++) {
            if (isDiff(getWeight(g), gMMDiag.getWeight(g))) {
                return false;
            }
            for (int c = 0; c < getNcoefs(); c++) {
                if (isDiff(getMean(g, c), gMMDiag.getMean(g, c))
                        || isDiff(getVar(g, c), gMMDiag.getVar(g, c))) {
                    return false;
                }
            }
        }
        return true;
    }

    /** True when {@code f2} differs from {@code f} by more than 1% relative to {@code f}. */
    private boolean isDiff(float f, float f2) {
        return ((double) Math.abs(1.0f - (f2 / f))) > 0.01d;
    }

    /** One line per component: first mean coefficient and first variance. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int g = 0; g < getNgauss(); g++) {
            sb.append(getMean(g, 0)).append(' ').append(getVar(g, 0)).append('\n');
        }
        return sb.toString();
    }
}
