All Downloads are FREE. Search and download functionalities are using the official Maven repository.

marytts.modules.acoustic.CARTModel Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2010 DFKI GmbH.
 * All Rights Reserved.  Use is subject to license terms.
 *
 * This file is part of MARY TTS.
 *
 * MARY TTS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, version 3 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

package marytts.modules.acoustic;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;

import marytts.cart.DirectedGraph;
import marytts.cart.io.DirectedGraphReader;
import marytts.exceptions.MaryConfigurationException;
import marytts.features.FeatureProcessorManager;
import marytts.unitselection.select.Target;

/**
 * Model for applying a CART to a list of Targets.
 * <p>
 * The CART is read from the model's data stream in {@link #loadData()}. {@link #evaluate(Target)} interprets the CART for a
 * single Target and returns element {@code [1]} of the resulting {@code float[]} (assumed layout: {@code [stdev, val]}).
 * 
 * @author steiner
 * 
 */
public class CARTModel extends Model {
	/** The CART read from the data stream; assigned in {@link #loadData()}. */
	private DirectedGraph cart;

	/**
	 * Create a CART-based model. All arguments are forwarded unchanged to the {@link Model} constructor, after which
	 * {@code load()} is called to read the model data.
	 * 
	 * @throws MaryConfigurationException
	 *             if the superclass constructor or {@code load()} fails
	 */
	public CARTModel(FeatureProcessorManager featureManager, String voiceName, InputStream dataStream,
			String targetAttributeName, String targetAttributeFormat, String featureName, String predictFrom, String applyTo)
			throws MaryConfigurationException {
		super(featureManager, voiceName, dataStream, targetAttributeName, targetAttributeFormat, featureName, predictFrom,
				applyTo);
		load();
	}

	/**
	 * Load the CART for this Model from {@code dataStream} and store the names of the features it predicts from in
	 * {@code predictionFeatureNames}.
	 * 
	 * @throws IOException
	 *             if the CART cannot be read, if it has no FeatureDefinition, or if its feature names are missing or empty
	 * @throws MaryConfigurationException
	 *             if propagated from the CART reader
	 */
	@Override
	protected void loadData() throws IOException, MaryConfigurationException {
		cart = new DirectedGraphReader().load(dataStream);
		// Check explicitly instead of catching NullPointerException, so that every missing piece
		// (CART, FeatureDefinition, feature names) is reported as an IOException.
		if (cart == null || cart.getFeatureDefinition() == null) {
			throw new IOException("Could not get FeatureDefinition from CART");
		}
		predictionFeatureNames = cart.getFeatureDefinition().getFeatureNames();
		if (predictionFeatureNames == null || predictionFeatureNames.length() == 0) {
			throw new IOException("Could not get prediction feature names");
		}
	}

	/**
	 * Apply the CART to a Target to get its predicted value.
	 * 
	 * @param target
	 *            the Target to evaluate; must not be null
	 * @return element {@code [1]} of the CART's {@code float[]} prediction (assumed to be {@code [stdev, val]})
	 * @throws Exception
	 *             if the CART cannot interpret the target, or if its prediction is not a {@code float[]} with at least two
	 *             elements
	 */
	@Override
	protected float evaluate(Target target) throws Exception {
		assert target != null;

		Object prediction;
		try {
			prediction = cart.interpret(target);
		} catch (IllegalArgumentException e) {
			throw new Exception("Could not interpret target '" + target + "'", e);
		}

		// Validate the prediction explicitly rather than relying on ClassCastException, NullPointerException or
		// ArrayIndexOutOfBoundsException, so that a malformed prediction yields a message showing what was returned.
		if (!(prediction instanceof float[])) {
			throw new Exception("Could not handle predicted value: '" + prediction + "'");
		}
		float[] result = (float[]) prediction;
		if (result.length < 2) {
			throw new Exception("Could not handle predicted value: '" + Arrays.toString(result) + "'");
		}
		return result[1]; // assuming result is [stdev, val]
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy