package edu.cmu.sphinx.fst.operations;

import edu.cmu.sphinx.fst.Arc;
import edu.cmu.sphinx.fst.Fst;
import edu.cmu.sphinx.fst.State;
import edu.cmu.sphinx.fst.semiring.Semiring;
import edu.cmu.sphinx.fst.utils.Pair;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.PriorityQueue;

/* loaded from: input_file:edu/cmu/sphinx/fst/operations/NShortestPaths.class */
public class NShortestPaths {
    private NShortestPaths() {
    }

    /**
     * One partial path of the best-first search in {@link #get(Fst, int, boolean)}: the state it
     * ends in, the weight accumulated so far, and a back-pointer to the path it extends.
     */
    private static final class PathNode {
        /** State of the searched FST this path ends in. */
        final State state;
        /** Weight accumulated along the path from the start state. */
        final float cost;
        /** Path this one extends; {@code null} for the start node. */
        final PathNode parent;
        /** Arc of {@code parent.state} taken to reach {@code state}; {@code null} for the start node. */
        final Arc arc;
        /** Copy of {@code state} in the result FST, assigned when this node is dequeued. */
        State copy;

        PathNode(State state, float cost, PathNode parent, Arc arc) {
            this.state = state;
            this.cost = cost;
            this.parent = parent;
            this.arc = arc;
        }
    }

    /**
     * Computes shortest distances over the reversed form of {@code fst}.
     *
     * <p>The FST is reversed with {@link Reverse#get(Fst)}, and a queue-based relaxation is run
     * from the reversed start state. {@code d} holds the current distance estimate of every state.
     * {@code r} holds the weight accumulated since that state was last dequeued. A state is
     * (re)queued whenever {@code semiring.plus} changes its estimate, and the loop ends when no
     * estimate changes any more. Change is tested with exact float comparison.
     *
     * @param fst the FST to analyse
     * @return an array indexed by the state ids of the reversed FST. Presumably these ids match
     *     those of {@code fst}, so entry {@code i} is the distance from state {@code i} to the
     *     final states (TODO confirm against {@code Reverse})
     */
    public static float[] shortestDistance(Fst fst) {
        Fst reversed = Reverse.get(fst);
        Semiring semiring = reversed.getSemiring();
        int numStates = reversed.getNumStates();

        float[] d = new float[numStates];
        float[] r = new float[numStates];
        Arrays.fill(d, semiring.zero());
        Arrays.fill(r, semiring.zero());

        State start = reversed.getStart();
        d[start.getId()] = semiring.one();
        r[start.getId()] = semiring.one();

        // Insertion-ordered set: states are processed FIFO, and a state that is already
        // waiting is not queued a second time (add() leaves its position unchanged).
        LinkedHashSet<State> queue = new LinkedHashSet<State>();
        queue.add(start);

        while (!queue.isEmpty()) {
            State q = queue.iterator().next();
            queue.remove(q);

            float rq = r[q.getId()];
            r[q.getId()] = semiring.zero();

            for (int i = 0; i < q.getNumArcs(); i++) {
                Arc arc = q.getArc(i);
                State next = arc.getNextState();
                int nextId = next.getId();
                float extension = semiring.times(rq, arc.getWeight());
                float updated = semiring.plus(d[nextId], extension);
                if (d[nextId] != updated) {
                    d[nextId] = updated;
                    r[nextId] = semiring.plus(r[nextId], extension);
                    queue.add(next);
                }
            }
        }
        return d;
    }

    /**
     * Builds an FST containing the {@code n} best paths of {@code fst}.
     *
     * <p>Partial paths are explored best-first. The queue is ordered by
     * {@code times(cost, d[state])} using the semiring's {@code naturalLess}, where {@code d} comes
     * from {@link #shortestDistance(Fst)}. Every dequeued partial path becomes a fresh state in the
     * result. That state is linked to its parent's copy by the single arc that extended the path.
     * A state of the searched FST is expanded at most {@code n} times. The search stops as soon as
     * a state with a non-zero final weight is dequeued for the {@code n}-th time. With
     * {@code n <= 0} the result contains only a copy of the start state.
     *
     * @param fst the input FST
     * @param n the number of paths to keep
     * @param determinize whether to run {@link Determinize#get(Fst)} on {@code fst} first
     * @return the result FST, or {@code null} if {@code fst} or its semiring is {@code null}
     */
    public static Fst get(Fst fst, int n, boolean determinize) {
        if (fst == null || fst.getSemiring() == null) {
            return null;
        }
        Fst searched = determinize ? Determinize.get(fst) : fst;
        final Semiring semiring = searched.getSemiring();

        Fst result = new Fst(semiring);
        result.setIsyms(searched.getIsyms());
        result.setOsyms(searched.getOsyms());

        final float[] d = shortestDistance(searched);

        // NOTE(review): this modifies `searched` in place. When `determinize` is false that is
        // the caller's own FST, and the change is not reverted in this method.
        ExtendFinal.apply(searched);

        int[] popCount = new int[searched.getNumStates()];
        PriorityQueue<PathNode> queue = new PriorityQueue<PathNode>(10, new Comparator<PathNode>() {
            @Override
            public int compare(PathNode first, PathNode second) {
                float secondEstimate = semiring.times(second.cost, d[second.state.getId()]);
                float firstEstimate = semiring.times(first.cost, d[first.state.getId()]);
                // The entry whose estimate is naturally less is dequeued first.
                if (semiring.naturalLess(secondEstimate, firstEstimate)) {
                    return 1;
                }
                return secondEstimate == firstEstimate ? 0 : -1;
            }
        });

        queue.add(new PathNode(searched.getStart(), semiring.one(), null, null));
        while (!queue.isEmpty()) {
            PathNode node = queue.remove();
            State p = node.state;

            State copy = new State(p.getFinalWeight());
            result.addState(copy);
            node.copy = copy;
            if (node.parent == null) {
                result.setStart(copy);
            } else {
                // Reproduce only the arc that produced this path. Copying every parallel arc
                // into `p` would add paths that were never selected.
                Arc taken = node.arc;
                node.parent.copy.addArc(
                        new Arc(taken.getIlabel(), taken.getOlabel(), taken.getWeight(), copy));
            }

            int visits = ++popCount[p.getId()];
            if (visits == n && p.getFinalWeight() != semiring.zero()) {
                break;
            }
            if (visits <= n) {
                for (int j = 0; j < p.getNumArcs(); j++) {
                    Arc arc = p.getArc(j);
                    float cost = semiring.times(node.cost, arc.getWeight());
                    queue.add(new PathNode(arc.getNextState(), cost, node, arc));
                }
            }
        }
        return result;
    }
}
