marytts.tools.voiceimport.traintrees.Wagon Maven / Gradle / Ivy
The newest version!
/**
* Copyright 2009 DFKI GmbH.
* All Rights Reserved. Use is subject to license terms.
*
* This file is part of MARY TTS.
*
* MARY TTS is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, version 3 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see .
*
*/
package marytts.tools.voiceimport.traintrees;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;
import marytts.cart.CART;
import marytts.cart.LeafNode;
import marytts.cart.LeafNode.LeafType;
import marytts.cart.io.WagonCARTReader;
import marytts.features.FeatureDefinition;
import marytts.features.FeatureVector;
import marytts.util.MaryUtils;
import org.apache.log4j.Logger;
/**
* A class providing the functionality to interface with an external wagon process.
*
* @author marc
*
*/
public class Wagon implements Runnable {
private static File wagonExecutable;
public static void setWagonExecutable(File wagonExe) {
wagonExecutable = wagonExe;
}
private Logger logger;
private String id;
private FeatureDefinition featureDefinition;
private FeatureVector[] fv;
private DistanceMeasure distMeasure;
private File distFile;
private File descFile;
private File featFile;
private File cartFile;
private String systemCall;
private boolean finished = false;
private boolean success = false;
private CART cart = null;
/**
* Set up a new wagon process. Wagon.setWagonExecutable() must be called beforehand.
*
* @param id
* id
* @param featureDefinition
* featureDefinition
* @param featureVectors
* featureVectors
* @param aDistanceMeasure
* aDistanceMeasure
* @param dir
* dir
* @param balance
* balance
* @param stop
* stop
* @throws IOException
* if there was no call to Wagon.setWagonExecutable() with an executable file before calling this constructor.
*/
public Wagon(String id, FeatureDefinition featureDefinition, FeatureVector[] featureVectors,
DistanceMeasure aDistanceMeasure, File dir, int balance, int stop) throws IOException {
if (wagonExecutable == null || !wagonExecutable.isFile()) {
throw new IOException("No wagon executable set using Wagon.setExecutable()!");
}
this.logger = MaryUtils.getLogger("Wagon");
this.id = id;
this.featureDefinition = featureDefinition;
this.fv = featureVectors;
this.distMeasure = aDistanceMeasure;
this.distFile = new File(dir, id + ".dist");
this.descFile = new File(dir, id + ".desc");
this.featFile = new File(dir, id + ".feat");
this.cartFile = new File(dir, id + ".cart");
this.systemCall = wagonExecutable.getAbsolutePath() + " -desc " + descFile.getAbsolutePath() + " -data "
+ featFile.getAbsolutePath() + " -balance " + balance + " -distmatrix " + distFile.getAbsolutePath() + " -stop "
+ stop + " -output " + cartFile.getAbsolutePath();
}
/**
* Export this feature definition in the "all.desc" format which can be read by wagon.
*
* @throws IOException
* IOException
*/
private void createDescFile() throws IOException {
PrintWriter out = new PrintWriter(new FileOutputStream(descFile));
Set featuresToIgnore = new HashSet();
featuresToIgnore.add("unit_logf0");
featuresToIgnore.add("unit_duration");
int numDiscreteFeatures = featureDefinition.getNumberOfByteFeatures() + featureDefinition.getNumberOfShortFeatures();
out.println("(");
out.println("(occurid cluster)");
for (int i = 0, n = featureDefinition.getNumberOfFeatures(); i < n; i++) {
out.print("( ");
String featureName = featureDefinition.getFeatureName(i);
out.print(featureName);
if (featuresToIgnore != null && featuresToIgnore.contains(featureName)) {
out.print(" ignore");
}
if (i < numDiscreteFeatures) { // list values
for (int v = 0, vmax = featureDefinition.getNumberOfValues(i); v < vmax; v++) {
out.print(" ");
// Print values surrounded by double quotes, and make sure any
// double quotes in the value are preceded by a backslash --
// otherwise, we get problems e.g. for sentence_punc
String val = featureDefinition.getFeatureValueAsString(i, v);
if (val.indexOf('"') != -1) {
StringBuilder buf = new StringBuilder();
for (int c = 0; c < val.length(); c++) {
char ch = val.charAt(c);
if (ch == '"')
buf.append("\\\"");
else
buf.append(ch);
}
val = buf.toString();
}
out.print("\"" + val + "\"");
}
out.println(" )");
} else { // float feature
out.println(" float )");
}
}
out.println(")");
out.close();
}
private void dumpFeatureVectors() throws IOException {
// open file
PrintWriter out = new PrintWriter(new BufferedOutputStream(new FileOutputStream(featFile)));
for (int i = 0; i < fv.length; i++) {
// Print the feature string
out.print(i + " " + featureDefinition.toFeatureString(fv[i]));
// print a newline if this is not the last vector
if (i + 1 != fv.length) {
out.print("\n");
}
}
// dump and close
out.flush();
out.close();
}
/*
* Save in "ancient" text format. Extremely inefficient for large files. Keeping this for documentation purposes only.
*/
private void saveDistanceMatrix() throws IOException {
PrintWriter out = new PrintWriter(new BufferedOutputStream(new FileOutputStream(distFile)));
for (int i = 0; i < fv.length; i++) {
for (int j = 0; j < fv.length; j++) {
float distance = (i == j ? 0f : distMeasure.squaredDistance(fv[i], fv[j]));
out.printf(Locale.US, "%.1f ", distance);
}
out.print("\n");
}
out.flush();
out.close();
}
/* Save in efficient binary format. */
private void binarySaveDistanceMatrix() throws IOException {
DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(distFile)));
out.writeBytes("EST_File fmatrix\n");
out.writeBytes("version 1\n");
out.writeBytes("DataType binary\n");
out.writeBytes("ByteOrder BigEndian\n");
out.writeBytes("rows " + fv.length + "\n");
out.writeBytes("columns " + fv.length + "\n");
out.writeBytes("EST_Header_End\n");
for (int i = 0; i < fv.length; i++) {
for (int j = 0; j < fv.length; j++) {
float distance = (i == j ? 0f : distMeasure.squaredDistance(fv[i], fv[j]));
out.writeFloat(distance);
}
}
out.flush();
out.close();
}
public void run() {
try {
long startTime = System.currentTimeMillis();
logger.debug(id + "> Creating " + descFile.getName());
createDescFile();
logger.debug(id + "> Dumping features to " + featFile.getName());
dumpFeatureVectors();
logger.debug(id + "> Dumping distance matrix to " + distFile.getName());
binarySaveDistanceMatrix();
logger.debug(id + "> Calling wagon as follows:");
logger.debug(systemCall);
Process p = Runtime.getRuntime().exec(systemCall);
// collect the output
// read from error stream
StreamGobbler errorGobbler = new StreamGobbler(p.getErrorStream(), id + " err");
// read from output stream
StreamGobbler outputGobbler = new StreamGobbler(p.getInputStream(), id + " out");
// start reading from the streams
errorGobbler.start();
outputGobbler.start();
p.waitFor();
if (p.exitValue() != 0) {
finished = true;
success = false;
} else {
success = true;
logger.debug(id + "> Wagon call took " + (System.currentTimeMillis() - startTime) + " ms");
// read in the resulting CART
logger.debug(id + "> Reading CART");
BufferedReader buf = new BufferedReader(new FileReader(cartFile));
WagonCARTReader wagonReader = new WagonCARTReader(LeafType.IntAndFloatArrayLeafNode);
cart = new CART(wagonReader.load(buf, featureDefinition), featureDefinition);
buf.close();
// Fix the new cart's leaves:
// They are currently the index numbers in featureVectors;
// but what we need is the unit index numbers!
for (LeafNode leaf : cart.getLeafNodes()) {
int[] data = (int[]) leaf.getAllData();
for (int i = 0; i < data.length; i++) {
data[i] = fv[data[i]].getUnitIndex();
}
}
logger.debug(id + "> completed in " + (System.currentTimeMillis() - startTime) + " ms");
finished = true;
}
if (!Boolean.getBoolean("wagon.keepfiles")) {
featFile.delete();
distFile.delete();
}
} catch (Exception e) {
e.printStackTrace();
finished = true;
success = false;
throw new RuntimeException("Exception running wagon");
}
}
public boolean finished() {
return finished;
}
public boolean success() {
return success;
}
public String id() {
return id;
}
public CART getCART() {
return cart;
}
static class StreamGobbler extends Thread {
InputStream is;
String type;
StreamGobbler(InputStream is, String type) {
this.is = is;
this.type = type;
}
public void run() {
try {
InputStreamReader isr = new InputStreamReader(is);
BufferedReader br = new BufferedReader(isr);
String line = null;
while ((line = br.readLine()) != null)
System.out.println(type + ">" + line);
} catch (IOException ioe) {
ioe.printStackTrace();
}
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy