All Downloads are FREE. Search and download functionalities are using the official Maven repository.

marytts.tools.voiceimport.traintrees.AgglomerativeClusterer Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2009 DFKI GmbH.
 * All Rights Reserved.  Use is subject to license terms.
 *
 * This file is part of MARY TTS.
 *
 * MARY TTS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, version 3 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

package marytts.tools.voiceimport.traintrees;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Future;

import marytts.cart.CART;
import marytts.cart.DecisionNode;
import marytts.cart.DirectedGraph;
import marytts.cart.DirectedGraphNode;
import marytts.cart.FeatureVectorCART;
import marytts.cart.LeafNode;
import marytts.cart.Node;
import marytts.cart.LeafNode.FeatureVectorLeafNode;
import marytts.cart.impose.FeatureArrayIndexer;
import marytts.features.FeatureDefinition;
import marytts.features.FeatureVector;

/**
 * @author marc
 *
 */
public class AgglomerativeClusterer {
	private static final float SINGLE_ITEM_IMPURITY = 0;
	private FeatureVector[] trainingFeatures;
	private FeatureVector[] testFeatures;
	private Map impurities = new HashMap();
	private FeatureDefinition featureDefinition;
	private int numByteFeatures;
	private int[] availableFeatures;
	// private double globalMean;
	private double globalStddev;
	private DistanceMeasure dist;

	private double minFSGI, minCriterion;
	private int iBestFeature;

	private float[][] squaredDistances;

	private DirectedGraph graph;
	private int[] prevFeatureList;
	private double prevFSGI;
	private double prevTestDataDistance;
	private boolean canClusterMore = true;

	/**
	 * Convenience constructor: delegates to the five-argument constructor with {@code proportionTestData = 0.1f}.
	 *
	 * @param features
	 *            feature vectors to cluster (the delegate constructor replaces the array entries in place)
	 * @param featureDefinition
	 *            definition of the features in {@code features}
	 * @param featuresToUse
	 *            features that may be used for clustering
	 * @param dist
	 *            distance measure between items
	 */
	public AgglomerativeClusterer(FeatureVector[] features, FeatureDefinition featureDefinition, List featuresToUse,
			DistanceMeasure dist) {
		this(features, featureDefinition, featuresToUse, dist,
				0.1F);
	}

	public AgglomerativeClusterer(FeatureVector[] features, FeatureDefinition featureDefinition, List featuresToUse,
			DistanceMeasure dist, float proportionTestData) {
		// Now replace all feature vectors with feature vectors whose unit index
		// corresponds to the distance matrix in squaredDistance:
		for (int i = 0; i < features.length; i++) {
			features[i] = new FeatureVector(features[i].getByteValuedDiscreteFeatures(),
					features[i].getShortValuedDiscreteFeatures(), features[i].getContinuousFeatures(), i);
		}

		this.dist = dist;

		this.globalStddev = Math.sqrt(((F0ContourPolynomialDistanceMeasure) dist).computeVariance(features));

		System.out.println("Global stddev: " + globalStddev);
		/*
		 * // Get an estimate of the global mean by sampling: estimateGlobalMean(features, dist);
		 * 
		 * // Precompute distances and set unit index features accordingly System.out.println("Precomputing distances..."); long
		 * startTime = System.currentTimeMillis(); squaredDistances = new float[features.length-1][]; for (int i=0;
		 * i 0)
				prevNLeaves++;
		}
		iBestFeature = -1;
		minFSGI = Double.POSITIVE_INFINITY;
		minCriterion = Double.POSITIVE_INFINITY;
		Set> openJobs = new HashSet>();
		// Loop over all unused discrete features, and compute their Global Impurity
		for (int f = 0; f < availableFeatures.length; f++) {
			int fi = availableFeatures[f];
			boolean featureAlreadyUsed = false;
			for (int i = 0; i < prevFeatureList.length; i++) {
				if (prevFeatureList[i] == fi) {
					featureAlreadyUsed = true;
					break;
				}
			}
			if (featureAlreadyUsed)
				continue;
			newFeatureList[newFeatureList.length - 1] = fi;
			fai.deepSort(newFeatureList);
			CART testCART = new FeatureVectorCART(fai.getTree(), fai);
			assert testCART.getRootNode().getNumberOfData() == trainingFeatures.length;
			verifyFeatureQuality(fi, testCART, prevNLeaves);
		}

		newFeatureList[newFeatureList.length - 1] = iBestFeature;
		fai.deepSort(newFeatureList);
		CART bestFeatureCart = new FeatureVectorCART(fai.getTree(), fai);
		int nLeaves = 0;
		for (LeafNode leaf : bestFeatureCart.getLeafNodes()) {
			if (leaf != null && leaf.getNumberOfData() > 0)
				nLeaves++;
		}
		long featSelectedTime = System.currentTimeMillis();

		// Now walk through graphSoFar and bestFeatureCart in parallel,
		// and add the leaves of bestFeatureCart into graphSoFar in order
		// to enable clustering:
		Node fNode = bestFeatureCart.getRootNode();
		Node gNode = graph.getRootNode();

		List newLeavesList = new ArrayList();
		updateGraphFromTree((DecisionNode) fNode, (DirectedGraphNode) gNode, newLeavesList);
		DirectedGraphNode[] newLeaves = newLeavesList.toArray(new DirectedGraphNode[0]);
		System.out.printf("Level %2d: %25s (%5d leaves, gi=%7.3f -->", newFeatureList.length,
				featureDefinition.getFeatureName(iBestFeature), newLeaves.length, minFSGI);

		float[][] deltaGI = new float[newLeaves.length - 1][];
		for (int i = 0; i < newLeaves.length - 1; i++) {
			deltaGI[i] = new float[newLeaves.length - i - 1];
			for (int j = i + 1; j < newLeaves.length; j++) {
				deltaGI[i][j - i - 1] = (float) computeDeltaGI(newLeaves[i], newLeaves[j]);
			}
		}
		int numLeavesLeft = newLeaves.length;

		// Now cluster the leaves
		float minDeltaGI, threshold;
		int bestPair1, bestPair2;
		do {
			// threshold = 100*(float)(Math.log(numLeavesLeft)-Math.log(numLeavesLeft-1));
			// threshold = (float)(Math.log(numLeavesLeft)-Math.log(numLeavesLeft-1));
			threshold = 0;
			// threshold = 0.01f;
			minDeltaGI = threshold; // if we cannot find any that is better, stop.
			bestPair1 = bestPair2 = -1;
			for (int i = 0; i < newLeaves.length - 1; i++) {
				if (newLeaves[i] == null)
					continue;
				for (int j = i + 1; j < newLeaves.length; j++) {
					if (newLeaves[j] == null)
						continue;
					if (deltaGI[i][j - i - 1] < minDeltaGI) {
						bestPair1 = i;
						bestPair2 = j;
						minDeltaGI = deltaGI[i][j - i - 1];
					}
				}
			}
			// System.out.printf("NumLeavesLeft=%4d, threshold=%f, minDeltaGI=%f\n", numLeavesLeft, threshold, minDeltaGI);
			if (minDeltaGI < threshold) { // found something to merge
				mergeLeaves(newLeaves[bestPair1], newLeaves[bestPair2]);
				numLeavesLeft--;
				// System.out.println("Merged leaves "+bestPair1+" and "+bestPair2+" (deltaGI: "+minDeltaGI+")");
				newLeaves[bestPair2] = null;
				// Update deltaGI table:
				for (int i = 0; i < bestPair2; i++) {
					deltaGI[i][bestPair2 - i - 1] = Float.NaN;
				}
				for (int j = bestPair2 + 1; j < newLeaves.length; j++) {
					deltaGI[bestPair2][j - bestPair2 - 1] = Float.NaN;
				}
				for (int i = 0; i < bestPair1; i++) {
					if (newLeaves[i] != null)
						deltaGI[i][bestPair1 - i - 1] = (float) computeDeltaGI(newLeaves[i], newLeaves[bestPair1]);
				}
				for (int j = bestPair1 + 1; j < newLeaves.length; j++) {
					if (newLeaves[j] != null)
						deltaGI[bestPair1][j - bestPair1 - 1] = (float) computeDeltaGI(newLeaves[bestPair1], newLeaves[j]);
				}
			}
		} while (minDeltaGI < threshold);

		int nLeavesLeft = 0;
		List survivors = new ArrayList();
		for (int i = 0; i < newLeaves.length; i++) {
			if (newLeaves[i] != null) {
				nLeavesLeft++;
				survivors.add((LeafNode) ((DirectedGraphNode) newLeaves[i]).getLeafNode());
			}
		}

		long clusteredTime = System.currentTimeMillis();

		System.out.printf("%5d leaves, gi=%7.3f).", nLeavesLeft, computeGlobalImpurity(survivors));

		deltaGI = null;
		impurities.clear();

		float testDist = rmsDistanceTestData(graph);
		System.out.printf(" Distance test data: %5.3f", testDist);

		System.out.printf(" | fs %5dms, cl %5dms", (featSelectedTime - startTime), (clusteredTime - featSelectedTime));

		System.out.println();

		// Stop criterion: stop if feature selection does not succeed in reducing global impurity further,
		// and at the same time, the test data approximation is getting worse.
		if (minFSGI > prevFSGI && testDist > prevTestDataDistance) {
			canClusterMore = false;
		}
		// Iteration step:
		prevFeatureList = newFeatureList;
		prevFSGI = minFSGI;
		prevTestDataDistance = testDist;

		return graph;
	}

	/**
	 * Scores candidate feature {@code fi} by the global impurity of the non-empty leaves of {@code testCART}. If the score
	 * beats the current {@code minCriterion}, the candidate becomes the best feature so far: {@code minCriterion},
	 * {@code minFSGI} and {@code iBestFeature} are updated.
	 * <p>
	 * A candidate whose tree does not have more non-empty leaves than {@code prevNLeaves} is ignored.
	 *
	 * @param fi
	 *            index of the candidate feature
	 * @param testCART
	 *            tree built with the candidate feature appended to the previously selected features
	 * @param prevNLeaves
	 *            number of non-empty leaves before adding the candidate
	 */
	private void verifyFeatureQuality(int fi, CART testCART, int prevNLeaves) {
		List<LeafNode> leaves = new ArrayList<LeafNode>();
		for (LeafNode leaf : testCART.getLeafNodes()) {
			// Skip missing and empty leaves; the leaf count of the selected tree is computed the same way.
			if (leaf == null || leaf.isEmpty()) {
				continue;
			}
			leaves.add(leaf);
		}
		int nLeaves = leaves.size();
		if (nLeaves <= prevNLeaves) { // this feature adds no leaf
			return; // will not consider this further
		}
		// The current best criterion is passed as the second argument -- presumably a limit allowing
		// computeGlobalImpurity to stop early; TODO confirm against its implementation.
		double gi = computeGlobalImpurity(leaves, minCriterion);
		// A size bias penalising more leaves is currently disabled (the criterion is the raw impurity);
		// only sanity-check that the candidate actually grew the tree.
		assert Math.log((float) nLeaves / prevNLeaves) > 0;
		double criterion = gi;
		if (criterion < minCriterion) {
			setMinCriterion(criterion);
			setMinFSGI(gi);
			setBestFeature(fi);
		}
	}

	/**
	 * Estimate the mean of all *distances* in the training set.
	 * 
	 * @param leaves
	 *            leaves
	 * @return computeglobalimpurity(leaves, double.Positive_infinity)
	 */
	/*
	 * private void estimateGlobalMean(FeatureVector[] data, DistanceMeasure dist) { int sampleSize = 100000;
	 * System.out.println("Estimating global mean by random sampling "+sampleSize+" distances"); long startTime =
	 * System.currentTimeMillis(); // Compute mean and stddev using recurrence relation, attributed by Donald Knuth // (The Art of
	 * Computer Programming, Volume 2: Seminumerical Algorithms, Section 4.2.2) // to B.P. Welford, Technometrics, 4, (1962),
	 * 419-420. // M(1) = x(1), M(k) = M(k-1) + (x(k) - M(k-1))/k // S(1) = 0, S(k) = S(k-1) + (x(k) - M(k-1))*(x(k)-M(k)) // for
	 * 2 <= k <= n, then sigma = sqrt(S(n)/(n-1)) // globalMean = 0; Random random = new Random(); for (int k=1; k I(%d)=%.3f\n", deltaGI, len1, imp1, len2, imp2, len12,
		// imp12);
		return deltaGI;
	}

	/**
	 * Merges the leaf held by {@code dgn2} into the leaf held by {@code dgn1}.
	 * <p>
	 * The feature vectors of {@code dgn2}'s leaf are appended to those of {@code dgn1}'s leaf. Every mother of {@code dgn2}
	 * is then redirected to {@code dgn1}: a decision-node mother gets {@code dgn1} as the daughter at {@code dgn2}'s index,
	 * and a graph-node mother gets {@code dgn1} as its leaf node. Finally {@code dgn2} and its leaf are detached, and the
	 * cached impurities of both leaves are dropped.
	 *
	 * @param dgn1
	 *            surviving node; its leaf must be a {@link FeatureVectorLeafNode}
	 * @param dgn2
	 *            node to be absorbed; its leaf must be a {@link FeatureVectorLeafNode}
	 */
	private void mergeLeaves(DirectedGraphNode dgn1, DirectedGraphNode dgn2) {
		// Copy all data from dgn2 into dgn1
		FeatureVectorLeafNode l1 = (FeatureVectorLeafNode) dgn1.getLeafNode();
		FeatureVectorLeafNode l2 = (FeatureVectorLeafNode) dgn2.getLeafNode();
		FeatureVector[] fv1 = l1.getFeatureVectors();
		FeatureVector[] fv2 = l2.getFeatureVectors();
		FeatureVector[] newFV = new FeatureVector[fv1.length + fv2.length];
		System.arraycopy(fv1, 0, newFV, 0, fv1.length);
		System.arraycopy(fv2, 0, newFV, fv1.length, fv2.length);
		l1.setFeatureVectors(newFV);
		// Then update all mother/daughter relationships. Iterate over a copy so that the
		// removeMother() calls below do not modify the collection being iterated.
		Set<Node> dgn2Mothers = new HashSet<Node>(dgn2.getMothers());
		for (Node mother : dgn2Mothers) {
			if (mother instanceof DecisionNode) {
				DecisionNode dm = (DecisionNode) mother;
				dm.replaceDaughter(dgn1, dgn2.getNodeIndex(mother));
			} else if (mother instanceof DirectedGraphNode) {
				DirectedGraphNode gm = (DirectedGraphNode) mother;
				gm.setLeafNode(dgn1);
			}
			dgn2.removeMother(mother);
		}
		dgn2.setLeafNode(null);
		l2.setMother(null, 0);
		// The cached impurities of both leaves are now stale. l1 and l2 are non-null (dereferenced above),
		// so removing them from the map cannot throw.
		impurities.remove(l1);
		impurities.remove(l2);
	}

	/**
	 * Walks a freshly built decision tree and the directed graph built so far in parallel, and grows the graph by one level.
	 * <p>
	 * If {@code graphNode} already carries a decision node, tree and graph must be aligned here (same feature, same number
	 * of daughters; checked by assertions). The walk then descends into every tree daughter that is a decision node.
	 * <p>
	 * If {@code graphNode} has no decision node yet, a new byte or short decision node on the tree node's feature is
	 * attached to it. The tree node's daughters are expected to be leaves. Each non-empty leaf is wrapped in a
	 * {@link DirectedGraphNode}, which is wrapped again in another {@link DirectedGraphNode} that becomes the new daughter.
	 * Missing or empty leaves become {@code null} daughters.
	 *
	 * @param treeNode
	 *            decision node of the freshly built tree
	 * @param graphNode
	 *            graph node at the corresponding position in the graph
	 * @param newLeaves
	 *            output list; receives each newly created leaf-holding {@link DirectedGraphNode}, in daughter order
	 */
	private void updateGraphFromTree(DecisionNode treeNode, DirectedGraphNode graphNode, List newLeaves) {
		int featureIndex = treeNode.getFeatureIndex();
		int numDaughters = treeNode.getNumberOfDaugthers();
		DecisionNode existing = graphNode.getDecisionNode();

		if (existing == null) {
			// The graph has no structure here yet: copy this tree level into the graph.
			DecisionNode created;
			if (featureDefinition.isByteFeature(featureIndex)) {
				created = new DecisionNode.ByteDecisionNode(featureIndex, numDaughters, featureDefinition);
			} else {
				assert featureDefinition.isShortFeature(featureIndex) : "Only support byte and short features";
				created = new DecisionNode.ShortDecisionNode(featureIndex, numDaughters, featureDefinition);
			}
			assert numDaughters == created.getNumberOfDaugthers();
			graphNode.setDecisionNode(created);
			for (int d = 0; d < numDaughters; d++) {
				// At this level the tree is expected to hold only leaves (a ClassCastException signals otherwise).
				LeafNode treeLeaf = (LeafNode) treeNode.getDaughter(d);
				if (treeLeaf == null || treeLeaf.getNumberOfData() <= 0) {
					created.addDaughter(null);
					continue;
				}
				DirectedGraphNode leafHolder = new DirectedGraphNode(null, treeLeaf);
				DirectedGraphNode daughter = new DirectedGraphNode(null, leafHolder);
				created.addDaughter(daughter);
				newLeaves.add(leafHolder);
			}
			return;
		}

		// Graph and tree already share this level; they must agree on feature and fan-out.
		int graphFeatureIndex = existing.getFeatureIndex();
		assert featureIndex == graphFeatureIndex : "Tree indices out of sync!";
		assert numDaughters == existing.getNumberOfDaugthers() : "Tree structure out of sync!";
		for (int d = 0; d < numDaughters; d++) {
			Node treeChild = treeNode.getDaughter(d);
			if (treeChild == null) {
				continue;
			}
			if (treeChild instanceof LeafNode) {
				// Above the level being added, only empty leaves are expected.
				assert ((LeafNode) treeChild).getNumberOfData() == 0;
				continue;
			}
			assert treeChild instanceof DecisionNode;
			DirectedGraphNode graphChild = (DirectedGraphNode) existing.getDaughter(d);
			updateGraphFromTree((DecisionNode) treeChild, graphChild, newLeaves);
		}
	}

	/**
	 * Measures how well {@code graph} approximates the test data.
	 * <p>
	 * Currently delegates to {@link #rmsMeanDistanceTestData(DirectedGraph)}. The alternative
	 * {@link #rmsMutualDistanceTestData(DirectedGraph)} is not used here.
	 *
	 * @param graph
	 *            graph to evaluate
	 * @return the test-data distance reported by the delegate
	 */
	private float rmsDistanceTestData(DirectedGraph graph) {
		final float distance = rmsMeanDistanceTestData(graph);
		return distance;
	}

	/**
	 * For each test feature vector, finds the training data in the graph leaf it is classified into, computes that data's
	 * mean via the distance measure, and takes the square root of the squared distance between the test vector and that
	 * mean. Returns the average of these distances over all test vectors.
	 * <p>
	 * Despite the "rms" in the name, this is a mean of distances, not a root-mean-square. With no test data the result
	 * is {@code NaN}.
	 *
	 * @param graph
	 *            graph used to classify the test vectors
	 * @return mean distance of the test vectors to the means of their leaves
	 */
	private float rmsMeanDistanceTestData(DirectedGraph graph) {
		// The constructor already requires dist to be an F0ContourPolynomialDistanceMeasure; cast once, not per item.
		F0ContourPolynomialDistanceMeasure f0Dist = (F0ContourPolynomialDistanceMeasure) dist;
		// Accumulate in double to avoid float round-off when summing over many test items.
		double sumDist = 0;
		for (FeatureVector testFV : testFeatures) {
			FeatureVector[] leafData = (FeatureVector[]) graph.interpret(testFV);
			float[] mean = f0Dist.computeMean(leafData);
			sumDist += Math.sqrt(f0Dist.squaredDistance(testFV, mean));
		}
		return (float) (sumDist / testFeatures.length);
	}

	/**
	 * Alternative test-data measure based on precomputed pairwise squared distances.
	 * <p>
	 * For each test item, averages its squared distances to all items in the graph leaf it reaches (pairs with equal unit
	 * index contribute nothing but are still counted), takes the square root, and returns the mean over all test items.
	 * <p>
	 * {@code squaredDistances} is read as an upper-triangular matrix indexed by unit index ({@code [lo][hi - lo - 1]}).
	 * NOTE(review): the precomputation of {@code squaredDistances} appears to be commented out in this class, so calling
	 * this method would presumably fail; confirm before re-enabling it.
	 *
	 * @param graph
	 *            graph used to classify the test vectors
	 * @return mean over test items of the root of the average squared distance to their leaf's items
	 */
	private float rmsMutualDistanceTestData(DirectedGraph graph) {
		float total = 0;
		for (FeatureVector testFV : testFeatures) {
			int testIdx = testFV.getUnitIndex();
			FeatureVector[] leafData = (FeatureVector[]) graph.interpret(testFV);
			float sumSq = 0;
			for (FeatureVector leafFV : leafData) {
				int leafIdx = leafFV.getUnitIndex();
				if (leafIdx == testIdx) {
					continue; // no stored distance of an item to itself
				}
				int lo = Math.min(testIdx, leafIdx);
				int hi = Math.max(testIdx, leafIdx);
				sumSq += squaredDistances[lo][hi - lo - 1];
			}
			sumSq /= leafData.length;
			total += (float) Math.sqrt(sumSq);
		}
		total /= testFeatures.length;
		return total;
	}

	/**
	 * Stores {@code value} in the {@code minCriterion} field (the best feature-selection criterion found so far).
	 *
	 * @param value
	 *            new criterion value
	 */
	private void setMinCriterion(double value) {
		this.minCriterion = value;
	}

	/**
	 * Stores {@code value} in the {@code minFSGI} field (global impurity of the best feature found so far).
	 *
	 * @param value
	 *            new global-impurity value
	 */
	private void setMinFSGI(double value) {
		this.minFSGI = value;
	}

	/**
	 * Stores {@code featureIndex} in the {@code iBestFeature} field (index of the best feature found so far).
	 *
	 * @param featureIndex
	 *            index of the selected feature
	 */
	private void setBestFeature(int featureIndex) {
		this.iBestFeature = featureIndex;
	}

	/**
	 * Debug helper: prints every node returned by the graph's node iterator.
	 *
	 * @param graph
	 *            graph to print
	 */
	private void debugOut(DirectedGraph graph) {
		for (Iterator<? extends Node> it = graph.getNodeIterator(); it.hasNext();) {
			Node next = it.next();
			debugOut(next);
		}
	}

	/**
	 * Debug helper: prints a CART recursively, starting at its root node.
	 *
	 * @param graph
	 *            tree to print
	 */
	private void debugOut(CART graph) {
		Node root = graph.getRootNode();
		debugOut(root);
	}

	/**
	 * Debug helper: dispatches to the overload matching the runtime type of {@code node}. A {@code null} node is printed
	 * as "null", as for null daughters of a decision node. Any node that is not a graph node or leaf is treated as a
	 * decision node.
	 *
	 * @param node
	 *            node to print, may be null
	 */
	private void debugOut(Node node) {
		if (node == null)
			System.out.println("null");
		else if (node instanceof DirectedGraphNode)
			debugOut((DirectedGraphNode) node);
		else if (node instanceof LeafNode)
			debugOut((LeafNode) node);
		else
			debugOut((DecisionNode) node);
	}

	/**
	 * Debug helper: prints a directed-graph node, followed by its leaf node and/or decision node if present.
	 *
	 * @param node
	 *            node to print
	 */
	private void debugOut(DirectedGraphNode node) {
		System.out.println("DGN");
		if (node.getLeafNode() != null)
			debugOut(node.getLeafNode());
		if (node.getDecisionNode() != null)
			debugOut(node.getDecisionNode());
	}

	/**
	 * Debug helper: prints the decision path of a leaf.
	 *
	 * @param node
	 *            leaf to print
	 */
	private void debugOut(LeafNode node) {
		System.out.println("Leaf: " + node.getDecisionPath());
	}

	/**
	 * Debug helper: prints a decision node and then each of its daughters recursively ("null" for missing daughters).
	 *
	 * @param node
	 *            decision node to print
	 */
	private void debugOut(DecisionNode node) {
		System.out.println("DN with " + node.getNumberOfDaugthers() + " daughters: " + node.toString());
		for (int i = 0; i < node.getNumberOfDaugthers(); i++) {
			Node daughter = node.getDaughter(i);
			if (daughter == null)
				System.out.println("null");
			else
				debugOut(daughter);
		}
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy