All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.steveash.jg2p.rerank.Rerank3Trainer Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2015 Steve Ash
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.steveash.jg2p.rerank;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import com.github.steveash.jg2p.syll.PhoneSyllTagModel;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;
import java.util.List;

import cc.mallet.classify.RankMaxEnt;
import cc.mallet.classify.RankMaxEntTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;

/**
 * Trains a maxent classifier to do A-B reranking
 *
 * @author Steve Ash
 */
public class Rerank3Trainer {

  private static final Logger log = LoggerFactory.getLogger(Rerank3Trainer.class);

  private Pipe pipe = null;
  private PhoneSyllTagModel phoneSyllModel = null;

  public void setPhoneSyllModel(PhoneSyllTagModel phoneSyllModel) {
    this.phoneSyllModel = phoneSyllModel;
  }

  /**
   * Takes "one-sided" rerank examples and will create the "flip" side and then train on both (so we dont learn to just
   * prefer one side over the other
   */
  public Rerank3Model trainFor(Collection> trainingData) {
    pipe = makePipe();
    InstanceList instances = convert(trainingData);
    RankMaxEntTrainer trainer = new RankMaxEntTrainer(10.0);
//    AdaBoostTrainer trainer = new AdaBoostTrainer(new MaxEntL1Trainer(), 10);
    RankMaxEnt model = (RankMaxEnt) trainer.train(instances);
//    Trial trial = new Trial(model, instances);
//    log.info("Trained reranker. Final accuracy on itself: " + trial.getAccuracy());
//    log.info(new ConfusionMatrix(trial).toString());
    return new Rerank3Model(model);
  }

  private InstanceList convert(Collection> trainingData) {
    InstanceList instances = new InstanceList(pipe, trainingData.size());
    int count = 0;
    for (List data : trainingData) {
      instances.addThruPipe(new Instance(data, 1 /*just putting something here triggers pipe*/, null, data.get(0).getWordGraphs()));
      count += 1;

      if (count % 10000 == 0) {
        log.info("Loaded " + count + " instances ...");
      }
    }
    log.info("Loaded all " + instances.size() + " instances");
    return instances;
  }

  private Pipe makePipe() {
    Alphabet alpha = new Alphabet();
    LabelAlphabet labelAlpha = new LabelAlphabet();

    List features = Lists.newArrayList(
        new DupsPipe(),
        new ModePipe(),
        new PrefixPipe(),
        new RanksPipe(),
        new ScoresPipe(),
        new ShapePipe(),
        new ShapePrefixPipe()
//        new VowelBigramPipe(),
//        new VowelPatternPipe()
    );
    if (phoneSyllModel != null) {
      features.add(new SyllAgreeRerankFeature(phoneSyllModel));
      log.info("Using the syll phone tagger in the reranker");
    }
    return new SerialPipes(ImmutableList.of(
        new LoadTargetPipe(alpha, labelAlpha),
        new RerankFeaturePipe(alpha, labelAlpha, features)
    ));

  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy