All Downloads are FREE. Search and download functionalities are using the official Maven repository.

marytts.tools.voiceimport.DurationCARTTrainer Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2007 DFKI GmbH.
 * All Rights Reserved.  Use is subject to license terms.
 *
 * This file is part of MARY TTS.
 *
 * MARY TTS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, version 3 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */
package marytts.tools.voiceimport;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.SortedMap;
import java.util.TreeMap;

import marytts.cart.CART;
import marytts.cart.Node;
import marytts.cart.LeafNode.LeafType;
import marytts.cart.io.MaryCARTWriter;
import marytts.cart.io.WagonCARTReader;
import marytts.exceptions.MaryConfigurationException;
import marytts.features.FeatureDefinition;
import marytts.unitselection.data.FeatureFileReader;
import marytts.unitselection.data.HnmTimelineReader;
import marytts.unitselection.data.TimelineReader;
import marytts.unitselection.data.Unit;
import marytts.unitselection.data.UnitFileReader;

/**
 * Voice import component which trains a duration CART using the external "wagon" tool.
 * <p>
 * {@link #compute()} writes one line per unit (the unit duration in seconds followed by the unit's feature string)
 * to <code>dur.feats</code> in the temp directory, writes a matching wagon description file <code>dur.desc</code>,
 * calls wagon to produce <code>dur.tree</code>, and on success converts that tree into the MARY CART file named by
 * the {@link #DURTREE} property.
 * 
 * @author schroed
 *
 */
public class DurationCARTTrainer extends VoiceImportComponent {
	// TODO: hardcoded path = EVIL
	/** Location of the festvox script used to split the data into a training and a test part. */
	private static final String TRAINTEST_EXECUTABLE = "/project/mary/Festival/festvox/src/general/traintest";

	protected File unitlabelDir;
	protected File unitfeatureDir;
	/** Directory in which the intermediate wagon input and output files are kept. */
	protected File durationDir;
	/** Wagon data file: one line per unit, duration followed by the feature values. */
	protected File durationFeatsFile;
	/** Wagon description file listing the features and their possible values. */
	protected File durationDescFile;
	/** Tree file written by wagon. */
	protected File wagonTreeFile;
	protected DatabaseLayout db = null;
	/** Progress value returned by {@link #getProgress()}. */
	protected int percent = 0;
	protected boolean useStepwiseTraining = false;

	private final String name = "DurationCARTTrainer";
	public final String DURTREE = name + ".durTree";
	public final String STEPWISETRAINING = name + ".stepwiseTraining";
	public final String FEATUREFILE = name + ".featureFile";
	public final String UNITFILE = name + ".unitFile";
	public final String WAVETIMELINE = name + ".waveTimeline";
	public final String ISHNMTIMELINE = name + ".isHnmTimeline";

	/**
	 * @return the name of this component, which is also the prefix of all its property keys.
	 */
	public String getName() {
		return name;
	}

	/**
	 * Resolve the working directories and intermediate file locations from the database layout, creating the temp
	 * directory if it does not exist yet, and read the stepwise-training switch.
	 * 
	 * @throws Error
	 *             if the temp directory does not exist and cannot be created
	 */
	@Override
	protected void initialiseComp() {
		this.unitlabelDir = new File(db.getProp(DatabaseLayout.PHONELABDIR));
		this.unitfeatureDir = new File(db.getProp(DatabaseLayout.PHONEFEATUREDIR));
		String durDir = db.getProp(DatabaseLayout.TEMPDIR);
		this.durationDir = new File(durDir);
		if (!durationDir.exists()) {
			System.out.print("temp dir " + durDir + " does not exist; ");
			// mkdirs() also creates missing parent directories
			if (!durationDir.mkdirs()) {
				throw new Error("Could not create DURDIR");
			}
			System.out.print("Created successfully.\n");
		}
		// Resolve against the directory object so that the result does not depend on
		// whether the configured temp dir ends with a path separator.
		this.durationFeatsFile = new File(durationDir, "dur.feats");
		this.durationDescFile = new File(durationDir, "dur.desc");
		this.wagonTreeFile = new File(durationDir, "dur.tree");
		this.useStepwiseTraining = Boolean.parseBoolean(getProp(STEPWISETRAINING));

	}

	/**
	 * Remember the database layout and, on first call, fill in the default values of this component's properties.
	 * 
	 * @param dbl
	 *            the database layout from which the default file locations are derived
	 * @return the property map of this component
	 */
	public SortedMap getDefaultProps(DatabaseLayout dbl) {
		this.db = dbl;
		if (props == null) {
			props = new TreeMap();
			String fileDir = db.getProp(DatabaseLayout.FILEDIR);
			String maryExt = db.getProp(DatabaseLayout.MARYEXT);
			props.put(STEPWISETRAINING, "false");
			props.put(FEATUREFILE, fileDir + "phoneFeatures" + maryExt);
			props.put(UNITFILE, fileDir + "phoneUnits" + maryExt);
			props.put(WAVETIMELINE, fileDir + "timeline_waveforms" + maryExt);
			props.put(ISHNMTIMELINE, "false");
			props.put(DURTREE, fileDir + "dur.tree");
		}
		return props;
	}

	/**
	 * Fill the map of help texts, one entry per property of this component.
	 */
	protected void setupHelp() {
		props2Help = new TreeMap();
		props2Help.put(STEPWISETRAINING, "\"false\" or \"true\" ???????????????????????????????????????????????????????????");
		props2Help.put(FEATUREFILE, "file containing all phone units and their target cost features");
		props2Help.put(UNITFILE, "file containing all phone units");
		props2Help.put(WAVETIMELINE, "file containing all waveforms or models that can genarate them");
		// This property is a boolean switch (see compute()), not a file name.
		props2Help.put(ISHNMTIMELINE, "\"true\" if the wave timeline is an HNM timeline, \"false\" otherwise");
		props2Help.put(DURTREE, "file containing the duration CART. Will be created by this module");
	}

	/**
	 * Export the duration training data, run wagon on it and convert the resulting tree to MARY CART format.
	 * 
	 * @return true if wagon reported success and the CART was written, false if the wagon call failed
	 * @throws IOException
	 *             if one of the input files cannot be read, an intermediate file cannot be written, or the thread is
	 *             interrupted while waiting for the train/test split
	 * @throws MaryConfigurationException
	 *             if one of the readers cannot be set up
	 */
	@Override
	public boolean compute() throws IOException, MaryConfigurationException {
		FeatureFileReader featureFile = FeatureFileReader.getFeatureFileReader(getProp(FEATUREFILE));
		UnitFileReader unitFile = new UnitFileReader(getProp(UNITFILE));
		// NOTE(review): the timeline is not used further down; it is still opened here so that
		// whatever checks its constructor performs keep happening -- confirm whether it is needed.
		TimelineReader waveTimeline = null;
		if (Boolean.parseBoolean(getProp(ISHNMTIMELINE)))
			waveTimeline = new HnmTimelineReader(getProp(WAVETIMELINE));
		else
			waveTimeline = new TimelineReader(getProp(WAVETIMELINE));

		System.out.println("Duration CART trainer: exporting duration features");
		FeatureDefinition featureDefinition = featureFile.getFeatureDefinition();
		int nUnits = exportDurationFeatures(unitFile, featureFile, featureDefinition);
		if (useStepwiseTraining)
			percent = 1;
		else
			percent = 10;
		System.out.println("Duration features extracted for " + nUnits + " units");

		writeWagonDescription(featureDefinition);

		boolean ok = false;
		// Now, call wagon
		WagonCaller wagonCaller = new WagonCaller(db.getProp(DatabaseLayout.ESTDIR), null);
		if (useStepwiseTraining) {
			// Split the data set in training and test part:
			splitTrainingAndTestData();
			ok = wagonCaller.callWagon("-data " + durationFeatsFile.getAbsolutePath() + ".train" + " -test "
					+ durationFeatsFile.getAbsolutePath() + ".test -stepwise" + " -desc " + durationDescFile.getAbsolutePath()
					+ " -stop 10 " + " -output " + wagonTreeFile.getAbsolutePath());
		} else {
			ok = wagonCaller.callWagon("-data " + durationFeatsFile.getAbsolutePath() + " -desc "
					+ durationDescFile.getAbsolutePath() + " -stop 10 " + " -output " + wagonTreeFile.getAbsolutePath());
		}
		if (ok) {
			convertWagonTree(featureDefinition, getProp(DURTREE));
		}
		percent = 100;
		return ok;

	}

	/**
	 * Write the wagon data file: for every unit of at least 10 ms, one line with the duration in seconds followed by
	 * the unit's feature string. Updates {@link #percent} while running.
	 * 
	 * @return the number of units written
	 */
	private int exportDurationFeatures(UnitFileReader unitFile, FeatureFileReader featureFile,
			FeatureDefinition featureDefinition) throws IOException, MaryConfigurationException {
		int nUnits = 0;
		PrintWriter toFeaturesFile = new PrintWriter(new FileOutputStream(durationFeatsFile));
		try {
			for (int i = 0, len = unitFile.getNumberOfUnits(); i < len; i++) {
				// We estimate that feature extraction takes 1/10 of the total time
				// (that's probably wrong, but never mind)
				percent = 10 * i / len;
				Unit u = unitFile.getUnit(i);
				float dur = u.duration / (float) unitFile.getSampleRate();
				if (dur >= 0.01) { // enforce a minimum duration for training data
					toFeaturesFile.println(dur + " " + featureDefinition.toFeatureString(featureFile.getFeatureVector(i)));
					nUnits++;
				}
			}
			// PrintWriter never throws on write errors; ask explicitly so that wagon
			// is not run on a truncated data file.
			if (toFeaturesFile.checkError()) {
				throw new IOException("Could not write duration features to " + durationFeatsFile.getAbsolutePath());
			}
		} finally {
			toFeaturesFile.close();
		}
		return nUnits;
	}

	/**
	 * Write the wagon description file for the given feature definition.
	 */
	private void writeWagonDescription(FeatureDefinition featureDefinition) throws IOException {
		PrintWriter toDesc = new PrintWriter(new FileOutputStream(durationDescFile));
		try {
			generateFeatureDescriptionForWagon(featureDefinition, toDesc);
			if (toDesc.checkError()) {
				throw new IOException("Could not write feature description to " + durationDescFile.getAbsolutePath());
			}
		} finally {
			toDesc.close();
		}
	}

	/**
	 * Run the external traintest script on the data file and wait for it to finish. The stepwise wagon call in
	 * {@link #compute()} expects the resulting <code>.train</code> and <code>.test</code> files next to the data file.
	 */
	private void splitTrainingAndTestData() throws IOException {
		// Pass the arguments as an array so that a path containing whitespace stays a single argument.
		Process traintest = Runtime.getRuntime().exec(
				new String[] { TRAINTEST_EXECUTABLE, durationFeatsFile.getAbsolutePath() });
		try {
			traintest.waitFor();
		} catch (InterruptedException ie) {
			// Keep the interrupt visible to callers and do not go on to train on a partial split.
			Thread.currentThread().interrupt();
			traintest.destroy();
			throw new IOException("Interrupted while waiting for traintest on " + durationFeatsFile.getAbsolutePath(),
					ie);
		}
	}

	/**
	 * Load the tree written by wagon and dump it as a MARY CART to the given file.
	 */
	private void convertWagonTree(FeatureDefinition featureDefinition, String destinationFile) throws IOException {
		WagonCARTReader wagonDURReader = new WagonCARTReader(LeafType.FloatLeafNode);
		Node rootNode;
		BufferedReader treeReader = new BufferedReader(new FileReader(wagonTreeFile));
		try {
			rootNode = wagonDURReader.load(treeReader, featureDefinition);
		} finally {
			treeReader.close();
		}
		CART durCart = new CART(rootNode, featureDefinition);
		MaryCARTWriter wwdur = new MaryCARTWriter();
		wwdur.dumpMaryCART(durCart, destinationFile);
	}

	/**
	 * Print a wagon description of the features: first the predictee <code>segment_duration</code> as a float, then
	 * every feature of the definition. Byte and short features are written as a list of their quoted values, except
	 * for features with exactly 20 values whose last value is "19", which are declared as float; all remaining
	 * features are declared as float.
	 * 
	 * @param fd
	 *            the feature definition to describe
	 * @param out
	 *            where to print the description; not closed by this method
	 */
	private void generateFeatureDescriptionForWagon(FeatureDefinition fd, PrintWriter out) {
		out.println("(");
		out.println("(segment_duration float)");
		int nDiscreteFeatures = fd.getNumberOfByteFeatures() + fd.getNumberOfShortFeatures();
		for (int i = 0, n = fd.getNumberOfFeatures(); i < n; i++) {
			out.print("( ");
			out.print(fd.getFeatureName(i));
			if (i < nDiscreteFeatures) { // list values
				if (fd.getNumberOfValues(i) == 20 && fd.getFeatureValueAsString(i, 19).equals("19")) {
					// one of our pseudo-floats
					out.println(" float )");
				} else { // list the values
					for (int v = 0, vmax = fd.getNumberOfValues(i); v < vmax; v++) {
						out.print("  ");
						// escape embedded double quotes, since each value is written inside double quotes
						String val = fd.getFeatureValueAsString(i, v).replace("\"", "\\\"");
						out.print("\"" + val + "\"");
					}
					out.println(" )");
				}
			} else { // float feature
				out.println(" float )");
			}

		}
		out.println(")");
	}

	/**
	 * Provide the progress of computation, in percent, or -1 if that feature is not implemented.
	 * 
	 * @return -1 if not implemented, or an integer between 0 and 100.
	 */
	public int getProgress() {
		return percent;
	}

	/**
	 * Command line entry point: set up a database layout for a new trainer and run {@link #compute()}.
	 * 
	 * @param args
	 *            ignored
	 * @throws Exception
	 *             if the setup or the computation fails
	 */
	public static void main(String[] args) throws Exception {
		DurationCARTTrainer dct = new DurationCARTTrainer();
		// The layout is created only for its effect on the component; the object itself is not needed here.
		new DatabaseLayout(dct);
		dct.compute();
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy