package edu.cmu.sphinx.decoder.search;

import edu.cmu.sphinx.decoder.pruner.Pruner;
import edu.cmu.sphinx.decoder.scorer.AcousticScorer;
import edu.cmu.sphinx.frontend.Data;
import edu.cmu.sphinx.linguist.Linguist;
import edu.cmu.sphinx.linguist.SearchState;
import edu.cmu.sphinx.linguist.SearchStateArc;
import edu.cmu.sphinx.linguist.WordSearchState;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Loader;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Sphinx3Loader;
import edu.cmu.sphinx.linguist.allphone.PhoneHmmSearchState;
import edu.cmu.sphinx.linguist.lextree.LexTreeLinguist;
import edu.cmu.sphinx.result.Result;
import edu.cmu.sphinx.util.props.PropertyException;
import edu.cmu.sphinx.util.props.PropertySheet;
import edu.cmu.sphinx.util.props.S4Component;
import edu.cmu.sphinx.util.props.S4Double;
import edu.cmu.sphinx.util.props.S4Integer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;

/* loaded from: input_file:META-INF/jars/sphinx4-core-5prealpha-SNAPSHOT.jar:edu/cmu/sphinx/decoder/search/WordPruningBreadthFirstLookaheadSearchManager.class */
/**
 * Word-pruning breadth-first search manager that runs an additional "fastmatch" search ahead of the
 * main search and uses its results to prune lex-tree phone HMM successors.
 *
 * <p>The fastmatch search uses its own linguist ({@link #PROP_FASTMATCH_LINGUIST}) and active list
 * factory ({@link #PROP_FM_ACTIVE_LIST_FACTORY}). {@link #localStart()} advances it
 * {@code lookaheadWindow - 1} frames ahead of the main search. After that, each main-search frame
 * first advances the fastmatch search by one more frame.
 *
 * <p>For every fastmatch frame, the best token score per phone base id (taken from
 * {@link PhoneHmmSearchState} tokens) and the frame maximum are queued as {@link FrameCiScores}.
 * While expanding successors, the main search skips a {@link LexTreeLinguist.LexTreeHMMState}
 * successor when {@code tokenScore + lookaheadWeight * penalty < beamThreshold}. Here
 * {@code penalty} is the best {@code ciScore - frameMaxScore} over the queued frames for that
 * successor's base id.
 */
public class WordPruningBreadthFirstLookaheadSearchManager extends WordPruningBreadthFirstSearchManager {

    /** Property naming the acoustic model loader. */
    @S4Component(type = Loader.class)
    public static final String PROP_LOADER = "loader";

    /** Property naming the linguist used by the fastmatch (lookahead) search. */
    @S4Component(type = Linguist.class)
    public static final String PROP_FASTMATCH_LINGUIST = "fastmatchLinguist";

    /** Property naming the active list factory used by the fastmatch search. */
    @S4Component(type = ActiveListFactory.class)
    public static final String PROP_FM_ACTIVE_LIST_FACTORY = "fastmatchActiveListFactory";

    /** Property giving the weight applied to the lookahead penalty. */
    @S4Double(defaultValue = 1.0d)
    public static final String PROP_LOOKAHEAD_PENALTY_WEIGHT = "lookaheadPenaltyWeight";

    /** Property giving the lookahead window size in frames; must be within [1..10]. */
    @S4Integer(defaultValue = 5)
    public static final String PROP_LOOKAHEAD_WINDOW = "lookaheadWindow";

    /** Smallest accepted lookahead window. */
    private static final int MIN_LOOKAHEAD_WINDOW = 1;

    /** Largest accepted lookahead window. */
    private static final int MAX_LOOKAHEAD_WINDOW = 10;

    /** Initial length of each per-frame CI score array; grown on demand for larger base ids. */
    private static final int INITIAL_CI_SCORES_LENGTH = 100;

    private Linguist fastmatchLinguist;
    private Loader loader;
    private ActiveListFactory fastmatchActiveListFactory;
    private int lookaheadWindow;
    private float lookaheadWeight;
    /** Per-frame cache of lookahead penalties keyed by base id; cleared every main-search frame. */
    private HashMap<Integer, Float> penalties;
    /** Queue of CI score snapshots, one per fastmatch frame not yet consumed by the main search. */
    private LinkedList<FrameCiScores> ciScores;
    private int currentFastMatchFrameNumber;
    protected ActiveList fastmatchActiveList;
    protected Map<SearchState, Token> fastMatchBestTokenMap;
    private boolean fastmatchStreamEnd;

    /**
     * Snapshot of one fastmatch frame: the best score per phone base id and the best score of the
     * whole frame. Base ids with no token hold {@code -Float.MAX_VALUE}.
     */
    public static class FrameCiScores {
        public final float[] scores;
        public final float maxScore;

        public FrameCiScores(float[] scores, float maxScore) {
            this.scores = scores;
            this.maxScore = maxScore;
        }
    }

    /**
     * Creates a fully configured lookahead search manager.
     *
     * @param linguist the linguist for the main search (forwarded to the superclass)
     * @param fastmatchLinguist the linguist for the fastmatch search
     * @param loader the acoustic model loader; if it is a tied-mixture {@link Sphinx3Loader}, its
     *     gaussian score queue length is set to {@code lookaheadWindow + 2}
     * @param pruner forwarded to the superclass
     * @param acousticScorer forwarded to the superclass
     * @param activeListManager forwarded to the superclass
     * @param fastmatchActiveListFactory factory for fastmatch active lists
     * @param z forwarded unchanged to the superclass constructor
     * @param d forwarded unchanged to the superclass constructor
     * @param i forwarded unchanged to the superclass constructor
     * @param z2 forwarded unchanged to the superclass constructor
     * @param z3 forwarded unchanged to the superclass constructor
     * @param lookaheadWindow lookahead window size in frames, within [1..10]
     * @param lookaheadWeight weight applied to the lookahead penalty
     * @param i3 forwarded unchanged to the superclass constructor
     * @param f2 forwarded unchanged to the superclass constructor
     * @param z4 forwarded unchanged to the superclass constructor
     * @throws IllegalArgumentException if {@code lookaheadWindow} is outside [1..10]
     */
    public WordPruningBreadthFirstLookaheadSearchManager(Linguist linguist, Linguist fastmatchLinguist, Loader loader,
            Pruner pruner, AcousticScorer acousticScorer, ActiveListManager activeListManager,
            ActiveListFactory fastmatchActiveListFactory, boolean z, double d, int i, boolean z2, boolean z3,
            int lookaheadWindow, float lookaheadWeight, int i3, float f2, boolean z4) {
        super(linguist, pruner, acousticScorer, activeListManager, z, d, i, z2, z3, i3, f2, z4);
        this.loader = loader;
        this.fastmatchLinguist = fastmatchLinguist;
        this.fastmatchActiveListFactory = fastmatchActiveListFactory;
        this.lookaheadWindow = lookaheadWindow;
        this.lookaheadWeight = lookaheadWeight;
        if (!isSupportedLookaheadWindow(lookaheadWindow)) {
            throw new IllegalArgumentException(unsupportedLookaheadWindowMessage(lookaheadWindow));
        }
        initLookaheadState();
    }

    /** Creates an unconfigured instance; configuration happens in {@link #newProperties}. */
    public WordPruningBreadthFirstLookaheadSearchManager() {
    }

    /**
     * Reads this component's configuration from the property sheet.
     *
     * @throws PropertyException if the lookahead window is outside [1..10], or if the superclass
     *     configuration fails
     */
    @Override
    public void newProperties(PropertySheet propertySheet) throws PropertyException {
        super.newProperties(propertySheet);
        this.fastmatchLinguist = (Linguist) propertySheet.getComponent(PROP_FASTMATCH_LINGUIST);
        this.fastmatchActiveListFactory = (ActiveListFactory) propertySheet.getComponent(PROP_FM_ACTIVE_LIST_FACTORY);
        this.loader = (Loader) propertySheet.getComponent(PROP_LOADER);
        this.lookaheadWindow = propertySheet.getInt(PROP_LOOKAHEAD_WINDOW);
        this.lookaheadWeight = propertySheet.getFloat(PROP_LOOKAHEAD_PENALTY_WEIGHT);
        if (!isSupportedLookaheadWindow(this.lookaheadWindow)) {
            throw new PropertyException(WordPruningBreadthFirstLookaheadSearchManager.class.getName(),
                    PROP_LOOKAHEAD_WINDOW, unsupportedLookaheadWindowMessage(this.lookaheadWindow));
        }
        initLookaheadState();
    }

    /** Returns whether {@code window} lies within the supported lookahead window range. */
    private static boolean isSupportedLookaheadWindow(int window) {
        return window >= MIN_LOOKAHEAD_WINDOW && window <= MAX_LOOKAHEAD_WINDOW;
    }

    /** Builds the error message reported for an unsupported lookahead window. */
    private static String unsupportedLookaheadWindowMessage(int window) {
        return "Unsupported lookahead window size: " + window + ". Value in range [" + MIN_LOOKAHEAD_WINDOW + ".."
                + MAX_LOOKAHEAD_WINDOW + "] is expected";
    }

    /**
     * Resets the lookahead bookkeeping and sizes the tied-mixture loader's gaussian score queue
     * (if any) to {@code lookaheadWindow + 2}.
     */
    private void initLookaheadState() {
        this.ciScores = new LinkedList<>();
        this.penalties = new HashMap<>();
        Sphinx3Loader tiedMixtureLoader = tiedMixtureLoader();
        if (tiedMixtureLoader != null) {
            tiedMixtureLoader.setGauScoresQueueLength(this.lookaheadWindow + 2);
        }
    }

    /** Returns the loader as a {@link Sphinx3Loader} if it has tied mixtures, otherwise {@code null}. */
    private Sphinx3Loader tiedMixtureLoader() {
        if (this.loader instanceof Sphinx3Loader && ((Sphinx3Loader) this.loader).hasTiedMixtures()) {
            return (Sphinx3Loader) this.loader;
        }
        return null;
    }

    /**
     * Runs up to {@code nFrames} frames of recognition.
     *
     * <p>Each frame does three things in order:
     * <ol>
     *   <li>Advances the fastmatch search by one frame, unless its stream has ended.</li>
     *   <li>Drops the cached penalties and the oldest CI score snapshot.</li>
     *   <li>Runs one main-search frame.</li>
     * </ol>
     *
     * @param nFrames maximum number of frames to process
     * @return the current result, or {@code null} if the main search reached stream end
     */
    @Override
    public Result recognize(int nFrames) {
        boolean done = false;
        this.streamEnd = false;
        for (int frame = 0; frame < nFrames && !done; frame++) {
            if (!this.fastmatchStreamEnd) {
                fastMatchRecognize();
            }
            this.penalties.clear();
            this.ciScores.poll();
            done = recognize();
        }
        Result result = null;
        if (!this.streamEnd) {
            result = new Result(this.loserManager, this.activeList, this.resultList, this.currentCollectTime, done,
                    this.linguist.getSearchGraph().getWordTokenFirst(), true);
        }
        if (this.showTokenCount) {
            showTokenCount();
        }
        return result;
    }

    /** Advances the fastmatch search by one frame: score, prune, then grow (if scoring produced data). */
    private void fastMatchRecognize() {
        if (!scoreFastMatchTokens()) {
            return;
        }
        pruneFastMatchBranches();
        this.currentFastMatchFrameNumber++;
        createFastMatchBestTokenMap();
        growFastmatchBranches();
    }

    /** Replaces the fastmatch best-token map with an empty one sized from the current active list. */
    protected void createFastMatchBestTokenMap() {
        int size = Math.max(1, this.fastmatchActiveList.size() * 10);
        this.fastMatchBestTokenMap = new HashMap<>(size);
    }

    /**
     * Prepares for a new utterance.
     *
     * <p>This method:
     * <ul>
     *   <li>Clears the tied-mixture gaussian score cache, if any.</li>
     *   <li>Seeds the fastmatch search with the fastmatch graph's initial state.</li>
     *   <li>Runs the fastmatch search {@code lookaheadWindow - 1} frames ahead (stopping early at
     *       stream end).</li>
     *   <li>Starts the main search.</li>
     * </ul>
     */
    @Override
    public void localStart() {
        this.currentFastMatchFrameNumber = 0;
        Sphinx3Loader tiedMixtureLoader = tiedMixtureLoader();
        if (tiedMixtureLoader != null) {
            tiedMixtureLoader.clearGauScores();
        }
        this.fastmatchActiveList = this.fastmatchActiveListFactory.newInstance();
        this.fastmatchActiveList.add(new Token(this.fastmatchLinguist.getSearchGraph().getInitialState(),
                this.currentFastMatchFrameNumber));
        createFastMatchBestTokenMap();
        growFastmatchBranches();
        this.fastmatchStreamEnd = false;
        for (int frame = 0; frame < this.lookaheadWindow - 1 && !this.fastmatchStreamEnd; frame++) {
            fastMatchRecognize();
        }
        super.localStart();
    }

    /**
     * Expands every fastmatch token at or above the beam threshold into a fresh active list.
     *
     * <p>While expanding, it records the best score per {@link PhoneHmmSearchState} base id and the
     * frame maximum. These are appended to the CI score queue as one {@link FrameCiScores}.
     */
    protected void growFastmatchBranches() {
        this.growTimer.start();
        try {
            ActiveList previous = this.fastmatchActiveList;
            this.fastmatchActiveList = this.fastmatchActiveListFactory.newInstance();
            float threshold = previous.getBeamThreshold();
            float[] frameScores = ensureCiScoresCapacity(new float[0], INITIAL_CI_SCORES_LENGTH);
            float frameMax = -Float.MAX_VALUE;
            for (Token token : previous) {
                float score = token.getScore();
                // Positive comparison kept on purpose: NaN scores are skipped.
                if (score >= threshold) {
                    if (token.getSearchState() instanceof PhoneHmmSearchState) {
                        int baseId = ((PhoneHmmSearchState) token.getSearchState()).getBaseId();
                        // Grow instead of failing on base ids beyond the initial capacity.
                        frameScores = ensureCiScoresCapacity(frameScores, baseId + 1);
                        if (frameScores[baseId] < score) {
                            frameScores[baseId] = score;
                        }
                        if (frameMax < score) {
                            frameMax = score;
                        }
                    }
                    collectFastMatchSuccessorTokens(token);
                }
            }
            this.ciScores.add(new FrameCiScores(frameScores, frameMax));
        } finally {
            this.growTimer.stop();
        }
    }

    /**
     * Returns {@code scores} if it already holds at least {@code minLength} entries. Otherwise returns
     * a larger copy whose new slots are filled with {@code -Float.MAX_VALUE} ("no score").
     */
    private static float[] ensureCiScoresCapacity(float[] scores, int minLength) {
        if (minLength <= scores.length) {
            return scores;
        }
        int oldLength = scores.length;
        float[] grown = Arrays.copyOf(scores, Math.max(minLength, oldLength * 2));
        Arrays.fill(grown, oldLength, grown.length, -Float.MAX_VALUE);
        return grown;
    }

    /**
     * Scores the fastmatch active list.
     *
     * <p>If the scorer returns something other than a {@link Token}, the fastmatch stream is marked
     * as ended and the list's best token is set to {@code null}.
     *
     * @return {@code true} if a best token was produced
     */
    protected boolean scoreFastMatchTokens() {
        Data scored;
        this.scoreTimer.start();
        try {
            scored = this.scorer.calculateScoresAndStoreData(this.fastmatchActiveList.getTokens());
        } finally {
            this.scoreTimer.stop();
        }
        Token best = null;
        if (scored instanceof Token) {
            best = (Token) scored;
        } else {
            this.fastmatchStreamEnd = true;
        }
        this.fastmatchActiveList.setBestToken(best);
        monitorStates(this.fastmatchActiveList);
        this.curTokensScored.value += this.fastmatchActiveList.size();
        this.totalTokensScored.value += this.fastmatchActiveList.size();
        return best != null;
    }

    /** Replaces the fastmatch active list with its pruned version. */
    protected void pruneFastMatchBranches() {
        this.pruneTimer.start();
        try {
            this.fastmatchActiveList = this.pruner.prune(this.fastmatchActiveList);
        } finally {
            this.pruneTimer.stop();
        }
    }

    /** Returns the best fastmatch token for {@code searchState} in the current frame, or {@code null}. */
    protected Token getFastMatchBestToken(SearchState searchState) {
        return this.fastMatchBestTokenMap.get(searchState);
    }

    /** Records {@code token} as the best fastmatch token for {@code searchState} in the current frame. */
    protected void setFastMatchBestToken(Token token, SearchState searchState) {
        this.fastMatchBestTokenMap.put(searchState, token);
    }

    /**
     * Expands the successors of a fastmatch token.
     *
     * <ul>
     *   <li>Emitting successors are merged into the fastmatch active list, keeping one best token
     *       per state.</li>
     *   <li>Non-emitting successors are expanded recursively, unless they were already visited.</li>
     * </ul>
     */
    protected void collectFastMatchSuccessorTokens(Token token) {
        for (SearchStateArc arc : token.getSearchState().getSuccessors()) {
            SearchState nextState = arc.getState();
            float entryScore = token.getScore() + arc.getProbability();
            Token predecessor = getResultListPredecessor(token);
            if (!nextState.isEmitting()) {
                Token nonEmitting = new Token(predecessor, nextState, entryScore, arc.getInsertionProbability(),
                        arc.getLanguageProbability(), this.currentFastMatchFrameNumber);
                this.tokensCreated.value++;
                // Guard against looping through chains of non-emitting states.
                if (!isVisited(nonEmitting)) {
                    collectFastMatchSuccessorTokens(nonEmitting);
                }
                continue;
            }
            Token best = getFastMatchBestToken(nextState);
            if (best == null) {
                Token created = new Token(predecessor, nextState, entryScore, arc.getInsertionProbability(),
                        arc.getLanguageProbability(), this.currentFastMatchFrameNumber);
                this.tokensCreated.value++;
                setFastMatchBestToken(created, nextState);
                this.fastmatchActiveList.add(created);
            } else if (best.getScore() <= entryScore) {
                best.update(predecessor, nextState, entryScore, arc.getInsertionProbability(),
                        arc.getLanguageProbability(), this.currentFastMatchFrameNumber);
            }
        }
    }

    /**
     * Expands the successors of a main-search token, pruning lex-tree HMM successors with the
     * fastmatch lookahead penalty.
     *
     * <p>Final tokens go to the result list. A successor {@link LexTreeLinguist.LexTreeHMMState}
     * reached from a non-emitting HMM, word or end-unit state is skipped when
     * {@code tokenScore + lookaheadWeight * penalty} falls below the active list's beam threshold.
     */
    @Override
    protected void collectSuccessorTokens(Token token) {
        if (token.isFinal()) {
            this.resultList.add(getResultListPredecessor(token));
            return;
        }
        if (!token.isEmitting() && this.keepAllTokens && isVisited(token)) {
            return;
        }
        SearchState searchState = token.getSearchState();
        SearchStateArc[] successors = searchState.getSuccessors();
        Token predecessor = getResultListPredecessor(token);
        float tokenScore = token.getScore();
        float beamThreshold = this.activeList.getBeamThreshold();
        boolean producesPhoneHmms = searchState instanceof LexTreeLinguist.LexTreeNonEmittingHMMState
                || searchState instanceof LexTreeLinguist.LexTreeWordState
                || searchState instanceof LexTreeLinguist.LexTreeEndUnitState;
        for (SearchStateArc arc : successors) {
            SearchState nextState = arc.getState();
            if (producesPhoneHmms && nextState instanceof LexTreeLinguist.LexTreeHMMState) {
                int baseId = ((LexTreeLinguist.LexTreeHMMState) nextState).getHMMState().getHMM().getBaseUnit()
                        .getBaseID();
                // Lookahead pruning: skip this arc entirely if the penalized score is below the beam.
                if (tokenScore + this.lookaheadWeight * lookaheadPenalty(baseId) < beamThreshold) {
                    continue;
                }
            }
            if (this.checkStateOrder) {
                checkStateOrder(searchState, nextState);
            }
            float entryScore = tokenScore + arc.getProbability();
            Token best = getBestToken(nextState);
            if (best == null) {
                Token created = new Token(predecessor, nextState, entryScore, arc.getInsertionProbability(),
                        arc.getLanguageProbability(), this.currentCollectTime);
                this.tokensCreated.value++;
                setBestToken(created, nextState);
                activeListAdd(created);
            } else if (best.getScore() < entryScore) {
                Token oldPredecessor = best.getPredecessor();
                best.update(predecessor, nextState, entryScore, arc.getInsertionProbability(),
                        arc.getLanguageProbability(), this.currentCollectTime);
                if (this.buildWordLattice && nextState instanceof WordSearchState) {
                    this.loserManager.addAlternatePredecessor(best, oldPredecessor);
                }
            } else if (this.buildWordLattice && nextState instanceof WordSearchState && predecessor != null) {
                this.loserManager.addAlternatePredecessor(best, predecessor);
            }
        }
    }

    /** Returns the cached penalty for {@code baseId}, computing and caching it on first use. */
    private float lookaheadPenalty(int baseId) {
        Float cached = this.penalties.get(baseId);
        return cached != null ? cached : updateLookaheadPenalty(baseId);
    }

    /**
     * Computes the lookahead penalty for a base id.
     *
     * <p>The penalty is the best {@code scores[baseId] - maxScore} over the queued CI score frames.
     * Base ids outside a frame's score array count as "no score" ({@code -Float.MAX_VALUE}). The
     * result is cached in {@code penalties}.
     *
     * @return the penalty, or {@code 0} (uncached) when no CI score frames are queued
     */
    private float updateLookaheadPenalty(int baseId) {
        if (this.ciScores.isEmpty()) {
            return 0.0f;
        }
        float best = -Float.MAX_VALUE;
        for (FrameCiScores frame : this.ciScores) {
            float ciScore = baseId >= 0 && baseId < frame.scores.length ? frame.scores[baseId] : -Float.MAX_VALUE;
            float relative = ciScore - frame.maxScore;
            if (relative > best) {
                best = relative;
            }
        }
        this.penalties.put(baseId, best);
        return best;
    }
}
