All downloads are free. The search and download features use the official Maven repository.

com.github.steveash.jg2p.seq.PhonemeACrfTrainer Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2015 Steve Ash
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.steveash.jg2p.seq;

import com.google.common.base.Function;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;

import com.github.steveash.jg2p.align.Alignment;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;
import java.util.List;

import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.learning.ACRFTrainer;
import cc.mallet.grmm.learning.DefaultAcrfTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2LabelSequence;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.pipe.TokenSequenceLowercase;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;

import static org.apache.commons.lang3.StringUtils.isBlank;

/**
 * Just the BP version of the linear chain CRF, performs similarly to the CRF class version (albeit 10x slower)
 * @author Steve Ash
 */
public class PhonemeACrfTrainer {

  private static final Logger log = LoggerFactory.getLogger(PhonemeACrfTrainer.class);

  public void train(Collection examples) {
    Pipe pipe = makePipe();
    InstanceList instances = makeExamplesFromAligns(examples, pipe);

    ACRF.Template[] tmpls = new ACRF.Template[]{
        new ACRF.BigramTemplate(0)
//                new ACRF.BigramTemplate (1),
//                new ACRF.PairwiseFactorTemplate (0,1),
//                new CrossTemplate1(0,1)
    };

    ACRF acrf = new ACRF(pipe, tmpls);

    ACRFTrainer trainer = new DefaultAcrfTrainer();
    acrf.setSupportedOnly(true);
    acrf.setGaussianPriorVariance(2.0);
    DefaultAcrfTrainer.LogEvaluator eval = new DefaultAcrfTrainer.LogEvaluator();
    eval.setNumIterToSkip(2);
    trainer.train(acrf, instances, null, null, eval, 9999);

  }

  private static InstanceList makeExamplesFromAligns(Iterable alignsToTrain, Pipe pipe) {
    int count = 0;
    InstanceList instances = new InstanceList(pipe);
    for (Alignment align : alignsToTrain) {
      List phones = align.getAllYTokensAsList();
      updateEpsilons(phones);
      Instance ii = new Instance(align.getAllXTokensAsList(), phones, null, null);
      instances.addThruPipe(ii);
      count += 1;

      //      if (count > 1000) {
      //        break;
      //      }
    }
    log.info("Read {} instances of training data", count);
    return instances;
  }

  private Iterable getAlignsFromGroup(List inputs) {
    return FluentIterable.from(inputs).transformAndConcat(
        new Function>() {
          @Override
          public Iterable apply(SeqInputReader.AlignGroup input) {
            return input.alignments;
          }
        });
  }

  private static void updateEpsilons(List phones) {
    String last = "";
    int blankCount = 0;
    for (int i = 0; i < phones.size(); i++) {
      String p = phones.get(i);
      if (isBlank(p)) {
        //        phones.set(i, last + "_" + blankCount);
        phones.set(i, "");
        blankCount += 1;
      } else {
        last = p;
        blankCount = 0;
      }
    }
  }

  private static Pipe makePipe() {
    Alphabet alpha = new Alphabet();
    Target2LabelSequence labelPipe = new Target2LabelSequence();
    LabelAlphabet labelAlpha = (LabelAlphabet) labelPipe.getTargetAlphabet();

    return new SerialPipes(ImmutableList.of(
        new StringListToTokenSequence(alpha, labelAlpha),   // convert to token sequence
        new TokenSequenceLowercase(),                       // make all lowercase
        new NeighborTokenFeature(true, makeNeighbors()),         // grab neighboring graphemes
        new NeighborShapeFeature(true, makeShapeNeighs()),
        new TokenSequenceToFeature(),                       // convert the strings in the text to features
        new TokenSequence2FeatureVectorSequence(alpha, true, true),
        labelPipe,
        new LabelSequenceToLabelsAssignment(alpha, labelAlpha)
    ));
  }

  private static List makeShapeNeighs() {
    return ImmutableList.of(
        new TokenWindow(-5, 5),
        new TokenWindow(-4, 4),
        new TokenWindow(-3, 3),
        new TokenWindow(-2, 2),
        new TokenWindow(-1, 1),
        new TokenWindow(1, 1),
        new TokenWindow(1, 2),
        new TokenWindow(1, 3),
        new TokenWindow(1, 4),
        new TokenWindow(1, 5)
    );
  }

  private static List makeNeighbors() {
    return ImmutableList.of(
        new TokenWindow(1, 1),
        new TokenWindow(2, 1),
        new TokenWindow(3, 1),
        //        new TokenWindow(1, 2),
        //              new TokenWindow(1, 3),
        new TokenWindow(-1, 1),
        new TokenWindow(-2, 1),
        new TokenWindow(-3, 1),
        new TokenWindow(-2, 2)
        //        new TokenWindow(-3, 3)
        //        new TokenWindow(-4, 4),
    );
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy