com.github.steveash.jg2p.seq.PhonemeACrfTrainer Maven / Gradle / Ivy
The newest version!
/*
* Copyright 2015 Steve Ash
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.steveash.jg2p.seq;
import com.google.common.base.Function;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
import com.github.steveash.jg2p.align.Alignment;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Collection;
import java.util.List;
import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.learning.ACRFTrainer;
import cc.mallet.grmm.learning.DefaultAcrfTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2LabelSequence;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.pipe.TokenSequenceLowercase;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import static org.apache.commons.lang3.StringUtils.isBlank;
/**
* Just the BP version of the linear chain CRF, performs similarly to the CRF class version (albeit 10x slower)
* @author Steve Ash
*/
public class PhonemeACrfTrainer {
private static final Logger log = LoggerFactory.getLogger(PhonemeACrfTrainer.class);
public void train(Collection examples) {
Pipe pipe = makePipe();
InstanceList instances = makeExamplesFromAligns(examples, pipe);
ACRF.Template[] tmpls = new ACRF.Template[]{
new ACRF.BigramTemplate(0)
// new ACRF.BigramTemplate (1),
// new ACRF.PairwiseFactorTemplate (0,1),
// new CrossTemplate1(0,1)
};
ACRF acrf = new ACRF(pipe, tmpls);
ACRFTrainer trainer = new DefaultAcrfTrainer();
acrf.setSupportedOnly(true);
acrf.setGaussianPriorVariance(2.0);
DefaultAcrfTrainer.LogEvaluator eval = new DefaultAcrfTrainer.LogEvaluator();
eval.setNumIterToSkip(2);
trainer.train(acrf, instances, null, null, eval, 9999);
}
private static InstanceList makeExamplesFromAligns(Iterable alignsToTrain, Pipe pipe) {
int count = 0;
InstanceList instances = new InstanceList(pipe);
for (Alignment align : alignsToTrain) {
List phones = align.getAllYTokensAsList();
updateEpsilons(phones);
Instance ii = new Instance(align.getAllXTokensAsList(), phones, null, null);
instances.addThruPipe(ii);
count += 1;
// if (count > 1000) {
// break;
// }
}
log.info("Read {} instances of training data", count);
return instances;
}
private Iterable getAlignsFromGroup(List inputs) {
return FluentIterable.from(inputs).transformAndConcat(
new Function>() {
@Override
public Iterable apply(SeqInputReader.AlignGroup input) {
return input.alignments;
}
});
}
private static void updateEpsilons(List phones) {
String last = "";
int blankCount = 0;
for (int i = 0; i < phones.size(); i++) {
String p = phones.get(i);
if (isBlank(p)) {
// phones.set(i, last + "_" + blankCount);
phones.set(i, "");
blankCount += 1;
} else {
last = p;
blankCount = 0;
}
}
}
private static Pipe makePipe() {
Alphabet alpha = new Alphabet();
Target2LabelSequence labelPipe = new Target2LabelSequence();
LabelAlphabet labelAlpha = (LabelAlphabet) labelPipe.getTargetAlphabet();
return new SerialPipes(ImmutableList.of(
new StringListToTokenSequence(alpha, labelAlpha), // convert to token sequence
new TokenSequenceLowercase(), // make all lowercase
new NeighborTokenFeature(true, makeNeighbors()), // grab neighboring graphemes
new NeighborShapeFeature(true, makeShapeNeighs()),
new TokenSequenceToFeature(), // convert the strings in the text to features
new TokenSequence2FeatureVectorSequence(alpha, true, true),
labelPipe,
new LabelSequenceToLabelsAssignment(alpha, labelAlpha)
));
}
private static List makeShapeNeighs() {
return ImmutableList.of(
new TokenWindow(-5, 5),
new TokenWindow(-4, 4),
new TokenWindow(-3, 3),
new TokenWindow(-2, 2),
new TokenWindow(-1, 1),
new TokenWindow(1, 1),
new TokenWindow(1, 2),
new TokenWindow(1, 3),
new TokenWindow(1, 4),
new TokenWindow(1, 5)
);
}
private static List makeNeighbors() {
return ImmutableList.of(
new TokenWindow(1, 1),
new TokenWindow(2, 1),
new TokenWindow(3, 1),
// new TokenWindow(1, 2),
// new TokenWindow(1, 3),
new TokenWindow(-1, 1),
new TokenWindow(-2, 1),
new TokenWindow(-3, 1),
new TokenWindow(-2, 2)
// new TokenWindow(-3, 3)
// new TokenWindow(-4, 4),
);
}
}