com.github.steveash.jg2p.seq.PhonemeACrfTrainer2 Maven / Gradle / Ivy
The newest version!
/*
* Copyright 2015 Steve Ash
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.steveash.jg2p.seq;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.github.steveash.jg2p.align.Alignment;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import cc.mallet.grmm.examples.CrossTemplate1;
import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.learning.ACRFTrainer;
import cc.mallet.grmm.learning.DefaultAcrfTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2LabelSequence;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.pipe.TokenSequenceLowercase;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import static org.apache.commons.lang3.StringUtils.isBlank;
/**
* General CRF version which uses indicator functions as latent variables to see if that increases accuracy
* @author Steve Ash
*/
public class PhonemeACrfTrainer2 {
private static final Logger log = LoggerFactory.getLogger(PhonemeACrfTrainer2.class);
public void train(Collection examples) {
Pipe pipe = makePipe();
InstanceList instances = makeExamplesFromAligns(examples, pipe);
ACRF.Template[] tmpls = new ACRF.Template[]{
new ACRF.BigramTemplate(0),
new ACRF.BigramTemplate (1),
new ACRF.PairwiseFactorTemplate (0,1),
new CrossTemplate1(0,1)
};
ACRF acrf = new ACRF(pipe, tmpls);
ACRFTrainer trainer = new DefaultAcrfTrainer();
acrf.setSupportedOnly(true);
acrf.setGaussianPriorVariance(2.0);
DefaultAcrfTrainer.LogEvaluator eval = new DefaultAcrfTrainer.LogEvaluator();
eval.setNumIterToSkip(2);
trainer.train(acrf, instances, null, null, eval, 9999);
}
private static InstanceList makeExamplesFromAligns(Iterable alignsToTrain, Pipe pipe) {
int count = 0;
InstanceList instances = new InstanceList(pipe);
for (Alignment align : alignsToTrain) {
List xs = align.getWordUnigrams();
List bs = align.getXBoundaryMarks();
Iterator ys = align.getAllYTokensAsList().iterator();
Preconditions.checkState(xs.size() == bs.size());
List targets = Lists.newArrayListWithCapacity(xs.size());
for (int i = 0; i < xs.size(); i++) {
targets.add(new String[] {
bs.get(i) ? "1" : "0",
bs.get(i) ? ys.next() : "<>"
});
}
Instance ii = new Instance(xs, targets, null, null);
instances.addThruPipe(ii);
count += 1;
// if (count > 1000) {
// break;
// }
}
log.info("Read {} instances of training data", count);
return instances;
}
private Iterable getAlignsFromGroup(List inputs) {
return FluentIterable.from(inputs).transformAndConcat(
new Function>() {
@Override
public Iterable apply(SeqInputReader.AlignGroup input) {
return input.alignments;
}
});
}
private static Pipe makePipe() {
Alphabet alpha = new Alphabet();
JointInputToTokenSequence inputPipe = new JointInputToTokenSequence(alpha, new LabelAlphabet(), new LabelAlphabet());
return new SerialPipes(ImmutableList.of(
inputPipe,
new TokenSequenceLowercase(), // make all lowercase
new NeighborTokenFeature(true, makeNeighbors()), // grab neighboring graphemes
new NeighborShapeFeature(true, makeShapeNeighs()),
new TokenSequenceToFeature(), // convert the strings in the text to features
new TokenSequence2FeatureVectorSequence(alpha, true, true)
));
}
private static List makeShapeNeighs() {
return ImmutableList.of(
new TokenWindow(-5, 5),
new TokenWindow(-4, 4),
new TokenWindow(-3, 3),
new TokenWindow(-2, 2),
new TokenWindow(-1, 1),
new TokenWindow(1, 1),
new TokenWindow(1, 2),
new TokenWindow(1, 3),
new TokenWindow(1, 4),
new TokenWindow(1, 5)
);
}
private static List makeNeighbors() {
return ImmutableList.of(
new TokenWindow(1, 1),
new TokenWindow(2, 1),
new TokenWindow(3, 1),
// new TokenWindow(1, 2),
// new TokenWindow(1, 3),
new TokenWindow(-1, 1),
new TokenWindow(-2, 1),
new TokenWindow(-3, 1),
new TokenWindow(-2, 2)
// new TokenWindow(-3, 3)
// new TokenWindow(-4, 4),
);
}
}