marytts.unitselection.select.JoinCostFeatures Maven / Gradle / Ivy
The newest version!
/**
* Copyright 2006 DFKI GmbH.
* All Rights Reserved. Use is subject to license terms.
*
* This file is part of MARY TTS.
*
* MARY TTS is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, version 3 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see .
*
*/
package marytts.unitselection.select;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;
import java.util.Vector;
import marytts.exceptions.MaryConfigurationException;
import marytts.features.ByteValuedFeatureProcessor;
import marytts.features.MaryGenericFeatureProcessors;
import marytts.modules.phonemiser.Allophone;
import marytts.server.MaryProperties;
import marytts.signalproc.display.Histogram;
import marytts.unitselection.data.DiphoneUnit;
import marytts.unitselection.data.Unit;
import marytts.unitselection.weightingfunctions.WeightFunc;
import marytts.unitselection.weightingfunctions.WeightFunctionManager;
import marytts.util.MaryUtils;
import marytts.util.data.MaryHeader;
import marytts.util.io.StreamUtils;
public class JoinCostFeatures implements JoinCostFunction {
protected float wSignal;
protected float wPhonetic;
protected boolean debugShowCostGraph = false;
protected double[] cumulWeightedSignalCosts = null;
protected int nCostComputations = 0;
protected PrecompiledJoinCostReader precompiledCosts;
protected JoinCostReporter jcr;
/****************/
/* DATA FIELDS */
/****************/
private MaryHeader hdr = null;
private float[] featureWeight = null;
private WeightFunc[] weightFunction = null;
private boolean[] isLinear = null; // wether the i'th weight function is a linear function
private float[][] leftJCF = null;
private float[][] rightJCF = null;
/****************/
/* CONSTRUCTORS */
/****************/
/**
* Empty constructor; when using this, call load() separately to initialise this class.
*
* @see #load(String joinFileName, InputStream weightStream, String precompiledCostFileName, float wSignal)
*/
public JoinCostFeatures() {
}
/**
* Constructor which read a Mary Join Cost file.
*
* @param fileName
* fileName
* @throws IOException
* IOException
* @throws MaryConfigurationException
* MaryConfigurationException
*/
public JoinCostFeatures(String fileName) throws IOException, MaryConfigurationException {
load(fileName, null, null, (float) 0.5);
}
/**
* Initialise this join cost function by reading the appropriate settings from the MaryProperties using the given
* configPrefix.
*
* @param configPrefix
* the prefix for the (voice-specific) config entries to use when looking up files to load.
* @throws MaryConfigurationException
* MaryConfigurationException
*/
public void init(String configPrefix) throws MaryConfigurationException {
String joinFileName = MaryProperties.needFilename(configPrefix + ".joinCostFile");
String precomputedJoinCostFileName = MaryProperties.getFilename(configPrefix + ".precomputedJoinCostFile");
float wSignal = Float.parseFloat(MaryProperties.getProperty(configPrefix + ".joincostfunction.wSignal", "1.0"));
try {
InputStream joinWeightStream = MaryProperties.getStream(configPrefix + ".joinCostWeights");
load(joinFileName, joinWeightStream, precomputedJoinCostFileName, wSignal);
} catch (IOException ioe) {
throw new MaryConfigurationException("Problem loading join file " + joinFileName, ioe);
}
}
/**
* Load weights and values from the given file
*
* @param joinFileName
* the file from which to read default weights and join cost features
* @param weightStream
* an optional file from which to read weights, taking precedence over
* @param precompiledCostFileName
* an optional file containing precompiled join costs
* @param wSignal
* Relative weight of the signal-based join costs relative to the phonetic join costs computed from the target
* @throws IOException
* IOException
* @throws MaryConfigurationException
* MaryConfigurationException
*/
public void load(String joinFileName, InputStream weightStream, String precompiledCostFileName, float wSignal)
throws IOException, MaryConfigurationException {
loadFromByteBuffer(joinFileName, weightStream, precompiledCostFileName, wSignal);
}
/**
* Load weights and values from the given file
*
* @param joinFileName
* the file from which to read default weights and join cost features
* @param weightStream
* an optional file from which to read weights, taking precedence over
* @param precompiledCostFileName
* an optional file containing precompiled join costs
* @param wSignal
* Relative weight of the signal-based join costs relative to the phonetic join costs computed from the target
* @throws IOException
* IOException
* @throws MaryConfigurationException
* MaryConfigurationException
*/
private void loadFromByteBuffer(String joinFileName, InputStream weightStream, String precompiledCostFileName, float wSignal)
throws IOException, MaryConfigurationException {
if (precompiledCostFileName != null) {
precompiledCosts = new PrecompiledJoinCostReader(precompiledCostFileName);
}
this.wSignal = wSignal;
wPhonetic = 1 - wSignal;
/* Open the file */
FileInputStream fis = new FileInputStream(joinFileName);
FileChannel fc = fis.getChannel();
ByteBuffer bb = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size());
/* Read the Mary header */
hdr = new MaryHeader(bb);
if (hdr.getType() != MaryHeader.JOINFEATS) {
throw new IOException("File [" + joinFileName + "] is not a valid Mary join features file.");
}
try {
/* Read the feature weights and feature processors */
int numberOfFeatures = bb.getInt();
featureWeight = new float[numberOfFeatures];
weightFunction = new WeightFunc[numberOfFeatures];
isLinear = new boolean[numberOfFeatures];
WeightFunctionManager wfm = new WeightFunctionManager();
String wfStr = null;
for (int i = 0; i < numberOfFeatures; i++) {
featureWeight[i] = bb.getFloat();
wfStr = StreamUtils.readUTF(bb);
if ("".equals(wfStr))
weightFunction[i] = wfm.getWeightFunction("linear");
else
weightFunction[i] = wfm.getWeightFunction(wfStr);
}
// Overwrite weights and weight functions from file?
if (weightStream != null) {
MaryUtils.getLogger("JoinCostFeatures").debug("Overwriting join cost weights");
Object[] weightData = readJoinCostWeightsStream(weightStream);
featureWeight = (float[]) weightData[0];
String[] wf = (String[]) weightData[1];
if (featureWeight.length != numberOfFeatures)
throw new IllegalArgumentException("Join cost file contains " + numberOfFeatures
+ " features, but weight file contains " + featureWeight.length + " feature weights!");
for (int i = 0; i < numberOfFeatures; i++) {
weightFunction[i] = wfm.getWeightFunction(wf[i]);
}
}
for (int i = 0; i < numberOfFeatures; i++) {
isLinear[i] = weightFunction[i].whoAmI().equals("linear");
}
/* Read the left and right Join Cost Features */
int numberOfUnits = bb.getInt();
FloatBuffer fb = bb.asFloatBuffer();
leftJCF = new float[numberOfUnits][];
rightJCF = new float[numberOfUnits][];
for (int i = 0; i < numberOfUnits; i++) {
// System.out.println("Reading join features for unit "+i+" out of "+numberOfUnits);
leftJCF[i] = new float[numberOfFeatures];
fb.get(leftJCF[i]);
rightJCF[i] = new float[numberOfFeatures];
fb.get(rightJCF[i]);
}
} catch (EOFException e) {
IOException ioe = new IOException("The currently read Join Cost File has prematurely reached EOF.");
ioe.initCause(e);
throw ioe;
}
if (MaryProperties.getBoolean("debug.show.cost.graph")) {
debugShowCostGraph = true;
cumulWeightedSignalCosts = new double[featureWeight.length];
jcr = new JoinCostReporter(cumulWeightedSignalCosts);
jcr.showInJFrame("Average signal join costs", false, false);
jcr.start();
}
}
/**
* Load weights and values from the given file
*
* @param joinFileName
* the file from which to read default weights and join cost features
* @param weightStream
* an optional file from which to read weights, taking precedence over
* @param precompiledCostFileName
* an optional file containing precompiled join costs
* @param wSignal
* Relative weight of the signal-based join costs relative to the phonetic join costs computed from the target
* @throws IOException
* IOException
* @throws MaryConfigurationException
* MaryConfigurationException
*/
private void loadFromStream(String joinFileName, InputStream weightStream, String precompiledCostFileName, float wSignal)
throws IOException, MaryConfigurationException {
if (precompiledCostFileName != null) {
precompiledCosts = new PrecompiledJoinCostReader(precompiledCostFileName);
}
this.wSignal = wSignal;
wPhonetic = 1 - wSignal;
/* Open the file */
File fid = new File(joinFileName);
DataInput raf = new DataInputStream(new BufferedInputStream(new FileInputStream(fid)));
/* Read the Mary header */
hdr = new MaryHeader(raf);
if (hdr.getType() != MaryHeader.JOINFEATS) {
throw new MaryConfigurationException("File [" + joinFileName + "] is not a valid Mary join features file.");
}
try {
/* Read the feature weights and feature processors */
int numberOfFeatures = raf.readInt();
featureWeight = new float[numberOfFeatures];
weightFunction = new WeightFunc[numberOfFeatures];
isLinear = new boolean[numberOfFeatures];
WeightFunctionManager wfm = new WeightFunctionManager();
String wfStr = null;
for (int i = 0; i < numberOfFeatures; i++) {
featureWeight[i] = raf.readFloat();
wfStr = raf.readUTF();
if ("".equals(wfStr))
weightFunction[i] = wfm.getWeightFunction("linear");
else
weightFunction[i] = wfm.getWeightFunction(wfStr);
}
// Overwrite weights and weight functions from file?
if (weightStream != null) {
MaryUtils.getLogger("JoinCostFeatures").debug("Overwriting join cost weights");
Object[] weightData = readJoinCostWeightsStream(weightStream);
featureWeight = (float[]) weightData[0];
String[] wf = (String[]) weightData[1];
if (featureWeight.length != numberOfFeatures)
throw new IllegalArgumentException("Join cost file contains " + numberOfFeatures
+ " features, but weight file contains " + featureWeight.length + " feature weights!");
for (int i = 0; i < numberOfFeatures; i++) {
weightFunction[i] = wfm.getWeightFunction(wf[i]);
}
}
for (int i = 0; i < numberOfFeatures; i++) {
isLinear[i] = weightFunction[i].whoAmI().equals("linear");
}
/* Read the left and right Join Cost Features */
int numberOfUnits = raf.readInt();
leftJCF = new float[numberOfUnits][];
rightJCF = new float[numberOfUnits][];
for (int i = 0; i < numberOfUnits; i++) {
// System.out.println("Reading join features for unit "+i+" out of "+numberOfUnits);
leftJCF[i] = new float[numberOfFeatures];
for (int j = 0; j < numberOfFeatures; j++) {
leftJCF[i][j] = raf.readFloat();
}
rightJCF[i] = new float[numberOfFeatures];
for (int j = 0; j < numberOfFeatures; j++) {
rightJCF[i][j] = raf.readFloat();
}
}
} catch (EOFException e) {
IOException ioe = new IOException("The currently read Join Cost File has prematurely reached EOF.");
ioe.initCause(e);
throw ioe;
}
if (MaryProperties.getBoolean("debug.show.cost.graph")) {
debugShowCostGraph = true;
cumulWeightedSignalCosts = new double[featureWeight.length];
jcr = new JoinCostReporter(cumulWeightedSignalCosts);
jcr.showInJFrame("Average signal join costs", false, false);
jcr.start();
}
}
/**
* Read the join cost weight specifications from the given file. The weights will be normalized such that they sum to one.
*
* @param fileName
* the text file containing the join weights
* @throws IOException
* IOException
* @throws FileNotFoundException
* FileNotFoundException
* @return readJoinCostWeightsStream(new FileInputStream(fileName))
* */
public static Object[] readJoinCostWeightsFile(String fileName) throws IOException, FileNotFoundException {
return readJoinCostWeightsStream(new FileInputStream(fileName));
}
/**
* Read the join cost weight specifications from the given file. The weights will be normalized such that they sum to one.
*
* @param weightStream
* the text file containing the join weights
* @throws IOException
* IOException
* @throws FileNotFoundException
* FileNotFoundException
* @return Object[] { fw, wfun }
* */
public static Object[] readJoinCostWeightsStream(InputStream weightStream) throws IOException, FileNotFoundException {
Vector v = new Vector(16, 16);
Vector vf = new Vector(16, 16);
/* Open the file */
BufferedReader in = new BufferedReader(new InputStreamReader(weightStream, "UTF-8"));
/* Loop through the lines */
String line = null;
String[] fields = null;
float sumOfWeights = 0;
while ((line = in.readLine()) != null) {
// System.out.println( line );
line = line.split("#", 2)[0]; // Remove possible trailing comments
line = line.trim(); // Remove leading and trailing blanks
if (line.equals(""))
continue; // Empty line: don't parse
line = line.split(":", 2)[1].trim(); // Remove the line number and :
// System.out.print( "CLEANED: [" + line + "]" );
fields = line.split("\\s", 2); // Separate the weight value from the function name
float aWeight = Float.parseFloat(fields[0]);
sumOfWeights += aWeight;
v.add(new Float(aWeight)); // Push the weight
vf.add(fields[1]); // Push the function
// System.out.println( "NBFEA=" + numberOfFeatures );
}
in.close();
// System.out.flush();
/* Export the vector of weighting function names as a String array: */
String[] wfun = (String[]) vf.toArray(new String[vf.size()]);
/*
* For the weights, create a float array containing the weights, normalized such that they sum to one:
*/
float[] fw = new float[v.size()];
for (int i = 0; i < fw.length; i++) {
Float aWeight = (Float) v.get(i);
fw[i] = aWeight.floatValue() / sumOfWeights;
}
/* Return these as an Object[2]. */
return new Object[] { fw, wfun };
}
/*****************/
/* ACCESSORS */
/*****************/
/**
* Get the number of feature weights and weighting functions.
*
* @return (featureWeight.length)
*/
public int getNumberOfFeatures() {
return (featureWeight.length);
}
/**
* Get the number of units.
*
* @return (leftJCF.length)
*/
public int getNumberOfUnits() {
return (leftJCF.length);
}
/**
* Gets the array of left join cost features for a particular unit index.
*
* @param u
* The index of the considered unit.
*
* @return The array of left join cost features for the given unit.
*/
public float[] getLeftJCF(int u) {
if (u < 0) {
throw new RuntimeException("The unit index [" + u + "] is out of range: a unit index can't be negative.");
}
if (u > getNumberOfUnits()) {
throw new RuntimeException("The unit index [" + u + "] is out of range: this file contains [" + getNumberOfUnits()
+ "] units.");
}
return (leftJCF[u]);
}
/**
* Gets the array of right join cost features for a particular unit index.
*
* @param u
* The index of the considered unit.
*
* @return The array of right join cost features for the given unit.
*/
public float[] getRightJCF(int u) {
if (u < 0) {
throw new RuntimeException("The unit index [" + u + "] is out of range: a unit index can't be negative.");
}
if (u > getNumberOfUnits()) {
throw new RuntimeException("The unit index [" + u + "] is out of range: this file contains [" + getNumberOfUnits()
+ "] units.");
}
return (rightJCF[u]);
}
/*****************/
/* MISC METHODS */
/*****************/
/**
* Deliver the join cost between two units described by their index.
*
* @param u1
* the left unit
* @param u2
* the right unit
*
* @return the cost of joining the right Join Cost features of the left unit with the left Join Cost Features of the right
* unit.
*/
public double cost(int u1, int u2) {
/* Check the given indexes */
if (u1 < 0) {
throw new RuntimeException("The left unit index [" + u1 + "] is out of range: a unit index can't be negative.");
}
// if ( u1 > getNumberOfUnits() ) {
if (u1 > leftJCF.length) {
throw new RuntimeException("The left unit index [" + u1 + "] is out of range: this file contains ["
+ getNumberOfUnits() + "] units.");
}
if (u2 < 0) {
throw new RuntimeException("The right unit index [" + u2 + "] is out of range: a unit index can't be negative.");
}
// if ( u2 > getNumberOfUnits() ) {
if (u2 > leftJCF.length) {
throw new RuntimeException("The right unit index [" + u2 + "] is out of range: this file contains ["
+ getNumberOfUnits() + "] units.");
}
if (debugShowCostGraph) {
jcr.tick();
}
/* Cumulate the join costs for each feature */
double res = 0.0;
float[] v1 = rightJCF[u1];
float[] v2 = leftJCF[u2];
for (int i = 0; i < v1.length; i++) {
float a = v1[i];
float b = v2[i];
// if (!Float.isNaN(v1[i]) && !Float.isNaN(v2[i])) {
if (!(a != a) && !(b != b)) {
double c;
if (isLinear[i]) {
c = featureWeight[i] * (a > b ? (a - b) : (b - a));
} else {
c = featureWeight[i] * weightFunction[i].cost(a, b);
}
res += c;
if (debugShowCostGraph) {
cumulWeightedSignalCosts[i] += wSignal * c;
}
} // if anything is NaN, count the cost as 0.
}
return (res);
}
/**
* A combined cost computation, as a weighted sum of the signal-based cost (computed from the units) and the phonetics-based
* cost (computed from the targets).
*
* @param t1
* The left target.
* @param u1
* The left unit.
* @param t2
* The right target.
* @param u2
* The right unit.
*
* @return the cost of joining the left unit with the right unit, as a non-negative value.
*/
public double cost(Target t1, Unit u1, Target t2, Unit u2) {
// Units of length 0 cannot be joined:
if (u1.duration == 0 || u2.duration == 0)
return Double.POSITIVE_INFINITY;
// In the case of diphones, replace them with the relevant part:
boolean bothDiphones = true;
if (u1 instanceof DiphoneUnit) {
u1 = ((DiphoneUnit) u1).right;
} else {
bothDiphones = false;
}
if (u2 instanceof DiphoneUnit) {
u2 = ((DiphoneUnit) u2).left;
} else {
bothDiphones = false;
}
if (u1.index + 1 == u2.index)
return 0;
// Either not half phone synthesis, or at a diphone boundary
double cost = 1; // basic penalty for joins of non-contiguous units.
if (bothDiphones && precompiledCosts != null) {
cost += precompiledCosts.cost(t1, u1, t2, u2);
} else { // need to actually compute the cost
cost += cost(u1.index, u2.index);
}
return cost;
}
/**
* A phonetic join cost, computed solely from the target.
*
* @param t1
* the left target
* @param t2
* the right target
* @return a non-negative join cost, usually between 0 (best) and 1 (worst).
* @deprecated
*/
protected double cost(Target t1, Target t2) {
// TODO: This is really ad hoc for the moment. Redo once we know what we are doing.
// Add penalties for a number of criteria.
double cost = 0;
ByteValuedFeatureProcessor stressProcessor = new MaryGenericFeatureProcessors.Stressed("",
new MaryGenericFeatureProcessors.SyllableNavigator());
// Stressed?
boolean stressed1 = stressProcessor.process(t1) == (byte) 1;
boolean stressed2 = stressProcessor.process(t1) == (byte) 1;
// Try to avoid joining in a stressed syllable:
if (stressed1 || stressed2)
cost += 0.2;
Allophone p1 = t1.getAllophone();
Allophone p2 = t2.getAllophone();
// Discourage joining vowels:
if (p1.isVowel() || p2.isVowel())
cost += 0.2;
// Discourage joining glides:
if (p1.isGlide() || p2.isGlide())
cost += 0.2;
// Discourage joining voiced segments:
if (p1.isVoiced() || p2.isVoiced())
cost += 0.1;
// If both are voiced, it's really bad
if (p1.isVoiced() && p2.isVoiced())
cost += 0.1;
// Slightly penalize nasals and liquids
if (p1.isNasal() || p2.isNasal())
cost += 0.05;
if (p1.isLiquid() || p2.isLiquid())
cost += 0.05;
// Fricatives -- nothing?
// Plosives -- nothing?
if (cost > 1)
cost = 1;
return cost;
}
public static class JoinCostReporter extends Histogram {
private double[] data;
private int lastN = 0;
private int nCostComputations = 0;
public JoinCostReporter(double[] data) {
super(0, 1, data);
this.data = data;
}
public void start() {
new Thread() {
public void run() {
while (isVisible()) {
try {
Thread.sleep(500);
} catch (InterruptedException ie) {
}
updateGraph();
}
}
}.start();
}
/**
* Register one new cost computation
*/
public void tick() {
nCostComputations++;
}
protected void updateGraph() {
if (nCostComputations == lastN)
return;
lastN = nCostComputations;
double[] newCosts = new double[data.length];
for (int i = 0; i < newCosts.length; i++) {
newCosts[i] = data[i] / nCostComputations;
}
updateData(0, 1, newCosts);
repaint();
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy