All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.steveash.jg2p.train.SimpleEncoderTrainer Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2014 Steve Ash
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.steveash.jg2p.train;

import com.google.common.base.Preconditions;

import com.github.steveash.jg2p.PhoneticEncoder;
import com.github.steveash.jg2p.PhoneticEncoderFactory;
import com.github.steveash.jg2p.align.AlignModel;
import com.github.steveash.jg2p.align.AlignerTrainer;
import com.github.steveash.jg2p.align.Alignment;
import com.github.steveash.jg2p.align.InputRecord;
import com.github.steveash.jg2p.align.TrainOptions;
import com.github.steveash.jg2p.aligntag.AlignTagModel;
import com.github.steveash.jg2p.aligntag.AlignTagTrainer;
import com.github.steveash.jg2p.seq.PhonemeCrfModel;
import com.github.steveash.jg2p.seq.PhonemeCrfTrainer;

import java.util.List;

/**
 * Encoder trainer that only does a forward pass from the alignment stage to the phoneme classification stage and
 * doesn't feed anything back to the alignment model.  This only trains: trainingAligner, testingAligner, and
 * pronouncer; no reranking.
 *
 * @author Steve Ash
 */
public class SimpleEncoderTrainer extends AbstractEncoderTrainer {

  /**
   * When {@code true}, {@link #train} does not train a new align model and instead uses the one returned by
   * {@code getAlignModel()}.
   */
  private boolean skipAlignTrain = false;

  public SimpleEncoderTrainer() {
  }

  /**
   * Runs the training pipeline once, front to back: obtain an align model (freshly trained or previously set),
   * build the CRF examples from it, train the align-tag model and the phoneme CRF model on those examples, and
   * assemble the resulting encoder.
   *
   * @param inputs the training records handed to the aligner trainer and to {@code makeCrfExamples}
   * @param opts   the options passed to the aligner trainer, example construction and the CRF trainer
   * @return an encoder built from the trained align-tag and CRF models, with the align model set on it
   * @throws IllegalArgumentException if align training is skipped but no align model has been set
   */
  // NOTE(review): the List parameter and the crfExamples local are raw types in this copy of the file; the
  // element types were probably List<InputRecord> / List<Alignment> (both are imported but otherwise unused).
  // Confirm against AbstractEncoderTrainer before parameterizing.
  @Override
  public PhoneticEncoder train(List inputs, TrainOptions opts) {
    AlignTagTrainer alignTagTrainer = new AlignTagTrainer();
    AlignModel model = trainOrReuseAlignModel(inputs, opts);

    // The same examples feed both the align-tag trainer and the CRF trainer.
    List crfExamples = makeCrfExamples(inputs, model, opts);
    AlignTagModel alignTagModel = alignTagTrainer.train(crfExamples);

    PhonemeCrfTrainer crfTrainer = PhonemeCrfTrainer.open(opts);
    crfTrainer.trainFor(crfExamples);
    PhonemeCrfModel crfModel = crfTrainer.buildModel();

    PhoneticEncoder encoder = PhoneticEncoderFactory.make(alignTagModel, crfModel);
    encoder.setAlignModel(model);
    return encoder;
  }

  /**
   * Returns the align model to use for this training run: the previously set one when align training is skipped,
   * otherwise a newly trained model, which is also stored via {@code setAlignModel}.
   */
  private AlignModel trainOrReuseAlignModel(List inputs, TrainOptions opts) {
    if (skipAlignTrain) {
      AlignModel existing = getAlignModel();
      Preconditions.checkArgument(existing != null, "cant skip align training if you dont set a model");
      return existing;
    }
    AlignModel trained = new AlignerTrainer(opts).train(inputs);
    setAlignModel(trained);
    return trained;
  }

  /**
   * @return whether {@link #train} skips align training and reuses the already-set align model
   */
  public boolean isSkipAlignTrain() {
    return skipAlignTrain;
  }

  /**
   * @param skipAlignTrain {@code true} to make {@link #train} reuse the already-set align model instead of training
   *                       one; an align model must then be set before calling {@link #train}
   */
  public void setSkipAlignTrain(boolean skipAlignTrain) {
    this.skipAlignTrain = skipAlignTrain;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy