gov.sandia.cognition.learning.algorithm.tree.RandomSubVectorThresholdLearner Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: RandomSubVectorThresholdLearner.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright December 06, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.tree;
import gov.sandia.cognition.collection.ArrayUtil;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.Permutation;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.AbstractRandomized;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;
/**
* Learns a decision function by taking a randomly sampling a subspace from
* a given set of input vectors and then learning a threshold function by
* passing the subspace vectors to a sublearner. This component is typically
* used along with a decision tree learner to create random forests of decision
* trees.
*
* @param
* The output type for the decider.
* @author Justin Basilico
* @since 3.0
*/
// TODO: Find a publication reference for random forests.
// -- jdbasil (2009-12-23)
public class RandomSubVectorThresholdLearner
extends AbstractRandomized
implements VectorThresholdLearner,
VectorFactoryContainer
{
/** The default percent to sample is {@value}. */
public static final double DEFAULT_PERCENT_TO_SAMPLE = 0.1;
/** The decider learner for the subspace. */
protected DeciderLearner subLearner;
/** The percentage of the dimensionality to sample. */
protected double percentToSample;
/** The dimensions to sample from in the learner. */
protected int[] dimensionsToConsider;
/** The vector factory to use. */
protected VectorFactory extends Vector> vectorFactory;
/**
* Creates a new {@code RandomSubVectorThresholdLearner}.
*/
public RandomSubVectorThresholdLearner()
{
this(null, DEFAULT_PERCENT_TO_SAMPLE, new Random());
}
/**
* Creates a new {@code RandomSubVectorThresholdLearner}.
*
* @param subLearner
* The threshold decision function learner to use over the subspace.
* @param percentToSample
* The percentage of the dimensionality to sample (must be between
* 0.0 (exclusive) and 1.0 (inclusive).
* @param random
* The random number generator.
*/
public RandomSubVectorThresholdLearner(
final DeciderLearner subLearner,
final double percentToSample,
final Random random)
{
this(subLearner, percentToSample, random, VectorFactory.getDefault());
}
/**
* Creates a new {@code RandomSubVectorThresholdLearner}.
*
* @param subLearner
* The threshold decision function learner to use over the subspace.
* @param percentToSample
* The percentage of the dimensionality to sample (must be between
* 0.0 and 1.0.
* @param random
* The random number generator.
* @param vectorFactory
* The vector factory to use.
*/
public RandomSubVectorThresholdLearner(
final DeciderLearner subLearner,
final double percentToSample,
final Random random,
final VectorFactory extends Vector> vectorFactory)
{
this(subLearner, percentToSample, null, random, vectorFactory);
}
/**
* Creates a new {@code RandomSubVectorThresholdLearner}.
*
* @param subLearner
* The threshold decision function learner to use over the subspace.
* @param percentToSample
* The percentage of the dimensionality to sample (must be between
* 0.0 and 1.0.
* @param dimensionsToConsider
* The array of vector dimensions to consider. Null means all of them
* are considered.
* @param random
* The random number generator.
* @param vectorFactory
* The vector factory to use.
*/
public RandomSubVectorThresholdLearner(
final DeciderLearner subLearner,
final double percentToSample,
final int[] dimensionsToConsider,
final Random random,
final VectorFactory extends Vector> vectorFactory)
{
super(random);
this.setSubLearner(subLearner);
this.setPercentToSample(percentToSample);
this.setDimensionsToConsider(dimensionsToConsider);
this.setVectorFactory(vectorFactory);
}
@Override
public RandomSubVectorThresholdLearner clone()
{
@SuppressWarnings("unchecked")
final RandomSubVectorThresholdLearner result = (RandomSubVectorThresholdLearner)
super.clone();
result.subLearner = ObjectUtil.cloneSmart(this.subLearner);
result.dimensionsToConsider = ArrayUtil.copy(this.dimensionsToConsider);
return result;
}
@Override
public VectorElementThresholdCategorizer learn(
final Collection extends InputOutputPair extends Vectorizable, OutputType>> data)
{
if (this.random == null)
{
this.random = new Random();
}
// Gets the dimensionality of the input.
final int dimensionality;
if (this.dimensionsToConsider == null)
{
// Include all dimensions.
dimensionality = DatasetUtil.getInputDimensionality(data);
}
else
{
dimensionality = this.dimensionsToConsider.length;
}
// Get the dimensionality of the subspace.
final int subDimensionality = this.getSubDimensionality(dimensionality);
final int[] subDimensions;
if (subDimensionality >= dimensionality)
{
if (this.dimensionsToConsider == null)
{
// No point in subsampling if the requested dimensionality is as
// big (or bigger) than the actual dimensionality.
return this.subLearner.learn(data);
}
else
{
// The subdimensions are just the set of dimensions to consider.
// Use them.
subDimensions = this.dimensionsToConsider;
}
}
else
{
// Create a partial permutation of the indices of the dimensionality.
subDimensions = Permutation.createPartialPermutation(
dimensionality, subDimensionality, this.random);
if (this.dimensionsToConsider != null)
{
// We only use the dimensions to consider based on the array.
for (int i = 0; i < subDimensionality; i++)
{
// Replace the index with the one from the dimensions to
// consider.
subDimensions[i] = this.dimensionsToConsider[subDimensions[i]];
}
}
}
if (this.subLearner instanceof VectorThresholdLearner>)
{
// In this case we can avoid copying the data by giving the learner
// the indices to learn using.
((VectorThresholdLearner>) this.subLearner).setDimensionsToConsider(
subDimensions);
return this.subLearner.learn(data);
}
// Build up the dataset for the subspace.
final ArrayList> subData =
new ArrayList<>(data.size());
for (InputOutputPair extends Vectorizable, OutputType> example
: data)
{
// Create the new subspace vector.
final Vector subVector = this.vectorFactory.createVector(
subDimensionality);
// Copy over the values from the original vector.
final Vector vector = example.getInput().convertToVector();
for (int i = 0; i < subDimensionality; i++)
{
subVector.setElement(i, vector.getElement(subDimensions[i]));
}
// Add the new example.
subData.add(new DefaultInputOutputPair<>(
subVector, example.getOutput()));
}
// Learn on the subspace data.
final VectorElementThresholdCategorizer subDecider =
this.subLearner.learn(subData);
if (subDecider != null)
{
// Change the index the threshold is applied to.
final int subIndex = subDecider.getIndex();
final int index = subDimensions[subIndex];
subDecider.setIndex(index);
}
// else - Null just gets returned.
// Return the learned function.
return subDecider;
}
/**
* Gets the dimensionality of the subspace based on the full dimensionality.
*
* @param dimensionality
* The full dimensionality
* @return
* The dimensionality of the subspace. Will always be greater than or
* equal to 1.
*/
public int getSubDimensionality(
final int dimensionality)
{
return Math.max(1, (int) (dimensionality * this.percentToSample));
}
/**
* Gets the learner used to learn a threshold function over the subspace.
*
* @return
* The learner for the subspace.
*/
public DeciderLearner
getSubLearner()
{
return this.subLearner;
}
/**
* Sets the learner used to learn a threshold function over the subspace.
*
* @param subLearner
* The learner for the subspace.
*/
public void setSubLearner(
final DeciderLearner subLearner)
{
this.subLearner = subLearner;
}
/**
* Gets the percent of the dimensionality to sample. Must be between 0.0
* and 1.0.
*
* @return
* The percent of the dimensionality to sample.
*/
public double getPercentToSample()
{
return this.percentToSample;
}
/**
* Sets the percent of the dimensionality to sample. Must be between 0.0
* and 1.0.
*
* @param percentToSample
* The percent of the dimensionality to sample.
*/
public void setPercentToSample(
final double percentToSample)
{
// Note: Technically, the percent to sample should be in the range (0.0, 1.0)
// not [0.0, 1.0] (in otherwords, where it is exclusive, not inclusive).
// However, a value of 0.0 will mean that only 1 index is chosen and a value of
// 1.0 will mean that all indices are chosen (pass-through). Since these could
// be useful values for testing various configurations, I decided to allow them.
// However, I'm not sure if that makes things more confusing or not.
// --jdbasil (2009-12-06)
if (percentToSample < 0.0 || percentToSample > 1.0)
{
throw new IllegalArgumentException(
"percentToSample must be between 0.0 and 1.0");
}
this.percentToSample = percentToSample;
}
@Override
public int[] getDimensionsToConsider()
{
return this.dimensionsToConsider;
}
@Override
public void setDimensionsToConsider(
final int... dimensionsToConsider)
{
this.dimensionsToConsider = dimensionsToConsider;
}
/**
* Gets the vector factory.
*
* @return
* The vector factory.
*/
@Override
public VectorFactory extends Vector> getVectorFactory()
{
return this.vectorFactory;
}
/**
* Sets the vector factory.
*
* @param vectorFactory
* The vector factory.
*/
public void setVectorFactory(
final VectorFactory extends Vector> vectorFactory)
{
this.vectorFactory = vectorFactory;
}
}