All Downloads are FREE. Search and download functionalities are using the official Maven repository.

marytts.tools.voiceimport.traintrees.Wagon Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2009 DFKI GmbH.
 * All Rights Reserved.  Use is subject to license terms.
 *
 * This file is part of MARY TTS.
 *
 * MARY TTS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, version 3 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

package marytts.tools.voiceimport.traintrees;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;

import marytts.cart.CART;
import marytts.cart.LeafNode;
import marytts.cart.LeafNode.LeafType;
import marytts.cart.io.WagonCARTReader;
import marytts.features.FeatureDefinition;
import marytts.features.FeatureVector;
import marytts.util.MaryUtils;

import org.apache.log4j.Logger;

/**
 * A class providing the functionality to interface with an external wagon process.
 * <p>
 * Usage: call {@link #setWagonExecutable(File)} once, construct a {@code Wagon} for a set of feature vectors, then execute
 * {@link #run()} (directly or in a separate thread). Other threads may poll {@link #finished()}; once it returns true,
 * {@link #success()} tells whether {@link #getCART()} holds a trained tree.
 * 
 * @author marc
 * 
 */
public class Wagon implements Runnable {
	private static File wagonExecutable;

	/**
	 * Set the wagon binary used by all subsequently constructed {@code Wagon} instances.
	 * 
	 * @param wagonExe
	 *            the wagon executable; the constructor requires it to be an existing file
	 */
	public static void setWagonExecutable(File wagonExe) {
		wagonExecutable = wagonExe;
	}

	private final Logger logger;
	private final String id;
	private final FeatureDefinition featureDefinition;
	private final FeatureVector[] fv;
	private final DistanceMeasure distMeasure;
	private final File distFile;
	private final File descFile;
	private final File featFile;
	private final File cartFile;
	/** Space-joined form of {@link #command}; used for logging only. */
	private final String systemCall;
	/** The wagon command line, one argument per element, so that paths containing spaces are passed intact. */
	private final String[] command;
	// volatile: written by the thread executing run(), read by other threads via finished()/success()/getCART().
	// run() assigns cart, then success, then finished, so a reader that sees finished == true also sees the others.
	private volatile boolean finished = false;
	private volatile boolean success = false;
	private volatile CART cart = null;

	/**
	 * Set up a new wagon process. Wagon.setWagonExecutable() must be called beforehand.
	 * 
	 * @param id
	 *            identifier used as the base name of all temporary files and as a prefix in log output
	 * @param featureDefinition
	 *            feature definition describing the feature vectors
	 * @param featureVectors
	 *            the feature vectors to train the tree on
	 * @param aDistanceMeasure
	 *            measure used to fill the distance matrix passed to wagon
	 * @param dir
	 *            directory in which the .desc, .feat, .dist and .cart files are created
	 * @param balance
	 *            value passed to wagon as {@code -balance}
	 * @param stop
	 *            value passed to wagon as {@code -stop}
	 * @throws IOException
	 *             if there was no call to Wagon.setWagonExecutable() with an executable file before calling this constructor.
	 */
	public Wagon(String id, FeatureDefinition featureDefinition, FeatureVector[] featureVectors,
			DistanceMeasure aDistanceMeasure, File dir, int balance, int stop) throws IOException {
		File exe = wagonExecutable; // read the static once so the check and the use agree
		if (exe == null || !exe.isFile()) {
			throw new IOException("No wagon executable set using Wagon.setWagonExecutable()!");
		}
		this.logger = MaryUtils.getLogger("Wagon");
		this.id = id;
		this.featureDefinition = featureDefinition;
		this.fv = featureVectors;
		this.distMeasure = aDistanceMeasure;
		this.distFile = new File(dir, id + ".dist");
		this.descFile = new File(dir, id + ".desc");
		this.featFile = new File(dir, id + ".feat");
		this.cartFile = new File(dir, id + ".cart");
		this.command = new String[] { exe.getAbsolutePath(), "-desc", descFile.getAbsolutePath(), "-data",
				featFile.getAbsolutePath(), "-balance", String.valueOf(balance), "-distmatrix", distFile.getAbsolutePath(),
				"-stop", String.valueOf(stop), "-output", cartFile.getAbsolutePath() };
		this.systemCall = joinWithSpaces(command);
	}

	/** Join the given arguments with single spaces (for log output). */
	private static String joinWithSpaces(String[] parts) {
		StringBuilder sb = new StringBuilder();
		for (int i = 0; i < parts.length; i++) {
			if (i > 0) {
				sb.append(' ');
			}
			sb.append(parts[i]);
		}
		return sb.toString();
	}

	/**
	 * Surround a feature value with double quotes, preceding any double quote inside the value with a backslash --
	 * otherwise, we get problems e.g. for sentence_punc.
	 */
	private static String quoteForWagon(String val) {
		StringBuilder buf = new StringBuilder(val.length() + 2);
		buf.append('"');
		for (int c = 0; c < val.length(); c++) {
			char ch = val.charAt(c);
			if (ch == '"') {
				buf.append("\\\"");
			} else {
				buf.append(ch);
			}
		}
		buf.append('"');
		return buf.toString();
	}

	/**
	 * Export this feature definition in the "all.desc" format which can be read by wagon. The features unit_logf0 and
	 * unit_duration are marked "ignore"; byte and short features list their values, all remaining features are declared
	 * "float".
	 * 
	 * @throws IOException
	 *             if the file cannot be opened or writing to it fails
	 */
	private void createDescFile() throws IOException {
		Set<String> featuresToIgnore = new HashSet<String>();
		featuresToIgnore.add("unit_logf0");
		featuresToIgnore.add("unit_duration");

		int numDiscreteFeatures = featureDefinition.getNumberOfByteFeatures() + featureDefinition.getNumberOfShortFeatures();
		PrintWriter out = new PrintWriter(new FileOutputStream(descFile));
		try {
			out.println("(");
			out.println("(occurid cluster)");
			for (int i = 0, n = featureDefinition.getNumberOfFeatures(); i < n; i++) {
				out.print("( ");
				String featureName = featureDefinition.getFeatureName(i);
				out.print(featureName);
				if (featuresToIgnore.contains(featureName)) {
					out.print(" ignore");
				}
				if (i < numDiscreteFeatures) { // list values
					for (int v = 0, vmax = featureDefinition.getNumberOfValues(i); v < vmax; v++) {
						out.print("  ");
						out.print(quoteForWagon(featureDefinition.getFeatureValueAsString(i, v)));
					}
					out.println(" )");
				} else { // float feature
					out.println(" float )");
				}
			}
			out.println(")");
			// PrintWriter swallows I/O errors; surface them so wagon never runs on a truncated file
			if (out.checkError()) {
				throw new IOException("Error writing " + descFile.getAbsolutePath());
			}
		} finally {
			out.close();
		}
	}

	/**
	 * Write one line per feature vector: its index in {@code fv} followed by its feature string. No newline is written
	 * after the last vector.
	 * 
	 * @throws IOException
	 *             if the file cannot be opened or writing to it fails
	 */
	private void dumpFeatureVectors() throws IOException {
		PrintWriter out = new PrintWriter(new BufferedOutputStream(new FileOutputStream(featFile)));
		try {
			for (int i = 0; i < fv.length; i++) {
				out.print(i + " " + featureDefinition.toFeatureString(fv[i]));
				// print a newline if this is not the last vector
				if (i + 1 != fv.length) {
					out.print("\n");
				}
			}
			if (out.checkError()) {
				throw new IOException("Error writing " + featFile.getAbsolutePath());
			}
		} finally {
			out.close();
		}
	}

	/*
	 * Save in "ancient" text format. Extremely inefficient for large files. Keeping this for documentation purposes only.
	 */
	@SuppressWarnings("unused")
	private void saveDistanceMatrix() throws IOException {
		PrintWriter out = new PrintWriter(new BufferedOutputStream(new FileOutputStream(distFile)));
		try {
			for (int i = 0; i < fv.length; i++) {
				for (int j = 0; j < fv.length; j++) {
					float distance = (i == j ? 0f : distMeasure.squaredDistance(fv[i], fv[j]));
					out.printf(Locale.US, "%.1f ", distance);
				}
				out.print("\n");
			}
			if (out.checkError()) {
				throw new IOException("Error writing " + distFile.getAbsolutePath());
			}
		} finally {
			out.close();
		}
	}

	/**
	 * Save the full fv.length x fv.length matrix of squared distances (0 on the diagonal) in the binary EST_File fmatrix
	 * format: a text header followed by big-endian floats, row by row.
	 * 
	 * @throws IOException
	 *             if the file cannot be opened or written
	 */
	private void binarySaveDistanceMatrix() throws IOException {
		DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(distFile)));
		try {
			out.writeBytes("EST_File fmatrix\n");
			out.writeBytes("version 1\n");
			out.writeBytes("DataType binary\n");
			out.writeBytes("ByteOrder BigEndian\n"); // DataOutputStream.writeFloat is always big-endian
			out.writeBytes("rows " + fv.length + "\n");
			out.writeBytes("columns " + fv.length + "\n");
			out.writeBytes("EST_Header_End\n");
			for (int i = 0; i < fv.length; i++) {
				for (int j = 0; j < fv.length; j++) {
					float distance = (i == j ? 0f : distMeasure.squaredDistance(fv[i], fv[j]));
					out.writeFloat(distance);
				}
			}
			out.flush();
		} finally {
			out.close();
		}
	}

	/**
	 * Start wagon, echo its stdout/stderr, and wait for it to finish.
	 * 
	 * @return wagon's exit value
	 * @throws InterruptedException
	 *             if interrupted while waiting; the wagon process is destroyed in that case
	 */
	private int runWagonProcess() throws IOException, InterruptedException {
		Process p = Runtime.getRuntime().exec(command);
		// wagon's output streams must be drained, otherwise it can block on a full pipe
		StreamGobbler errorGobbler = new StreamGobbler(p.getErrorStream(), id + " err");
		StreamGobbler outputGobbler = new StreamGobbler(p.getInputStream(), id + " out");
		errorGobbler.start();
		outputGobbler.start();
		try {
			int exitValue = p.waitFor();
			// make sure all of wagon's output has been echoed before reporting the result
			errorGobbler.join();
			outputGobbler.join();
			return exitValue;
		} catch (InterruptedException ie) {
			p.destroy(); // don't leave an orphaned wagon process behind
			throw ie;
		}
	}

	/**
	 * Write the input files, run wagon and, on a zero exit value, read back the resulting CART with leaf data mapped to unit
	 * indices. Unless the system property "wagon.keepfiles" is true, the .feat and .dist files are deleted afterwards.
	 * 
	 * @throws RuntimeException
	 *             wrapping any exception raised on the way; finished() is true and success() false in that case
	 */
	@Override
	public void run() {
		try {
			long startTime = System.currentTimeMillis();

			logger.debug(id + "> Creating " + descFile.getName());
			createDescFile();

			logger.debug(id + "> Dumping features to " + featFile.getName());
			dumpFeatureVectors();

			logger.debug(id + "> Dumping distance matrix to " + distFile.getName());
			binarySaveDistanceMatrix();

			logger.debug(id + "> Calling wagon as follows:");
			logger.debug(systemCall);
			int exitValue = runWagonProcess();
			if (exitValue != 0) {
				success = false;
				finished = true;
			} else {
				logger.debug(id + "> Wagon call took " + (System.currentTimeMillis() - startTime) + " ms");

				// read in the resulting CART
				logger.debug(id + "> Reading CART");
				CART newCart;
				BufferedReader buf = new BufferedReader(new FileReader(cartFile));
				try {
					WagonCARTReader wagonReader = new WagonCARTReader(LeafType.IntAndFloatArrayLeafNode);
					newCart = new CART(wagonReader.load(buf, featureDefinition), featureDefinition);
				} finally {
					buf.close();
				}

				// Fix the new cart's leaves:
				// They are currently the index numbers in featureVectors;
				// but what we need is the unit index numbers!
				for (LeafNode leaf : newCart.getLeafNodes()) {
					int[] data = (int[]) leaf.getAllData();
					for (int i = 0; i < data.length; i++) {
						data[i] = fv[data[i]].getUnitIndex();
					}
				}

				logger.debug(id + "> completed in " + (System.currentTimeMillis() - startTime) + " ms");
				// publish in this order: a thread seeing finished == true also sees cart and success
				cart = newCart;
				success = true;
				finished = true;
			}
			if (!Boolean.getBoolean("wagon.keepfiles")) {
				featFile.delete();
				distFile.delete();
			}

		} catch (Exception e) {
			if (e instanceof InterruptedException) {
				Thread.currentThread().interrupt(); // preserve the interrupt status for the caller
			}
			success = false;
			finished = true;
			logger.error(id + "> Exception running wagon", e);
			throw new RuntimeException("Exception running wagon", e);
		}

	}

	/** @return true once run() has completed, successfully or not */
	public boolean finished() {
		return finished;
	}

	/** @return true if run() completed and a CART was read successfully */
	public boolean success() {
		return success;
	}

	/** @return the id given to the constructor */
	public String id() {
		return id;
	}

	/** @return the trained CART, or null if run() has not (yet) completed successfully */
	public CART getCART() {
		return cart;
	}

	/** Echoes every line read from a stream to System.out, prefixed with a type tag. */
	static class StreamGobbler extends Thread {
		private final InputStream is;
		private final String type;

		StreamGobbler(InputStream is, String type) {
			this.is = is;
			this.type = type;
		}

		@Override
		public void run() {
			BufferedReader br = new BufferedReader(new InputStreamReader(is));
			try {
				String line;
				while ((line = br.readLine()) != null) {
					System.out.println(type + ">" + line);
				}
			} catch (IOException ioe) {
				ioe.printStackTrace();
			} finally {
				try {
					br.close();
				} catch (IOException ignored) {
					// nothing useful to do: the stream is exhausted or broken anyway
				}
			}
		}
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy