All Downloads are FREE. Search and download functionalities are using the official Maven repository.

marytts.cart.io.DirectedGraphWriter Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2006 DFKI GmbH.
 * All Rights Reserved.  Use is subject to license terms.
 *
 * This file is part of MARY TTS.
 *
 * MARY TTS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, version 3 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */
package marytts.cart.io;

import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Properties;

import marytts.cart.DecisionNode;
import marytts.cart.DirectedGraph;
import marytts.cart.DirectedGraphNode;
import marytts.cart.LeafNode;
import marytts.cart.Node;
import marytts.cart.DecisionNode.BinaryByteDecisionNode;
import marytts.cart.DecisionNode.BinaryFloatDecisionNode;
import marytts.cart.DecisionNode.BinaryShortDecisionNode;
import marytts.cart.LeafNode.FeatureVectorLeafNode;
import marytts.cart.LeafNode.FloatLeafNode;
import marytts.cart.LeafNode.IntAndFloatArrayLeafNode;
import marytts.cart.LeafNode.IntArrayLeafNode;
import marytts.cart.LeafNode.LeafType;
import marytts.features.FeatureVector;
import marytts.util.MaryUtils;
import marytts.util.data.MaryHeader;

import org.apache.log4j.Logger;

/**
 * IO functions for directed graphs in Mary format
 * 
 * @author Marcela Charfuelan, Marc Schröder
 */
public class DirectedGraphWriter {

	protected Logger logger = MaryUtils.getLogger(this.getClass().getName());

	/**
	 * Number of low-order bits of a written node reference that hold the node id. The node type code
	 * (DirectedGraphReader.DECISIONNODE / LEAFNODE / DIRECTEDGRAPHNODE) is shifted left by this amount and OR-ed in, so
	 * every id must be strictly smaller than {@code 1 << ID_BITS}.
	 */
	private static final int ID_BITS = 30;

	/**
	 * Dump the graph in Mary binary format: a MaryHeader of type DIRECTED_GRAPH, the graph properties, the feature
	 * definition and finally the decision, leaf and directed-graph nodes.
	 * 
	 * @param graph
	 *            graph to write; must not be null
	 * @param destFile
	 *            the destination file; must not be null
	 * @throws IOException
	 *             if writing fails or the properties block is too large for the format
	 * @throws NullPointerException
	 *             if graph or destFile is null
	 */
	public void saveGraph(DirectedGraph graph, String destFile) throws IOException {
		if (graph == null)
			throw new NullPointerException("Cannot dump null graph");
		if (destFile == null)
			throw new NullPointerException("No destination file");

		logger.debug("Dumping directed graph in Mary format to " + destFile + " ...");

		// Open the destination file. The finally block releases the file handle even when
		// serialisation fails half-way.
		DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(destFile)));
		try {
			// create new directed-graph header and write it to output file
			MaryHeader hdr = new MaryHeader(MaryHeader.DIRECTED_GRAPH);
			hdr.writeTo(out);

			writeProperties(graph.getProperties(), out);

			// feature definition
			graph.getFeatureDefinition().writeBinaryTo(out);

			// dump graph
			dumpBinary(graph, out);
		} finally {
			out.close();
		}
		logger.debug(" ... done\n");
	}

	/**
	 * Write the graph properties as a length-prefixed block: a 16-bit length followed by the bytes produced by
	 * {@link Properties#store(java.io.OutputStream, String)}, or just a length of 0 when props is null.
	 * 
	 * @param props
	 *            the properties to write, may be null
	 * @param out
	 *            the binary output
	 * @throws IOException
	 *             if writing fails or the serialised properties do not fit into the 16-bit length field
	 */
	private void writeProperties(Properties props, DataOutput out) throws IOException {
		if (props == null) {
			out.writeShort(0);
			return;
		}
		ByteArrayOutputStream baos = new ByteArrayOutputStream();
		props.store(baos, null);
		byte[] propData = baos.toByteArray();
		// writeShort() keeps only the low 16 bits of its argument, so a longer block would silently
		// produce a corrupt file. NOTE(review): if the reader uses a signed readShort(), the real
		// limit is 32767 -- confirm against DirectedGraphReader.
		if (propData.length > 0xFFFF) {
			throw new IOException("Graph properties too large to write: " + propData.length + " bytes (max 65535)");
		}
		out.writeShort(propData.length);
		out.write(propData);
	}

	/**
	 * Write a human-readable dump of the graph: node counts, then all decision nodes, leaf nodes and directed graph
	 * nodes, separated by dashed lines. Unique node ids are (re)assigned as a side effect. The writer is flushed and
	 * closed when done.
	 * 
	 * @param graph
	 *            graph to print
	 * @param pw
	 *            destination writer; closed by this method
	 * @throws IOException
	 *             if dumping fails
	 */
	public void toTextOut(DirectedGraph graph, PrintWriter pw) throws IOException {
		try {
			int numLeafNodes = setUniqueLeafNodeIds(graph);
			int numDecNodes = setUniqueDecisionNodeIds(graph);
			int numGraphNodes = setUniqueDirectedGraphNodeIds(graph);
			pw.println("Num decision nodes= " + numDecNodes + "  Num leaf nodes= " + numLeafNodes
					+ "  Num directed graph nodes= " + numGraphNodes);
			printDecisionNodes(graph, null, pw);
			pw.println("\n----------------\n");
			printLeafNodes(graph, null, pw);
			pw.println("\n----------------\n");
			printDirectedGraphNodes(graph, null, pw);

			pw.flush();
			pw.close();
		} catch (IOException ioe) {
			IOException newIOE = new IOException("Error dumping graph to standard output");
			newIOE.initCause(ioe);
			throw newIOE;
		}
	}

	/**
	 * Assign unique ids 1..n to leaf nodes, in the iteration order of graph.getLeafNodes().
	 * 
	 * @param graph
	 *            the graph whose leaves are numbered
	 * @return the number of different leaf nodes
	 */
	private int setUniqueLeafNodeIds(DirectedGraph graph) {
		int i = 0;
		for (LeafNode l : graph.getLeafNodes()) {
			l.setUniqueLeafId(++i);
		}
		return i;
	}

	/**
	 * Assign unique ids 1..n to decision nodes, in the iteration order of graph.getDecisionNodes().
	 * 
	 * @param graph
	 *            the graph whose decision nodes are numbered
	 * @return the number of different decision nodes
	 */
	private int setUniqueDecisionNodeIds(DirectedGraph graph) {
		int i = 0;
		for (DecisionNode d : graph.getDecisionNodes()) {
			d.setUniqueDecisionNodeId(++i);
		}
		return i;
	}

	/**
	 * Assign unique ids 1..n to directed graph nodes, in the iteration order of graph.getDirectedGraphNodes().
	 * 
	 * @param graph
	 *            the graph whose directed graph nodes are numbered
	 * @return the number of different directed graph nodes
	 */
	private int setUniqueDirectedGraphNodeIds(DirectedGraph graph) {
		int i = 0;
		for (DirectedGraphNode g : graph.getDirectedGraphNodes()) {
			g.setUniqueGraphNodeID(++i);
		}
		return i;
	}

	/**
	 * Write all nodes in binary form. For each node kind, the count (int) precedes the node records.
	 * 
	 * @param graph
	 *            graph to write
	 * @param os
	 *            binary output
	 * @throws IOException
	 *             if writing fails
	 * @throws UnsupportedOperationException
	 *             if some node kind has too many nodes for the id encoding
	 */
	private void dumpBinary(DirectedGraph graph, DataOutput os) throws IOException {
		try {
			int numLeafNodes = setUniqueLeafNodeIds(graph);
			int numDecNodes = setUniqueDecisionNodeIds(graph);
			int numGraphNodes = setUniqueDirectedGraphNodeIds(graph);
			// Ids run from 1 to the node count and must stay below the type bits, i.e. < 1 << ID_BITS.
			int maxNum = 1 << ID_BITS;
			if (numLeafNodes >= maxNum || numDecNodes >= maxNum || numGraphNodes >= maxNum) {
				throw new UnsupportedOperationException("Cannot write more than " + (maxNum - 1)
						+ " nodes of one type in this format");
			}
			// write the number of decision nodes
			os.writeInt(numDecNodes);
			printDecisionNodes(graph, os, null);

			// write the number of leaves.
			os.writeInt(numLeafNodes);
			printLeafNodes(graph, os, null);

			// write the number of directed graph nodes
			os.writeInt(numGraphNodes);
			printDirectedGraphNodes(graph, os, null);

		} catch (IOException ioe) {
			IOException newIOE = new IOException("Error dumping CART to output stream");
			newIOE.initCause(ioe);
			throw newIOE;
		}
	}

	/**
	 * Write every decision node, in binary form to out and/or as text to pw (either may be null).
	 * 
	 * Binary record: feature index (int), node type ordinal (int), then either the criterion value (int for
	 * byte/short binary nodes, float for float binary nodes) or the number of daughters (int, for multi-way byte/short
	 * nodes), followed by one int reference per daughter. A reference is 0 for a missing daughter, otherwise the
	 * daughter's unique id OR-ed with its type code shifted left by ID_BITS.
	 * 
	 * Text line: "-id nodeDefinition" followed by " 0", " -id", " idN" or " DGNN" per daughter.
	 */
	private void printDecisionNodes(DirectedGraph graph, DataOutput out, PrintWriter pw) throws IOException {
		for (DecisionNode decNode : graph.getDecisionNodes()) {
			int id = decNode.getUniqueDecisionNodeId();
			String nodeDefinition = decNode.getNodeDefinition();
			int featureIndex = decNode.getFeatureIndex();
			DecisionNode.Type nodeType = decNode.getDecisionNodeType();

			if (out != null) {
				// dump in binary form to output
				out.writeInt(featureIndex);
				out.writeInt(nodeType.ordinal());
				// Now, questionValue, which depends on nodeType
				switch (nodeType) {
				case BinaryByteDecisionNode:
					out.writeInt(((BinaryByteDecisionNode) decNode).getCriterionValueAsByte());
					assert decNode.getNumberOfDaugthers() == 2;
					break;
				case BinaryShortDecisionNode:
					out.writeInt(((BinaryShortDecisionNode) decNode).getCriterionValueAsShort());
					assert decNode.getNumberOfDaugthers() == 2;
					break;
				case BinaryFloatDecisionNode:
					out.writeFloat(((BinaryFloatDecisionNode) decNode).getCriterionValueAsFloat());
					assert decNode.getNumberOfDaugthers() == 2;
					break;
				case ByteDecisionNode:
				case ShortDecisionNode:
					// multi-way nodes: the daughter count precedes the daughter references
					out.writeInt(decNode.getNumberOfDaugthers());
				}

				// The child nodes
				for (int i = 0, n = decNode.getNumberOfDaugthers(); i < n; i++) {
					Node daughter = decNode.getDaughter(i);
					if (daughter == null) {
						out.writeInt(0);
					} else if (daughter.isDecisionNode()) {
						int daughterID = ((DecisionNode) daughter).getUniqueDecisionNodeId();
						// Mark as decision node (type bits are set even for id 0, unlike leaves/graph nodes)
						daughterID |= DirectedGraphReader.DECISIONNODE << ID_BITS;
						out.writeInt(daughterID);
					} else if (daughter.isLeafNode()) {
						int daughterID = ((LeafNode) daughter).getUniqueLeafId();
						// Mark as leaf node; id 0 (empty leaf) is written as a plain 0
						if (daughterID != 0)
							daughterID |= DirectedGraphReader.LEAFNODE << ID_BITS;
						out.writeInt(daughterID);
					} else if (daughter.isDirectedGraphNode()) {
						int daughterID = ((DirectedGraphNode) daughter).getUniqueGraphNodeID();
						// Mark as directed graph node; id 0 is written as a plain 0
						if (daughterID != 0)
							daughterID |= DirectedGraphReader.DIRECTEDGRAPHNODE << ID_BITS;
						out.writeInt(daughterID);
					} else {
						// Writing nothing here would misalign every following field of the stream.
						throw new IllegalArgumentException("Unexpected daughter type: " + daughter.getClass());
					}
				}
			}
			if (pw != null) {
				// dump to print writer. Text mode only: out may be null here (see toTextOut), so it must
				// not be touched in this branch.
				StringBuilder strNode = new StringBuilder("-" + id + " " + nodeDefinition);
				for (int i = 0, n = decNode.getNumberOfDaugthers(); i < n; i++) {
					strNode.append(" ");
					Node daughter = decNode.getDaughter(i);
					if (daughter == null) {
						strNode.append("0");
					} else if (daughter.isDecisionNode()) {
						int daughterID = ((DecisionNode) daughter).getUniqueDecisionNodeId();
						strNode.append("-").append(daughterID);
					} else if (daughter.isLeafNode()) {
						int daughterID = ((LeafNode) daughter).getUniqueLeafId();
						if (daughterID == 0)
							strNode.append("0");
						else
							strNode.append("id").append(daughterID);
					} else if (daughter.isDirectedGraphNode()) {
						int daughterID = ((DirectedGraphNode) daughter).getUniqueGraphNodeID();
						if (daughterID == 0)
							strNode.append("0");
						else
							strNode.append("DGN").append(daughterID);
					}
				}
				pw.println(strNode.toString());
			}
		}
	}

	/**
	 * Write every non-empty leaf node (unique id != 0), in binary form to out and/or as text to pw (either may be
	 * null). Each record starts with the leaf type ordinal (int); feature-vector leaves are written with the
	 * IntArrayLeafNode type and their unit indices as the int data.
	 * 
	 * @throws IllegalArgumentException
	 *             for pdf leaf nodes, which are not supported
	 */
	private void printLeafNodes(DirectedGraph graph, DataOutput out, PrintWriter pw) throws IOException {
		for (LeafNode leaf : graph.getLeafNodes()) {
			if (leaf.getUniqueLeafId() == 0) // empty leaf, do not write
				continue;
			LeafType leafType = leaf.getLeafNodeType();
			if (leafType == LeafType.FeatureVectorLeafNode) {
				leafType = LeafType.IntArrayLeafNode;
				// save feature vector leaf nodes as int array leaf nodes
			}
			if (out != null) {
				// Leaf node type
				out.writeInt(leafType.ordinal());
			}
			if (pw != null) {
				pw.print("id" + leaf.getUniqueLeafId() + " " + leafType);
			}
			// Dispatch on the real type: the written type may have been remapped above.
			switch (leaf.getLeafNodeType()) {
			case IntArrayLeafNode:
				int data[] = ((IntArrayLeafNode) leaf).getIntData();
				// Number of data points following:
				if (out != null)
					out.writeInt(data.length);
				if (pw != null)
					pw.print(" " + data.length);
				// for each index, write the index
				for (int i = 0; i < data.length; i++) {
					if (out != null)
						out.writeInt(data[i]);
					if (pw != null)
						pw.print(" " + data[i]);
				}
				break;
			case FloatLeafNode:
				float stddev = ((FloatLeafNode) leaf).getStDeviation();
				float mean = ((FloatLeafNode) leaf).getMean();
				if (out != null) {
					out.writeFloat(stddev);
					out.writeFloat(mean);
				}
				if (pw != null) {
					pw.print(" 1 " + stddev + " " + mean);
				}
				break;
			case IntAndFloatArrayLeafNode:
			case StringAndFloatLeafNode:
				int data1[] = ((IntAndFloatArrayLeafNode) leaf).getIntData();
				float floats[] = ((IntAndFloatArrayLeafNode) leaf).getFloatData();
				// Number of data points following:
				if (out != null)
					out.writeInt(data1.length);
				if (pw != null)
					pw.print(" " + data1.length);
				// for each index, write the index and then its float
				for (int i = 0; i < data1.length; i++) {
					if (out != null) {
						out.writeInt(data1[i]);
						out.writeFloat(floats[i]);
					}
					if (pw != null)
						pw.print(" " + data1[i] + " " + floats[i]);
				}
				break;
			case FeatureVectorLeafNode:
				FeatureVector fv[] = ((FeatureVectorLeafNode) leaf).getFeatureVectors();
				// Number of data points following:
				if (out != null)
					out.writeInt(fv.length);
				if (pw != null)
					pw.print(" " + fv.length);
				// for each feature vector, write the index
				for (int i = 0; i < fv.length; i++) {
					if (out != null)
						out.writeInt(fv[i].getUnitIndex());
					if (pw != null)
						pw.print(" " + fv[i].getUnitIndex());
				}
				break;
			case PdfLeafNode:
				throw new IllegalArgumentException("Writing of pdf leaf nodes not yet implemented");
			}
			if (pw != null)
				pw.println();
		}
	}

	/**
	 * Write every directed graph node with a non-zero unique id, in binary form to out and/or as text to pw (either may
	 * be null). Binary record: the leaf reference (int) then the decision node reference (int). Each reference is 0 when
	 * absent, otherwise id OR-ed with the type code shifted left by ID_BITS.
	 * 
	 * @throws IllegalArgumentException
	 *             if a node's leaf is neither a LeafNode nor a DirectedGraphNode
	 */
	private void printDirectedGraphNodes(DirectedGraph graph, DataOutput out, PrintWriter pw) throws IOException {
		for (DirectedGraphNode g : graph.getDirectedGraphNodes()) {
			int id = g.getUniqueGraphNodeID();
			if (id == 0)
				continue;// empty node, do not write
			Node leaf = g.getLeafNode();
			int leafID = 0;
			int leafNodeType = 0;
			if (leaf != null) {
				if (leaf instanceof LeafNode) {
					leafID = ((LeafNode) leaf).getUniqueLeafId();
					leafNodeType = DirectedGraphReader.LEAFNODE;
				} else if (leaf instanceof DirectedGraphNode) {
					leafID = ((DirectedGraphNode) leaf).getUniqueGraphNodeID();
					leafNodeType = DirectedGraphReader.DIRECTEDGRAPHNODE;
				} else {
					throw new IllegalArgumentException("Unexpected leaf type: " + leaf.getClass());
				}
			}
			DecisionNode d = g.getDecisionNode();
			int decID = d != null ? d.getUniqueDecisionNodeId() : 0;
			if (out != null) {
				int outLeafId = leafID == 0 ? 0 : leafID | (leafNodeType << ID_BITS);
				out.writeInt(outLeafId);
				int outDecId = decID == 0 ? 0 : decID | (DirectedGraphReader.DECISIONNODE << ID_BITS);
				out.writeInt(outDecId);
			}
			if (pw != null) {
				pw.print("DGN" + id);
				if (leafID == 0) {
					pw.print(" 0");
				} else if (leaf.isLeafNode()) {
					pw.print(" id" + leafID);
				} else {
					assert leaf.isDirectedGraphNode();
					pw.print(" DGN" + leafID);
				}
				if (decID == 0)
					pw.print(" 0");
				else
					pw.print(" -" + decID);
				pw.println();
			}
		}
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy