All downloads are free. The search and download features use the official Maven repository.

gov.sandia.cognition.statistics.distribution.MultivariatePolyaDistribution Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                MultivariatePolyaDistribution.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright May 15, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.MathUtil;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.ClosedFormComputableDiscreteDistribution;
import gov.sandia.cognition.statistics.ProbabilityMassFunction;
import gov.sandia.cognition.statistics.ProbabilityMassFunctionUtil;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;

/**
 * A multivariate Polya Distribution, also known as a Dirichlet-Multinomial
 * model, is a compound distribution where the parameters of a multinomial
 * are drawn from a Dirichlet distribution with fixed parameters and a constant
 * number of trials and then the observations are generated by this
 * multinomial.  This is the multivariate generalization of the Beta-Binomial
 * Distribution and is the predictive posterior distribution for a
 * Multinomial Bayesian estimator using its conjugate prior Dirichlet
 * distribution.
 * @author Kevin R. Dixon
 * @since 3.0
 */
@PublicationReference(
    author="Wikipedia",
    title="Multivariate Polya Distribution",
    type=PublicationType.WebPage,
    year=2010,
    url="http://en.wikipedia.org/wiki/Multivariate_Polya_distribution"
)
public class MultivariatePolyaDistribution 
    extends AbstractDistribution
    implements ClosedFormComputableDiscreteDistribution
{

    /**
     * Default number of trials, {@value}.
     */
    public static final int DEFAULT_NUM_TRIALS =
        MultinomialDistribution.DEFAULT_NUM_TRIALS;

    /**
     * Default dimensionality, {@value}.
     */
    public static final int DEFAULT_DIMENSIONALITY = 2;

    /**
     * Parameters of the Dirichlet distribution, must be at least 2-dimensional
     * and each element must be positive.
     */
    protected Vector parameters;

    /**
     * Number of trials in the distribution, must be greater than 0.
     */
    private int numTrials;

    /**
     * Creates a new instance of DirichletDistribution
     */
    public MultivariatePolyaDistribution()
    {
        this( DEFAULT_DIMENSIONALITY, DEFAULT_NUM_TRIALS );
    }

    /**
     * Creates a new instance of MultivariatePolyaDistribution
     * @param dimensionality
     * Dimensionality of the distribution
     * @param numTrials
     * Number of trials in the distribution, must be greater than 0.
     */
    public MultivariatePolyaDistribution(
        final int dimensionality,
        final int numTrials )
    {
        this( VectorFactory.getDefault().createVector(dimensionality,1.0), numTrials );
    }

    /**
     * Creates a new instance of MultivariatePolyaDistribution
     * @param parameters
     * Parameters of the Dirichlet distribution, must be at least 2-dimensional
     * and each element must be positive.
     * @param numTrials
     * Number of trials in the distribution, must be greater than 0.
     */
    public MultivariatePolyaDistribution(
        final Vector parameters,
        final int numTrials )
    {
        this.setParameters(parameters);
        this.setNumTrials(numTrials);
    }

    /**
     * Copy Constructor.
     * @param other
     * MultivariatePolyaDistribution to copy.
     */
    public MultivariatePolyaDistribution(
        MultivariatePolyaDistribution other )
    {
        this( ObjectUtil.cloneSafe( other.getParameters() ), other.getNumTrials() );
    }

    @Override
    public MultivariatePolyaDistribution clone()
    {
        MultivariatePolyaDistribution clone =
            (MultivariatePolyaDistribution) super.clone();
        clone.setParameters( ObjectUtil.cloneSafe(this.getParameters()) );
        return clone;
    }

    @Override
    public Vector getMean()
    {
        return this.parameters.scale( this.numTrials / this.parameters.norm1() );
    }

    @Override
    public void sampleInto(
        final Random random,
        final int sampleCount,
        final Collection output)
    {
        DirichletDistribution prior =
            new DirichletDistribution( this.parameters );
        ArrayList dirichletSamples = prior.sample(random, sampleCount);

        final int dim = this.getInputDimensionality();
        final int N = this.getNumTrials();
        MultinomialDistribution conditional =
            new MultinomialDistribution( dim, N );
        conditional.setNumTrials(N);
        for( int i = 0; i < sampleCount; i++ )
        {
            conditional.setParameters(dirichletSamples.get(i));
            output.add( conditional.sample(random) );
        }
    }

    /**
     * Getter for numTrials
     * @return
     * Number of trials in the distribution, must be greater than 0.
     */
    public int getNumTrials()
    {
        return this.numTrials;
    }

    /**
     * Setter for numTrials
     * @param numTrials
     * Number of trials in the distribution, must be greater than 0.
     */
    public void setNumTrials(
        final int numTrials)
    {
        if( numTrials <= 0 )
        {
            throw new IllegalArgumentException( "numTrials must be > 0" );
        }
        this.numTrials = numTrials;
    }

    @Override
    public MultivariatePolyaDistribution.PMF getProbabilityFunction()
    {
        return new MultivariatePolyaDistribution.PMF( this );
    }

    @Override
    public Vector convertToVector()
    {
        return ObjectUtil.cloneSafe(this.getParameters());
    }

    @Override
    public void convertFromVector(
        final Vector parameters)
    {
        parameters.assertSameDimensionality( this.getParameters() );
        this.setParameters( ObjectUtil.cloneSafe(parameters) );
    }

    /**
     * Gets the dimensionality of the parameters
     * @return
     * Number of parameters in the distribution.
     */
    public int getInputDimensionality()
    {
        return (this.parameters != null) ? this.parameters.getDimensionality() : 0;
    }

    /**
     * Getter for parameters
     * @return
     * Parameters of the Dirichlet distribution, must be at least 2-dimensional
     * and each element must be positive.
     */
    public Vector getParameters()
    {
        return this.parameters;
    }

    /**
     * Setter for parameters
     * @param parameters
     * Parameters of the Dirichlet distribution, must be at least 2-dimensional
     * and each element must be positive.
     */
    public void setParameters(
        final Vector parameters)
    {

        final int N = parameters.getDimensionality();

        if( N < 2 )
        {
            throw new IllegalArgumentException( "Dimensionality must be >= 2" );
        }

        for( int i = 0; i < N; i++ )
        {
            if( parameters.getElement(i) <= 0.0 )
            {
                throw new IllegalArgumentException(
                    "All parameter elements must be > 0.0" );
            }
        }
        this.parameters = parameters;
    }

    @Override
    public MultinomialDistribution.Domain getDomain()
    {
        return new MultinomialDistribution.Domain(
            this.getInputDimensionality(), this.getNumTrials() );
    }

    @Override
    public int getDomainSize()
    {
        return this.getDomain().size();
    }

    @Override
    public String toString()
    {
        return "N = " + this.getNumTrials() + ", Parameters = " + this.getParameters();
    }

    /**
     * PMF of the MultivariatePolyaDistribution
     */
    public static class PMF
        extends MultivariatePolyaDistribution
        implements ProbabilityMassFunction,
        VectorInputEvaluator
    {

        /**
         * Creates a new instance of DirichletDistribution
         */
        public PMF()
        {
            super();
        }

        /**
         * Creates a new instance of MultivariatePolyaDistribution
         * @param dimensionality
         * Dimensionality of the distribution
         * @param numTrials
         * Number of trials in the distribution, must be greater than 0.
         */
        public PMF(
            final int dimensionality,
            final int numTrials )
        {
            super( dimensionality, numTrials );
        }

        /**
         * Creates a new instance of MultivariatePolyaDistribution
         * @param parameters
         * Parameters of the Dirichlet distribution, must be at least 2-dimensional
         * and each element must be positive.
         * @param numTrials
         * Number of trials in the distribution, must be greater than 0.
         *
         */
        public PMF(
            final Vector parameters,
            final int numTrials )
        {
            super( parameters, numTrials );
        }

        /**
         * Copy Constructor.
         * @param other
         * MultivariatePolyaDistribution to copy.
         */
        public PMF(
            MultivariatePolyaDistribution other )
        {
            super( other );
        }

        @Override
        public MultivariatePolyaDistribution.PMF getProbabilityFunction()
        {
            return this;
        }

        @Override
        public double logEvaluate(
            final Vector input)
        {
            final int dim = this.getInputDimensionality();
            input.assertDimensionalityEquals(dim);
            final int ni = (int) Math.round( input.norm1() );
            final int N = this.getNumTrials();
            final double A = this.parameters.norm1();
            if( ni != N )
            {
                return Math.log(0.0);
            }

            double logSum = 0.0;
            logSum += Math.log(ni);
            logSum += MathUtil.logBetaFunction(A, ni);
            for( int i = 0; i < dim; i++ )
            {
                double pi = this.parameters.getElement(i);
                double xi = input.getElement(i);
                if( (pi > 0.0) && (xi > 0.0) )
                {
                    logSum -= Math.log(xi);
                    logSum -= MathUtil.logBetaFunction( pi, xi );
                }
            }
            return logSum;
        }

        @Override
        public Double evaluate(
            final Vector input)
        {
            return Math.exp( this.logEvaluate(input) );
        }

        @Override
        public double getEntropy()
        {
            return ProbabilityMassFunctionUtil.getEntropy(this);
        }
        
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy