marytts.unitselection.select.JoinModelCost Maven / Gradle / Ivy
The newest version!
/**
* Copyright 2006 DFKI GmbH.
* All Rights Reserved. Use is subject to license terms.
*
* This file is part of MARY TTS.
*
* MARY TTS is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, version 3 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
package marytts.unitselection.select;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import marytts.cart.CART;
import marytts.cart.Node;
import marytts.cart.LeafNode.PdfLeafNode;
import marytts.cart.io.HTSCARTReader;
import marytts.exceptions.MaryConfigurationException;
import marytts.features.FeatureDefinition;
import marytts.features.FeatureVector;
import marytts.htsengine.PhoneTranslator;
import marytts.htsengine.HMMData.PdfFileFormat;
import marytts.server.MaryProperties;
import marytts.signalproc.analysis.distance.DistanceComputer;
import marytts.unitselection.data.DiphoneUnit;
import marytts.unitselection.data.Unit;
/**
 * A join cost function for unit selection that scores each non-contiguous join with a statistical join model.
 * <p>
 * The difference between the right-edge join cost features of the left unit and the left-edge join cost features of the
 * right unit is compared, by normalized Euclidean distance, with the mean and variance stored in a {@link PdfLeafNode}. The
 * leaf is selected by a CART (read in HTS format) from the left target's feature vector.
 * <p>
 * Usage: call {@link #setFeatureDefinition(FeatureDefinition)} first, then {@link #init(String)} (or
 * {@link #load(String, InputStream, InputStream, String)}), before calling {@link #cost(Target, Unit, Target, Unit)}.
 */
public class JoinModelCost implements JoinCostFunction {
    /** Counter of cost computations. NOTE(review): this class never updates it. */
    protected int nCostComputations = 0;

    /****************/
    /* DATA FIELDS */
    /****************/
    /** Join cost features (left/right edge vectors per unit index), read from the join cost file. */
    private JoinCostFeatures jcf = null;
    /** An array of CARTs, one per HMM state; the join model uses a single state (index 0). */
    CART[] joinTree = null;
    /** Feature definition used to interpret target feature vectors; must be set before loading or costing. */
    private FeatureDefinition featureDef = null;

    /****************/
    /* CONSTRUCTORS */
    /****************/
    /**
     * Empty constructor. When using it, call {@link #setFeatureDefinition(FeatureDefinition)} and then {@link #init(String)}
     * (or {@link #load(String, InputStream, InputStream, String)}) separately to initialise this class.
     */
    public JoinModelCost() {
    }

    /**
     * Initialise this join cost function by reading the appropriate settings from the MaryProperties using the given
     * configPrefix.
     * <p>
     * The join PDF and join tree streams are opened here and closed again once {@link #load} has consumed them, even if
     * loading fails.
     *
     * @param configPrefix
     *            the prefix for the (voice-specific) config entries to use when looking up files to load.
     * @throws MaryConfigurationException
     *             if a required config entry is missing, or if loading the join files fails (wrapping the IOException).
     */
    public void init(String configPrefix) throws MaryConfigurationException {
        try {
            String joinFileName = MaryProperties.needFilename(configPrefix + ".joinCostFile");
            // This method opens the streams, so it owns them and must close them after load().
            try (InputStream joinPdfStream = MaryProperties.needStream(configPrefix + ".joinPdfFile");
                    InputStream joinTreeStream = MaryProperties.needStream(configPrefix + ".joinTreeFile")) {
                // CHECK not tested the trickyPhonesFile needs to be added into the configuration file
                String trickyPhonesFileName = MaryProperties.needFilename(configPrefix + ".trickyPhonesFile");
                load(joinFileName, joinPdfStream, joinTreeStream, trickyPhonesFileName);
            }
        } catch (IOException ioe) {
            throw new MaryConfigurationException("Problem loading join file", ioe);
        }
    }

    /**
     * Not supported by this join cost function.
     *
     * @deprecated use {@link #init(String)} instead.
     * @throws UnsupportedOperationException
     *             always.
     */
    @Override
    @Deprecated
    public void load(String a, InputStream b, String c, float d) {
        throw new UnsupportedOperationException("Do not use load() -- use init()");
    }

    /**
     * Load weights and values from the given files.
     * <p>
     * The feature definition must have been set via {@link #setFeatureDefinition(FeatureDefinition)} beforehand. The
     * caller keeps ownership of {@code joinPdfStream} and {@code joinTreeStream}. The tricky phones file is opened and
     * closed by this method.
     *
     * @param joinFileName
     *            the file from which to read join cost features
     * @param joinPdfStream
     *            the stream from which to read the Gaussian models in the leaves of the tree
     * @param joinTreeStream
     *            the stream from which to read the Tree, in HTS format.
     * @param trickyPhonesFile
     *            the file listing tricky phones, used to build the {@link PhoneTranslator}
     * @throws IOException
     *             if a file cannot be read, or (wrapping the cause) if the join model trees cannot be loaded
     * @throws MaryConfigurationException
     *             MaryConfigurationException
     */
    public void load(String joinFileName, InputStream joinPdfStream, InputStream joinTreeStream, String trickyPhonesFile)
            throws IOException, MaryConfigurationException {
        jcf = new JoinCostFeatures(joinFileName);
        assert featureDef != null : "Expected to have a feature definition, but it is null!";

        /* Load Trees */
        HTSCARTReader htsReader = new HTSCARTReader();
        int numStates = 1; // just one state in the joinModeller

        // Create a PhoneTranslator from the tricky phones file. The stream used to be leaked; it is now closed.
        // NOTE(review): assumes PhoneTranslator reads the stream completely in its constructor.
        PhoneTranslator phTranslator;
        try (InputStream trickyPhonesStream = new FileInputStream(trickyPhonesFile)) {
            phTranslator = new PhoneTranslator(trickyPhonesStream);
        }

        try {
            joinTree = htsReader.load(numStates, joinTreeStream, joinPdfStream, PdfFileFormat.join, featureDef, phTranslator);
        } catch (Exception e) {
            // Any failure while parsing the trees is reported to callers as an IOException, keeping the cause.
            throw new IOException("Cannot load join model trees", e);
        }
    }

    /**
     * Set the feature definition to use for interpreting target feature vectors.
     *
     * @param def
     *            the feature definition to use.
     */
    public void setFeatureDefinition(FeatureDefinition def) {
        this.featureDef = def;
    }

    /*****************/
    /* MISC METHODS */
    /*****************/
    /**
     * Compute the cost of joining the left unit with the right unit under the join model.
     * <p>
     * Diphone units are reduced to the halves adjacent to the join: the right half of {@code u1} and the left half of
     * {@code u2}. The result is:
     * <ul>
     * <li>{@link Double#POSITIVE_INFINITY} if either unit has zero duration;</li>
     * <li>0 if the units are contiguous ({@code u1.index + 1 == u2.index});</li>
     * <li>otherwise 1 (a basic penalty) plus the normalized Euclidean distance between the join-cost-feature difference
     * and the Gaussian leaf that the join tree selects from the left target's feature vector.</li>
     * </ul>
     *
     * @param t1
     *            The left target; its feature vector (the right half, for a diphone target) selects the join model.
     * @param u1
     *            The left unit.
     * @param t2
     *            The right target (not used by this implementation).
     * @param u2
     *            The right unit.
     *
     * @return the cost of joining the left unit with the right unit, as a non-negative value.
     */
    public double cost(Target t1, Unit u1, Target t2, Unit u2) {
        // Units of length 0 cannot be joined:
        if (u1.duration == 0 || u2.duration == 0) {
            return Double.POSITIVE_INFINITY;
        }
        // In the case of diphones, replace them with the part adjacent to the join:
        if (u1 instanceof DiphoneUnit) {
            u1 = ((DiphoneUnit) u1).right;
        }
        if (u2 instanceof DiphoneUnit) {
            u2 = ((DiphoneUnit) u2).left;
        }
        // Contiguous units need no join, so they cost nothing.
        if (u1.index + 1 == u2.index) {
            return 0;
        }
        double cost = 1; // basic penalty for joins of non-contiguous units.

        float[] v1 = jcf.getRightJCF(u1.index);
        float[] v2 = jcf.getLeftJCF(u2.index);
        double[] diff = new double[v1.length];
        for (int i = 0; i < v1.length; i++) {
            diff[i] = (double) v1[i] - v2[i];
        }

        // Now evaluate the diff under the join model; the model is selected by the left target's features.
        assert featureDef != null : "Feature Definition was not set";
        FeatureVector fv1;
        if (t1 instanceof DiphoneTarget) {
            HalfPhoneTarget hpt1 = ((DiphoneTarget) t1).right;
            assert hpt1 != null;
            fv1 = hpt1.getFeatureVector();
        } else {
            fv1 = t1.getFeatureVector();
        }
        assert fv1 != null : "Target has no feature vector";

        int state = 0; // just one state in the joinModeller
        Node node = joinTree[state].interpretToNode(fv1, 1);
        assert node instanceof PdfLeafNode : "The node must be a PdfLeafNode.";
        PdfLeafNode leaf = (PdfLeafNode) node;
        double[] mean = leaf.getMean();
        double[] variance = leaf.getVariance();
        cost += DistanceComputer.getNormalizedEuclideanDistance(diff, mean, variance);
        return cost;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy