com.github.steveash.jg2p.syllchain.SyllChainTrainer Maven / Gradle / Ivy
The newest version!
/*
* Copyright 2016 Steve Ash
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.steveash.jg2p.syllchain;
import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import com.github.steveash.jg2p.Grams;
import com.github.steveash.jg2p.Word;
import com.github.steveash.jg2p.align.Alignment;
import com.github.steveash.jg2p.seq.LeadingTrailingFeature;
import com.github.steveash.jg2p.seq.NeighborShapeFeature;
import com.github.steveash.jg2p.seq.NeighborTokenFeature;
import com.github.steveash.jg2p.seq.StringListToTokenSequence;
import com.github.steveash.jg2p.seq.SurroundingTokenFeature;
import com.github.steveash.jg2p.seq.TokenSequenceToFeature;
import com.github.steveash.jg2p.seq.TokenWindow;
import com.github.steveash.jg2p.syll.SWord;
import com.github.steveash.jg2p.syll.SyllTagTrainer;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFTrainerByThreadedLabelLikelihood;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2LabelSequence;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.pipe.TokenSequenceLowercase;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
/**
* @author Steve Ash
*/
public class SyllChainTrainer {
// private static final boolean USE_BIO_CODING = true;
private static final Logger log = LoggerFactory.getLogger(SyllChainTrainer.class);
private CRF initFrom = null;
public void setInitFrom(CRF initFrom) {
this.initFrom = initFrom;
}
public SyllChainModel train(List aligns) {
log.info("About to train the syll chain...");
InstanceList examples = makeExamplesFromAligns(aligns);
Pipe pipe = examples.getPipe();
log.info("Training test-time syll chain tagger on whole data...");
TransducerTrainer trainer = trainOnce(pipe, examples);
return new SyllChainModel((CRF) trainer.getTransducer());
}
private TransducerTrainer trainOnce(Pipe pipe, InstanceList examples) {
Stopwatch watch = Stopwatch.createStarted();
CRF crf = new CRF(pipe, null);
crf.addOrderNStates(examples, new int[]{1}, null, null, null, null, false);
crf.addStartState();
crf.setWeightsDimensionAsIn(examples, true);
if (initFrom != null) {
crf.initializeApplicableParametersFrom(initFrom);
}
log.info("Starting syllchain training...");
CRFTrainerByThreadedLabelLikelihood trainer = new CRFTrainerByThreadedLabelLikelihood(crf, 8);
trainer.setGaussianPriorVariance(2);
trainer.setAddNoFactors(true);
// trainer.setUseSomeUnsupportedTrick(true);
trainer.train(examples);
trainer.shutdown();
watch.stop();
log.info("SyllChain CRF Training took " + watch.toString());
crf.getInputAlphabet().stopGrowth();
crf.getOutputAlphabet().stopGrowth();
return trainer;
}
private InstanceList makeExamplesFromAligns(List aligns) {
Pipe pipe = makePipe();
int count = 0;
InstanceList instances = new InstanceList(pipe);
for (Alignment align : aligns) {
Set graphStarts = SyllChainTrainer.splitGraphsByPhoneSylls(align);
Word orig = Word.fromSpaceSeparated(align.getWordAsSpaceString());
// Word marks = Word.fromGrams(SyllTagTrainer.makeSyllableGraphEndMarksFor(align));
// Word marks = Word.fromGrams(SyllTagTrainer.makeOncForGraphemes(align));
Word marks = Word.fromGrams(SyllTagTrainer.makeSyllableGraphEndMarksFromGraphStarts(align.getInputWord(), graphStarts));
Preconditions.checkState(orig.unigramCount() == marks.unigramCount());
Instance ii = new Instance(orig.getValue(), marks.getValue(), null, null);
instances.addThruPipe(ii);
count += 1;
}
log.info("Read {} instances of training data for align tag", count);
return instances;
}
private Pipe makePipe() {
Alphabet alpha = new Alphabet();
Target2LabelSequence labelPipe = new Target2LabelSequence();
LabelAlphabet labelAlpha = (LabelAlphabet) labelPipe.getTargetAlphabet();
return new SerialPipes(ImmutableList.of(
new StringListToTokenSequence(alpha, labelAlpha), // convert to token sequence
new TokenSequenceLowercase(), // make all lowercase
new NeighborTokenFeature(true, makeNeighbors()), // grab neighboring graphemes
new SurroundingTokenFeature(false),
new SurroundingTokenFeature(true),
new NeighborShapeFeature(true, makeShapeNeighs()),
new LeadingTrailingFeature(),
new TokenSequenceToFeature(), // convert the strings in the text to features
new TokenSequence2FeatureVectorSequence(alpha, true, false),
labelPipe
));
}
private static List makeShapeNeighs() {
return ImmutableList.of(
// new TokenWindow(-5, 5),
new TokenWindow(-4, 4),
new TokenWindow(-3, 3),
new TokenWindow(-2, 2),
new TokenWindow(-1, 1),
new TokenWindow(1, 1),
new TokenWindow(1, 2),
new TokenWindow(1, 3),
new TokenWindow(1, 4)
// new TokenWindow(1, 5)
);
}
private List makeNeighbors() {
return ImmutableList.of(
new TokenWindow(1, 1),
new TokenWindow(1, 2),
new TokenWindow(2, 1),
new TokenWindow(1, 3),
new TokenWindow(4, 1),
new TokenWindow(-1, 1),
new TokenWindow(-2, 2),
new TokenWindow(-3, 3),
new TokenWindow(-4, 1)
);
}
public static Set splitGraphsByPhoneSylls(Alignment ali) {
SWord sword = ali.getSyllWord();
Preconditions.checkNotNull(sword, "cant use this at test time");
Preconditions.checkArgument(ali.getGraphones().size() > 0, "empty alignment");
HashSet starts = Sets.newHashSet();
starts.add(0); // always the first is a start
int owedSyll = 0;
int xx = 0;
int yy = 0;
// we skip the first labelled phoneme because we always deliberately add 0 and sometimes epsilon graph
boolean sawFirst = false;
for (Pair, List> graphone : ali.getGraphonesSplit()) {
if (owedSyll > 0) {
starts.add(xx);
owedSyll = 0;
}
List graphs = graphone.getLeft();
List phones = graphone.getRight();
boolean sawNewSyll = false;
boolean sawMultiSyll = false;
for (int i = 0; i < phones.size(); i++) {
if (phones.get(i).equals(Grams.EPSILON)) continue; // skip epsilons
if (sword.isStartOfSyllable(yy)) {
if (sawFirst) {
starts.add(xx);
} else {
sawFirst = true;
}
if (sawNewSyll) {
sawMultiSyll = true;
}
sawNewSyll = true;
}
yy += 1;
}
if (sawMultiSyll) {
owedSyll += 1;
}
if (graphs.size() > 0 && !graphs.get(0).equals(Grams.EPSILON)) {
xx += graphs.size();
}
}
Preconditions.checkState(xx == ali.getInputWord().unigramCount(), "bad ending gram count", ali.getInputWord());
Preconditions.checkState(yy == sword.unigramCount(), "bad ending phone count ", ali.getInputWord());
return starts;
}
}