marytts.tools.voiceimport.DurationTreeTrainer Maven / Gradle / Ivy
The newest version!
/**
* Copyright 2007 DFKI GmbH.
* All Rights Reserved. Use is subject to license terms.
*
* This file is part of MARY TTS.
*
* MARY TTS is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, version 3 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
package marytts.tools.voiceimport;
import java.io.IOException;
import java.util.SortedMap;
import java.util.TreeMap;
import marytts.cart.DecisionNode;
import marytts.cart.DirectedGraph;
import marytts.cart.DirectedGraphNode;
import marytts.cart.LeafNode;
import marytts.cart.Node;
import marytts.cart.LeafNode.FeatureVectorLeafNode;
import marytts.cart.LeafNode.FloatLeafNode;
import marytts.cart.io.DirectedGraphWriter;
import marytts.exceptions.MaryConfigurationException;
import marytts.features.FeatureVector;
import marytts.tools.voiceimport.traintrees.AgglomerativeClusterer;
import marytts.tools.voiceimport.traintrees.DurationDistanceMeasure;
import marytts.unitselection.data.FeatureFileReader;
import marytts.unitselection.data.UnitFileReader;
import marytts.util.math.MathUtils;
/**
* A class which converts a text file in festvox format into a one-file-per-utterance format in a given directory.
*
* @author schroed
*
*/
public class DurationTreeTrainer extends VoiceImportComponent {
protected DatabaseLayout db = null;
private final String name = "DurationTreeTrainer";
public final String DURTREE = name + ".durTree";
public final String FEATUREFILE = name + ".featureFile";
public final String UNITFILE = name + ".unitFile";
public final String MAXDATA = name + ".maxData";
public final String PROPORTIONTESTDATA = name + ".propTestData";
public String getName() {
return name;
}
public SortedMap getDefaultProps(DatabaseLayout theDb) {
this.db = theDb;
if (props == null) {
props = new TreeMap();
String fileSeparator = System.getProperty("file.separator");
props.put(FEATUREFILE, db.getProp(db.FILEDIR) + "phoneFeatures" + db.getProp(db.MARYEXT));
props.put(UNITFILE, db.getProp(db.FILEDIR) + "phoneUnits" + db.getProp(db.MARYEXT));
props.put(DURTREE, db.getProp(db.FILEDIR) + "dur.graph.mry");
props.put(MAXDATA, "0");
props.put(PROPORTIONTESTDATA, "0.1");
}
return props;
}
protected void setupHelp() {
props2Help = new TreeMap();
props2Help.put(FEATUREFILE, "file containing all phone units and their target cost features");
props2Help.put(UNITFILE, "file containing all phone units");
props2Help.put(DURTREE, "file containing the duration tree. Will be created by this module");
props2Help.put(MAXDATA, "if >0, gives the maximum number of syllables to use for training the tree");
props2Help.put(PROPORTIONTESTDATA,
"the proportion of the data to use as test data (choose so that 1/value is an integer)");
}
@Override
public boolean compute() throws IOException, MaryConfigurationException {
logger.info("Duration tree trainer started.");
FeatureFileReader featureFile = FeatureFileReader.getFeatureFileReader(getProp(FEATUREFILE));
UnitFileReader unitFile = new UnitFileReader(getProp(UNITFILE));
FeatureVector[] allFeatureVectors = featureFile.getFeatureVectors();
int maxData = Integer.parseInt(getProp(MAXDATA));
if (maxData == 0)
maxData = allFeatureVectors.length;
FeatureVector[] featureVectors = new FeatureVector[Math.min(maxData, allFeatureVectors.length)];
System.arraycopy(allFeatureVectors, 0, featureVectors, 0, featureVectors.length);
logger.debug("Total of " + allFeatureVectors.length + " feature vectors -- will use " + featureVectors.length);
AgglomerativeClusterer clusterer = new AgglomerativeClusterer(featureVectors, featureFile.getFeatureDefinition(), null,
new DurationDistanceMeasure(unitFile), Float.parseFloat(getProp(PROPORTIONTESTDATA)));
DirectedGraphWriter writer = new DirectedGraphWriter();
DirectedGraph graph;
int iteration = 0;
do {
graph = clusterer.cluster();
iteration++;
if (graph != null) {
writer.saveGraph(graph, getProp(DURTREE) + ".level" + iteration);
}
} while (clusterer.canClusterMore());
if (graph == null) {
return false;
}
// Now replace each leaf with a FloatLeafNode containing mean and stddev
for (LeafNode leaf : graph.getLeafNodes()) {
FeatureVectorLeafNode fvLeaf = (FeatureVectorLeafNode) leaf;
FeatureVector[] fvs = fvLeaf.getFeatureVectors();
double[] dur = new double[fvs.length];
for (int i = 0; i < fvs.length; i++) {
dur[i] = unitFile.getUnit(fvs[i].getUnitIndex()).duration / (float) unitFile.getSampleRate();
}
double mean = MathUtils.mean(dur);
double stddev = MathUtils.standardDeviation(dur, mean);
FloatLeafNode floatLeaf = new FloatLeafNode(new float[] { (float) stddev, (float) mean });
Node mother = fvLeaf.getMother();
assert mother != null;
if (mother.isDecisionNode()) {
((DecisionNode) mother).replaceDaughter(floatLeaf, fvLeaf.getNodeIndex());
} else {
assert mother.isDirectedGraphNode();
assert ((DirectedGraphNode) mother).getLeafNode() == fvLeaf;
((DirectedGraphNode) mother).setLeafNode(floatLeaf);
}
}
writer.saveGraph(graph, getProp(DURTREE));
return true;
}
/**
* Provide the progress of computation, in percent, or -1 if that feature is not implemented.
*
* @return -1 if not implemented, or an integer between 0 and 100.
*/
public int getProgress() {
return -1;
}
public static void main(String[] args) throws Exception {
DurationTreeTrainer dct = new DurationTreeTrainer();
DatabaseLayout db = new DatabaseLayout(dct);
dct.compute();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy