marytts.tools.voiceimport.traintrees.AgglomerativeClusterer Maven / Gradle / Ivy
The newest version!
/**
* Copyright 2009 DFKI GmbH.
* All Rights Reserved. Use is subject to license terms.
*
* This file is part of MARY TTS.
*
* MARY TTS is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, version 3 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
package marytts.tools.voiceimport.traintrees;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Future;
import marytts.cart.CART;
import marytts.cart.DecisionNode;
import marytts.cart.DirectedGraph;
import marytts.cart.DirectedGraphNode;
import marytts.cart.FeatureVectorCART;
import marytts.cart.LeafNode;
import marytts.cart.Node;
import marytts.cart.LeafNode.FeatureVectorLeafNode;
import marytts.cart.impose.FeatureArrayIndexer;
import marytts.features.FeatureDefinition;
import marytts.features.FeatureVector;
/**
* @author marc
*
*/
public class AgglomerativeClusterer {
// Impurity constant for a leaf holding a single item. Not referenced in the visible part of this file;
// presumably used by the impurity computation -- TODO confirm.
private static final float SINGLE_ITEM_IMPURITY = 0;
// Training data; the feature-selection step asserts that every candidate tree's root holds exactly this many items.
private FeatureVector[] trainingFeatures;
// Held-out data on which the rms*DistanceTestData methods evaluate the graph.
private FeatureVector[] testFeatures;
// Per-leaf cache keyed by leaf node: mergeLeaves() drops the entries of both merged leaves, and the whole map
// is cleared after each clustering pass. Values are presumably impurity figures -- TODO confirm.
private Map impurities = new HashMap();
// Used to look up feature names and to tell byte-valued from short-valued features.
private FeatureDefinition featureDefinition;
// NOTE(review): not referenced in the visible part of this file.
private int numByteFeatures;
// Indices of the candidate features the feature-selection loop iterates over.
private int[] availableFeatures;
// private double globalMean;
// Square root of the variance the distance measure reports for all feature vectors (printed by the constructor).
private double globalStddev;
// Distance measure; downcast to F0ContourPolynomialDistanceMeasure wherever variance, mean or squared distance is needed.
private DistanceMeasure dist;
// Best global impurity / best criterion value of the current feature-selection round; both are reset to
// +infinity before the round and updated by verifyFeatureQuality().
private double minFSGI, minCriterion;
// Index of the best feature of the current round; reset to -1 before the round.
private int iBestFeature;
// Triangular matrix of squared distances addressed by unit index: the pair (i, j) with i < j is stored at
// [i][j - i - 1]. NOTE(review): the only visible code filling it is commented out -- confirm it is initialised
// before rmsMutualDistanceTestData() is used.
private float[][] squaredDistances;
// The graph that is extended by updateGraphFromTree() and returned by the clustering step.
private DirectedGraph graph;
// State carried over from the previous clustering step (assigned at the end of each step).
private int[] prevFeatureList;
private double prevFSGI;
private double prevTestDataDistance;
// Set to false once the stop criterion fires (global impurity and test-data distance both got worse).
private boolean canClusterMore = true;
public AgglomerativeClusterer(FeatureVector[] features, FeatureDefinition featureDefinition, List featuresToUse,
DistanceMeasure dist) {
this(features, featureDefinition, featuresToUse, dist, 0.1f);
}
public AgglomerativeClusterer(FeatureVector[] features, FeatureDefinition featureDefinition, List featuresToUse,
DistanceMeasure dist, float proportionTestData) {
// Now replace all feature vectors with feature vectors whose unit index
// corresponds to the distance matrix in squaredDistance:
for (int i = 0; i < features.length; i++) {
features[i] = new FeatureVector(features[i].getByteValuedDiscreteFeatures(),
features[i].getShortValuedDiscreteFeatures(), features[i].getContinuousFeatures(), i);
}
this.dist = dist;
this.globalStddev = Math.sqrt(((F0ContourPolynomialDistanceMeasure) dist).computeVariance(features));
System.out.println("Global stddev: " + globalStddev);
/*
* // Get an estimate of the global mean by sampling: estimateGlobalMean(features, dist);
*
* // Precompute distances and set unit index features accordingly System.out.println("Precomputing distances..."); long
* startTime = System.currentTimeMillis(); squaredDistances = new float[features.length-1][]; for (int i=0;
* i 0)
prevNLeaves++;
}
iBestFeature = -1;
minFSGI = Double.POSITIVE_INFINITY;
minCriterion = Double.POSITIVE_INFINITY;
Set> openJobs = new HashSet>();
// Loop over all unused discrete features, and compute their Global Impurity
for (int f = 0; f < availableFeatures.length; f++) {
int fi = availableFeatures[f];
boolean featureAlreadyUsed = false;
for (int i = 0; i < prevFeatureList.length; i++) {
if (prevFeatureList[i] == fi) {
featureAlreadyUsed = true;
break;
}
}
if (featureAlreadyUsed)
continue;
newFeatureList[newFeatureList.length - 1] = fi;
fai.deepSort(newFeatureList);
CART testCART = new FeatureVectorCART(fai.getTree(), fai);
assert testCART.getRootNode().getNumberOfData() == trainingFeatures.length;
verifyFeatureQuality(fi, testCART, prevNLeaves);
}
newFeatureList[newFeatureList.length - 1] = iBestFeature;
fai.deepSort(newFeatureList);
CART bestFeatureCart = new FeatureVectorCART(fai.getTree(), fai);
int nLeaves = 0;
for (LeafNode leaf : bestFeatureCart.getLeafNodes()) {
if (leaf != null && leaf.getNumberOfData() > 0)
nLeaves++;
}
long featSelectedTime = System.currentTimeMillis();
// Now walk through graphSoFar and bestFeatureCart in parallel,
// and add the leaves of bestFeatureCart into graphSoFar in order
// to enable clustering:
Node fNode = bestFeatureCart.getRootNode();
Node gNode = graph.getRootNode();
List newLeavesList = new ArrayList();
updateGraphFromTree((DecisionNode) fNode, (DirectedGraphNode) gNode, newLeavesList);
DirectedGraphNode[] newLeaves = newLeavesList.toArray(new DirectedGraphNode[0]);
System.out.printf("Level %2d: %25s (%5d leaves, gi=%7.3f -->", newFeatureList.length,
featureDefinition.getFeatureName(iBestFeature), newLeaves.length, minFSGI);
float[][] deltaGI = new float[newLeaves.length - 1][];
for (int i = 0; i < newLeaves.length - 1; i++) {
deltaGI[i] = new float[newLeaves.length - i - 1];
for (int j = i + 1; j < newLeaves.length; j++) {
deltaGI[i][j - i - 1] = (float) computeDeltaGI(newLeaves[i], newLeaves[j]);
}
}
int numLeavesLeft = newLeaves.length;
// Now cluster the leaves
float minDeltaGI, threshold;
int bestPair1, bestPair2;
do {
// threshold = 100*(float)(Math.log(numLeavesLeft)-Math.log(numLeavesLeft-1));
// threshold = (float)(Math.log(numLeavesLeft)-Math.log(numLeavesLeft-1));
threshold = 0;
// threshold = 0.01f;
minDeltaGI = threshold; // if we cannot find any that is better, stop.
bestPair1 = bestPair2 = -1;
for (int i = 0; i < newLeaves.length - 1; i++) {
if (newLeaves[i] == null)
continue;
for (int j = i + 1; j < newLeaves.length; j++) {
if (newLeaves[j] == null)
continue;
if (deltaGI[i][j - i - 1] < minDeltaGI) {
bestPair1 = i;
bestPair2 = j;
minDeltaGI = deltaGI[i][j - i - 1];
}
}
}
// System.out.printf("NumLeavesLeft=%4d, threshold=%f, minDeltaGI=%f\n", numLeavesLeft, threshold, minDeltaGI);
if (minDeltaGI < threshold) { // found something to merge
mergeLeaves(newLeaves[bestPair1], newLeaves[bestPair2]);
numLeavesLeft--;
// System.out.println("Merged leaves "+bestPair1+" and "+bestPair2+" (deltaGI: "+minDeltaGI+")");
newLeaves[bestPair2] = null;
// Update deltaGI table:
for (int i = 0; i < bestPair2; i++) {
deltaGI[i][bestPair2 - i - 1] = Float.NaN;
}
for (int j = bestPair2 + 1; j < newLeaves.length; j++) {
deltaGI[bestPair2][j - bestPair2 - 1] = Float.NaN;
}
for (int i = 0; i < bestPair1; i++) {
if (newLeaves[i] != null)
deltaGI[i][bestPair1 - i - 1] = (float) computeDeltaGI(newLeaves[i], newLeaves[bestPair1]);
}
for (int j = bestPair1 + 1; j < newLeaves.length; j++) {
if (newLeaves[j] != null)
deltaGI[bestPair1][j - bestPair1 - 1] = (float) computeDeltaGI(newLeaves[bestPair1], newLeaves[j]);
}
}
} while (minDeltaGI < threshold);
int nLeavesLeft = 0;
List survivors = new ArrayList();
for (int i = 0; i < newLeaves.length; i++) {
if (newLeaves[i] != null) {
nLeavesLeft++;
survivors.add((LeafNode) ((DirectedGraphNode) newLeaves[i]).getLeafNode());
}
}
long clusteredTime = System.currentTimeMillis();
System.out.printf("%5d leaves, gi=%7.3f).", nLeavesLeft, computeGlobalImpurity(survivors));
deltaGI = null;
impurities.clear();
float testDist = rmsDistanceTestData(graph);
System.out.printf(" Distance test data: %5.3f", testDist);
System.out.printf(" | fs %5dms, cl %5dms", (featSelectedTime - startTime), (clusteredTime - featSelectedTime));
System.out.println();
// Stop criterion: stop if feature selection does not succeed in reducing global impurity further,
// and at the same time, the test data approximation is getting worse.
if (minFSGI > prevFSGI && testDist > prevTestDataDistance) {
canClusterMore = false;
}
// Iteration step:
prevFeatureList = newFeatureList;
prevFSGI = minFSGI;
prevTestDataDistance = testDist;
return graph;
}
/**
 * Scores one candidate feature and records it as the current best if it beats the best criterion seen so far.
 * <p>
 * The non-empty leaves of {@code testCART} are collected; if their number does not exceed
 * {@code prevNLeaves}, the feature adds no leaf and is ignored. Otherwise the global impurity of those leaves
 * is used as the criterion. A logarithmic size bias is computed and sanity-checked, but it is currently not
 * applied to the criterion.
 *
 * @param fi
 *            index of the candidate feature
 * @param testCART
 *            tree built with the candidate feature appended to the feature list
 * @param prevNLeaves
 *            number of leaves before the candidate feature was added
 */
private void verifyFeatureQuality(int fi, CART testCART, int prevNLeaves) {
    // Keep only the leaves that actually hold data.
    List<LeafNode> populatedLeaves = new ArrayList<LeafNode>();
    for (LeafNode candidate : testCART.getLeafNodes()) {
        if (!candidate.isEmpty()) {
            populatedLeaves.add(candidate);
        }
    }
    int leafCount = populatedLeaves.size();

    // A feature that does not create any additional leaf is not considered further.
    if (leafCount <= prevNLeaves) {
        return;
    }

    // minCriterion is handed to the impurity computation as its second argument
    // (presumably a cutoff for early termination -- TODO confirm).
    double impurity = computeGlobalImpurity(populatedLeaves, minCriterion);

    // More leaves cost a bit; the bias is only verified here, it does not enter the criterion.
    double sizeBias = Math.log((float) leafCount / prevNLeaves);
    assert sizeBias > 0;

    double criterion = impurity;
    if (criterion < minCriterion) {
        setMinCriterion(criterion);
        setMinFSGI(impurity);
        setBestFeature(fi);
    }
}
/**
* Estimate the mean of all *distances* in the training set.
*
* @param leaves
* leaves
* @return computeglobalimpurity(leaves, double.Positive_infinity)
*/
/*
* private void estimateGlobalMean(FeatureVector[] data, DistanceMeasure dist) { int sampleSize = 100000;
* System.out.println("Estimating global mean by random sampling "+sampleSize+" distances"); long startTime =
* System.currentTimeMillis(); // Compute mean and stddev using recurrence relation, attributed by Donald Knuth // (The Art of
* Computer Programming, Volume 2: Seminumerical Algorithms, Section 4.2.2) // to B.P. Welford, Technometrics, 4, (1962),
* 419-420. // M(1) = x(1), M(k) = M(k-1) + (x(k) - M(k-1))/k // S(1) = 0, S(k) = S(k-1) + (x(k) - M(k-1))*(x(k)-M(k)) // for
* 2 <= k <= n, then sigma = sqrt(S(n)/(n-1)) // globalMean = 0; Random random = new Random(); for (int k=1; k I(%d)=%.3f\n", deltaGI, len1, imp1, len2, imp2, len12,
// imp12);
return deltaGI;
}
/**
 * Merges the leaf wrapped by {@code dgn2} into the leaf wrapped by {@code dgn1}.
 * <p>
 * The feature vectors of both leaves are concatenated (those of {@code dgn1} first) and stored in the leaf of
 * {@code dgn1}. Every mother of {@code dgn2} is re-pointed to {@code dgn1}, {@code dgn2} is detached from its
 * leaf, and the entries of both leaves are removed from {@code impurities}.
 *
 * @param dgn1
 *            the graph node that survives and receives the merged data
 * @param dgn2
 *            the graph node that is emptied and unlinked from the graph
 */
private void mergeLeaves(DirectedGraphNode dgn1, DirectedGraphNode dgn2) {
    // Copy all data from dgn2 into dgn1
    FeatureVectorLeafNode l1 = (FeatureVectorLeafNode) dgn1.getLeafNode();
    FeatureVectorLeafNode l2 = (FeatureVectorLeafNode) dgn2.getLeafNode();
    FeatureVector[] fv1 = l1.getFeatureVectors();
    FeatureVector[] fv2 = l2.getFeatureVectors();
    FeatureVector[] newFV = new FeatureVector[fv1.length + fv2.length];
    System.arraycopy(fv1, 0, newFV, 0, fv1.length);
    System.arraycopy(fv2, 0, newFV, fv1.length, fv2.length);
    l1.setFeatureVectors(newFV);

    // Then update all mother/daughter relationships. Iterate over a copy, because removeMother()
    // is called on dgn2 inside the loop. The element type is required for the for-each to compile.
    Set<Node> dgn2Mothers = new HashSet<Node>(dgn2.getMothers());
    for (Node mother : dgn2Mothers) {
        if (mother instanceof DecisionNode) {
            DecisionNode dm = (DecisionNode) mother;
            dm.replaceDaughter(dgn1, dgn2.getNodeIndex(mother));
        } else if (mother instanceof DirectedGraphNode) {
            DirectedGraphNode gm = (DirectedGraphNode) mother;
            gm.setLeafNode(dgn1);
        }
        dgn2.removeMother(mother);
    }
    dgn2.setLeafNode(null);
    l2.setMother(null, 0);

    // The cached entries of both leaves are stale now. l1 and l2 were dereferenced above, so they are
    // non-null here and no NullPointerException handling is needed around the removal.
    impurities.remove(l1);
    impurities.remove(l2);
}
/**
 * Walks {@code treeNode} and {@code graphNode} in parallel and grafts the lowest level of the tree onto the graph.
 * <p>
 * Where the graph node already has a decision node, tree and graph are expected to be aligned (same feature,
 * same number of daughters) and the method recurses into every daughter that is a decision node. Where the
 * graph node has no decision node yet, a new byte or short decision node is created for it, and each tree
 * daughter holding data becomes a pair of nested graph nodes; the inner one is appended to {@code newLeaves}.
 * Daughters that are null or hold no data become null daughters in the graph.
 *
 * @param treeNode
 *            current decision node in the tree
 * @param graphNode
 *            the graph node corresponding to {@code treeNode}
 * @param newLeaves
 *            output list collecting the newly created leaf-holding graph nodes
 */
private void updateGraphFromTree(DecisionNode treeNode, DirectedGraphNode graphNode, List newLeaves) {
    final int featureIdx = treeNode.getFeatureIndex();
    final int daughterCount = treeNode.getNumberOfDaugthers();
    DecisionNode graphDecision = graphNode.getDecisionNode();

    if (graphDecision == null) {
        // No structure in the graph yet which corresponds to the tree: this is where new nodes are added.
        if (featureDefinition.isByteFeature(featureIdx)) {
            graphDecision = new DecisionNode.ByteDecisionNode(featureIdx, daughterCount, featureDefinition);
        } else {
            assert featureDefinition.isShortFeature(featureIdx) : "Only support byte and short features";
            graphDecision = new DecisionNode.ShortDecisionNode(featureIdx, daughterCount, featureDefinition);
        }
        assert daughterCount == graphDecision.getNumberOfDaugthers();
        graphNode.setDecisionNode(graphDecision);

        for (int d = 0; d < daughterCount; d++) {
            // At this level every tree daughter is expected to be a leaf.
            LeafNode treeLeaf = (LeafNode) treeNode.getDaughter(d);
            if (treeLeaf == null || treeLeaf.getNumberOfData() <= 0) {
                graphDecision.addDaughter(null);
                continue;
            }
            // The new daughter is a graph node without decision node whose "leaf" is itself a graph node
            // without decision node, wrapping the actual tree leaf.
            DirectedGraphNode leafHolder = new DirectedGraphNode(null, treeLeaf);
            DirectedGraphNode daughter = new DirectedGraphNode(null, leafHolder);
            graphDecision.addDaughter(daughter);
            newLeaves.add(leafHolder);
        }
        return;
    }

    // Sanity check: tree and graph must be aligned -- same feature, same number of children.
    int graphFeatureIdx = graphDecision.getFeatureIndex();
    assert featureIdx == graphFeatureIdx : "Tree indices out of sync!";
    assert daughterCount == graphDecision.getNumberOfDaugthers() : "Tree structure out of sync!";

    // Descend into all daughters. Non-empty daughters must be decision nodes here, because the level
    // just above the leaves does not exist in the graph yet.
    for (int d = 0; d < daughterCount; d++) {
        Node treeDaughter = treeNode.getDaughter(d);
        if (treeDaughter == null) {
            continue;
        }
        if (treeDaughter instanceof LeafNode) {
            assert ((LeafNode) treeDaughter).getNumberOfData() == 0;
            continue;
        }
        assert treeDaughter instanceof DecisionNode;
        DirectedGraphNode graphDaughter = (DirectedGraphNode) graphDecision.getDaughter(d);
        updateGraphFromTree((DecisionNode) treeDaughter, graphDaughter, newLeaves);
    }
}
/**
 * Computes the test-data distance for {@code graph}. Currently this is the value of
 * {@link #rmsMeanDistanceTestData(DirectedGraph)}; the mutual-distance variant is not used.
 *
 * @param graph
 *            the graph to evaluate
 * @return the test-data distance
 */
private float rmsDistanceTestData(DirectedGraph graph) {
    // Alternative measure, not in use: rmsMutualDistanceTestData(graph)
    final float distance = rmsMeanDistanceTestData(graph);
    return distance;
}
/**
 * Averages, over all test feature vectors, the distance between the vector and the mean of the leaf data
 * that {@code graph} returns for it.
 * <p>
 * For each test vector the graph is interpreted, the mean of the resulting feature vectors is computed, and
 * the square root of the squared distance between test vector and mean is accumulated. Despite the method
 * name, the result is the arithmetic mean of these root distances. If there are no test vectors the result
 * is NaN (0 divided by 0 in float arithmetic).
 *
 * @param graph
 *            the graph to evaluate
 * @return the mean distance between each test vector and the mean of its leaf
 */
private float rmsMeanDistanceTestData(DirectedGraph graph) {
    // The constructor already relies on this concrete type, so the cast is done once instead of per call.
    F0ContourPolynomialDistanceMeasure polyDist = (F0ContourPolynomialDistanceMeasure) dist;
    float avgDist = 0;
    for (FeatureVector testVector : testFeatures) {
        FeatureVector[] leafData = (FeatureVector[]) graph.interpret(testVector);
        float[] mean = polyDist.computeMean(leafData);
        float oneDist = polyDist.squaredDistance(testVector, mean);
        avgDist += (float) Math.sqrt(oneDist);
    }
    avgDist /= testFeatures.length;
    return avgDist;
}
/**
 * Averages, over all test feature vectors, the root-mean-square distance between the vector and the members
 * of the leaf that {@code graph} returns for it, using the {@code squaredDistances} matrix.
 * <p>
 * Leaf members with the same unit index as the test vector contribute nothing to the sum but still count in
 * the divisor.
 *
 * @param graph
 *            the graph to evaluate
 * @return the mean over all test vectors of the per-vector RMS distance
 */
private float rmsMutualDistanceTestData(DirectedGraph graph) {
    float total = 0;
    for (FeatureVector testVector : testFeatures) {
        final int testIdx = testVector.getUnitIndex();
        FeatureVector[] leafData = (FeatureVector[]) graph.interpret(testVector);

        float sumSquared = 0;
        for (FeatureVector member : leafData) {
            final int memberIdx = member.getUnitIndex();
            if (memberIdx == testIdx) {
                continue;
            }
            // Triangular storage: the pair (lower, higher) lives at [lower][higher - lower - 1].
            final int lower = Math.min(testIdx, memberIdx);
            final int higher = Math.max(testIdx, memberIdx);
            sumSquared += squaredDistances[lower][higher - lower - 1];
        }

        float meanSquared = sumSquared / leafData.length;
        total += (float) Math.sqrt(meanSquared);
    }
    return total / testFeatures.length;
}
/**
 * Stores the best criterion value found so far in the current feature-selection round.
 *
 * @param value
 *            the new minimum criterion
 */
private void setMinCriterion(double value) {
    this.minCriterion = value;
}
/**
 * Stores the global impurity belonging to the best feature found so far in the current round.
 *
 * @param value
 *            the new minimum global impurity
 */
private void setMinFSGI(double value) {
    this.minFSGI = value;
}
/**
 * Records which feature is the best one found so far in the current round.
 *
 * @param featureIndex
 *            index of the best feature
 */
private void setBestFeature(int featureIndex) {
    this.iBestFeature = featureIndex;
}
/**
 * Prints a debug dump of every node delivered by the graph's node iterator to standard output.
 *
 * @param graph
 *            the graph to dump
 */
private void debugOut(DirectedGraph graph) {
    // The iterator must be typed: with a raw Iterator, next() yields Object and cannot be assigned to Node.
    for (Iterator<Node> it = graph.getNodeIterator(); it.hasNext();) {
        Node next = it.next();
        debugOut(next);
    }
}
/**
 * Prints a debug dump of the given tree to standard output, starting at its root node.
 *
 * @param graph
 *            the tree to dump
 */
private void debugOut(CART graph) {
    final Node rootNode = graph.getRootNode();
    debugOut(rootNode);
}
/**
 * Dispatches to the debug printer matching the runtime type of {@code node}: graph nodes first, then leaf
 * nodes; anything else is cast to {@code DecisionNode}.
 *
 * @param node
 *            the node to dump
 */
private void debugOut(Node node) {
    if (node instanceof DirectedGraphNode) {
        debugOut((DirectedGraphNode) node);
        return;
    }
    if (node instanceof LeafNode) {
        debugOut((LeafNode) node);
        return;
    }
    debugOut((DecisionNode) node);
}
/**
 * Prints the marker line "DGN", then dumps the node's leaf node and its decision node, each only if present.
 *
 * @param node
 *            the graph node to dump
 */
private void debugOut(DirectedGraphNode node) {
    System.out.println("DGN");

    Node leaf = node.getLeafNode();
    if (leaf != null) {
        debugOut(leaf);
    }

    DecisionNode decision = node.getDecisionNode();
    if (decision != null) {
        debugOut(decision);
    }
}
/**
 * Prints one line consisting of "Leaf: " followed by the leaf's decision path.
 *
 * @param node
 *            the leaf to dump
 */
private void debugOut(LeafNode node) {
    String line = "Leaf: " + node.getDecisionPath();
    System.out.println(line);
}
/**
 * Prints a header line for the decision node and then dumps each daughter in index order; a missing
 * daughter is printed as the line "null".
 *
 * @param node
 *            the decision node to dump
 */
private void debugOut(DecisionNode node) {
    System.out.println("DN with " + node.getNumberOfDaugthers() + " daughters: " + node.toString());
    int index = 0;
    while (index < node.getNumberOfDaugthers()) {
        Node daughter = node.getDaughter(index);
        if (daughter != null) {
            debugOut(daughter);
        } else {
            System.out.println("null");
        }
        index++;
    }
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy