All downloads are free. Search and download functionality uses the official Maven repository.

com.github.steveash.jg2p.wfst.LangModelToFst Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 Steve Ash
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.steveash.jg2p.wfst;

import com.google.common.base.Charsets;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.primitives.Doubles;

import com.github.steveash.jopenfst.ImmutableFst;
import com.github.steveash.jopenfst.MutableFst;
import com.github.steveash.jopenfst.WriteableSymbolTable;
import com.github.steveash.jopenfst.operations.ArcSort;
import com.github.steveash.jopenfst.semiring.TropicalSemiring;
import com.github.steveash.kylm.model.ngram.NgramLM;
import com.github.steveash.kylm.model.ngram.NgramWalker;
import com.github.steveash.kylm.model.ngram.WalkerVisitor;
import com.github.steveash.kylm.model.ngram.reader.ArpaNgramReader;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.util.List;

import static com.github.steveash.jopenfst.Fst.EPS;
import static org.apache.commons.lang3.StringUtils.isBlank;
import static org.codehaus.groovy.runtime.DefaultGroovyMethods.last;

/**
 * Converts a lang model to an FST
 *
 * @author Steve Ash
 */
public class LangModelToFst {

  public static final double REALLY_HIGH = 999.0;
  public static final double PRETTY_HIGH = 99.0;

  private final Splitter graphoneSplitter = Splitter.on(SeqTransducer.GRAPHONE_DELIM).trimResults().limit(2);
  private final Joiner commaJoin = Joiner.on(',');

  private MutableFst fst;
  private int maxOrder;

  public SeqTransducer fromArpa(File arpaFile) {
    ArpaNgramReader reader = new ArpaNgramReader();
    try (BufferedReader br = Files.newBufferedReader(arpaFile.toPath(), Charsets.UTF_8)) {
      NgramLM model = reader.read(br);
      return fromModel(model);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  public SeqTransducer fromModel(NgramLM model) {
    Preconditions.checkNotNull(model.getStartSymbol(), "must use start symbol");
    Preconditions.checkNotNull(model.getTerminalSymbol(), "must use terminal symbol");
    Preconditions.checkArgument(model.getStartSymbol().equals(SeqTransducer.START), "Only using start %s", SeqTransducer.START);
    Preconditions.checkArgument(model.getTerminalSymbol().equals(SeqTransducer.END), "Only using end %s", SeqTransducer.END);

    this.maxOrder = model.getN();
    this.fst = new MutableFst(TropicalSemiring.INSTANCE);
    fst.useStateSymbols();
    fst.newStartState(SeqTransducer.START_STATE);
    fst.newState(SeqTransducer.END).setFinalWeight(TropicalSemiring.INSTANCE.one());

    for (String sym : SeqTransducer.ALL_SKIP_STRINGS) {
      fst.getInputSymbols().getOrAdd(sym);
      fst.getOutputSymbols().getOrAdd(sym);
    }

    addArc(SeqTransducer.START_STATE, SeqTransducer.START, SeqTransducer.START, SeqTransducer.START, 0.0);
    // add stuff for each order counts
    Preconditions.checkState(maxOrder > 1, "cant work with a unigram model");
    new NgramWalker(model).walk(new WalkerVisitor() {
      @Override
      public void visit(int order, List toks, float score, float backoffScore, boolean hasChildren,
                        boolean isLastOrder) {
        Preconditions.checkState(order == toks.size());
        if (order == 1) {
          String s = toks.get(0);
          if (s.equalsIgnoreCase(SeqTransducer.START)) {
            addArc(SeqTransducer.START, EPS, EPS, EPS, backoffScore);
          } else if (s.equalsIgnoreCase(SeqTransducer.END)) {
            addArc(EPS, SeqTransducer.END, SeqTransducer.END, SeqTransducer.END, score);
          } else {
            addArc(s, EPS, EPS, EPS, backoffScore);
            addArc(EPS, s, s, s, score);
          }
          return;
        }
        String last = last(toks);
        if (last.equalsIgnoreCase(SeqTransducer.END)) {
          // last in a sentence just ends in the terminal
          addArc(commaJoin.join(toks.subList(0, toks.size() - 1)),
                 last, last, last, score);
          return;
        }
        if (isLastOrder) {
          // we're moving one step
          addArc(commaJoin.join(toks.subList(0, toks.size() - 1)),
                 commaJoin.join(toks.subList(1, toks.size())),
                 last,
                 last,
                 score
          );
          return;
        }
        // we are in the middle so emit an arc for the backoff and an arc for the score
        addArc(commaJoin.join(toks),
               commaJoin.join(toks.subList(1, toks.size())),
               EPS,
               EPS,
               backoffScore
        );
        addArc(commaJoin.join(toks.subList(0, toks.size()-1)),
               commaJoin.join(toks),
               last,
               last,
               score
        );
      }
    });

    patchSymbols(fst.getInputSymbols(), true);
    patchSymbols(fst.getOutputSymbols(), false);

    ArcSort.sortByInput(fst);
    fst.dropStateSymbols(); // we dont need these for test time
    return new SeqTransducer(new ImmutableFst(fst), this.maxOrder);
  }

  // psaurus does this...but for the O-labels i dont really get this ...
  private void patchSymbols(WriteableSymbolTable symbols, boolean isInput) {
    for (int i = 0; i < symbols.size(); i++) {
      if (!symbols.invert().containsKey(i)) {
        continue;
      }
      String symbol = symbols.invert().keyForId(i);
      if (symbol.contains(SeqTransducer.GRAPHONE_DELIM)) {
        for (String unigram : graphoneSplitter.split(symbol)) {
          if (symbols.contains(unigram)) {
            continue;
          }
          // this unigram doesn't exist so add a backoff edge to the start
          if (isInput) {
            fst.addArc(SeqTransducer.START, unigram, SeqTransducer.SKIP, SeqTransducer.START_STATE, PRETTY_HIGH);
          } else {
            fst.addArc(SeqTransducer.START, SeqTransducer.SKIP, unigram, SeqTransducer.START_STATE, PRETTY_HIGH);
          }
        }
      }
    }
  }

  private void addArc(String thisStateSymbol, String nextStateSymbol, String inLabel, String outLabel, double weight) {
    weight = tropicalWeight(weight);
    List result = graphoneSplitter.splitToList(inLabel);
    String correctedIn = inLabel.trim();
    String correctedOut = outLabel.trim();
    if (result.size() > 1) {
      Preconditions.checkState(result.size() == 2, "we only support X:Y split or X:Y in input");
      correctedIn = result.get(0);
      correctedOut = result.get(1);
    }
    if (isBlank(correctedIn)) {
      correctedIn = EPS;
    }
    if (isBlank(correctedOut)) {
      correctedOut = EPS;
    }
    fst.addArc(thisStateSymbol, correctedIn, correctedOut, nextStateSymbol, weight);
  }

  private static double tropicalWeight(double inWeight) {
    double val = Math.log(10.0) * inWeight * -1.0;
    if (!Doubles.isFinite(val)) {
      val = REALLY_HIGH;
    }
    return val;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy