All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.steveash.jg2p.seq.PhonemeCrfTrainer Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2014 Steve Ash
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.steveash.jg2p.seq;

import com.google.common.base.Function;
import com.google.common.base.Stopwatch;
import com.google.common.base.Throwables;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;

import com.github.steveash.jg2p.align.Alignment;
import com.github.steveash.jg2p.align.TrainOptions;
import com.github.steveash.jg2p.util.FeatureSelections;
import com.github.steveash.jg2p.util.GramBuilder;
import com.github.steveash.jg2p.util.ModelReadWrite;
import com.github.steveash.jg2p.util.ReadWrite;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.util.Collection;
import java.util.List;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.CRFTrainerByThreadedLabelLikelihood;
import cc.mallet.fst.TokenAccuracyEvaluator;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2LabelSequence;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.pipe.TokenSequenceLowercase;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.RankedFeatureVector;

import static org.apache.commons.lang3.StringUtils.isBlank;

/**
 * Trains a CRF to use the alignment model
 *
 * @author Steve Ash
 */
public class PhonemeCrfTrainer {

//  public static final String PROP_ALIGNMENT = "prop_alignment";
  public static final String PROP_STRUCTURE = "prop_structure";

  private static final Logger log = LoggerFactory.getLogger(PhonemeCrfTrainer.class);

  public static PhonemeCrfTrainer open(TrainOptions opts) {
    PhonemeCrfTrainer pct = new PhonemeCrfTrainer(opts);
    return pct;
  }

  private final TrainOptions opts;

  private CRF crf = null;
  private CRF crfFrom = null;
  private TransducerTrainer lastTrainer = null;

  private PhonemeCrfTrainer(TrainOptions opts) {
    this.opts = opts;
    if (opts.initCrfFromModelFile != null) {
      try {
        log.info("Loading initial weights from " + opts.initCrfFromModelFile);
        this.crfFrom = readCrfFrom();
      } catch (Exception e) {
        throw Throwables.propagate(e);
      }
    }
  }

  private void initializeFor(InstanceList examples) {
    this.crf = new CRF(examples.getPipe(), null);
    crf.addOrderNStates(examples, new int[]{1}, null, null, null, null, false);
    crf.addStartState();
    crf.setWeightsDimensionAsIn(examples, false);

    if (crfFrom != null) {
      crf.initializeApplicableParametersFrom(crfFrom);
    }
  }

  private CRF readCrfFrom() throws IOException, ClassNotFoundException {
    return ModelReadWrite.readPronouncerFrom(opts.initCrfFromModelFile).getCrf();
  }

  public void trainFor(Collection inputs) {
    // this pipe is the default pipe with new alphabet
    Stopwatch watch = Stopwatch.createStarted();
    trainRound(inputs, new Alphabet(), 0);

    crf.getInputAlphabet().stopGrowth();
    crf.getOutputAlphabet().stopGrowth();
    watch.stop();
    log.info("Training took " + watch);
  }

  private void trainRound(Collection inputs, Alphabet alpha, int trainRound) {
    SerialPipes initialPipe = makePipe(alpha);
    InstanceList examples = makeExamplesFromAligns(initialPipe, inputs);
    initializeFor(examples);

    CRFTrainerByThreadedLabelLikelihood trainer = makeNewTrainer(crf);
    this.lastTrainer = trainer;

    trainer.train(examples, opts.maxPronouncerTrainingIterations);
    trainer.shutdown(); // just closes the pool; next call to train will create a new one

    if (trainRound == 0 && opts.trimFeaturesUnderPercentile > 0) {
      trainer.getCRF().pruneFeaturesBelowPercentile(opts.trimFeaturesUnderPercentile);
      trainer.train(examples);
      trainer.shutdown();
    }
    if (trainRound == 0 && opts.trimFeaturesByGradientGain > 0) {

      // calc the gradients, report some stats on them, then move on for now
      log.info("Trimming based on gradiant gain ratio...");
//      String dateString = DateFormatUtils.format(new Date(), "yyMMddmmss");
      RankedFeatureVector rfv = FeatureSelections.gradientGainRatioFrom(examples, crf);

//      writeRankedToFile(pair.getLeft(), new File("grads" + dateString + ".txt"));
//      writeRankedToFile(pair.getRight(), new File("gradratio" + dateString + ".txt"));
//      writeRankedToFile(featureCountsFrom(examples), new File("featcounts" + dateString + ".txt"));

      Alphabet newDict = new Alphabet();
      for (int i = 0; i < rfv.singleSize(); i++) {
        double ratio = rfv.value(i);
        if (ratio > opts.trimFeaturesByGradientGain) {
          newDict.lookupIndex(alpha.lookupObject(i), true);
        }
      }
      log.info("Feature selection before count " + alpha.size() + " after " + newDict.size());
      newDict.stopGrowth();
      this.crfFrom = this.crf;

      trainRound(inputs, newDict, trainRound + 1);
    }
  }

  private double accuracyFor(InstanceList examples) {
    TokenAccuracyEvaluator teval = new TokenAccuracyEvaluator(examples, "train");
    teval.evaluate(lastTrainer);
    return teval.getAccuracy("train");
  }

  public PhonemeCrfModel buildModel() {

    return new PhonemeCrfModel(crf);
  }

  private static CRFTrainerByThreadedLabelLikelihood makeNewTrainer(CRF crf) {
    CRFTrainerByThreadedLabelLikelihood trainer = new CRFTrainerByThreadedLabelLikelihood(crf, getCpuCount());
    trainer.setGaussianPriorVariance(2);
    trainer.setAddNoFactors(true);
    trainer.setUseSomeUnsupportedTrick(false);
    return trainer;
  }

  private static CRFTrainerByLabelLikelihood makeNewTrainerSingleThreaded(CRF crf) {
    CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf);
    trainer.setGaussianPriorVariance(2);
//    trainer.setUseHyperbolicPrior(true);
    trainer.setAddNoFactors(true);
    trainer.setUseSomeUnsupportedTrick(false);
    return trainer;
  }

//  private static CRF makeNewCrf(InstanceList examples, Pipe pipe) {
//    CRF crf = new CRF(pipe, null);
//    crf.addOrderNStates(examples, new int[]{1}, null, null, null, null, false);
//    crf.addStartState();
//    crf.setWeightsDimensionDensely();
//    crf.setWeightsDimensionAsIn(examples, false);
//    crf.setWeightsDimensionWithFilterAsIn(examples, 2);
//    crf.addFullyConnectedStatesForBiLabels();
//    crf.addStartState();
//    return crf;
//  }

  private static int getCpuCount() {
    return Runtime.getRuntime().availableProcessors();
  }

  public void writeModel(File target) throws IOException {
    CRF crf = (CRF) lastTrainer.getTransducer();
    ReadWrite.writeTo(new PhonemeCrfModel(crf), target);
    log.info("Wrote for whole data");
  }

  private InstanceList makeExamplesFromAligns(Pipe pipe, Iterable alignsToTrain) {
    int count = 0;
    InstanceList instances = new InstanceList(pipe);
    for (Alignment align : alignsToTrain) {
      List phones = align.getAllYTokensAsList();
      updateEpsilons(phones);
      Instance ii = new Instance(align, phones, null, null);
      instances.addThruPipe(ii);
      count += 1;
    }
    log.info("Read {} instances of training data for pronouncer training", count);
    return instances;
  }

  private Iterable getAlignsFromGroup(List inputs) {
    return FluentIterable.from(inputs).transformAndConcat(
        new Function>() {
          @Override
          public Iterable apply(SeqInputReader.AlignGroup input) {
            return input.alignments;
          }
        });
  }

  private static void updateEpsilons(List phones) {
    String last = GramBuilder.EPS;
    int blankCount = 0;
    for (int i = 0; i < phones.size(); i++) {
      String p = phones.get(i);
      if (isBlank(p)) {
//        phones.set(i, last + "_" + blankCount);
        phones.set(i, GramBuilder.EPS);
        blankCount += 1;
      } else {
        last = p;
        blankCount = 0;
      }
    }
  }

  private SerialPipes makePipe(Alphabet alpha) {
    Target2LabelSequence labelPipe = new Target2LabelSequence();
    LabelAlphabet labelAlpha = (LabelAlphabet) labelPipe.getTargetAlphabet();

    return new SerialPipes(ImmutableList.of(
        new AlignmentToTokenSequence(alpha, labelAlpha, true, true, false),   // convert to token sequence
        new TokenSequenceLowercase(),                       // make all lowercase
        new NeighborTokenFeature(true, makeNeighbors()),         // grab neighboring graphemes
        new NeighborShapeFeature(true, makeShapeNeighs()),
//        new WindowFeature(false, 4),
//        new WindowFeature(true, 6),
        new NeighborSyllableFeature(-2, -1, 1, 2),
        new SyllCountingFeature(),
        new SyllCharRoleFeature(),
//        new NearSyllFeature(true),
//        new NearSyllFeature(false),
//        new SyllMarkingFeature(),
//        new SyllSequenceFeature(),
//        new SyllRelativeMarkFeature(),
        new EndingVowelFeature(),
        //new SonorityFeature2(true),
        //new SonorityFeature2(false),
//        new WindowFeature(false, 4),
        new VowelWindowFeature(2, 1, "PRESYL_", -1, false),
        new VowelWindowFeature(2, 1, "PSTSYL_", 1, false),
//        new VowelWindowFeature(3, 2, "LSTSYL_", 0, true),
        new SurroundingTokenFeature2(false, 1, 1),
//        new SurroundingTokenFeature2(true, 1, 1),
        new SurroundingTokenFeature2(false, 2, 2),
//        new SurroundingTokenFeature2(false, 3, 2),
        new SurroundingTokenFeature2(true, 3, 3),
//        new SurroundingTokenFeature2(true, 4, 4),
//        new LeadingTrailingFeature(),
        new TokenSequenceToFeature(),                       // convert the strings in the text to features
        new TokenSequence2FeatureVectorSequence(alpha, true, false),
        labelPipe
    ));
  }

  private static List makeShapeNeighs() {
    return ImmutableList.of(
        new TokenWindow(-6, 6),
        new TokenWindow(-5, 5),
        new TokenWindow(-4, 4),
//        new TokenWindow(-3, 3),
//        new TokenWindow(-2, 2),
//        new TokenWindow(-1, 1),
//        new TokenWindow(1, 1),
//        new TokenWindow(1, 2),
//        new TokenWindow(1, 3),
        new TokenWindow(1, 4),
        new TokenWindow(1, 5),
        new TokenWindow(1, 6)
    );
  }

  private static List makeNeighbors() {
    return ImmutableList.of(
        new TokenWindow(1, 1),
        new TokenWindow(1, 2),
        new TokenWindow(2, 1),
        new TokenWindow(1, 3),
        new TokenWindow(3, 1),
        new TokenWindow(1, 4),
//        new TokenWindow(4, 1),
        new TokenWindow(-1, 1),
        new TokenWindow(-2, 2),
        new TokenWindow(-2, 1),
        new TokenWindow(-3, 3),
        new TokenWindow(-4, 4),
        new TokenWindow(-5, 5)

//        new TokenWindow(-2, 2),
//        new TokenWindow(-3, 3)
//        new TokenWindow(-4, 4),
    );
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy