All Downloads are FREE. Search and download functionalities are using the official Maven repository.

marytts.tools.voiceimport.DurationTreeTrainer Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2007 DFKI GmbH.
 * All Rights Reserved.  Use is subject to license terms.
 *
 * This file is part of MARY TTS.
 *
 * MARY TTS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, version 3 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */
package marytts.tools.voiceimport;

import java.io.IOException;
import java.util.SortedMap;
import java.util.TreeMap;

import marytts.cart.DecisionNode;
import marytts.cart.DirectedGraph;
import marytts.cart.DirectedGraphNode;
import marytts.cart.LeafNode;
import marytts.cart.Node;
import marytts.cart.LeafNode.FeatureVectorLeafNode;
import marytts.cart.LeafNode.FloatLeafNode;
import marytts.cart.io.DirectedGraphWriter;
import marytts.exceptions.MaryConfigurationException;
import marytts.features.FeatureVector;
import marytts.tools.voiceimport.traintrees.AgglomerativeClusterer;
import marytts.tools.voiceimport.traintrees.DurationDistanceMeasure;
import marytts.unitselection.data.FeatureFileReader;
import marytts.unitselection.data.UnitFileReader;
import marytts.util.math.MathUtils;

/**
 * Voice import component that trains a duration tree, stored as a {@link DirectedGraph}, from the units of a voice
 * database.
 * <p>
 * The unit feature vectors are clustered agglomeratively using a {@link DurationDistanceMeasure}. Every intermediate
 * clustering level is saved as {@code <durTree>.level<N>}. In the final graph, each {@link FeatureVectorLeafNode} is
 * replaced by a {@link FloatLeafNode} holding {@code { stddev, mean }} of the durations of the units in that leaf. Each
 * duration is the unit's duration divided by the unit file's sample rate, i.e. seconds if durations are stored in
 * samples (assumption — TODO confirm).
 * 
 * @author schroed
 *
 */
public class DurationTreeTrainer extends VoiceImportComponent {
	protected DatabaseLayout db = null;

	private final String name = "DurationTreeTrainer";
	public final String DURTREE = name + ".durTree";
	public final String FEATUREFILE = name + ".featureFile";
	public final String UNITFILE = name + ".unitFile";
	public final String MAXDATA = name + ".maxData";
	public final String PROPORTIONTESTDATA = name + ".propTestData";

	public String getName() {
		return name;
	}

	/**
	 * Returns the default properties of this component, creating them on first call.
	 * 
	 * @param theDb
	 *            the database layout used to resolve default file locations; it is also remembered in {@link #db}
	 * @return the (possibly previously created) property map
	 */
	public SortedMap getDefaultProps(DatabaseLayout theDb) {
		this.db = theDb;
		if (props == null) {
			props = new TreeMap();
			props.put(FEATUREFILE, db.getProp(db.FILEDIR) + "phoneFeatures" + db.getProp(db.MARYEXT));
			props.put(UNITFILE, db.getProp(db.FILEDIR) + "phoneUnits" + db.getProp(db.MARYEXT));
			props.put(DURTREE, db.getProp(db.FILEDIR) + "dur.graph.mry");
			// "0" means: no limit, use all available feature vectors.
			props.put(MAXDATA, "0");
			props.put(PROPORTIONTESTDATA, "0.1");
		}
		return props;
	}

	/**
	 * Fills {@code props2Help} with a help text for each property of this component.
	 */
	protected void setupHelp() {
		props2Help = new TreeMap();
		props2Help.put(FEATUREFILE, "file containing all phone units and their target cost features");
		props2Help.put(UNITFILE, "file containing all phone units");
		props2Help.put(DURTREE, "file containing the duration tree. Will be created by this module");
		props2Help.put(MAXDATA, "if >0, gives the maximum number of units to use for training the tree");
		props2Help.put(PROPORTIONTESTDATA,
				"the proportion of the data to use as test data (choose so that 1/value is an integer)");
	}

	/**
	 * Trains the duration tree and writes it to the file given by {@link #DURTREE}.
	 * 
	 * @return {@code false} if the clusterer produced no graph in its final iteration, {@code true} once the final
	 *         tree has been saved
	 * @throws IOException
	 *             if reading the feature/unit files or writing a graph fails
	 * @throws MaryConfigurationException
	 *             if the feature file cannot be loaded as configured
	 */
	@Override
	public boolean compute() throws IOException, MaryConfigurationException {
		logger.info("Duration tree trainer started.");
		FeatureFileReader featureFile = FeatureFileReader.getFeatureFileReader(getProp(FEATUREFILE));
		UnitFileReader unitFile = new UnitFileReader(getProp(UNITFILE));

		FeatureVector[] allFeatureVectors = featureFile.getFeatureVectors();
		int maxData = Integer.parseInt(getProp(MAXDATA));
		// Only a positive maxData limits the training data. Zero or negative means "use everything".
		// A negative value must not reach the array allocation below, where it would throw.
		int numUsed = (maxData > 0) ? Math.min(maxData, allFeatureVectors.length) : allFeatureVectors.length;
		// Always work on a copy of the (possibly truncated) feature vector array.
		FeatureVector[] featureVectors = new FeatureVector[numUsed];
		System.arraycopy(allFeatureVectors, 0, featureVectors, 0, numUsed);
		logger.debug("Total of " + allFeatureVectors.length + " feature vectors -- will use " + featureVectors.length);

		AgglomerativeClusterer clusterer = new AgglomerativeClusterer(featureVectors, featureFile.getFeatureDefinition(), null,
				new DurationDistanceMeasure(unitFile), Float.parseFloat(getProp(PROPORTIONTESTDATA)));
		DirectedGraphWriter writer = new DirectedGraphWriter();
		DirectedGraph graph;
		int iteration = 0;
		do {
			graph = clusterer.cluster();
			iteration++;
			// Save each intermediate clustering level next to the final tree as <durTree>.level<N>.
			if (graph != null) {
				writer.saveGraph(graph, getProp(DURTREE) + ".level" + iteration);
			}
		} while (clusterer.canClusterMore());

		// Only the graph from the last iteration is used for the final tree.
		if (graph == null) {
			return false;
		}

		replaceLeavesWithDurationStatistics(graph, unitFile);
		writer.saveGraph(graph, getProp(DURTREE));
		return true;
	}

	/**
	 * Replaces each feature-vector leaf of {@code graph} by a {@link FloatLeafNode} holding the standard deviation and
	 * mean (in that order) of the durations of the units in that leaf.
	 * 
	 * @param graph
	 *            the clustered graph; every leaf is cast to {@link FeatureVectorLeafNode}
	 * @param unitFile
	 *            source of the unit durations and of the sample rate used to scale them
	 */
	private static void replaceLeavesWithDurationStatistics(DirectedGraph graph, UnitFileReader unitFile) {
		// Divide in double precision; the values are collected in a double[] anyway.
		double sampleRate = unitFile.getSampleRate();
		for (LeafNode leaf : graph.getLeafNodes()) {
			FeatureVectorLeafNode fvLeaf = (FeatureVectorLeafNode) leaf;
			FeatureVector[] fvs = fvLeaf.getFeatureVectors();
			double[] dur = new double[fvs.length];
			for (int i = 0; i < fvs.length; i++) {
				dur[i] = unitFile.getUnit(fvs[i].getUnitIndex()).duration / sampleRate;
			}
			double mean = MathUtils.mean(dur);
			double stddev = MathUtils.standardDeviation(dur, mean);
			// Data layout passed to FloatLeafNode here: { stddev, mean }.
			FloatLeafNode floatLeaf = new FloatLeafNode(new float[] { (float) stddev, (float) mean });
			Node mother = fvLeaf.getMother();
			assert mother != null;
			// Hook the new leaf into whichever kind of parent held the old one.
			if (mother.isDecisionNode()) {
				((DecisionNode) mother).replaceDaughter(floatLeaf, fvLeaf.getNodeIndex());
			} else {
				assert mother.isDirectedGraphNode();
				assert ((DirectedGraphNode) mother).getLeafNode() == fvLeaf;
				((DirectedGraphNode) mother).setLeafNode(floatLeaf);
			}
		}
	}

	/**
	 * Provide the progress of computation, in percent, or -1 if that feature is not implemented.
	 * 
	 * @return -1 if not implemented, or an integer between 0 and 100.
	 */
	public int getProgress() {
		return -1;
	}

	/**
	 * Runs the trainer stand-alone.
	 * 
	 * @param args
	 *            ignored
	 * @throws Exception
	 *             if training fails
	 */
	public static void main(String[] args) throws Exception {
		DurationTreeTrainer dct = new DurationTreeTrainer();
		// The layout is constructed only for its effect on dct (presumably initialising its properties — TODO confirm).
		new DatabaseLayout(dct);
		dct.compute();
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy