All Downloads are FREE. Search and download functionalities are using the official Maven repository.

marytts.unitselection.select.JoinModelCost Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2006 DFKI GmbH.
 * All Rights Reserved.  Use is subject to license terms.
 *
 * This file is part of MARY TTS.
 *
 * MARY TTS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, version 3 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */
package marytts.unitselection.select;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;

import marytts.cart.CART;
import marytts.cart.Node;
import marytts.cart.LeafNode.PdfLeafNode;
import marytts.cart.io.HTSCARTReader;
import marytts.exceptions.MaryConfigurationException;
import marytts.features.FeatureDefinition;
import marytts.features.FeatureVector;
import marytts.htsengine.PhoneTranslator;
import marytts.htsengine.HMMData.PdfFileFormat;
import marytts.server.MaryProperties;
import marytts.signalproc.analysis.distance.DistanceComputer;
import marytts.unitselection.data.DiphoneUnit;
import marytts.unitselection.data.Unit;

/**
 * Join cost function that scores a join by comparing the join-cost features on both sides of the join point against a model
 * stored in the leaves of an HTS-format CART.
 * <p>
 * For a non-contiguous join, the difference between the left unit's right join-cost features and the right unit's left
 * join-cost features is evaluated against the mean/variance of the leaf selected by the left target's feature vector.
 * <p>
 * A feature definition must be supplied via {@link #setFeatureDefinition(FeatureDefinition)} before loading.
 */
public class JoinModelCost implements JoinCostFunction {
	// NOTE(review): not updated anywhere in this class; presumably maintained by subclasses or unused.
	protected int nCostComputations = 0;

	/****************/
	/* DATA FIELDS */
	/****************/
	/** Per-unit left/right join-cost feature vectors, read from the join cost file. */
	private JoinCostFeatures jcf = null;

	CART[] joinTree = null; // an array of carts, one per HMM state (only state 0 is used here).

	// NOTE(review): not read anywhere in this class.
	private float f0Weight;

	/** Feature definition used to build the join trees and interpret target feature vectors. */
	private FeatureDefinition featureDef = null;

	// NOTE(review): not read anywhere in this class.
	private boolean debugShowCostGraph = false;

	/****************/
	/* CONSTRUCTORS */
	/****************/

	/**
	 * Empty constructor. Call {@link #setFeatureDefinition(FeatureDefinition)} and then {@link #init(String)} (or the
	 * four-argument {@code load}) to initialise this instance before computing costs.
	 * 
	 * @see #init(String)
	 * @see #load(String, InputStream, InputStream, String)
	 */
	public JoinModelCost() {
	}

	/**
	 * Initialise this join cost function by reading the appropriate settings from the MaryProperties using the given
	 * configPrefix.
	 * <p>
	 * Reads the {@code .joinCostFile}, {@code .joinPdfFile}, {@code .joinTreeFile} and {@code .trickyPhonesFile} entries
	 * under the prefix. The pdf and tree streams are opened here, so they are also closed here once loading has finished.
	 * 
	 * @param configPrefix
	 *            the prefix for the (voice-specific) config entries to use when looking up files to load.
	 * @throws MaryConfigurationException
	 *             if loading fails; I/O problems are wrapped with the message "Problem loading join file".
	 */
	public void init(String configPrefix) throws MaryConfigurationException {
		try {
			String joinFileName = MaryProperties.needFilename(configPrefix + ".joinCostFile");
			try (InputStream joinPdfStream = MaryProperties.needStream(configPrefix + ".joinPdfFile");
					InputStream joinTreeStream = MaryProperties.needStream(configPrefix + ".joinTreeFile")) {
				// NOTE: the .trickyPhonesFile entry must be present in the voice configuration
				// (the original author marked this path as not tested).
				String trickyPhonesFileName = MaryProperties.needFilename(configPrefix + ".trickyPhonesFile");
				load(joinFileName, joinPdfStream, joinTreeStream, trickyPhonesFileName);
			}
		} catch (IOException ioe) {
			throw new MaryConfigurationException("Problem loading join file", ioe);
		}
	}

	/**
	 * Not supported by this join cost function; always throws.
	 * 
	 * @deprecated use {@link #init(String)} instead.
	 * @throws UnsupportedOperationException
	 *             always
	 */
	@Override
	@Deprecated
	public void load(String a, InputStream b, String c, float d) {
		throw new UnsupportedOperationException("Do not use load() -- use init()");
	}

	/**
	 * Load join cost features and the join model trees.
	 * <p>
	 * The caller keeps ownership of {@code joinPdfStream} and {@code joinTreeStream}; they are not closed here. The
	 * tricky-phones file is opened and closed by this method.
	 * 
	 * @param joinFileName
	 *            the file from which to read join cost features
	 * @param joinPdfStream
	 *            stream from which to read the Gaussian models in the leaves of the tree
	 * @param joinTreeStream
	 *            stream from which to read the tree, in HTS format
	 * @param trickyPhonesFile
	 *            path of the tricky-phones file used to build the {@link PhoneTranslator}
	 * @throws IOException
	 *             if a file cannot be read, or (wrapping the cause) if the join model trees cannot be loaded
	 * @throws MaryConfigurationException
	 *             MaryConfigurationException
	 */
	public void load(String joinFileName, InputStream joinPdfStream, InputStream joinTreeStream, String trickyPhonesFile)
			throws IOException, MaryConfigurationException {
		jcf = new JoinCostFeatures(joinFileName);

		assert featureDef != null : "Expected to have a feature definition, but it is null!";

		/* Load Trees */
		HTSCARTReader htsReader = new HTSCARTReader();
		int numStates = 1; // just one state in the joinModeller

		// Build the PhoneTranslator from the tricky-phones file. This method opened the stream, so it closes it again
		// instead of leaking the file handle.
		PhoneTranslator phTranslator;
		try (InputStream trickyPhonesStream = new FileInputStream(trickyPhonesFile)) {
			phTranslator = new PhoneTranslator(trickyPhonesStream);
		}

		try {
			joinTree = htsReader.load(numStates, joinTreeStream, joinPdfStream, PdfFileFormat.join, featureDef, phTranslator);
		} catch (Exception e) {
			throw new IOException("Cannot load join model trees", e);
		}
	}

	/**
	 * Set the feature definition to use for interpreting target feature vectors.
	 * 
	 * @param def
	 *            the feature definition to use.
	 */
	public void setFeatureDefinition(FeatureDefinition def) {
		this.featureDef = def;
	}

	/*****************/
	/* MISC METHODS */
	/*****************/

	/**
	 * Compute the cost of joining the left unit with the right unit.
	 * <p>
	 * Diphone units are reduced to the halves adjacent to the join point. The cost is:
	 * <ul>
	 * <li>{@link Double#POSITIVE_INFINITY} if either unit has zero duration;</li>
	 * <li>0 if the units are contiguous in the database (left index + 1 == right index);</li>
	 * <li>otherwise 1 (penalty for a non-contiguous join) plus
	 * {@link DistanceComputer#getNormalizedEuclideanDistance} between the join-feature difference and the mean/variance of
	 * the join-tree leaf selected by the left target's feature vector.</li>
	 * </ul>
	 * 
	 * @param t1
	 *            The left target.
	 * @param u1
	 *            The left unit.
	 * @param t2
	 *            The right target (not used by this implementation).
	 * @param u2
	 *            The right unit.
	 * 
	 * @return the cost of joining the left unit with the right unit.
	 */
	public double cost(Target t1, Unit u1, Target t2, Unit u2) {
		// Units of length 0 cannot be joined:
		if (u1.duration == 0 || u2.duration == 0) {
			return Double.POSITIVE_INFINITY;
		}
		// In the case of diphones, replace them with the half that touches the join point:
		if (u1 instanceof DiphoneUnit) {
			u1 = ((DiphoneUnit) u1).right;
		}
		if (u2 instanceof DiphoneUnit) {
			u2 = ((DiphoneUnit) u2).left;
		}

		// Contiguous units need no join.
		if (u1.index + 1 == u2.index) {
			return 0;
		}
		double cost = 1; // basic penalty for joins of non-contiguous units.

		float[] v1 = jcf.getRightJCF(u1.index);
		float[] v2 = jcf.getLeftJCF(u2.index);
		double[] diff = new double[v1.length];
		for (int i = 0; i < v1.length; i++) {
			diff[i] = (double) v1[i] - v2[i];
		}

		// Now evaluate the diff under the join model; the model is selected by the left target's features.
		assert featureDef != null : "Feature Definition was not set";
		FeatureVector fv1 = null;
		if (t1 instanceof DiphoneTarget) {
			HalfPhoneTarget hpt1 = ((DiphoneTarget) t1).right;
			assert hpt1 != null;
			fv1 = hpt1.getFeatureVector();
		} else {
			fv1 = t1.getFeatureVector();
		}
		assert fv1 != null : "Target has no feature vector";

		int state = 0; // just one state in the joinModeller
		double[] mean;
		double[] variance;

		Node node = joinTree[state].interpretToNode(fv1, 1);

		assert node instanceof PdfLeafNode : "The node must be a PdfLeafNode.";

		mean = ((PdfLeafNode) node).getMean();
		variance = ((PdfLeafNode) node).getVariance();

		double distance = DistanceComputer.getNormalizedEuclideanDistance(diff, mean, variance);

		cost += distance;

		return cost;
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy