All Downloads are FREE. Search and download functionalities are using the official Maven repository.

marytts.cart.io.MaryCARTReader Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2006 DFKI GmbH.
 * All Rights Reserved.  Use is subject to license terms.
 *
 * This file is part of MARY TTS.
 *
 * MARY TTS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, version 3 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */
package marytts.cart.io;

import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.Properties;

import marytts.cart.CART;
import marytts.cart.DecisionNode;
import marytts.cart.LeafNode;
import marytts.cart.Node;
import marytts.exceptions.MaryConfigurationException;
import marytts.features.FeatureDefinition;
import marytts.util.data.MaryHeader;

/**
 * IO functions for CARTs in MaryCART format
 * 
 * @author Marcela Charfuelan
 */
public class MaryCARTReader {
	/**
	 * Read a CART in MaryCART format from the file with the given name.
	 * <p>
	 * The file is opened as a stream, handed to {@link #loadFromStream(InputStream)}, and closed again before this method
	 * returns, whether loading succeeded or failed.
	 * 
	 * @param fileName
	 *            the file to load the cart from
	 * @throws IOException
	 *             if the file cannot be opened or closed, or if a problem occurs while loading
	 * @throws MaryConfigurationException
	 *             if reported by {@link #loadFromStream(InputStream)}
	 * @return the cart produced by {@link #loadFromStream(InputStream)} for the content of the file
	 */
	public CART load(String fileName) throws IOException, MaryConfigurationException {
		final InputStream cartStream = new FileInputStream(fileName);
		try {
			final CART cart = loadFromStream(cartStream);
			return cart;
		} finally {
			// always release the file handle, also when loading fails
			cartStream.close();
		}
	}

	/**
	 * Load a cart in MaryCART binary format from the given stream.
	 * <p>
	 * The stream is wrapped in a buffered data stream and read in this order: the Mary header, the optional properties
	 * block, the feature definition, all decision nodes (each followed by the indexes of its daughters) and finally all leaf
	 * nodes. Only after every node has been read are the decision nodes linked to their daughters. This method does not
	 * close the stream.
	 * 
	 * @param inStream
	 *            the stream to load the cart from
	 * @throws IOException
	 *             if a problem occurs while reading, if the header has the wrong version or is not of type CARTS, or if
	 *             the content is inconsistent (negative counts, unknown node or leaf type numbers, a multi-way node whose
	 *             child count differs from the number of feature values, daughter indexes outside the node tables)
	 * @throws MaryConfigurationException
	 *             if one of the project classes used while loading reports it
	 * @throws IllegalArgumentException
	 *             if the stream contains feature vector or pdf leaf nodes, which this reader does not implement
	 * @return a new CART holding the root node, the feature definition and the properties (null if the properties block is
	 *         empty). The root is the first decision node; if there are no decision nodes it is the first leaf; if there
	 *         are no nodes at all it is null.
	 */
	public CART loadFromStream(InputStream inStream) throws IOException, MaryConfigurationException {
		// open the CART-File and read the header
		DataInput raf = new DataInputStream(new BufferedInputStream(inStream));

		MaryHeader maryHeader = new MaryHeader(raf);
		if (!maryHeader.hasCurrentVersion()) {
			throw new IOException("Wrong version of database file");
		}
		if (maryHeader.getType() != MaryHeader.CARTS) {
			throw new IOException("No CARTs file");
		}

		// Read properties
		Properties props = readProperties(raf);

		// Read the feature definition
		FeatureDefinition featureDefinition = new FeatureDefinition(raf);

		// First we need to read all nodes into memory, then we can link them properly
		// in terms of parent/child.
		int numDecNodes = readCount(raf, "decision nodes");
		DecisionNode[] dns = new DecisionNode[numDecNodes];
		int[][] childIndexes = new int[numDecNodes][];
		DecisionNode.Type[] knownNodeTypes = DecisionNode.Type.values();
		for (int i = 0; i < numDecNodes; i++) {
			// read one decision node
			int featureNameIndex = raf.readInt();
			int nodeTypeNr = raf.readInt();
			if (nodeTypeNr < 0 || nodeTypeNr >= knownNodeTypes.length) {
				// report corrupt data as a loading problem instead of an ArrayIndexOutOfBoundsException
				throw new IOException("Inconsistent cart file: decision node " + i + " has unknown node type number "
						+ nodeTypeNr);
			}
			DecisionNode.Type nodeType = knownNodeTypes[nodeTypeNr];
			int numChildren = 2; // for binary nodes
			switch (nodeType) {
			case BinaryByteDecisionNode:
				// the criterion is stored as an int and narrowed to the node's value type
				dns[i] = new DecisionNode.BinaryByteDecisionNode(featureNameIndex, (byte) raf.readInt(), featureDefinition);
				break;
			case BinaryShortDecisionNode:
				dns[i] = new DecisionNode.BinaryShortDecisionNode(featureNameIndex, (short) raf.readInt(),
						featureDefinition);
				break;
			case BinaryFloatDecisionNode:
				dns[i] = new DecisionNode.BinaryFloatDecisionNode(featureNameIndex, raf.readFloat(), featureDefinition);
				break;
			case ByteDecisionNode:
				numChildren = raf.readInt();
				checkChildCount(featureDefinition, featureNameIndex, i, numChildren);
				dns[i] = new DecisionNode.ByteDecisionNode(featureNameIndex, numChildren, featureDefinition);
				break;
			case ShortDecisionNode:
				numChildren = raf.readInt();
				checkChildCount(featureDefinition, featureNameIndex, i, numChildren);
				dns[i] = new DecisionNode.ShortDecisionNode(featureNameIndex, numChildren, featureDefinition);
				break;
			}
			// now read the children, indexes only:
			childIndexes[i] = new int[numChildren];
			for (int k = 0; k < numChildren; k++) {
				childIndexes[i][k] = raf.readInt();
			}
		}

		// read the leaves
		int numLeafNodes = readCount(raf, "leaf nodes"); // number of leaves, it does not include empty leaves
		LeafNode[] lns = new LeafNode[numLeafNodes];
		for (int j = 0; j < numLeafNodes; j++) {
			lns[j] = readLeafNode(raf, j);
		}

		// Now, link up the decision nodes with their daughters
		for (int i = 0; i < numDecNodes; i++) {
			for (int k = 0; k < childIndexes[i].length; k++) {
				dns[i].addDaughter(resolveDaughter(childIndexes[i][k], dns, lns, i));
			}
		}

		Node rootNode;
		if (dns.length > 0) {
			rootNode = dns[0];
			// Now count all data once, so that getNumberOfData()
			// will return the correct figure.
			dns[0].countData();
		} else if (lns.length > 0) {
			rootNode = lns[0]; // single-leaf tree...
		} else {
			rootNode = null;
		}

		// set the rootNode as the rootNode of cart
		return new CART(rootNode, featureDefinition, props);
	}

	/**
	 * Read the properties block: a 16-bit length followed by that many bytes in {@link Properties#load(InputStream)}
	 * format.
	 * <p>
	 * The length is read as an unsigned value so that blocks of 32768 to 65535 bytes load instead of producing a negative
	 * array size. This assumes the writer stores the low 16 bits of the byte length, as DataOutput.writeShort does --
	 * confirm against the writer.
	 * 
	 * @param in
	 *            the input positioned at the properties length
	 * @throws IOException
	 *             if reading fails
	 * @return the loaded properties, or null if the stored length is 0
	 */
	private static Properties readProperties(DataInput in) throws IOException {
		int propDataLength = in.readUnsignedShort();
		if (propDataLength == 0) {
			return null;
		}
		byte[] propsData = new byte[propDataLength];
		in.readFully(propsData);
		Properties props = new Properties();
		props.load(new ByteArrayInputStream(propsData));
		return props;
	}

	/**
	 * Read an int that is used as an array size and reject negative values.
	 * 
	 * @param in
	 *            the input positioned at the count
	 * @param what
	 *            description of the counted items, used in the error message
	 * @throws IOException
	 *             if reading fails or the count is negative
	 * @return the non-negative count
	 */
	private static int readCount(DataInput in, String what) throws IOException {
		int count = in.readInt();
		if (count < 0) {
			throw new IOException("Inconsistent cart file: negative number of " + what + " (" + count + ")");
		}
		return count;
	}

	/**
	 * Verify that a multi-way decision node has exactly one child per value of its feature.
	 * 
	 * @param featureDefinition
	 *            the feature definition read from the same input
	 * @param featureNameIndex
	 *            index of the feature the node decides on
	 * @param nodeIndex
	 *            position of the node among the decision nodes, used in the error message
	 * @param numChildren
	 *            number of children stored for the node
	 * @throws IOException
	 *             if the stored child count differs from the feature's number of values
	 */
	private static void checkChildCount(FeatureDefinition featureDefinition, int featureNameIndex, int nodeIndex,
			int numChildren) throws IOException {
		if (featureDefinition.getNumberOfValues(featureNameIndex) != numChildren) {
			throw new IOException("Inconsistent cart file: feature " + featureDefinition.getFeatureName(featureNameIndex)
					+ " should have " + featureDefinition.getNumberOfValues(featureNameIndex)
					+ " values, but decision node " + nodeIndex + " has only " + numChildren + " child nodes");
		}
	}

	/**
	 * Read one leaf node: a leaf type number followed by the type-specific payload.
	 * 
	 * @param in
	 *            the input positioned at the leaf type number
	 * @param leafIndex
	 *            position of the leaf among the leaf nodes, used in error messages
	 * @throws IOException
	 *             if reading fails, the leaf type number is unknown, or a stored element count is negative
	 * @throws IllegalArgumentException
	 *             for feature vector and pdf leaf nodes, which are not implemented
	 * @return the leaf node, or null for a leaf type that has no reading code here
	 */
	private static LeafNode readLeafNode(DataInput in, int leafIndex) throws IOException {
		int leafTypeNr = in.readInt();
		LeafNode.LeafType[] knownLeafTypes = LeafNode.LeafType.values();
		if (leafTypeNr < 0 || leafTypeNr >= knownLeafTypes.length) {
			throw new IOException("Inconsistent cart file: leaf node " + leafIndex + " has unknown leaf type number "
					+ leafTypeNr);
		}
		LeafNode.LeafType leafNodeType = knownLeafTypes[leafTypeNr];
		switch (leafNodeType) {
		case IntArrayLeafNode: {
			int[] data = new int[readCount(in, "data items in leaf node " + leafIndex)];
			for (int d = 0; d < data.length; d++) {
				data[d] = in.readInt();
			}
			return new LeafNode.IntArrayLeafNode(data);
		}
		case FloatLeafNode: {
			// the two floats are stored in this order
			float stddev = in.readFloat();
			float mean = in.readFloat();
			return new LeafNode.FloatLeafNode(new float[] { stddev, mean });
		}
		case IntAndFloatArrayLeafNode:
		case StringAndFloatLeafNode: {
			// both types share the same payload: interleaved (int, float) pairs
			int numPairs = readCount(in, "pairs in leaf node " + leafIndex);
			int[] ints = new int[numPairs];
			float[] floats = new float[numPairs];
			for (int d = 0; d < numPairs; d++) {
				ints[d] = in.readInt();
				floats[d] = in.readFloat();
			}
			if (leafNodeType == LeafNode.LeafType.IntAndFloatArrayLeafNode) {
				return new LeafNode.IntAndFloatArrayLeafNode(ints, floats);
			}
			return new LeafNode.StringAndFloatLeafNode(ints, floats);
		}
		case FeatureVectorLeafNode:
			throw new IllegalArgumentException("Reading feature vector leaf nodes is not yet implemented");
		case PdfLeafNode:
			throw new IllegalArgumentException("Reading pdf leaf nodes is not yet implemented");
		default:
			return null;
		}
	}

	/**
	 * Translate a stored daughter index into the node it refers to.
	 * <p>
	 * A negative index -n refers to decision node n-1, a positive index n refers to leaf node n-1, and 0 stands for an empty
	 * leaf.
	 * 
	 * @param childIndex
	 *            the stored daughter index
	 * @param dns
	 *            all decision nodes
	 * @param lns
	 *            all leaf nodes
	 * @param parentIndex
	 *            position of the parent among the decision nodes, used in error messages
	 * @throws IOException
	 *             if the index points outside the corresponding node table
	 * @return the referenced node, or null for an empty leaf
	 */
	private static Node resolveDaughter(int childIndex, DecisionNode[] dns, LeafNode[] lns, int parentIndex)
			throws IOException {
		if (childIndex == 0) { // an empty leaf
			return null;
		}
		if (childIndex < 0) { // a decision node
			// written as -(x + 1) so that Integer.MIN_VALUE cannot overflow into a negative position
			int decisionNodePos = -(childIndex + 1);
			if (decisionNodePos >= dns.length) {
				throw new IOException("Inconsistent cart file: decision node " + parentIndex
						+ " refers to non-existing decision node " + decisionNodePos);
			}
			return dns[decisionNodePos];
		}
		// a leaf node
		int leafPos = childIndex - 1;
		if (leafPos >= lns.length) {
			throw new IOException("Inconsistent cart file: decision node " + parentIndex
					+ " refers to non-existing leaf node " + leafPos);
		}
		return lns[leafPos];
	}

	/**
	 * Load the cart from the given file by memory-mapping it read-only and parsing the mapped buffer.
	 * <p>
	 * The content is read in this order: the Mary header, the optional properties block, the feature definition, all
	 * decision nodes (each followed by the indexes of its daughters) and finally all leaf nodes. Only after every node has
	 * been read are the decision nodes linked to their daughters.
	 * 
	 * @param fileName
	 *            the file to load the cart from
	 * @throws IOException
	 *             if the file cannot be opened, mapped or closed, if the header has the wrong version or is not of type
	 *             CARTS, or if the content is inconsistent (negative counts, unknown node or leaf type numbers, a
	 *             multi-way node whose child count differs from the number of feature values, daughter indexes outside the
	 *             node tables)
	 * @throws MaryConfigurationException
	 *             if one of the project classes used while loading reports it
	 * @throws IllegalArgumentException
	 *             if the file contains feature vector or pdf leaf nodes, which this reader does not implement
	 * @return a new CART holding the root node, the feature definition and the properties (null if the properties block is
	 *         empty). The root is the first decision node; if there are no decision nodes it is the first leaf; if there
	 *         are no nodes at all it is null.
	 */
	private CART loadFromByteBuffer(String fileName) throws IOException, MaryConfigurationException {
		// open the CART-File and map it into memory
		ByteBuffer bb;
		FileInputStream fis = new FileInputStream(fileName);
		try {
			FileChannel fc = fis.getChannel();
			bb = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size());
		} finally {
			// close also when mapping fails; an established mapping stays valid after its channel is closed
			fis.close();
		}

		// read the header
		MaryHeader maryHeader = new MaryHeader(bb);
		if (!maryHeader.hasCurrentVersion()) {
			throw new IOException("Wrong version of database file");
		}
		if (maryHeader.getType() != MaryHeader.CARTS) {
			throw new IOException("No CARTs file");
		}

		// Read properties. The 16-bit length is taken as unsigned so that blocks of 32768 to 65535 bytes
		// do not produce a negative array size (assumes the writer stores the low 16 bits of the length -- confirm).
		int propDataLength = bb.getShort() & 0xFFFF;
		Properties props;
		if (propDataLength == 0) {
			props = null;
		} else {
			byte[] propsData = new byte[propDataLength];
			bb.get(propsData);
			props = new Properties();
			props.load(new ByteArrayInputStream(propsData));
		}

		// Read the feature definition
		FeatureDefinition featureDefinition = new FeatureDefinition(bb);

		// read the decision nodes
		int numDecNodes = bb.getInt(); // number of decision nodes
		if (numDecNodes < 0) {
			throw new IOException("Inconsistent cart file: negative number of decision nodes (" + numDecNodes + ")");
		}

		// First we need to read all nodes into memory, then we can link them properly
		// in terms of parent/child.
		DecisionNode[] dns = new DecisionNode[numDecNodes];
		int[][] childIndexes = new int[numDecNodes][];
		DecisionNode.Type[] knownNodeTypes = DecisionNode.Type.values();
		for (int i = 0; i < numDecNodes; i++) {
			// read one decision node
			int featureNameIndex = bb.getInt();
			int nodeTypeNr = bb.getInt();
			if (nodeTypeNr < 0 || nodeTypeNr >= knownNodeTypes.length) {
				// report corrupt data as a loading problem instead of an ArrayIndexOutOfBoundsException
				throw new IOException("Inconsistent cart file: decision node " + i + " has unknown node type number "
						+ nodeTypeNr);
			}
			DecisionNode.Type nodeType = knownNodeTypes[nodeTypeNr];
			int numChildren = 2; // for binary nodes
			switch (nodeType) {
			case BinaryByteDecisionNode:
				int criterion = bb.getInt();
				dns[i] = new DecisionNode.BinaryByteDecisionNode(featureNameIndex, (byte) criterion, featureDefinition);
				break;
			case BinaryShortDecisionNode:
				criterion = bb.getInt();
				dns[i] = new DecisionNode.BinaryShortDecisionNode(featureNameIndex, (short) criterion, featureDefinition);
				break;
			case BinaryFloatDecisionNode:
				float floatCriterion = bb.getFloat();
				dns[i] = new DecisionNode.BinaryFloatDecisionNode(featureNameIndex, floatCriterion, featureDefinition);
				break;
			case ByteDecisionNode:
			case ShortDecisionNode:
				// multi-way nodes must have exactly one child per value of their feature
				numChildren = bb.getInt();
				if (featureDefinition.getNumberOfValues(featureNameIndex) != numChildren) {
					throw new IOException("Inconsistent cart file: feature " + featureDefinition.getFeatureName(featureNameIndex)
							+ " should have " + featureDefinition.getNumberOfValues(featureNameIndex)
							+ " values, but decision node " + i + " has only " + numChildren + " child nodes");
				}
				if (nodeType == DecisionNode.Type.ByteDecisionNode) {
					dns[i] = new DecisionNode.ByteDecisionNode(featureNameIndex, numChildren, featureDefinition);
				} else {
					dns[i] = new DecisionNode.ShortDecisionNode(featureNameIndex, numChildren, featureDefinition);
				}
				break;
			}
			// now read the children, indexes only:
			childIndexes[i] = new int[numChildren];
			for (int k = 0; k < numChildren; k++) {
				childIndexes[i][k] = bb.getInt();
			}
		}

		// read the leaves
		int numLeafNodes = bb.getInt(); // number of leaves, it does not include empty leaves
		if (numLeafNodes < 0) {
			throw new IOException("Inconsistent cart file: negative number of leaf nodes (" + numLeafNodes + ")");
		}
		LeafNode[] lns = new LeafNode[numLeafNodes];
		LeafNode.LeafType[] knownLeafTypes = LeafNode.LeafType.values();

		for (int j = 0; j < numLeafNodes; j++) {
			// read one leaf node
			int leafTypeNr = bb.getInt();
			if (leafTypeNr < 0 || leafTypeNr >= knownLeafTypes.length) {
				throw new IOException("Inconsistent cart file: leaf node " + j + " has unknown leaf type number "
						+ leafTypeNr);
			}
			LeafNode.LeafType leafNodeType = knownLeafTypes[leafTypeNr];
			switch (leafNodeType) {
			case IntArrayLeafNode:
				int numData = bb.getInt();
				if (numData < 0) {
					throw new IOException("Inconsistent cart file: leaf node " + j + " has negative data count " + numData);
				}
				int[] data = new int[numData];
				for (int d = 0; d < numData; d++) {
					data[d] = bb.getInt();
				}
				lns[j] = new LeafNode.IntArrayLeafNode(data);
				break;
			case FloatLeafNode:
				// the two floats are stored in this order
				float stddev = bb.getFloat();
				float mean = bb.getFloat();
				lns[j] = new LeafNode.FloatLeafNode(new float[] { stddev, mean });
				break;
			case IntAndFloatArrayLeafNode:
			case StringAndFloatLeafNode:
				// both types share the same payload: interleaved (int, float) pairs
				int numPairs = bb.getInt();
				if (numPairs < 0) {
					throw new IOException("Inconsistent cart file: leaf node " + j + " has negative pair count " + numPairs);
				}
				int[] ints = new int[numPairs];
				float[] floats = new float[numPairs];
				for (int d = 0; d < numPairs; d++) {
					ints[d] = bb.getInt();
					floats[d] = bb.getFloat();
				}
				if (leafNodeType == LeafNode.LeafType.IntAndFloatArrayLeafNode)
					lns[j] = new LeafNode.IntAndFloatArrayLeafNode(ints, floats);
				else
					lns[j] = new LeafNode.StringAndFloatLeafNode(ints, floats);
				break;
			case FeatureVectorLeafNode:
				throw new IllegalArgumentException("Reading feature vector leaf nodes is not yet implemented");
			case PdfLeafNode:
				throw new IllegalArgumentException("Reading pdf leaf nodes is not yet implemented");
			}
		}

		// Now, link up the decision nodes with their daughters.
		// A negative index -n refers to decision node n-1, a positive index n to leaf node n-1, 0 to an empty leaf.
		for (int i = 0; i < numDecNodes; i++) {
			for (int k = 0; k < childIndexes[i].length; k++) {
				int childIndex = childIndexes[i][k];
				if (childIndex < 0) { // a decision node
					// written as -(x + 1) so that Integer.MIN_VALUE cannot overflow into a negative position
					int decisionNodePos = -(childIndex + 1);
					if (decisionNodePos >= numDecNodes) {
						throw new IOException("Inconsistent cart file: decision node " + i
								+ " refers to non-existing decision node " + decisionNodePos);
					}
					dns[i].addDaughter(dns[decisionNodePos]);
				} else if (childIndex > 0) { // a leaf node
					int leafPos = childIndex - 1;
					if (leafPos >= numLeafNodes) {
						throw new IOException("Inconsistent cart file: decision node " + i
								+ " refers to non-existing leaf node " + leafPos);
					}
					dns[i].addDaughter(lns[leafPos]);
				} else { // == 0, an empty leaf
					dns[i].addDaughter(null);
				}
			}
		}

		Node rootNode;
		if (dns.length > 0) {
			rootNode = dns[0];
			// Now count all data once, so that getNumberOfData()
			// will return the correct figure.
			((DecisionNode) rootNode).countData();
		} else if (lns.length > 0) {
			rootNode = lns[0]; // single-leaf tree...
		} else {
			rootNode = null;
		}

		// set the rootNode as the rootNode of cart
		return new CART(rootNode, featureDefinition, props);
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy