// marytts.cart.io.MaryCARTWriter Maven / Gradle / Ivy
// The newest version!
/**
* Copyright 2006 DFKI GmbH.
* All Rights Reserved. Use is subject to license terms.
*
* This file is part of MARY TTS.
*
* MARY TTS is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, version 3 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
package marytts.cart.io;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Properties;
import marytts.cart.CART;
import marytts.cart.DecisionNode;
import marytts.cart.LeafNode;
import marytts.cart.Node;
import marytts.cart.DecisionNode.BinaryByteDecisionNode;
import marytts.cart.DecisionNode.BinaryFloatDecisionNode;
import marytts.cart.DecisionNode.BinaryShortDecisionNode;
import marytts.cart.LeafNode.FeatureVectorLeafNode;
import marytts.cart.LeafNode.FloatLeafNode;
import marytts.cart.LeafNode.IntAndFloatArrayLeafNode;
import marytts.cart.LeafNode.IntArrayLeafNode;
import marytts.cart.LeafNode.LeafType;
import marytts.features.FeatureVector;
import marytts.util.MaryUtils;
import marytts.util.data.MaryHeader;
import org.apache.log4j.Logger;
/**
* IO functions for CARTs in MaryCART format
*
* @author Marcela Charfuelan
*/
public class MaryCARTWriter {
protected Logger logger = MaryUtils.getLogger(this.getClass().getName());
/**
* Dump the CARTs in MaryCART format
*
* @param cart
* cart
* @param destFile
* the destination file
* @throws IOException
* IOException
*/
public void dumpMaryCART(CART cart, String destFile) throws IOException {
if (cart == null)
throw new NullPointerException("Cannot dump null CART");
if (destFile == null)
throw new NullPointerException("No destination file");
logger.debug("Dumping CART in MaryCART format to " + destFile + " ...");
// Open the destination file (cart.bin) and output the header
DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(destFile)));
// create new CART-header and write it to output file
MaryHeader hdr = new MaryHeader(MaryHeader.CARTS);
hdr.writeTo(out);
Properties props = cart.getProperties();
if (props == null) {
out.writeShort(0);
} else {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
props.store(baos, null);
byte[] propData = baos.toByteArray();
out.writeShort(propData.length);
out.write(propData);
}
// feature definition
cart.getFeatureDefinition().writeBinaryTo(out);
// dump CART
dumpBinary(cart.getRootNode(), out);
// finish
out.close();
logger.debug(" ... done\n");
}
public void toTextOut(CART cart, PrintWriter pw) throws IOException {
try {
int id[] = new int[2];
id[0] = 0; // number of decision nodes
id[1] = 0; // number of leaf nodes
// System.out.println("Total number of nodes:" + rootNode.getNumberOfNodes());
setUniqueNodeId(cart.getRootNode(), id);
pw.println("Num decision nodes= " + id[0] + " Num leaf nodes= " + id[1]);
printDecisionNodes(cart.getRootNode(), null, pw);
pw.println("\n----------------\n");
printLeafNodes(cart.getRootNode(), null, pw);
pw.flush();
pw.close();
} catch (IOException ioe) {
IOException newIOE = new IOException("Error dumping CART to standard output");
newIOE.initCause(ioe);
throw newIOE;
}
}
private void setUniqueNodeId(Node node, int id[]) throws IOException {
int thisIdNode;
String leafstr = "";
// if the node is decision node
if (node.getNumberOfNodes() > 1) {
assert node instanceof DecisionNode;
DecisionNode decNode = (DecisionNode) node;
id[0]--;
decNode.setUniqueDecisionNodeId(id[0]);
String strNode = "";
// this.decNodeStr = "";
thisIdNode = id[0];
// add Ids to the daughters
for (int i = 0; i < decNode.getNumberOfDaugthers(); i++) {
setUniqueNodeId(decNode.getDaughter(i), id);
}
} else { // the node is a leaf node
assert node instanceof LeafNode;
LeafNode leaf = (LeafNode) node;
if (leaf.isEmpty()) {
leaf.setUniqueLeafId(0);
} else {
id[1]++;
leaf.setUniqueLeafId(id[1]);
}
}
}
private void dumpBinary(Node rootNode, DataOutput os) throws IOException {
try {
int id[] = new int[2];
id[0] = 0; // number of decision nodes
id[1] = 0; // number of leaf nodes
// first add unique identifiers to decision nodes and leaf nodes
setUniqueNodeId(rootNode, id);
// write the number of decision nodes
os.writeInt(Math.abs(id[0]));
// lines that start with a negative number are decision nodes
printDecisionNodes(rootNode, os, null);
// write the number of leaves.
os.writeInt(id[1]);
// lines that start with id are leaf nodes
printLeafNodes(rootNode, (DataOutputStream) os, null);
} catch (IOException ioe) {
IOException newIOE = new IOException("Error dumping CART to output stream");
newIOE.initCause(ioe);
throw newIOE;
}
}
private void printDecisionNodes(Node node, DataOutput out, PrintWriter pw) throws IOException {
if (!(node instanceof DecisionNode))
return; // nothing to do here
DecisionNode decNode = (DecisionNode) node;
int id = decNode.getUniqueDecisionNodeId();
String nodeDefinition = decNode.getNodeDefinition();
int featureIndex = decNode.getFeatureIndex();
DecisionNode.Type nodeType = decNode.getDecisionNodeType();
if (out != null) {
// dump in binary form to output
out.writeInt(featureIndex);
out.writeInt(nodeType.ordinal());
// Now, questionValue, which depends on nodeType
switch (nodeType) {
case BinaryByteDecisionNode:
out.writeInt(((BinaryByteDecisionNode) decNode).getCriterionValueAsByte());
assert decNode.getNumberOfDaugthers() == 2;
break;
case BinaryShortDecisionNode:
out.writeInt(((BinaryShortDecisionNode) decNode).getCriterionValueAsShort());
assert decNode.getNumberOfDaugthers() == 2;
break;
case BinaryFloatDecisionNode:
out.writeFloat(((BinaryFloatDecisionNode) decNode).getCriterionValueAsFloat());
assert decNode.getNumberOfDaugthers() == 2;
break;
case ByteDecisionNode:
case ShortDecisionNode:
out.writeInt(decNode.getNumberOfDaugthers());
}
// The child nodes
for (int i = 0, n = decNode.getNumberOfDaugthers(); i < n; i++) {
Node daughter = decNode.getDaughter(i);
if (daughter instanceof DecisionNode) {
out.writeInt(((DecisionNode) daughter).getUniqueDecisionNodeId());
} else {
assert daughter instanceof LeafNode;
out.writeInt(((LeafNode) daughter).getUniqueLeafId());
}
}
}
if (pw != null) {
// dump to print writer
StringBuilder strNode = new StringBuilder(id + " " + nodeDefinition);
for (int i = 0, n = decNode.getNumberOfDaugthers(); i < n; i++) {
strNode.append(" ");
Node daughter = decNode.getDaughter(i);
if (daughter instanceof DecisionNode) {
strNode.append(((DecisionNode) daughter).getUniqueDecisionNodeId());
} else {
assert daughter instanceof LeafNode;
strNode.append("id").append(((LeafNode) daughter).getUniqueLeafId());
}
}
pw.println(strNode.toString());
}
// add the daughters
for (int i = 0; i < ((DecisionNode) node).getNumberOfDaugthers(); i++) {
if (((DecisionNode) node).getDaughter(i).getNumberOfNodes() > 1)
printDecisionNodes(((DecisionNode) node).getDaughter(i), out, pw);
}
}
/** This function will print the leaf nodes only, but it goes through all the decision nodes. */
private void printLeafNodes(Node node, DataOutput out, PrintWriter pw) throws IOException {
// If the node does not have leaves then it just return.
// I we are in a decision node then print the leaves of the daughters.
Node nextNode;
if (node.getNumberOfNodes() > 1) {
assert node instanceof DecisionNode;
DecisionNode decNode = (DecisionNode) node;
for (int i = 0; i < decNode.getNumberOfDaugthers(); i++) {
nextNode = decNode.getDaughter(i);
printLeafNodes(nextNode, out, pw);
}
} else {
assert node instanceof LeafNode;
LeafNode leaf = (LeafNode) node;
if (leaf.getUniqueLeafId() == 0) // empty leaf, do not write
return;
LeafType leafType = leaf.getLeafNodeType();
if (leafType == LeafType.FeatureVectorLeafNode) {
leafType = LeafType.IntArrayLeafNode;
// save feature vector leaf nodes as int array leaf nodes
}
if (out != null) {
// Leaf node type
out.writeInt(leafType.ordinal());
}
if (pw != null) {
pw.print("id" + leaf.getUniqueLeafId() + " " + leafType);
}
switch (leaf.getLeafNodeType()) {
case IntArrayLeafNode:
int data[] = ((IntArrayLeafNode) leaf).getIntData();
// Number of data points following:
if (out != null)
out.writeInt(data.length);
if (pw != null)
pw.print(" " + data.length);
// for each index, write the index
for (int i = 0; i < data.length; i++) {
if (out != null)
out.writeInt(data[i]);
if (pw != null)
pw.print(" " + data[i]);
}
break;
case FloatLeafNode:
float stddev = ((FloatLeafNode) leaf).getStDeviation();
float mean = ((FloatLeafNode) leaf).getMean();
if (out != null) {
out.writeFloat(stddev);
out.writeFloat(mean);
}
if (pw != null) {
pw.print(" 1 " + stddev + " " + mean);
}
break;
case IntAndFloatArrayLeafNode:
case StringAndFloatLeafNode:
int data1[] = ((IntAndFloatArrayLeafNode) leaf).getIntData();
float floats[] = ((IntAndFloatArrayLeafNode) leaf).getFloatData();
// Number of data points following:
if (out != null)
out.writeInt(data1.length);
if (pw != null)
pw.print(" " + data1.length);
// for each index, write the index and then its float
for (int i = 0; i < data1.length; i++) {
if (out != null) {
out.writeInt(data1[i]);
out.writeFloat(floats[i]);
}
if (pw != null)
pw.print(" " + data1[i] + " " + floats[i]);
}
break;
case FeatureVectorLeafNode:
FeatureVector fv[] = ((FeatureVectorLeafNode) leaf).getFeatureVectors();
// Number of data points following:
if (out != null)
out.writeInt(fv.length);
if (pw != null)
pw.print(" " + fv.length);
// for each feature vector, write the index
for (int i = 0; i < fv.length; i++) {
if (out != null)
out.writeInt(fv[i].getUnitIndex());
if (pw != null)
pw.print(" " + fv[i].getUnitIndex());
}
break;
case PdfLeafNode:
throw new IllegalArgumentException("Writing of pdf leaf nodes not yet implemented");
}
if (pw != null)
pw.println();
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy