All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.statistics.distribution.ChineseRestaurantProcess Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                ChineseRestaurantProcess.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright Feb 9, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.MathUtil;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.ClosedFormComputableDiscreteDistribution;
import gov.sandia.cognition.statistics.ProbabilityMassFunction;
import gov.sandia.cognition.statistics.ProbabilityMassFunctionUtil;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.Random;
import java.util.Set;

/**
 * A Chinese Restaurant Process is a discrete stochastic processes that
 * partitions data points to clusters.  This is done by imagining a
 * restaurant with an infinite number of tables.  The first customer sits at
 * an empty table.  The next customer picks an existing table proportionate to
 * how many customers are already sitting at the various tables, and a new
 * table with some nonzero probability.  This results in a Dirichlet
 * distribution with a variable number of parameters, which grows approximately
 * as O(log(n)), where n is the number of customers to assign to tables.
 * @author Kevin R. Dixon
 * @since 3.0
 */
@PublicationReferences(
    references={
        @PublicationReference(
            author="Michael I. Jordan",
            title="Dirichlet Processes, Chinese Restaurant Processes and All That",
            year=2005,
            type=PublicationType.Conference,
            publication="NIPS Tutorial",
            url="http://www.cs.berkeley.edu/~jordan/nips-tutorial05.ps"
        )
        ,
        @PublicationReference(
            author="Wikipedia",
            title="http://en.wikipedia.org/wiki/Chinese_restaurant_process",
            year=2010,
            type=PublicationType.WebPage,
            url="http://en.wikipedia.org/wiki/Chinese_restaurant_process",
            notes="Very poor, unclear description."
        )
    }
)
public class ChineseRestaurantProcess 
    extends AbstractDistribution
    implements ClosedFormComputableDiscreteDistribution
{

    /**
     * Default concentration parameter, {@value}.
     */
    public static final double DEFAULT_ALPHA = 1.0;

    /**
     * Default number of customers, {@value}.
     */
    public static final int DEFAULT_NUM_CUSTOMERS = 2;

    /**
     * CRP concentration parameter, must be greater than zero.
     */
    protected double alpha;

    /**
     * Total number of customers that we will arrange around tables,
     * must be greater than zero.
     */
    protected int numCustomers;

    /** 
     * Creates a new instance of ChineseRestaurantProcess 
     */
    public ChineseRestaurantProcess()
    {
        this( DEFAULT_ALPHA, DEFAULT_NUM_CUSTOMERS );
    }

    /**
     * Creates a new instance of ChineseRestaurantProcess
     * @param alpha
     * CRP concentration parameter, must be greater than zero.
     * @param numCustomers
     * Total number of customers that we will arrange around tables,
     * must be greater than zero.
     */
    public ChineseRestaurantProcess(
        final double alpha,
        final int numCustomers)
    {
        this.setAlpha(alpha);
        this.setNumCustomers(numCustomers);
    }

    /**
     * Default constructor
     * @param other
     * CRP to copy
     */
    public ChineseRestaurantProcess(
        final ChineseRestaurantProcess other )
    {
        this( other.getAlpha(), other.getNumCustomers() );
    }

    @Override
    public ChineseRestaurantProcess clone()
    {
        return (ChineseRestaurantProcess) super.clone();
    }

    @Override
    public Vector getMean()
    {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    @Override
    public Vector sample(
        final Random random)
    {

        ArrayList tables = new ArrayList( this.numCustomers );
        for( int n = 0; n < this.getNumCustomers(); n++ )
        {
            final int tableIndex =
                sampleNextCustomer(tables, n, this.alpha, random);
            if( tableIndex >= tables.size() )
            {
                // Add an empty table to the Restaurant
                tables.add( 0 );
            }

            // Increment the number of customers at the indexed table
            int nc = tables.get(tableIndex) + 1;
            tables.set(tableIndex, nc);
        }

        Vector parameters = VectorFactory.getDefault().copyValues(tables);
        return parameters;

    }

    /**
     * Determines where the next customer sits, given the number of customers
     * already sitting at the various tables and the concentration parameter
     * alpha.
     * @param tables
     * Number of customers sitting at the various tables.
     * @param numCustomers
     * Number of customers already sitting, should equal the sum
     * of "tables".
     * @param alpha
     * Concentration parameter.
     * @param random
     * Random number generator.
     * @return
     * Index of the table where the next customer sits, according to the
     * Chinese Restaurant Process.
     */
    public static int sampleNextCustomer(
        final Collection tables,
        final int numCustomers,
        final double alpha,
        final Random random )
    {
        double p = random.nextDouble();
        double pnew = alpha / (numCustomers + alpha);
        p -= pnew;
        if( p <= 0.0 )
        {
            return tables.size();
        }

        int tableIndex = 0;
        for( Integer customersAtTable : tables )
        {
            final double tableProb = customersAtTable / (numCustomers + alpha);
            p -= tableProb;
            if( p <= 0.0 )
            {
                return tableIndex;
            }
            tableIndex++;
        }

        throw new IllegalArgumentException(
            "Bad computation in sampleNextcustomer!!!" );
    }

    @Override
    public void sampleInto(
        final Random random,
        final int sampleCount,
        final Collection output)
    {
        for (int i = 0; i < sampleCount; i++)
        {
            output.add(this.sample(random));
        }
    }

    /**
     * Getter for alpha.
     * @return
     * CRP concentration parameter, must be greater than zero.
     */
    public double getAlpha()
    {
        return this.alpha;
    }

    /**
     * Setter for alpha.
     * @param alpha
     * CRP concentration parameter, must be greater than zero.
     */
    public void setAlpha(
        final double alpha)
    {
        if( alpha <= 0.0 )
        {
            throw new IllegalArgumentException( "Alpha must be > 0.0" );
        }
        this.alpha = alpha;
    }

    /**
     * Getter for numCustomers
     * @return
     * Total number of customers that we will arrange around tables,
     * must be greater than zero.
     */
    public int getNumCustomers()
    {
        return this.numCustomers;
    }

    /**
     * Setter for numCustomers
     * @param numCustomers
     * Total number of customers that we will arrange around tables,
     * must be greater than zero.
     */
    public void setNumCustomers(
        final int numCustomers)
    {
        if( numCustomers <= 0 )
        {
            throw new IllegalArgumentException(
                "numCustomers must be > 0" );
        }
        this.numCustomers = numCustomers;
    }

    @Override
    public Set getDomain()
    {
        LinkedHashSet domain = new LinkedHashSet(
            this.getNumCustomers()*this.getNumCustomers() );
        for( int i = 1; i <= this.getNumCustomers(); i++ )
        {
            domain.addAll( new MultinomialDistribution.Domain( i, this.getNumCustomers() ) );
        }
        return domain;
    }

    @Override
    public int getDomainSize()
    {
        return this.getDomain().size();
    }

    @Override
    public ChineseRestaurantProcess.PMF getProbabilityFunction()
    {
        return new ChineseRestaurantProcess.PMF( this );
    }

    @Override
    public Vector convertToVector()
    {
        return VectorFactory.getDefault().copyValues(
            this.getAlpha(), this.getNumCustomers() );
    }

    @Override
    public void convertFromVector(
        final Vector parameters)
    {
        parameters.assertDimensionalityEquals(2);
        this.setAlpha( parameters.getElement(0) );
        this.setNumCustomers( (int) Math.round( parameters.getElement(1) ) );
    }

    /**
     * PMF of the Chinese Restaurant Process
     */
    public static class PMF
        extends ChineseRestaurantProcess
        implements ProbabilityMassFunction
    {

        /**
         * Creates a new instance of ChineseRestaurantProcess
         */
        public PMF()
        {
            super();
        }

        /**
         * Creates a new instance of ChineseRestaurantProcess
         * @param alpha
         * CRP concentration parameter, must be greater than zero.
         * @param numCustomers
         * Total number of customers that we will arrange around tables,
         * must be greater than zero.
         */
        public PMF(
            final double alpha,
            final int numCustomers)
        {
            super( alpha, numCustomers );
        }

        /**
         * Copy constructor
         * @param other
         * CRP to copy
         */
        public PMF(
            final ChineseRestaurantProcess other )
        {
            super( other );
        }

        @Override
        public ChineseRestaurantProcess.PMF getProbabilityFunction()
        {
            return this;
        }

        @Override
        public double getEntropy()
        {
            return ProbabilityMassFunctionUtil.getEntropy(this);
        }

        @Override
        public double logEvaluate(
            final Vector input)
        {
            final int numTables = input.getDimensionality();
            double logSum = numTables * Math.log( this.alpha );
            int totalCustomers = 0;
            for( int table = 0; table < numTables; table++ )
            {
                // We must have at least 1 customer at each table in the CRP.
                final double value = input.getElement(table);
                if( value < 1.0 ||
                    value > this.numCustomers )
                {
                    return Math.log(0.0);
                }

                double floor = Math.floor(value);
                double ceil = Math.ceil(value);
                if( floor != ceil )
                {
                    throw new IllegalArgumentException(
                        "Customers at each table must be an integer: " + input );
                }
                // Posterior time: 31.828
                // Posterior time: 39.563
                final int customersAtTable = (int) floor;
                logSum += MathUtil.logFactorial(customersAtTable-1);
                totalCustomers += customersAtTable;
            }

            logSum += MathUtil.logGammaFunction( this.alpha );
            logSum -= MathUtil.logGammaFunction( this.alpha + totalCustomers );

            return logSum;
        }

        @Override
        public Double evaluate(
            final Vector input)
        {
            return Math.exp( this.logEvaluate(input) );
        }
        
    }
    
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy