All downloads are free. Search and download functionality uses the official Maven repository.

marytts.unitselection.select.JoinCostFeatures Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2006 DFKI GmbH.
 * All Rights Reserved.  Use is subject to license terms.
 *
 * This file is part of MARY TTS.
 *
 * MARY TTS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, version 3 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */
package marytts.unitselection.select;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.Vector;

import marytts.exceptions.MaryConfigurationException;
import marytts.features.ByteValuedFeatureProcessor;
import marytts.features.MaryGenericFeatureProcessors;
import marytts.modules.phonemiser.Allophone;
import marytts.server.MaryProperties;
import marytts.signalproc.display.Histogram;
import marytts.unitselection.data.DiphoneUnit;
import marytts.unitselection.data.Unit;
import marytts.unitselection.weightingfunctions.WeightFunc;
import marytts.unitselection.weightingfunctions.WeightFunctionManager;
import marytts.util.MaryUtils;
import marytts.util.data.MaryHeader;
import marytts.util.io.StreamUtils;

/**
 * Join cost function based on per-unit join cost features.
 * <p>
 * A Mary join features file (header type {@code MaryHeader.JOINFEATS}) stores, for every unit, one feature vector describing
 * the unit's left edge and one describing its right edge, plus a default weight and weighting function per feature. The
 * signal-based cost of joining unit {@code u1} to unit {@code u2} is the sum over all features of
 * {@code weight[i] * weightFunction[i](right(u1)[i], left(u2)[i])}, where features that are NaN on either side are skipped.
 */
public class JoinCostFeatures implements JoinCostFunction {

	/** Relative weight of the signal-based join costs, as passed to {@code load()}. */
	protected float wSignal;
	/** Relative weight of the phonetic join costs; {@code load()} sets it to {@code 1 - wSignal}. */
	protected float wPhonetic;

	/** True when the debug cost graph was requested via the {@code debug.show.cost.graph} property. */
	protected boolean debugShowCostGraph = false;
	/** Per-feature running sums of {@code wSignal}-weighted costs; only allocated and updated for the debug graph. */
	protected double[] cumulWeightedSignalCosts = null;
	/** Not used within this class. */
	protected int nCostComputations = 0;

	/** Optional precompiled join costs for diphone boundaries; {@code null} if no file was given. */
	protected PrecompiledJoinCostReader precompiledCosts;

	/** Debug display of average per-feature join costs; {@code null} unless the debug graph is active. */
	protected JoinCostReporter jcr;

	/****************/
	/* DATA FIELDS */
	/****************/
	/** Header of the loaded join features file. */
	private MaryHeader hdr = null;

	/** Weight of each feature (possibly replaced by an external weight stream). */
	private float[] featureWeight = null;
	/** Weighting function of each feature. */
	private WeightFunc[] weightFunction = null;
	/** Whether the i'th weighting function is "linear", in which case cost() computes |a - b| inline. */
	private boolean[] isLinear = null;

	/** Left-edge join cost features, indexed [unit][feature]. */
	private float[][] leftJCF = null;
	/** Right-edge join cost features, indexed [unit][feature]. */
	private float[][] rightJCF = null;

	/****************/
	/* CONSTRUCTORS */
	/****************/

	/**
	 * Empty constructor; when using this, call load() separately to initialise this class.
	 * 
	 * @see #load(String joinFileName, InputStream weightStream, String precompiledCostFileName, float wSignal)
	 */
	public JoinCostFeatures() {
	}

	/**
	 * Constructor which reads a Mary join cost file, without weight overrides or precompiled costs, and with a signal weight
	 * of 0.5.
	 * 
	 * @param fileName
	 *            the join features file to read
	 * @throws IOException
	 *             if the file cannot be read or is not a valid join features file
	 * @throws MaryConfigurationException
	 *             propagated from the project classes used while loading
	 */
	public JoinCostFeatures(String fileName) throws IOException, MaryConfigurationException {
		load(fileName, null, null, 0.5f);
	}

	/**
	 * Initialise this join cost function by reading the appropriate settings from the MaryProperties using the given
	 * configPrefix.
	 * <p>
	 * Reads {@code <prefix>.joinCostFile} (required), {@code <prefix>.precomputedJoinCostFile},
	 * {@code <prefix>.joincostfunction.wSignal} (default "1.0") and the stream for {@code <prefix>.joinCostWeights}.
	 * 
	 * @param configPrefix
	 *            the prefix for the (voice-specific) config entries to use when looking up files to load.
	 * @throws MaryConfigurationException
	 *             if a required setting is missing or any of the files cannot be loaded
	 */
	public void init(String configPrefix) throws MaryConfigurationException {
		String joinFileName = MaryProperties.needFilename(configPrefix + ".joinCostFile");
		String precomputedJoinCostFileName = MaryProperties.getFilename(configPrefix + ".precomputedJoinCostFile");
		float wSignal = Float.parseFloat(MaryProperties.getProperty(configPrefix + ".joincostfunction.wSignal", "1.0"));
		InputStream joinWeightStream = null;
		try {
			joinWeightStream = MaryProperties.getStream(configPrefix + ".joinCostWeights");
			load(joinFileName, joinWeightStream, precomputedJoinCostFileName, wSignal);
		} catch (IOException ioe) {
			throw new MaryConfigurationException("Problem loading join file " + joinFileName, ioe);
		} finally {
			// The weights reader closes the stream on the normal path; this covers early failures (e.g. a bad header).
			closeQuietly(joinWeightStream);
		}
	}

	/**
	 * Load weights and values from the given file.
	 * 
	 * @param joinFileName
	 *            the file from which to read default weights and join cost features
	 * @param weightStream
	 *            an optional stream from which to read weights and weighting functions, taking precedence over those stored in
	 *            the join file; may be {@code null}. If it is read, it is closed afterwards.
	 * @param precompiledCostFileName
	 *            an optional file containing precompiled join costs; may be {@code null}
	 * @param wSignal
	 *            Relative weight of the signal-based join costs relative to the phonetic join costs computed from the target
	 * @throws IOException
	 *             if a file cannot be read, is truncated, or is not a valid join features file
	 * @throws MaryConfigurationException
	 *             propagated from the project classes used while loading
	 * @throws IllegalArgumentException
	 *             if the weight stream does not provide exactly one weight per feature of the join file
	 */
	public void load(String joinFileName, InputStream weightStream, String precompiledCostFileName, float wSignal)
			throws IOException, MaryConfigurationException {
		loadFromByteBuffer(joinFileName, weightStream, precompiledCostFileName, wSignal);
	}

	/**
	 * Load weights and values from the given file, reading it through a read-only memory mapping.
	 * 
	 * @param joinFileName
	 *            the file from which to read default weights and join cost features
	 * @param weightStream
	 *            an optional stream of weights overriding those in the join file; may be {@code null}
	 * @param precompiledCostFileName
	 *            an optional file containing precompiled join costs; may be {@code null}
	 * @param wSignal
	 *            Relative weight of the signal-based join costs relative to the phonetic join costs computed from the target
	 * @throws IOException
	 *             if the file cannot be read, is truncated, or is not a valid join features file
	 * @throws MaryConfigurationException
	 *             propagated from the project classes used while loading
	 */
	private void loadFromByteBuffer(String joinFileName, InputStream weightStream, String precompiledCostFileName, float wSignal)
			throws IOException, MaryConfigurationException {
		if (precompiledCostFileName != null) {
			precompiledCosts = new PrecompiledJoinCostReader(precompiledCostFileName);
		}
		this.wSignal = wSignal;
		wPhonetic = 1 - wSignal;
		ByteBuffer bb = mapReadOnly(joinFileName);
		/* Read the Mary header */
		hdr = new MaryHeader(bb);
		if (hdr.getType() != MaryHeader.JOINFEATS) {
			throw new IOException("File [" + joinFileName + "] is not a valid Mary join features file.");
		}
		try {
			/* Read the default feature weights and weighting function names */
			int numberOfFeatures = bb.getInt();
			featureWeight = new float[numberOfFeatures];
			weightFunction = new WeightFunc[numberOfFeatures];
			WeightFunctionManager wfm = new WeightFunctionManager();
			for (int i = 0; i < numberOfFeatures; i++) {
				featureWeight[i] = bb.getFloat();
				String wfStr = StreamUtils.readUTF(bb);
				// An empty function name in the join file stands for the linear weighting function.
				weightFunction[i] = wfm.getWeightFunction("".equals(wfStr) ? "linear" : wfStr);
			}
			applyWeightOverrides(weightStream, wfm);
			computeLinearFlags();

			/* Read the left and right join cost features, stored per unit as left vector then right vector */
			int numberOfUnits = bb.getInt();
			FloatBuffer fb = bb.asFloatBuffer();
			leftJCF = new float[numberOfUnits][];
			rightJCF = new float[numberOfUnits][];
			for (int i = 0; i < numberOfUnits; i++) {
				leftJCF[i] = new float[numberOfFeatures];
				fb.get(leftJCF[i]);
				rightJCF[i] = new float[numberOfFeatures];
				fb.get(rightJCF[i]);
			}
		} catch (EOFException e) {
			throw prematureEof(e);
		} catch (BufferUnderflowException e) {
			// Relative ByteBuffer/FloatBuffer reads signal a truncated file this way rather than with EOFException.
			throw prematureEof(e);
		}
		startDebugCostGraphIfRequested();
	}

	/**
	 * Load weights and values from the given file, reading it sequentially through a DataInput stream. Currently not called
	 * from within this class; kept as a stream-based alternative to {@link #loadFromByteBuffer}.
	 * 
	 * @param joinFileName
	 *            the file from which to read default weights and join cost features
	 * @param weightStream
	 *            an optional stream of weights overriding those in the join file; may be {@code null}
	 * @param precompiledCostFileName
	 *            an optional file containing precompiled join costs; may be {@code null}
	 * @param wSignal
	 *            Relative weight of the signal-based join costs relative to the phonetic join costs computed from the target
	 * @throws IOException
	 *             if the file cannot be read or is truncated
	 * @throws MaryConfigurationException
	 *             if the file is not a valid join features file
	 */
	private void loadFromStream(String joinFileName, InputStream weightStream, String precompiledCostFileName, float wSignal)
			throws IOException, MaryConfigurationException {
		if (precompiledCostFileName != null) {
			precompiledCosts = new PrecompiledJoinCostReader(precompiledCostFileName);
		}
		this.wSignal = wSignal;
		wPhonetic = 1 - wSignal;
		DataInputStream stream = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(joinFileName))));
		try {
			// Typed as DataInput so that the DataInput-based MaryHeader constructor is selected.
			DataInput raf = stream;
			/* Read the Mary header */
			hdr = new MaryHeader(raf);
			if (hdr.getType() != MaryHeader.JOINFEATS) {
				throw new MaryConfigurationException("File [" + joinFileName + "] is not a valid Mary join features file.");
			}
			try {
				/* Read the default feature weights and weighting function names */
				int numberOfFeatures = raf.readInt();
				featureWeight = new float[numberOfFeatures];
				weightFunction = new WeightFunc[numberOfFeatures];
				WeightFunctionManager wfm = new WeightFunctionManager();
				for (int i = 0; i < numberOfFeatures; i++) {
					featureWeight[i] = raf.readFloat();
					String wfStr = raf.readUTF();
					// An empty function name in the join file stands for the linear weighting function.
					weightFunction[i] = wfm.getWeightFunction("".equals(wfStr) ? "linear" : wfStr);
				}
				applyWeightOverrides(weightStream, wfm);
				computeLinearFlags();

				/* Read the left and right join cost features, stored per unit as left vector then right vector */
				int numberOfUnits = raf.readInt();
				leftJCF = new float[numberOfUnits][];
				rightJCF = new float[numberOfUnits][];
				for (int i = 0; i < numberOfUnits; i++) {
					leftJCF[i] = readFloats(raf, numberOfFeatures);
					rightJCF[i] = readFloats(raf, numberOfFeatures);
				}
			} catch (EOFException e) {
				throw prematureEof(e);
			}
		} finally {
			stream.close();
		}
		startDebugCostGraphIfRequested();
	}

	/** Memory-maps the whole file read-only; the mapping stays valid after the channel is closed. */
	private static ByteBuffer mapReadOnly(String fileName) throws IOException {
		FileInputStream fis = new FileInputStream(fileName);
		try {
			FileChannel fc = fis.getChannel();
			return fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size());
		} finally {
			fis.close();
		}
	}

	/** Reads {@code count} consecutive floats from {@code in}. */
	private static float[] readFloats(DataInput in, int count) throws IOException {
		float[] values = new float[count];
		for (int j = 0; j < count; j++) {
			values[j] = in.readFloat();
		}
		return values;
	}

	/** Builds the IOException reported when the join file ends before all declared data was read. */
	private static IOException prematureEof(Throwable cause) {
		IOException ioe = new IOException("The currently read Join Cost File has prematurely reached EOF.");
		ioe.initCause(cause);
		return ioe;
	}

	/** Closes the stream if non-null, ignoring close failures. */
	private static void closeQuietly(InputStream stream) {
		if (stream == null) {
			return;
		}
		try {
			stream.close();
		} catch (IOException ignored) {
			// Best effort only: the stream's content has either been consumed already or loading has failed anyway.
		}
	}

	/**
	 * Replaces the feature weights and weighting functions read from the join file with those from {@code weightStream}, if
	 * one is given. The stream is closed after reading.
	 * 
	 * @throws IllegalArgumentException
	 *             if the stream does not contain exactly one weight per feature
	 */
	private void applyWeightOverrides(InputStream weightStream, WeightFunctionManager wfm)
			throws IOException, MaryConfigurationException {
		if (weightStream == null) {
			return;
		}
		MaryUtils.getLogger("JoinCostFeatures").debug("Overwriting join cost weights");
		int numberOfFeatures = featureWeight.length;
		Object[] weightData = readJoinCostWeightsStream(weightStream);
		float[] newWeights = (float[]) weightData[0];
		String[] newFunctions = (String[]) weightData[1];
		if (newWeights.length != numberOfFeatures) {
			throw new IllegalArgumentException("Join cost file contains " + numberOfFeatures
					+ " features, but weight file contains " + newWeights.length + " feature weights!");
		}
		featureWeight = newWeights;
		for (int i = 0; i < numberOfFeatures; i++) {
			weightFunction[i] = wfm.getWeightFunction(newFunctions[i]);
		}
	}

	/** Records which weighting functions are linear so that cost() can compute them inline. */
	private void computeLinearFlags() {
		isLinear = new boolean[weightFunction.length];
		for (int i = 0; i < weightFunction.length; i++) {
			isLinear[i] = weightFunction[i].whoAmI().equals("linear");
		}
	}

	/** Opens the debug join cost graph when the {@code debug.show.cost.graph} property is set. */
	private void startDebugCostGraphIfRequested() {
		if (MaryProperties.getBoolean("debug.show.cost.graph")) {
			debugShowCostGraph = true;
			cumulWeightedSignalCosts = new double[featureWeight.length];
			jcr = new JoinCostReporter(cumulWeightedSignalCosts);
			jcr.showInJFrame("Average signal join costs", false, false);
			jcr.start();
		}
	}

	/**
	 * Read the join cost weight specifications from the given file. The weights will be normalized such that they sum to one.
	 * 
	 * @param fileName
	 *            the text file containing the join weights
	 * @throws IOException
	 *             if the file cannot be read or is malformed
	 * @throws FileNotFoundException
	 *             if the file does not exist
	 * @return the same as {@link #readJoinCostWeightsStream(InputStream)} for a stream on the given file
	 */
	public static Object[] readJoinCostWeightsFile(String fileName) throws IOException, FileNotFoundException {
		// readJoinCostWeightsStream closes the stream in all cases.
		return readJoinCostWeightsStream(new FileInputStream(fileName));
	}

	/**
	 * Read the join cost weight specifications from the given stream. The weights will be normalized such that they sum to
	 * one.
	 * <p>
	 * The stream is read as UTF-8 text. Everything after a '#' is a comment; blank lines are skipped. Every other line has
	 * the form {@code <index> : <weight> <weighting function>}, where the index is ignored and the weighting function is the
	 * rest of the line after the first run of whitespace following the weight. The stream is closed before returning.
	 * 
	 * @param weightStream
	 *            the stream containing the join weights
	 * @throws IOException
	 *             if the stream cannot be read, a line lacks ':' or a weighting function, or the weights sum to zero
	 * @throws FileNotFoundException
	 *             declared for compatibility
	 * @throws NumberFormatException
	 *             if a weight is not a valid float
	 * @return {@code Object[] { float[] normalizedWeights, String[] weightingFunctionNames }}
	 * */
	public static Object[] readJoinCostWeightsStream(InputStream weightStream) throws IOException, FileNotFoundException {
		List<Float> weights = new ArrayList<Float>(16);
		List<String> functions = new ArrayList<String>(16);
		float sumOfWeights = 0;
		BufferedReader in = new BufferedReader(new InputStreamReader(weightStream, "UTF-8"));
		try {
			int lineNumber = 0;
			String line;
			while ((line = in.readLine()) != null) {
				lineNumber++;
				line = line.split("#", 2)[0].trim(); // Remove possible trailing comments and surrounding blanks
				if (line.length() == 0) {
					continue; // Empty line: don't parse
				}
				String[] indexAndSpec = line.split(":", 2);
				if (indexAndSpec.length < 2) {
					throw new IOException("Malformed join cost weight on line " + lineNumber + ": expected '<index> : <weight> <function>', found ["
							+ line + "]");
				}
				// Separate the weight value from the function name; \\s+ tolerates several blanks or tabs in between.
				String[] fields = indexAndSpec[1].trim().split("\\s+", 2);
				float aWeight = Float.parseFloat(fields[0]);
				if (fields.length < 2) {
					throw new IOException("Malformed join cost weight on line " + lineNumber + ": missing weighting function in ["
							+ line + "]");
				}
				sumOfWeights += aWeight;
				weights.add(Float.valueOf(aWeight));
				functions.add(fields[1]);
			}
		} finally {
			in.close();
		}
		if (!weights.isEmpty() && sumOfWeights == 0) {
			// Normalizing would divide by zero and yield NaN/infinite weights.
			throw new IOException("Join cost weights sum to zero and cannot be normalized.");
		}
		String[] wfun = functions.toArray(new String[functions.size()]);
		float[] fw = new float[weights.size()];
		for (int i = 0; i < fw.length; i++) {
			fw[i] = weights.get(i).floatValue() / sumOfWeights;
		}
		return new Object[] { fw, wfun };
	}

	/*****************/
	/* ACCESSORS */
	/*****************/

	/**
	 * Get the number of feature weights and weighting functions.
	 * 
	 * @return the number of join cost features per unit edge
	 */
	public int getNumberOfFeatures() {
		return featureWeight.length;
	}

	/**
	 * Get the number of units.
	 * 
	 * @return the number of units described in the loaded join file
	 */
	public int getNumberOfUnits() {
		return leftJCF.length;
	}

	/**
	 * Gets the array of left join cost features for a particular unit index.
	 * 
	 * @param u
	 *            The index of the considered unit.
	 * 
	 * @return The array of left join cost features for the given unit (the internal array, not a copy).
	 * @throws IndexOutOfBoundsException
	 *             if {@code u} is negative or not smaller than {@link #getNumberOfUnits()}
	 */
	public float[] getLeftJCF(int u) {
		checkUnitIndex(u, "unit index");
		return leftJCF[u];
	}

	/**
	 * Gets the array of right join cost features for a particular unit index.
	 * 
	 * @param u
	 *            The index of the considered unit.
	 * 
	 * @return The array of right join cost features for the given unit (the internal array, not a copy).
	 * @throws IndexOutOfBoundsException
	 *             if {@code u} is negative or not smaller than {@link #getNumberOfUnits()}
	 */
	public float[] getRightJCF(int u) {
		checkUnitIndex(u, "unit index");
		return rightJCF[u];
	}

	/**
	 * Validates a unit index against the loaded data.
	 * 
	 * @param u
	 *            the index to check
	 * @param description
	 *            how the index is named in the error message, e.g. "left unit index"
	 * @throws IndexOutOfBoundsException
	 *             if {@code u < 0} or {@code u >= getNumberOfUnits()}
	 */
	private void checkUnitIndex(int u, String description) {
		if (u < 0) {
			throw new IndexOutOfBoundsException("The " + description + " [" + u
					+ "] is out of range: a unit index can't be negative.");
		}
		if (u >= leftJCF.length) {
			throw new IndexOutOfBoundsException("The " + description + " [" + u + "] is out of range: this file contains ["
					+ leftJCF.length + "] units.");
		}
	}

	/*****************/
	/* MISC METHODS */
	/*****************/

	/**
	 * Deliver the join cost between two units described by their index.
	 * 
	 * @param u1
	 *            the left unit
	 * @param u2
	 *            the right unit
	 * 
	 * @return the cost of joining the right Join Cost features of the left unit with the left Join Cost Features of the right
	 *         unit. Features that are NaN on either side contribute no cost.
	 * @throws IndexOutOfBoundsException
	 *             if either index is negative or not smaller than {@link #getNumberOfUnits()}
	 */
	public double cost(int u1, int u2) {
		checkUnitIndex(u1, "left unit index");
		checkUnitIndex(u2, "right unit index");
		if (debugShowCostGraph) {
			jcr.tick();
		}
		/* Cumulate the join costs for each feature */
		double res = 0.0;
		float[] v1 = rightJCF[u1];
		float[] v2 = leftJCF[u2];
		for (int i = 0; i < v1.length; i++) {
			float a = v1[i];
			float b = v2[i];
			if (Float.isNaN(a) || Float.isNaN(b)) {
				continue; // if anything is NaN, count the cost as 0.
			}
			double c;
			if (isLinear[i]) {
				// Inline |a - b| for the common linear case instead of a virtual call.
				c = featureWeight[i] * (a > b ? (a - b) : (b - a));
			} else {
				c = featureWeight[i] * weightFunction[i].cost(a, b);
			}
			res += c;
			if (debugShowCostGraph) {
				cumulWeightedSignalCosts[i] += wSignal * c;
			}
		}
		return res;
	}

	/**
	 * The join cost between two units as used during unit selection.
	 * <p>
	 * Units of zero duration cannot be joined (positive infinity). Diphone units are replaced by their relevant half (the
	 * right half of the left unit, the left half of the right unit). Units that are contiguous in the database cost 0;
	 * otherwise the cost is 1 plus the signal-based cost, taken from the precompiled costs when both units are diphones and
	 * precompiled costs are available, and computed with {@link #cost(int, int)} otherwise. The targets are only passed on to
	 * the precompiled cost reader.
	 * 
	 * @param t1
	 *            The left target.
	 * @param u1
	 *            The left unit.
	 * @param t2
	 *            The right target.
	 * @param u2
	 *            The right unit.
	 * 
	 * @return the cost of joining the left unit with the right unit, as a non-negative value.
	 */
	public double cost(Target t1, Unit u1, Target t2, Unit u2) {
		// Units of length 0 cannot be joined:
		if (u1.duration == 0 || u2.duration == 0)
			return Double.POSITIVE_INFINITY;
		// In the case of diphones, replace them with the relevant part:
		boolean bothDiphones = true;
		if (u1 instanceof DiphoneUnit) {
			u1 = ((DiphoneUnit) u1).right;
		} else {
			bothDiphones = false;
		}
		if (u2 instanceof DiphoneUnit) {
			u2 = ((DiphoneUnit) u2).left;
		} else {
			bothDiphones = false;
		}

		if (u1.index + 1 == u2.index)
			return 0;
		// Either not half phone synthesis, or at a diphone boundary
		double cost = 1; // basic penalty for joins of non-contiguous units.
		if (bothDiphones && precompiledCosts != null) {
			cost += precompiledCosts.cost(t1, u1, t2, u2);
		} else { // need to actually compute the cost
			cost += cost(u1.index, u2.index);
		}
		return cost;
	}

	/**
	 * A phonetic join cost, computed solely from the target.
	 * 
	 * @param t1
	 *            the left target
	 * @param t2
	 *            the right target
	 * @return a non-negative join cost, capped at 1 (worst); 0 is best.
	 * @deprecated ad hoc heuristic; not used by {@link #cost(Target, Unit, Target, Unit)}.
	 */
	@Deprecated
	protected double cost(Target t1, Target t2) {
		// TODO: This is really ad hoc for the moment. Redo once we know what we are doing.
		// Add penalties for a number of criteria.
		double cost = 0;
		ByteValuedFeatureProcessor stressProcessor = new MaryGenericFeatureProcessors.Stressed("",
				new MaryGenericFeatureProcessors.SyllableNavigator());
		// Stressed?
		boolean stressed1 = stressProcessor.process(t1) == (byte) 1;
		boolean stressed2 = stressProcessor.process(t2) == (byte) 1;
		// Try to avoid joining in a stressed syllable:
		if (stressed1 || stressed2)
			cost += 0.2;
		Allophone p1 = t1.getAllophone();
		Allophone p2 = t2.getAllophone();

		// Discourage joining vowels:
		if (p1.isVowel() || p2.isVowel())
			cost += 0.2;
		// Discourage joining glides:
		if (p1.isGlide() || p2.isGlide())
			cost += 0.2;
		// Discourage joining voiced segments:
		if (p1.isVoiced() || p2.isVoiced())
			cost += 0.1;
		// If both are voiced, it's really bad
		if (p1.isVoiced() && p2.isVoiced())
			cost += 0.1;
		// Slightly penalize nasals and liquids
		if (p1.isNasal() || p2.isNasal())
			cost += 0.05;
		if (p1.isLiquid() || p2.isLiquid())
			cost += 0.05;
		// Fricatives -- nothing?
		// Plosives -- nothing?

		if (cost > 1)
			cost = 1;
		return cost;
	}

	/**
	 * Debug display of the average weighted signal join cost per feature. The shared {@code data} array holds running sums;
	 * a background thread started by {@link #start()} periodically divides them by the number of registered cost
	 * computations and redraws the histogram while it is visible.
	 */
	public static class JoinCostReporter extends Histogram {
		private final double[] data;
		private int lastN = 0;
		// Incremented by tick() on the caller's thread and read by the refresh thread; volatile for visibility.
		// Increments are not atomic, which is acceptable for this debug display.
		private volatile int nCostComputations = 0;

		public JoinCostReporter(double[] data) {
			super(0, 1, data);
			this.data = data;
		}

		/**
		 * Starts a background thread that refreshes the graph every 500 ms for as long as it is visible.
		 */
		public void start() {
			new Thread() {
				@Override
				public void run() {
					while (isVisible()) {
						try {
							Thread.sleep(500);
						} catch (InterruptedException ie) {
							// Restore the interrupt status and stop refreshing.
							Thread.currentThread().interrupt();
							return;
						}
						updateGraph();
					}
				}
			}.start();
		}

		/**
		 * Register one new cost computation
		 */
		public void tick() {
			nCostComputations++;
		}

		/**
		 * Redraws the graph with the current averages, if any cost computation was registered since the last update.
		 */
		protected void updateGraph() {
			int n = nCostComputations; // read once so the check and the division use the same value
			if (n == lastN)
				return;
			lastN = n;
			double[] newCosts = new double[data.length];
			for (int i = 0; i < newCosts.length; i++) {
				newCosts[i] = data[i] / n;
			}
			updateData(0, 1, newCosts);
			repaint();
		}
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy