All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.sphinx.fst.operations.Compose Maven / Gradle / Ivy

The newest version!
/**
 * 
 * Copyright 1999-2012 Carnegie Mellon University.  
 * Portions Copyright 2002 Sun Microsystems, Inc.  
 * Portions Copyright 2002 Mitsubishi Electric Research Laboratories.
 * All Rights Reserved.  Use is subject to license terms.
 * 
 * See the file "license.terms" for information on usage and
 * redistribution of this file, and for a DISCLAIMER OF ALL 
 * WARRANTIES.
 *
 */

package edu.cmu.sphinx.fst.operations;

import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Queue;

import edu.cmu.sphinx.fst.Arc;
import edu.cmu.sphinx.fst.Fst;
import edu.cmu.sphinx.fst.ImmutableFst;
import edu.cmu.sphinx.fst.State;
import edu.cmu.sphinx.fst.semiring.Semiring;
import edu.cmu.sphinx.fst.utils.Pair;

/**
 * Compose operation.
 * 
 * See: M. Mohri, "Weighted automata algorithms", Handbook of Weighted Automata.
 * Springer, pp. 213-250, 2009.
 * 
 * @author John Salatas
 */

public class Compose {
    /**
     * Default Constructor
     */
    private Compose() {
    }

    /**
     * Computes the composition of two Fsts. Assuming no epsilon transitions.
     * 
     * Input Fsts are not modified.
     * 
     * @param fst1 the first Fst
     * @param fst2 the second Fst
     * @param semiring the semiring to use in the operation
     * @param sorted sort result
     * @return the composed Fst
     */
    public static Fst compose(Fst fst1, Fst fst2, Semiring semiring,
            boolean sorted) {
        if (!Arrays.equals(fst1.getOsyms(), fst2.getIsyms())) {
            // symboltables do not match
            return null;
        }

        Fst res = new Fst(semiring);

        HashMap, State> stateMap = new HashMap, State>();
        Queue> queue = new LinkedList>();

        State s1 = fst1.getStart();
        State s2 = fst2.getStart();

        if ((s1 == null) || (s2 == null)) {
            System.err.println("Cannot find initial state.");
            return null;
        }

        Pair p = new Pair(s1, s2);
        State s = new State(semiring.times(s1.getFinalWeight(),
                s2.getFinalWeight()));

        res.addState(s);
        res.setStart(s);
        stateMap.put(p, s);
        queue.add(p);

        while (!queue.isEmpty()) {
            p = queue.remove();
            s1 = p.getLeft();
            s2 = p.getRight();
            s = stateMap.get(p);
            int numArcs1 = s1.getNumArcs();
            int numArcs2 = s2.getNumArcs();
            for (int i = 0; i < numArcs1; i++) {
                Arc a1 = s1.getArc(i);
                for (int j = 0; j < numArcs2; j++) {
                    Arc a2 = s2.getArc(j);
                    if (sorted && a1.getOlabel() < a2.getIlabel())
                        break;
                    if (a1.getOlabel() == a2.getIlabel()) {
                        State nextState1 = a1.getNextState();
                        State nextState2 = a2.getNextState();
                        Pair nextPair = new Pair(
                                nextState1, nextState2);
                        State nextState = stateMap.get(nextPair);
                        if (nextState == null) {
                            nextState = new State(semiring.times(
                                    nextState1.getFinalWeight(),
                                    nextState2.getFinalWeight()));
                            res.addState(nextState);
                            stateMap.put(nextPair, nextState);
                            queue.add(nextPair);
                        }
                        Arc a = new Arc(a1.getIlabel(), a2.getOlabel(),
                                semiring.times(a1.getWeight(), a2.getWeight()),
                                nextState);
                        s.addArc(a);
                    }
                }
            }
        }

        res.setIsyms(fst1.getIsyms());
        res.setOsyms(fst2.getOsyms());

        return res;
    }

    /**
     * Computes the composition of two Fsts. The two Fsts are augmented in order
     * to avoid multiple epsilon paths in the resulting Fst
     * 
     * @param fst1 the first Fst
     * @param fst2 the second Fst
     * @param semiring the semiring to use in the operation
     * @return the composed Fst
     */
    public static Fst get(Fst fst1, Fst fst2, Semiring semiring) {
        if ((fst1 == null) || (fst2 == null)) {
            return null;
        }

        if (!Arrays.equals(fst1.getOsyms(), fst2.getIsyms())) {
            // symboltables do not match
            return null;
        }

        Fst filter = getFilter(fst1.getOsyms(), semiring);
        augment(1, fst1, semiring);
        augment(0, fst2, semiring);

        Fst tmp = Compose.compose(fst1, filter, semiring, false);

        Fst res = Compose.compose(tmp, fst2, semiring, false);

        // Connect.apply(res);

        return res;
    }

    /**
     * Get a filter to use for avoiding multiple epsilon paths in the resulting
     * Fst
     * 
     * See: M. Mohri, "Weighted automata algorithms", Handbook of Weighted
     * Automata. Springer, pp. 213-250, 2009.
     * 
     * @param syms the gilter's input/output symbols
     * @param semiring the semiring to use in the operation
     * @return the filter
     */
    public static Fst getFilter(String[] syms, Semiring semiring) {
        Fst filter = new Fst(semiring);

        int e1index = syms.length;
        int e2index = syms.length + 1;

        filter.setIsyms(syms);
        filter.setOsyms(syms);

        // State 0
        State s0 = new State(syms.length + 3);
        s0.setFinalWeight(semiring.one());
        State s1 = new State(syms.length);
        s1.setFinalWeight(semiring.one());
        State s2 = new State(syms.length);
        s2.setFinalWeight(semiring.one());
        filter.addState(s0);
        s0.addArc(new Arc(e2index, e1index, semiring.one(), s0));
        s0.addArc(new Arc(e1index, e1index, semiring.one(), s1));
        s0.addArc(new Arc(e2index, e2index, semiring.one(), s2));
        for (int i = 1; i < syms.length; i++) {
            s0.addArc(new Arc(i, i, semiring.one(), s0));
        }
        filter.setStart(s0);

        // State 1
        filter.addState(s1);
        s1.addArc(new Arc(e1index, e1index, semiring.one(), s1));
        for (int i = 1; i < syms.length; i++) {
            s1.addArc(new Arc(i, i, semiring.one(), s0));
        }

        // State 2
        filter.addState(s2);
        s2.addArc(new Arc(e2index, e2index, semiring.one(), s2));
        for (int i = 1; i < syms.length; i++) {
            s2.addArc(new Arc(i, i, semiring.one(), s0));
        }

        return filter;
    }

    /**
     * Augments the labels of an Fst in order to use it for composition avoiding
     * multiple epsilon paths in the resulting Fst
     * 
     * Augment can be applied to both {@link edu.cmu.sphinx.fst.Fst} and
     * {@link edu.cmu.sphinx.fst.ImmutableFst}, as immutable fsts hold an
     * additional null arc for that operation
     * 
     * @param label constant denoting if the augment should take place on input
     *            or output labels For value equal to 0 augment will take place
     *            for input labels For value equal to 1 augment will take place
     *            for output labels
     * @param fst the fst to augment
     * @param semiring the semiring to use in the operation
     */
    public static void augment(int label, Fst fst, Semiring semiring) {
        // label: 0->augment on ilabel
        // 1->augment on olabel

        String[] isyms = fst.getIsyms();
        String[] osyms = fst.getOsyms();

        int e1inputIndex = isyms.length;
        int e2inputIndex = isyms.length + 1;

        int e1outputIndex = osyms.length;
        int e2outputIndex = osyms.length + 1;

        int numStates = fst.getNumStates();
        for (int i = 0; i < numStates; i++) {
            State s = fst.getState(i);
            // Immutable fsts hold an additional (null) arc for augmention
            int numArcs = (fst instanceof ImmutableFst) ? s.getNumArcs() - 1
                    : s.getNumArcs();
            for (int j = 0; j < numArcs; j++) {
                Arc a = s.getArc(j);
                if ((label == 1) && (a.getOlabel() == 0)) {
                    a.setOlabel(e2outputIndex);
                } else if ((label == 0) && (a.getIlabel() == 0)) {
                    a.setIlabel(e1inputIndex);
                }
            }
            if (label == 0) {
                if (fst instanceof ImmutableFst) {
                    s.setArc(numArcs, new Arc(e2inputIndex, 0, semiring.one(),
                            s));
                } else {
                    s.addArc(new Arc(e2inputIndex, 0, semiring.one(), s));
                }
            } else if (label == 1) {
                if (fst instanceof ImmutableFst) {
                    s.setArc(numArcs, new Arc(0, e1outputIndex, semiring.one(),
                            s));
                } else {
                    s.addArc(new Arc(0, e1outputIndex, semiring.one(), s));
                }
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy