marytts.cart.DecisionNode Maven / Gradle / Ivy
The newest version!
/**
* Copyright 2000-2009 DFKI GmbH.
* All Rights Reserved. Use is subject to license terms.
*
* This file is part of MARY TTS.
*
* MARY TTS is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, version 3 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see .
*
*/
package marytts.cart;
import marytts.cart.LeafNode.FeatureVectorLeafNode;
import marytts.cart.LeafNode.IntArrayLeafNode;
import marytts.features.FeatureDefinition;
import marytts.features.FeatureVector;
/**
* A decision node that determines the next Node to go to in the CART. All decision nodes inherit from this class
*/
public abstract class DecisionNode extends Node {
public enum Type {
BinaryByteDecisionNode, BinaryShortDecisionNode, BinaryFloatDecisionNode, ByteDecisionNode, ShortDecisionNode
};
protected boolean TRACE = false;
// for debugging:
protected FeatureDefinition featureDefinition;
// a decision node has an array of daughters
protected Node[] daughters;
// the feature index
protected int featureIndex;
// the feature name
protected String feature;
// remember last added daughter
protected int lastDaughter;
// the total number of data in the leaves below this node
protected int nData;
// unique index used in MaryCART format
protected int uniqueDecisionNodeId;
/**
* Construct a new DecisionNode
*
* @param feature
* the feature
* @param numDaughters
* the number of daughters
* @param featureDefinition
* feature definition
*/
public DecisionNode(String feature, int numDaughters, FeatureDefinition featureDefinition) {
this.feature = feature;
this.featureIndex = featureDefinition.getFeatureIndex(feature);
daughters = new Node[numDaughters];
isRoot = false;
// for trace and getDecisionPath():
this.featureDefinition = featureDefinition;
}
/**
* Construct a new DecisionNode
*
* @param featureIndex
* the feature index
* @param numDaughters
* the number of daughters
* @param featureDefinition
* feature definition
*/
public DecisionNode(int featureIndex, int numDaughters, FeatureDefinition featureDefinition) {
this.featureIndex = featureIndex;
this.feature = featureDefinition.getFeatureName(featureIndex);
daughters = new Node[numDaughters];
isRoot = false;
// for trace and getDecisionPath():
this.featureDefinition = featureDefinition;
}
/**
* Construct a new DecisionNode
*
* @param numDaughters
* the number of daughters
* @param featureDefinition
* feature definition
*/
public DecisionNode(int numDaughters, FeatureDefinition featureDefinition) {
daughters = new Node[numDaughters];
isRoot = false;
// for trace and getDecisionPath():
this.featureDefinition = featureDefinition;
}
@Override
public boolean isDecisionNode() {
return true;
}
/**
* Get the name of the feature
*
* @return the name of the feature
*/
public String getFeatureName() {
return feature;
}
public int getFeatureIndex() {
return featureIndex;
}
public FeatureDefinition getFeatureDefinition() {
return featureDefinition;
}
/**
* Add a daughter to the node
*
* @param daughter
* the new daughter
*/
public void addDaughter(Node daughter) {
if (lastDaughter > daughters.length - 1) {
throw new RuntimeException("Can not add daughter number " + (lastDaughter + 1) + ", since node has only "
+ daughters.length + " daughters!");
}
daughters[lastDaughter] = daughter;
if (daughter != null) {
daughter.setMother(this, lastDaughter);
}
lastDaughter++;
}
/**
* Get the daughter at the specified index
*
* @param index
* the index of the daughter
* @return the daughter (potentially null); if index out of range: null
*/
public Node getDaughter(int index) {
if (index > daughters.length - 1 || index < 0) {
return null;
}
return daughters[index];
}
/**
* Replace daughter at given index with another daughter
*
* @param newDaughter
* the new daughter
* @param index
* the index of the daughter to replace
*/
public void replaceDaughter(Node newDaughter, int index) {
if (index > daughters.length - 1 || index < 0) {
throw new RuntimeException("Can not replace daughter number " + index + ", since daughter index goes from 0 to "
+ (daughters.length - 1) + "!");
}
daughters[index] = newDaughter;
newDaughter.setMother(this, index);
}
/**
* Tests, if the given index refers to a daughter
*
* @param index
* the index
* @return true, if the index is in range of the daughters array
*/
public boolean hasMoreDaughters(int index) {
return (index > -1 && index < daughters.length);
}
/**
* Get all unit indices from all leaves below this node
*
* @return an int array containing the indices
*/
public Object getAllData() {
// What to do depends on the type of leaves.
LeafNode firstLeaf = new NodeIterator(this, true, false, false).next();
if (firstLeaf == null)
return null;
Object result;
if (firstLeaf instanceof IntArrayLeafNode) { // this includes subclass IntAndFloatArrayLeafNode
result = new int[nData];
} else if (firstLeaf instanceof FeatureVectorLeafNode) {
result = new FeatureVector[nData];
} else {
return null;
}
fillData(result, 0, nData);
return result;
}
protected void fillData(Object target, int pos, int total) {
// assert pos + total <= target.length;
for (int i = 0; i < daughters.length; i++) {
if (daughters[i] == null)
continue;
int len = daughters[i].getNumberOfData();
daughters[i].fillData(target, pos, len);
pos += len;
}
}
/**
* Count all the nodes at and below this node. A leaf will return 1; the root node will report the total number of decision
* and leaf nodes in the tree.
*
* @return the number of nodes
*/
public int getNumberOfNodes() {
int nNodes = 1; // this node
for (int i = 0; i < daughters.length; i++) {
if (daughters[i] != null)
nNodes += daughters[i].getNumberOfNodes();
}
return nNodes;
}
public int getNumberOfData() {
return nData;
}
/**
* Number of daughters of current node.
*
* @return daughters.length
*/
public int getNumberOfDaugthers() {
return daughters.length;
}
/**
* Set the number of candidates correctly, by counting while walking down the tree. This needs to be done once for the entire
* tree.
*
*/
// protected void countData() {
public void countData() {
nData = 0;
for (int i = 0; i < daughters.length; i++) {
if (daughters[i] instanceof DecisionNode)
((DecisionNode) daughters[i]).countData();
if (daughters[i] != null) {
nData += daughters[i].getNumberOfData();
}
}
}
public String toString() {
return "dn" + uniqueDecisionNodeId;
}
/**
* Get the path leading to the daughter with the given index. This will recursively go up to the root node.
*
* @param daughterIndex
* daughterIndex
* @return the unique decision node id
*/
public abstract String getDecisionPath(int daughterIndex);
// unique index used in MaryCART format
public void setUniqueDecisionNodeId(int id) {
this.uniqueDecisionNodeId = id;
}
public int getUniqueDecisionNodeId() {
return uniqueDecisionNodeId;
}
/**
* Gets the String that defines the decision done in the node
*
* @return the node definition
*/
public abstract String getNodeDefinition();
/**
* Get the decision node type
*
* @return the decision node type
*/
public abstract Type getDecisionNodeType();
/**
* Select a daughter node according to the value in the given target
*
* @param featureVector
* the feature vector
* @return a daughter
*/
public abstract Node getNextNode(FeatureVector featureVector);
/**
* A binary decision Node that compares two byte values.
*/
public static class BinaryByteDecisionNode extends DecisionNode {
// the value of this node
private byte value;
/**
* Create a new binary String DecisionNode.
*
* @param feature
* the string used to get a value from an Item
* @param value
* the value to compare to
* @param featureDefinition
* featureDefinition
*/
public BinaryByteDecisionNode(String feature, String value, FeatureDefinition featureDefinition) {
super(feature, 2, featureDefinition);
this.value = featureDefinition.getFeatureValueAsByte(feature, value);
}
public BinaryByteDecisionNode(int featureIndex, byte value, FeatureDefinition featureDefinition) {
super(featureIndex, 2, featureDefinition);
this.value = value;
}
/***
* Creates an empty BinaryByteDecisionNode, the feature and feature value of this node should be filled with
* setFeatureAndFeatureValue() function.
*
* @param uniqueId
* unique index from tree HTS test file.
* @param featureDefinition
* featureDefinition
*/
public BinaryByteDecisionNode(int uniqueId, FeatureDefinition featureDefinition) {
super(2, featureDefinition);
// System.out.println("adding decision node: " + uniqueId);
this.uniqueDecisionNodeId = uniqueId;
}
/***
* Fill the feature and feature value of an already created (empty) BinaryByteDecisionNode.
*
* @param feature
* feature
* @param value
* value
*/
public void setFeatureAndFeatureValue(String feature, String value) {
this.feature = feature;
this.featureIndex = featureDefinition.getFeatureIndex(feature);
this.value = featureDefinition.getFeatureValueAsByte(feature, value);
}
public byte getCriterionValueAsByte() {
return value;
}
public String getCriterionValueAsString() {
return featureDefinition.getFeatureValueAsString(featureIndex, value);
}
/**
* Select a daughter node according to the value in the given target
*
* @param featureVector
* the feature vector
* @return a daughter
*/
public Node getNextNode(FeatureVector featureVector) {
byte val = featureVector.getByteFeature(featureIndex);
Node returnNode;
if (val == value) {
returnNode = daughters[0];
} else {
returnNode = daughters[1];
}
if (TRACE) {
System.out.print(" " + feature + ": " + featureDefinition.getFeatureValueAsString(featureIndex, value)
+ " == " + featureDefinition.getFeatureValueAsString(featureIndex, val));
if (val == value)
System.out.println(" YES ");
else
System.out.println(" NO ");
}
return returnNode;
}
public String getDecisionPath(int daughterIndex) {
String thisNodeInfo;
if (daughterIndex == 0)
thisNodeInfo = feature + "==" + featureDefinition.getFeatureValueAsString(featureIndex, value);
else
thisNodeInfo = feature + "!=" + featureDefinition.getFeatureValueAsString(featureIndex, value);
if (mother == null)
return thisNodeInfo;
else if (mother.isDecisionNode())
return ((DecisionNode) mother).getDecisionPath(getNodeIndex()) + " - " + thisNodeInfo;
else
return mother.getDecisionPath() + " - " + thisNodeInfo;
}
/**
* Gets the String that defines the decision done in the node
*
* @return the node definition
*/
public String getNodeDefinition() {
return feature + " is " + featureDefinition.getFeatureValueAsString(featureIndex, value);
}
public Type getDecisionNodeType() {
return Type.BinaryByteDecisionNode;
}
}
/**
* A binary decision Node that compares two short values.
*/
public static class BinaryShortDecisionNode extends DecisionNode {
// the value of this node
private short value;
/**
* Create a new binary String DecisionNode.
*
* @param feature
* the string used to get a value from an Item
* @param value
* the value to compare to
* @param featureDefinition
* featureDefinition
*/
public BinaryShortDecisionNode(String feature, String value, FeatureDefinition featureDefinition) {
super(feature, 2, featureDefinition);
this.value = featureDefinition.getFeatureValueAsShort(feature, value);
}
public BinaryShortDecisionNode(int featureIndex, short value, FeatureDefinition featureDefinition) {
super(featureIndex, 2, featureDefinition);
this.value = value;
}
public short getCriterionValueAsShort() {
return value;
}
public String getCriterionValueAsString() {
return featureDefinition.getFeatureValueAsString(featureIndex, value);
}
/**
* Select a daughter node according to the value in the given target
*
* @param featureVector
* the feature vector
* @return a daughter
*/
public Node getNextNode(FeatureVector featureVector) {
short val = featureVector.getShortFeature(featureIndex);
Node returnNode;
if (val == value) {
returnNode = daughters[0];
} else {
returnNode = daughters[1];
}
if (TRACE) {
System.out.print(feature + ": " + featureDefinition.getFeatureValueAsString(featureIndex, val));
if (val == value)
System.out.print(" == ");
else
System.out.print(" != ");
System.out.println(featureDefinition.getFeatureValueAsString(featureIndex, value));
}
return returnNode;
}
public String getDecisionPath(int daughterIndex) {
String thisNodeInfo;
if (daughterIndex == 0)
thisNodeInfo = feature + "==" + featureDefinition.getFeatureValueAsString(featureIndex, value);
else
thisNodeInfo = feature + "!=" + featureDefinition.getFeatureValueAsString(featureIndex, value);
if (mother == null)
return thisNodeInfo;
else if (mother.isDecisionNode())
return ((DecisionNode) mother).getDecisionPath(getNodeIndex()) + " - " + thisNodeInfo;
else
return mother.getDecisionPath() + " - " + thisNodeInfo;
}
/**
* Gets the String that defines the decision done in the node
*
* @return the node definition
*/
public String getNodeDefinition() {
return feature + " is " + featureDefinition.getFeatureValueAsString(featureIndex, value);
}
public Type getDecisionNodeType() {
return Type.BinaryShortDecisionNode;
}
}
/**
* A binary decision Node that compares two float values.
*/
public static class BinaryFloatDecisionNode extends DecisionNode {
// the value of this node
private float value;
private boolean isByteFeature;
/**
* Create a new binary String DecisionNode.
*
* @param featureIndex
* the string used to get a value from an Item
* @param value
* the value to compare to
* @param featureDefinition
* featureDefinition
*/
public BinaryFloatDecisionNode(int featureIndex, float value, FeatureDefinition featureDefinition) {
this(featureDefinition.getFeatureName(featureIndex), value, featureDefinition);
}
public BinaryFloatDecisionNode(String feature, float value, FeatureDefinition featureDefinition) {
super(feature, 2, featureDefinition);
this.value = value;
// check for pseudo-floats:
// TODO: clean this up:
if (featureDefinition.isByteFeature(featureIndex))
isByteFeature = true;
else
isByteFeature = false;
}
public float getCriterionValueAsFloat() {
return value;
}
public String getCriterionValueAsString() {
return String.valueOf(value);
}
/**
* Select a daughter node according to the value in the given target
*
* @param featureVector
* the feature vector
* @return a daughter
*/
public Node getNextNode(FeatureVector featureVector) {
float val;
if (isByteFeature)
val = (float) featureVector.getByteFeature(featureIndex);
else
val = featureVector.getContinuousFeature(featureIndex);
Node returnNode;
if (val < value) {
returnNode = daughters[0];
} else {
returnNode = daughters[1];
}
if (TRACE) {
System.out.print(feature + ": " + val);
if (val < value)
System.out.print(" < ");
else
System.out.print(" >= ");
System.out.println(value);
}
return returnNode;
}
public String getDecisionPath(int daughterIndex) {
String thisNodeInfo;
if (daughterIndex == 0)
thisNodeInfo = feature + "<" + value;
else
thisNodeInfo = feature + ">=" + value;
if (mother == null)
return thisNodeInfo;
else if (mother.isDecisionNode())
return ((DecisionNode) mother).getDecisionPath(getNodeIndex()) + " - " + thisNodeInfo;
else
return mother.getDecisionPath() + " - " + thisNodeInfo;
}
/**
* Gets the String that defines the decision done in the node
*
* @return the node definition
*/
public String getNodeDefinition() {
return feature + " < " + value;
}
public Type getDecisionNodeType() {
return Type.BinaryFloatDecisionNode;
}
}
/**
* An decision Node with an arbitrary number of daughters. Value of the target corresponds to the index number of next
* daughter.
*/
public static class ByteDecisionNode extends DecisionNode {
/**
* Build a new byte decision node
*
* @param feature
* the feature name
* @param numDaughters
* the number of daughters
* @param featureDefinition
* featureDefinition
*/
public ByteDecisionNode(String feature, int numDaughters, FeatureDefinition featureDefinition) {
super(feature, numDaughters, featureDefinition);
}
/**
* Build a new byte decision node
*
* @param featureIndex
* the feature index
* @param numDaughters
* the number of daughters
* @param featureDefinition
* featureDefinition
*/
public ByteDecisionNode(int featureIndex, int numDaughters, FeatureDefinition featureDefinition) {
super(featureIndex, numDaughters, featureDefinition);
}
/**
* Select a daughter node according to the value in the given target
*
* @param featureVector
* the feature vector
* @return a daughter
*/
public Node getNextNode(FeatureVector featureVector) {
byte val = featureVector.getByteFeature(featureIndex);
if (TRACE) {
System.out.println(feature + ": " + featureDefinition.getFeatureValueAsString(featureIndex, val));
}
return daughters[val];
}
public String getDecisionPath(int daughterIndex) {
String thisNodeInfo = feature + "==" + featureDefinition.getFeatureValueAsString(featureIndex, daughterIndex);
if (mother == null)
return thisNodeInfo;
else if (mother.isDecisionNode())
return ((DecisionNode) mother).getDecisionPath(getNodeIndex()) + " - " + thisNodeInfo;
else
return mother.getDecisionPath() + " - " + thisNodeInfo;
}
/**
* Gets the String that defines the decision done in the node
*
* @return the node definition
*/
public String getNodeDefinition() {
return feature + " isByteOf " + daughters.length;
}
public Type getDecisionNodeType() {
return Type.ByteDecisionNode;
}
}
/**
* An decision Node with an arbitrary number of daughters. Value of the target corresponds to the index number of next
* daughter.
*/
public static class ShortDecisionNode extends DecisionNode {
/**
* Build a new short decision node
*
* @param feature
* the feature name
* @param numDaughters
* the number of daughters
* @param featureDefinition
* featureDefinition
*/
public ShortDecisionNode(String feature, int numDaughters, FeatureDefinition featureDefinition) {
super(feature, numDaughters, featureDefinition);
}
/**
* Build a new short decision node
*
* @param featureIndex
* the feature index
* @param numDaughters
* the number of daughters
* @param featureDefinition
* featureDefinition
*/
public ShortDecisionNode(int featureIndex, int numDaughters, FeatureDefinition featureDefinition) {
super(featureIndex, numDaughters, featureDefinition);
}
/**
* Select a daughter node according to the value in the given target
*
* @param featureVector
* the feature vector
* @return a daughter
*/
public Node getNextNode(FeatureVector featureVector) {
short val = featureVector.getShortFeature(featureIndex);
if (TRACE) {
System.out.println(feature + ": " + featureDefinition.getFeatureValueAsString(featureIndex, val));
}
return daughters[val];
}
public String getDecisionPath(int daughterIndex) {
String thisNodeInfo = feature + "==" + featureDefinition.getFeatureValueAsString(featureIndex, daughterIndex);
if (mother == null)
return thisNodeInfo;
else if (mother.isDecisionNode())
return ((DecisionNode) mother).getDecisionPath(getNodeIndex()) + " - " + thisNodeInfo;
else
return mother.getDecisionPath() + " - " + thisNodeInfo;
}
/**
* Gets the String that defines the decision done in the node
*
* @return the node definition
*/
public String getNodeDefinition() {
return feature + " isShortOf " + daughters.length;
}
public Type getDecisionNodeType() {
return Type.ShortDecisionNode;
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy