gov.sandia.cognition.learning.data.feature.StandardDistributionNormalizer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: StandardDistributionNormalizer.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright September 11, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.data.feature;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.collection.NumberComparator;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.math.AbstractUnivariateScalarFunction;
import gov.sandia.cognition.math.UnivariateStatisticsUtil;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
/**
* The {@code StandardDistributionNormalizer} class implements a normalization
* method where a real value is converted onto a standard distribution. This
* means that the value is subtracted by the mean and divided by the standard
* deviation.
*
* f(x) = (x - mean) / (standardDeviation)
*
* @author Justin Basilico
* @since 2.0
*/
@CodeReview(
reviewer="Kevin R. Dixon",
date="2009-07-06",
changesNeeded=false,
comments={
"Made the learning methods take extends Number>",
"Now extends AbstractUnivariateScalarFunction",
"Cleaned up javadoc"
}
)
public class StandardDistributionNormalizer
extends AbstractUnivariateScalarFunction
{
/** The default mean is {@value}. */
public static final double DEFAULT_MEAN = 0.0;
/** The default variance is {@value}. This means that by default there is no
* normalization by variance. */
public static final double DEFAULT_VARIANCE = 1.0;
/** The mean of normalization. */
protected double mean;
/** The variance for normalization. */
protected double variance;
/** The cached value of the standard deviation for normalization. */
protected double standardDeviation;
/**
* Creates a new instance of StandardNormalization with a mean of 0.0 and
* a variance of 1.0.
*/
public StandardDistributionNormalizer()
{
this(DEFAULT_MEAN, DEFAULT_VARIANCE);
}
/**
* Creates a new instance of StandardDistributionNormalizer with the given
* mean and variance.
*
* @param mean The mean.
* @param variance The variance.
*/
public StandardDistributionNormalizer(
final double mean,
final double variance)
{
super();
this.setMean(mean);
this.setVariance(variance);
}
/**
* Creates a new instance of StandardDistributionNormalizer from the given
* Gaussian.
*
* @param gaussian The Gaussian to initialize the normalizer with.
*/
public StandardDistributionNormalizer(
final UnivariateGaussian gaussian)
{
this(gaussian.getMean(), gaussian.getVariance());
}
/**
* Creates a new copy of a StandardDistributionNormalizer.
*
* @param other The StandardDistributionNormalizer to copy.
*/
public StandardDistributionNormalizer(
final StandardDistributionNormalizer other)
{
this(other.getMean(), other.getVariance());
}
/**
* Creates a new copy of this StandardDistributionNormalizer.
*
* @return A new copy of this StandardDistributionNormalizer.
*/
@Override
public StandardDistributionNormalizer clone()
{
return (StandardDistributionNormalizer) super.clone();
}
/**
* Normalizes the given double value by subtracting the mean and dividing
* by the standard deviation (the square root of the variance).
*
* @param value The value to normalize.
* @return The normalized value.
*/
public double evaluate(
final double value)
{
return (value - this.mean) / this.standardDeviation;
}
/**
* Gets the mean.
*
* @return The mean.
*/
public double getMean()
{
return this.mean;
}
/**
* Sets the mean.
*
* @param mean The mean.
*/
public void setMean(
final double mean)
{
this.mean = mean;
}
/**
* Gets the variance.
*
* @return The variance.
*/
public double getVariance()
{
return variance;
}
/**
* Sets the variance. It must be greater than 0.0.
*
* @param variance The variance.
*/
public void setVariance(
double variance)
{
if (variance <= 0.0)
{
throw new IllegalArgumentException("variance must be positive");
}
this.variance = variance;
this.standardDeviation = Math.sqrt(variance);
}
/**
* Builds a StandardDistributionNormalizer by computing the mean and
* variance of the given collection of values.
*
* @param values The values to use to build the normalizer.
* @return The StandardDistributionNormalizer created from the mean and
* variance of the given values.
*/
public static StandardDistributionNormalizer learn(
final Collection extends Number> values)
{
return learn(values, 0.0);
}
/**
* Builds a StandardDistributionNormalizer by computing the mean and
* variance of the given collection of values. It will exclude the given
* percentage of outliers from the value.
*
* @param values The values to use to build the normalizer.
* @param outlierPercent The percentage of outliers to exclude.
* @return The StandardDistributionNormalizer created from the mean and
* variance of the given values.
*/
public static StandardDistributionNormalizer learn(
final Collection extends Number> values,
double outlierPercent)
{
if (values == null)
{
// Error: Bad values.
throw new NullPointerException("values cannot be null.");
}
else if (outlierPercent < 0.0 || outlierPercent >= 1.0)
{
// Error: Bad outlier percent.
throw new IllegalArgumentException(
"outlierPercent must be [0.0, 1.0)");
}
int count = values.size();
if (count <= 0)
{
// Error: Not enough samples.
throw new IllegalArgumentException("values cannot be empty.");
}
// Figure out the collection to compute the mean and variance on.
Collection extends Number> included = values;
if (outlierPercent > 0.0)
{
// Discard the given percentage of outliers by removing half that
// percentage from each side.
final ArrayList sorted = new ArrayList(values);
Collections.sort(sorted, NumberComparator.INSTANCE );
int numToDiscard = (int) (count * outlierPercent / 2.0);
if (numToDiscard > 0 && (2 * numToDiscard) < count)
{
included = sorted.subList(numToDiscard, count - numToDiscard);
}
}
// Get the new count of values to compute.
count = included.size();
// Compute the mean.
Pair result =
UnivariateStatisticsUtil.computeMeanAndVariance(included);
double mean = result.getFirst();
double variance = ((count-1.0)/(count))*result.getSecond();
if( variance <= 0.0 )
{
variance = 1.0;
}
return new StandardDistributionNormalizer(mean, variance);
}
/**
* The {@code Learner} class implements a {@code BatchLearner} object for
* a {@code StandardDistributionNormalizer}.
*/
public static class Learner
extends AbstractCloneableSerializable
implements BatchLearner, StandardDistributionNormalizer>
{
/** The default percentage of outliers is {@value}. */
public static final double DEFAULT_OUTLIER_PERCENT = 0.0;
/** The percentage of outliers to exclude from learning. */
protected double outlierPercent;
/**
* Creates a new StandardDistributionNormalizer.Learner.
*/
public Learner()
{
this(DEFAULT_OUTLIER_PERCENT);
}
/**
* Creates a new StandardDistributionNormalizer.Learner.
*
* @param outlierPercent The percentage of outliers to exclude.
*/
public Learner(
final double outlierPercent)
{
super();
this.setOutlierPercent(outlierPercent);
}
/**
* Learns a StandardDistributionNormalizer from the given values by
* computing the mean and standard deviation of the values.
*
* @param values The values to use.
* @return The StandardDistributionNormalizer computed from the given
* values.
*/
public StandardDistributionNormalizer learn(
final Collection values)
{
return StandardDistributionNormalizer.learn(
values, this.outlierPercent);
}
/**
* Sets the percentage of outliers to exclude from learning.
*
* @return The percentage of outliers.
*/
public double getOutlierPercent()
{
return outlierPercent;
}
/**
* Sets the percentage of outliers to exclude from learning. Must be
* between 0.0 and 1.0.
*
* @param outlierPercent The percentage of outliers.
*/
public void setOutlierPercent(
final double outlierPercent)
{
if (outlierPercent < 0.0 || outlierPercent >= 1.0)
{
throw new IllegalArgumentException(
"outlierPercent must be [0.0, 1.0)");
}
this.outlierPercent = outlierPercent;
}
}
}