gov.sandia.cognition.statistics.distribution.CategoricalDistribution Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: CategoricalDistribution.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright May 3, 2011, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.statistics.distribution;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.ClosedFormComputableDiscreteDistribution;
import gov.sandia.cognition.statistics.ProbabilityMassFunction;
import gov.sandia.cognition.statistics.ProbabilityMassFunctionUtil;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.Random;
import java.util.Set;
/**
* The Categorical Distribution is the multivariate generalization of the
* Bernoulli distribution, where the outcome of an experiment is a one-of-N
* output, where the output is a selector Vector. This Vector will have all
* zeros except one index will have a 1.0.
* @author Kevin R. Dixon
* @since 3.3.0
*/
@PublicationReference(
author="Wikipedia",
title="Categoical Distribution",
type=PublicationType.WebPage,
year=2011,
url="http://en.wikipedia.org/wiki/Categorical_distribution"
)
public class CategoricalDistribution
extends AbstractDistribution
implements ClosedFormComputableDiscreteDistribution
{
/**
* Default number of classes (labels or parameters), {@value}.
*/
public static final int DEFAULT_NUM_CLASSES = 2;
/**
* Parameters of the multinomial distribution, must be at least
* 2-dimensional and each element must be nonnegative.
*/
protected Vector parameters;
/**
* Creates a new instance of CategoricalDistribution
*/
public CategoricalDistribution()
{
this( DEFAULT_NUM_CLASSES );
}
/**
* Creates a new instance of CategoricalDistribution
* @param numClasses
* Number of classes (labels or parameters) to use.
*/
public CategoricalDistribution(
final int numClasses )
{
this( VectorFactory.getDefault().createVector(numClasses, 1.0) );
}
/**
* Creates a new instance of CategoricalDistribution
* @param parameters
* Parameters of the multinomial distribution, must be at least
* 2-dimensional and each element must be nonnegative.
*/
public CategoricalDistribution(
final Vector parameters )
{
this.setParameters(parameters);
}
/**
* Copy constructor
* @param other
* CategoricalDistribution to copy
*/
public CategoricalDistribution(
CategoricalDistribution other )
{
this( ObjectUtil.cloneSafe(other.getParameters()) );
}
@Override
public CategoricalDistribution clone()
{
CategoricalDistribution clone = (CategoricalDistribution) super.clone();
clone.setParameters( ObjectUtil.cloneSafe( this.getParameters() ) );
return clone;
}
/**
* Getter for parameters
* @return
* Parameters of the multinomial distribution, must be at least
* 2-dimensional and each element must be nonnegative.
*/
public Vector getParameters()
{
return this.parameters;
}
/**
* Setter for parameters
* @param parameters
* Parameters of the multinomial distribution, must be at least
* 2-dimensional and each element must be nonnegative.
*/
public void setParameters(
final Vector parameters)
{
final int N = parameters.getDimensionality();
if( N < 2 )
{
throw new IllegalArgumentException( "Dimensionality must be >= 2" );
}
for( int i = 0; i < N; i++ )
{
if( parameters.getElement(i) < 0.0 )
{
throw new IllegalArgumentException(
"All parameter elements must be >= 0.0" );
}
}
this.parameters = parameters;
}
@Override
public void sampleInto(
final Random random,
final int sampleCount,
final Collection super Vector> output)
{
ArrayList domain = CollectionUtil.asArrayList(this.getDomain());
final int N = domain.size();
double[] cumulativeWeights = new double[N];
double sum = 0.0;
for( int n = 0; n < N; n++ )
{
double weight = this.parameters.getElement(n);
sum += weight;
cumulativeWeights[n] = sum;
}
ProbabilityMassFunctionUtil.sampleMultipleInto(
cumulativeWeights, domain, random, sampleCount, output);
}
@Override
public Vector getMean()
{
return this.parameters.scale( this.parameters.norm1() );
}
@Override
public Vector convertToVector()
{
return this.parameters.clone();
}
@Override
public void convertFromVector(
final Vector parameters)
{
this.parameters.assertSameDimensionality(parameters);
this.setParameters(parameters);
}
/**
* Gets the dimensionality of the input vectors
* @return
* Dimensionality of the input vectors
*/
public int getInputDimensionality()
{
return this.getParameters().getDimensionality();
}
@Override
public Set getDomain()
{
final int N = this.getInputDimensionality();
LinkedHashSet domain = new LinkedHashSet( N );
for( int n = 0; n < N; n++ )
{
Vector x = VectorFactory.getDefault().createVector(N);
x.setElement(n, 1.0);
domain.add( x );
}
return domain;
}
@Override
public int getDomainSize()
{
return this.getInputDimensionality();
}
@Override
public CategoricalDistribution.PMF getProbabilityFunction()
{
return new CategoricalDistribution.PMF( this );
}
/**
* PMF of the Categorical Distribution
*/
public static class PMF
extends CategoricalDistribution
implements ProbabilityMassFunction,
VectorInputEvaluator
{
/**
* Creates a new instance of CategoricalDistribution
*/
public PMF()
{
super();
}
/**
* Creates a new instance of CategoricalDistribution
* @param numClasses
* Number of classes (labels or parameters) to use.
*/
public PMF(
final int numClasses )
{
super( numClasses );
}
/**
* Creates a new instance of CategoricalDistribution
* @param parameters
* Parameters of the multinomial distribution, must be at least
* 2-dimensional and each element must be nonnegative.
*/
public PMF(
final Vector parameters )
{
super( parameters );
}
/**
* Copy constructor
* @param other
* CategoricalDistribution to copy
*/
public PMF(
CategoricalDistribution other )
{
super( other );
}
@Override
public double getEntropy()
{
return ProbabilityMassFunctionUtil.getEntropy(this);
}
@Override
public double logEvaluate(
Vector input)
{
return Math.log( this.evaluate(input) );
}
@Override
public Double evaluate(
Vector input)
{
this.parameters.assertSameDimensionality(input);
double pi = -1.0;
final int N = this.getInputDimensionality();
double sum = 0.0;
for( int n = 0; n < N; n++ )
{
double p = this.parameters.getElement(n);
sum += p;
double x = input.getElement(n);
if( x == 1.0 )
{
if( pi < 0.0 )
{
pi = p;
}
else
{
throw new IllegalArgumentException(
"input must only have one entry equal to 1.0!");
}
}
else if( x != 0.0 )
{
throw new IllegalArgumentException(
"input entries must be either 0.0 or 1.0" );
}
}
if( pi < 0.0 )
{
throw new IllegalArgumentException(
"input must have one entry equal to 1.0!" );
}
return pi/sum;
}
@Override
public CategoricalDistribution.PMF getProbabilityFunction()
{
return this;
}
}
}