marytts.tools.voiceimport.PauseDurationTrainer Maven / Gradle / Ivy
The newest version!
/**
* Copyright 2000-2009 DFKI GmbH.
* All Rights Reserved. Use is subject to license terms.
*
* This file is part of MARY TTS.
*
* MARY TTS is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, version 3 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
package marytts.tools.voiceimport;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.LineNumberReader;
import java.util.ArrayList;
import java.util.List;
import java.util.SortedMap;
import java.util.TreeMap;
import marytts.cart.StringPredictionTree;
import marytts.features.FeatureDefinition;
import marytts.features.FeatureVector;
import marytts.machinelearning.GmmDiscretizer;
import weka.classifiers.trees.j48.BinC45ModelSelection;
import weka.classifiers.trees.j48.C45PruneableClassifierTree;
import weka.classifiers.trees.j48.TreeConverter;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
/*
* This Class traverses intonised xml-files, and labeller files and trains a duration
* model to predict the labeled durations from the features of the xml file
*
*/
public class PauseDurationTrainer extends VoiceImportComponent {
/**
* Tuple holding list of feature vectors and feature definition.
*/
private class VectorsAndDefinition {
public VectorsAndDefinition(List fv, FeatureDefinition fd) {
this.fv = fv;
this.fd = fd;
}
private List fv;
private FeatureDefinition fd;
public List getFv() {
return fv;
}
public void setFv(List fv) {
this.fv = fv;
}
public FeatureDefinition getFd() {
return fd;
}
public void setFd(FeatureDefinition fd) {
this.fd = fd;
}
}
// maybe specify in config file?
public final String[] featureNames = new String[] { "breakindex", "ph_cplace", "ph_ctype", "next_pos",
"next_wordbegin_ctype", "next_wordbegin_cplace", "words_from_phrase_end", "words_from_phrase_start"/**/};
// feature files used for training ("pause features")
public final String FVFILES = "PauseDurationTrainer.featureDir";
// label files providing durations
public final String LABFILES = "PauseDurationTrainer.lab";
// resulting trained decision tree
public final String TRAINEDTREE = "PauseDurationTrainer.tree";
protected DatabaseLayout db = null;
private String fvExt = ".pfeats";
private String labExt = ".lab";
public SortedMap getDefaultProps(DatabaseLayout db) {
this.db = db;
if (props == null) {
props = new TreeMap();
// dir with pause feature files
String pauseFv = System.getProperty(FVFILES);
if (pauseFv == null) {
pauseFv = db.getProp(db.ROOTDIR) + "pausefeatures" + System.getProperty("file.separator");
}
props.put(FVFILES, pauseFv);
// dir with lab files containing the pause durations
String labs = System.getProperty(LABFILES);
if (labs == null) {
labs = db.getProp(db.ROOTDIR) + "lab" + System.getProperty("file.separator");
}
props.put(LABFILES, labs);
// resulting decision tree
String tree = System.getProperty(TRAINEDTREE);
if (tree == null) {
tree = db.getProp(db.ROOTDIR) + "durations.tree";
}
props.put(TRAINEDTREE, tree);
}
return props;
}
public boolean compute() throws Exception {
// object to store all instances
Instances data = null;
FeatureDefinition fd = null;
// pause durations are added at the end
// all of them are collected first
// then discretized
List durs = new ArrayList();
for (int i = 0; i < bnl.getLength(); i++) {
VectorsAndDefinition features = this.readFeaturesFor(bnl.getName(i));
if (null == features)
continue;
List vectors = features.getFv();
fd = features.getFd();
if (data == null)
data = initData(fd);
// reader for label file.
BufferedReader lab = new BufferedReader(new FileReader(getProp(LABFILES) + bnl.getName(i) + labExt));
List labSyms = new ArrayList();
List labDurs = new ArrayList();
int prevTime = 0;
int currTime = 0;
String line;
while ((line = lab.readLine()) != null) {
if (line.startsWith("#"))
continue;
String[] lineLmnts = line.split("\\s+");
if (lineLmnts.length != 3)
throw new IllegalArgumentException("Expected three columns in label file, got " + lineLmnts.length);
labSyms.add(lineLmnts[2]);
// collect durations
currTime = (int) (1000 * Float.parseFloat(lineLmnts[0]));
int dur = currTime - prevTime;
labDurs.add(dur);
prevTime = currTime;
}
int symbolFeature = fd.getFeatureIndex("phone");
int breakindexFeature = fd.getFeatureIndex("breakindex");
int currLabelNr = 0;
// treatment of first pause(s)...
while (labSyms.get(currLabelNr).equals("_"))
currLabelNr++;
for (FeatureVector fv : vectors) {
String fvSym = fv.getFeatureAsString(symbolFeature, fd);
// all pauses on feature vector side are ignored, they are captured within boundary treatment
if (fvSym.equals("_"))
continue;
if (!fvSym.equals(labSyms.get(currLabelNr)))
throw new IllegalArgumentException("Phone symbol of label file (" + fvSym + ") and of feature vector ("
+ labSyms.get(currLabelNr) + ") don't correspond. Run CorrectedTranscriptionAligner first.");
int pauseDur = 0;
// durations are taken from pauses on label side
if ((currLabelNr + 1) < labSyms.size() && labSyms.get(currLabelNr + 1).equals("_")) {
currLabelNr++;
pauseDur = labDurs.get(currLabelNr);
}
int bi = fv.getFeatureAsInt(breakindexFeature);
if (bi > 1) {
// add new training point with fv
durs.add(pauseDur);
data.add(createInstance(data, fd, fv));
} // for each break index > 1
currLabelNr++;
}// for each featurevector
} // for each file
// set duration target attribute
data = enterDurations(data, durs);
// train classifier
StringPredictionTree wagonTree = trainTree(data, fd);
FileWriter fw = new FileWriter(getProp(TRAINEDTREE));
fw.write(wagonTree.toString());
fw.close();
return true;
}
private StringPredictionTree trainTree(Instances data, FeatureDefinition fd) throws Exception {
System.out.println("training duration tree (" + data.numInstances() + " instances) ...");
// build the tree without using the J48 wrapper class
// standard parameters are:
// binary split selection with minimum x instances at the leaves, tree is pruned, confidence value, subtree raising,
// cleanup, don't collapse
C45PruneableClassifierTree decisionTree = new C45PruneableClassifierTree(new BinC45ModelSelection(2, data, true), true,
0.25f, true, true, false);
decisionTree.buildClassifier(data);
System.out.println("...done");
return TreeConverter.c45toStringPredictionTree(decisionTree, fd, data);
}
private Instances enterDurations(Instances data, List durs) {
// System.out.println("discretizing durations...");
// now discretize and set target attributes (= pause durations)
// for that, first train discretizer
GmmDiscretizer discr = GmmDiscretizer.trainDiscretizer(durs, 6, true);
// used to store the collected values
ArrayList targetVals = new ArrayList();
for (int mappedDur : discr.getPossibleValues()) {
targetVals.add(mappedDur + "ms");
}
// FastVector attributeDeclarations = data.;
// attribute declaration finished
data.insertAttributeAt(new Attribute("target", targetVals), data.numAttributes());
for (int i = 0; i < durs.size(); i++) {
Instance currInst = data.instance(i);
int dur = durs.get(i);
// System.out.println(" mapping " + dur + " to " + discr.discretize(dur) + " - bi:" +
// data.instance(i).value(data.attribute("breakindex")));
currInst.setValue(data.numAttributes() - 1, discr.discretize(dur) + "ms");
}
// Make the last attribute be the class
data.setClassIndex(data.numAttributes() - 1);
return data;
}
private Instance createInstance(Instances data, FeatureDefinition fd, FeatureVector fv) {
// relevant features + one target
Instance currInst = new DenseInstance(data.numAttributes());
currInst.setDataset(data);
// read only relevant features
for (String attName : this.featureNames) {
int featNr = fd.getFeatureIndex(attName);
String value = fv.getFeatureAsString(featNr, fd);
currInst.setValue(data.attribute(attName), value);
}
return currInst;
}
private Instances initData(FeatureDefinition fd) {
// this stores the attributes together with allowed values
ArrayList attributeDeclarations = new ArrayList();
// first declare all the relevant attributes.
// Assume that the feature definition and relevant features of the first
// in the list are the same as the others.
for (int attribute = 0; attribute < fd.getNumberOfFeatures(); attribute++) {
String attName = fd.getFeatureName(attribute);
// skip phone
if (attName.equals("phone")) {
continue;
}
// ...collect possible values
ArrayList attVals = new ArrayList();
for (String value : fd.getPossibleValues(attribute)) {
attVals.add(value);
}
attributeDeclarations.add(new Attribute(attName, attVals));
}
// now, create the dataset adding the datapoints
return new Instances("pausedurations", attributeDeclarations, 0);
}
/**
* This reads in the features for the symbols in the input (phonemic/automatic) file from a feature stream stored in textual
* format.
*
* @param featureTable
* a LineNumberReader from which the feature table is read.
* @throws IOException
* if the input stream is ill-formed
*/
private VectorsAndDefinition readFeatureTable(LineNumberReader featureTable) throws IOException {
List featureVectors = new ArrayList();
// read the beginning of the file, containing the feature definition
FeatureDefinition fd = new FeatureDefinition(featureTable, false);
try {
// for later checks, get index of phone identity feature
fd.getFeatureIndex("phone");
fd.getFeatureIndex("breakindex");
} catch (IllegalArgumentException e) {
throw new IOException("Unexpected FeatureDefinition: Does not contain the features 'phone' and 'breakindex'.");
}
// skip section with string representation
while (!featureTable.readLine().equals("")) {
}
// now, read the features line by line
String line = "";
while ((line = featureTable.readLine()) != null) {
FeatureVector fv;
try {
fv = fd.toFeatureVector(0, line);
} catch (Exception e) {
e.printStackTrace();
throw new IOException("Unexpected Input in line " + String.valueOf(featureTable.getLineNumber()));
}
featureVectors.add(fv);
}
return new VectorsAndDefinition(featureVectors, fd);
}
/**
* This reads in some pause feature file and returns feature vectors
*
*
* @param basename
* basename
* @return readFeatureTable(lnr)
* @throws IOException
* IOException
*/
private VectorsAndDefinition readFeaturesFor(String basename) throws IOException {
FileInputStream fis;
// First, test if there is a corresponding .rawmaryxml file in textdir:
File fvFile = new File(getProp(FVFILES) + basename + fvExt);
if (fvFile.exists()) {
fis = new FileInputStream(fvFile);
} else {
return null;
}
System.out.println("processing " + getProp(FVFILES) + basename + fvExt);
// didn't work ... FeatureFileReader ffr = new FeatureFileReader();
LineNumberReader lnr = new LineNumberReader(new InputStreamReader(fis));
return readFeatureTable(lnr);
}
public String getName() {
return "PauseDurationTrainer";
}
@Override
public int getProgress() {
return 0;
}
protected void setupHelp() {
props2Help = new TreeMap();
props2Help.put(FVFILES, "Directory containing the pause feature files.");
props2Help.put(LABFILES, "Directory containing label files from which pause durations are taken.");
props2Help.put(TRAINEDTREE, "Result of training.");
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy