marytts.tools.newlanguage.LTSTrainer Maven / Gradle / Ivy
The newest version!
/**
* Copyright 2000-2009 DFKI GmbH.
* All Rights Reserved. Use is subject to license terms.
*
* This file is part of MARY TTS.
*
* MARY TTS is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, version 3 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see .
*
*/
package marytts.tools.newlanguage;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import marytts.cart.CART;
import marytts.cart.DecisionNode;
import marytts.cart.io.MaryCARTWriter;
import marytts.exceptions.MaryConfigurationException;
import marytts.features.FeatureDefinition;
import marytts.fst.AlignerTrainer;
import marytts.fst.StringPair;
import marytts.modules.phonemiser.Allophone;
import marytts.modules.phonemiser.AllophoneSet;
import marytts.modules.phonemiser.TrainedLTS;
import org.apache.log4j.BasicConfigurator;
import weka.classifiers.trees.j48.BinC45ModelSelection;
import weka.classifiers.trees.j48.C45PruneableClassifierTree;
import weka.classifiers.trees.j48.C45PruneableClassifierTreeWithUnary;
import weka.classifiers.trees.j48.TreeConverter;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
/**
*
* This class is a generic approach to predict a phone sequence from a grapheme sequence.
*
* the normal sequence of steps is: 1) initialize the trainer with a phone set and a locale
*
* 2) read in the lexicon, preserve stress if you like
*
* 3) make some alignment iterations (usually 5-10)
*
* 4) train the trees and save them in wagon format in a specified directory
*
* see main method for an example.
*
* Apply the model using TrainedLTS
*
* @author benjaminroth
*
*/
public class LTSTrainer extends AlignerTrainer {
protected AllophoneSet phSet;
protected int context;
protected boolean convertToLowercase;
protected boolean considerStress;
/**
* Create a new LTSTrainer.
*
* @param aPhSet
* the allophone set to use.
* @param convertToLowercase
* whether to convert all graphemes to lowercase, using the locale of the allophone set.
* @param considerStress
* indicator if stress is preserved
* @param context
* context
*/
public LTSTrainer(AllophoneSet aPhSet, boolean convertToLowercase, boolean considerStress, int context) {
super();
this.phSet = aPhSet;
this.convertToLowercase = convertToLowercase;
this.considerStress = considerStress;
this.context = context;
BasicConfigurator.configure();
}
/**
* Train the tree, using binary decision nodes.
*
* @param minLeafData
* the minimum number of instances that have to occur in at least two subsets induced by split
* @return bigTree
* @throws IOException
* IOException
*/
public CART trainTree(int minLeafData) throws IOException {
Map> grapheme2align = new HashMap>();
for (String gr : this.graphemeSet) {
grapheme2align.put(gr, new ArrayList());
}
Set phChains = new HashSet();
// for every alignment pair collect counts
for (int i = 0; i < this.inSplit.size(); i++) {
StringPair[] alignment = this.getAlignment(i);
for (int inNr = 0; inNr < alignment.length; inNr++) {
// System.err.println(alignment[inNr]);
// quotation signs needed to represent empty string
String outAlNr = "'" + alignment[inNr].getString2() + "'";
// TODO: don't consider alignments to more than three characters
if (outAlNr.length() > 5)
continue;
phChains.add(outAlNr);
// storing context and target
String[] datapoint = new String[2 * context + 2];
for (int ct = 0; ct < 2 * context + 1; ct++) {
int pos = inNr - context + ct;
if (pos >= 0 && pos < alignment.length) {
datapoint[ct] = alignment[pos].getString1();
} else {
datapoint[ct] = "null";
}
}
// set target
datapoint[2 * context + 1] = outAlNr;
// add datapoint
grapheme2align.get(alignment[inNr].getString1()).add(datapoint);
}
}
// for conversion need feature definition file
FeatureDefinition fd = this.graphemeFeatureDef(phChains);
int centerGrapheme = fd.getFeatureIndex("att" + (context + 1));
List stl = new ArrayList(fd.getNumberOfValues(centerGrapheme));
for (String gr : fd.getPossibleValues(centerGrapheme)) {
System.out.println(" Training decision tree for: " + gr);
logger.debug(" Training decision tree for: " + gr);
ArrayList attributeDeclarations = new ArrayList();
// attributes with values
for (int att = 1; att <= context * 2 + 1; att++) {
// ...collect possible values
ArrayList attVals = new ArrayList();
String featureName = "att" + att;
for (String usableGrapheme : fd.getPossibleValues(fd.getFeatureIndex(featureName))) {
attVals.add(usableGrapheme);
}
attributeDeclarations.add(new Attribute(featureName, attVals));
}
List datapoints = grapheme2align.get(gr);
// maybe training is faster with targets limited to grapheme
Set graphSpecPh = new HashSet();
for (String[] dp : datapoints) {
graphSpecPh.add(dp[dp.length - 1]);
}
// targetattribute
// ...collect possible values
ArrayList targetVals = new ArrayList();
for (String phc : graphSpecPh) {// todo: use either fd of phChains
targetVals.add(phc);
}
attributeDeclarations.add(new Attribute(TrainedLTS.PREDICTED_STRING_FEATURENAME, targetVals));
// now, create the dataset adding the datapoints
Instances data = new Instances(gr, attributeDeclarations, 0);
// datapoints
for (String[] point : datapoints) {
Instance currInst = new DenseInstance(data.numAttributes());
currInst.setDataset(data);
for (int i = 0; i < point.length; i++) {
currInst.setValue(i, point[i]);
}
data.add(currInst);
}
// Make the last attribute be the class
data.setClassIndex(data.numAttributes() - 1);
// build the tree without using the J48 wrapper class
// standard parameters are:
// binary split selection with minimum x instances at the leaves, tree is pruned, confidenced value, subtree raising,
// cleanup, don't collapse
// Here is used a modifed version of C45PruneableClassifierTree that allow using Unary Classes (see Issue #51)
C45PruneableClassifierTree decisionTree;
try {
decisionTree = new C45PruneableClassifierTreeWithUnary(new BinC45ModelSelection(minLeafData, data, true), true,
0.25f, true, true, false);
decisionTree.buildClassifier(data);
} catch (Exception e) {
throw new RuntimeException("couldn't train decisiontree using weka: ", e);
}
CART maryTree = TreeConverter.c45toStringCART(decisionTree, fd, data);
stl.add(maryTree);
}
DecisionNode.ByteDecisionNode rootNode = new DecisionNode.ByteDecisionNode(centerGrapheme, stl.size(), fd);
for (CART st : stl) {
rootNode.addDaughter(st.getRootNode());
}
Properties props = new Properties();
props.setProperty("lowercase", String.valueOf(convertToLowercase));
props.setProperty("stress", String.valueOf(considerStress));
props.setProperty("context", String.valueOf(context));
CART bigTree = new CART(rootNode, fd, props);
return bigTree;
}
/**
*
* Convenience method to save files to graph2phon.wagon and graph2phon.pfeats in a specified directory with UTF-8 encoding.
*
* @param tree
* tree
* @param saveTreefile
* saveTreefile
* @throws IOException
* IOException
*/
public void save(CART tree, String saveTreefile) throws IOException {
MaryCARTWriter mcw = new MaryCARTWriter();
mcw.dumpMaryCART(tree, saveTreefile);
}
private FeatureDefinition graphemeFeatureDef(Set phChains) throws IOException {
String lineBreak = System.getProperty("line.separator");
StringBuilder fdString = new StringBuilder("ByteValuedFeatureProcessors");
fdString.append(lineBreak);
// add attribute features
for (int att = 1; att <= context * 2 + 1; att++) {
fdString.append("att").append(att);
for (String gr : this.graphemeSet) {
fdString.append(" ").append(gr);
}
fdString.append(lineBreak);
}
fdString.append("ShortValuedFeatureProcessors").append(lineBreak);
// add class features
fdString.append(TrainedLTS.PREDICTED_STRING_FEATURENAME);
for (String ph : phChains) {
fdString.append(" ").append(ph);
}
fdString.append(lineBreak);
fdString.append("ContinuousFeatureProcessors").append(lineBreak);
BufferedReader featureReader = new BufferedReader(new StringReader(fdString.toString()));
return new FeatureDefinition(featureReader, false);
}
/**
*
* reads in a lexicon in text format, lines are of the kind:
*
* graphemechain | phonechain | otherinformation
*
* Stress is optionally preserved, marking the first vowel of a stressed syllable with "1".
*
* @param lexicon
* reader with lines of lexicon
* @param splitPattern
* a regular expression used for identifying the field separator in each line.
* @throws IOException
* IOException
*/
public void readLexicon(BufferedReader lexicon, String splitPattern) throws IOException {
String line;
while ((line = lexicon.readLine()) != null) {
String[] lineParts = line.trim().split(splitPattern);
String graphStr = lineParts[0];
if (convertToLowercase)
graphStr = graphStr.toLowerCase(phSet.getLocale());
graphStr = graphStr.replaceAll("['-.]", "");
// remove all secondary stress markers
String phonStr = lineParts[1].replaceAll(",", "");
String[] syllables = phonStr.split("-");
List separatedPhones = new ArrayList();
List separatedGraphemes = new ArrayList();
String currPh;
for (String syl : syllables) {
boolean stress = false;
if (syl.startsWith("'")) {
syl = syl.substring(1);
stress = true;
}
for (Allophone ph : phSet.splitIntoAllophones(syl)) {
currPh = ph.name();
if (stress && considerStress && ph.isVowel()) {
currPh += "1";
stress = false;
}
separatedPhones.add(currPh);
}// ... for each allophone
}
for (int i = 0; i < graphStr.length(); i++) {
this.graphemeSet.add(graphStr.substring(i, i + 1));
separatedGraphemes.add(graphStr.substring(i, i + 1));
}
this.addAlreadySplit(separatedGraphemes, separatedPhones);
}
// Need one entry for the "null" grapheme, which maps to the empty string:
this.addAlreadySplit(new String[] { "null" }, new String[] { "" });
}
/**
* reads in a lexicon in text format, lines are of the kind:
*
* graphemechain | phonechain | otherinformation
*
* Stress is optionally preserved, marking the first vowel of a stressed syllable with "1".
*
* @param lexicon
* lexicon
*/
public void readLexicon(HashMap lexicon) {
Iterator it = lexicon.keySet().iterator();
while (it.hasNext()) {
String graphStr = it.next();
// remove all secondary stress markers
String phonStr = lexicon.get(graphStr).replaceAll(",", "");
if (convertToLowercase)
graphStr = graphStr.toLowerCase(phSet.getLocale());
graphStr = graphStr.replaceAll("['-.]", "");
String[] syllables = phonStr.split("-");
List separatedPhones = new ArrayList();
List separatedGraphemes = new ArrayList();
String currPh;
for (String syl : syllables) {
boolean stress = false;
if (syl.startsWith("'")) {
syl = syl.substring(1);
stress = true;
}
for (Allophone ph : phSet.splitIntoAllophones(syl)) {
currPh = ph.name();
if (stress && considerStress && ph.isVowel()) {
currPh += "1";
stress = false;
}
separatedPhones.add(currPh);
}// ... for each allophone
}
for (int i = 0; i < graphStr.length(); i++) {
this.graphemeSet.add(graphStr.substring(i, i + 1));
separatedGraphemes.add(graphStr.substring(i, i + 1));
}
this.addAlreadySplit(separatedGraphemes, separatedPhones);
}
// Need one entry for the "null" grapheme, which maps to the empty string:
this.addAlreadySplit(new String[] { "null" }, new String[] { "" });
}
public static void main(String[] args) throws IOException, MaryConfigurationException {
String phFileLoc = "/Users/benjaminroth/Desktop/mary/english/phone-list-engba.xml";
// initialize trainer
LTSTrainer tp = new LTSTrainer(AllophoneSet.getAllophoneSet(phFileLoc), true, true, 2);
BufferedReader lexReader = new BufferedReader(new InputStreamReader(new FileInputStream(
"/Users/benjaminroth/Desktop/mary/english/sampa-lexicon.txt"), "ISO-8859-1"));
// read lexicon for training
tp.readLexicon(lexReader, "\\\\");
// make some alignment iterations
for (int i = 0; i < 5; i++) {
System.out.println("iteration " + i);
tp.alignIteration();
}
CART st = tp.trainTree(100);
tp.save(st, "/Users/benjaminroth/Desktop/mary/english/trees/");
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy