All Downloads are FREE. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.statistics.distribution.LinearMixtureModel Maven / Gradle / Ivy

/*
 * File:                LinearMixtureModel.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright November 3, 2006, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 */

package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.Distribution;
import gov.sandia.cognition.statistics.ProbabilityMassFunctionUtil;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Random;

/**
 * A linear mixture of RandomVariables, with a prior probability distribution.
 * The posterior pdf is:
 * p(x|this) = \sum_{y\in this} p(x|y,this)P(y|this),
 * where p(x|y,this) is the pdf of the underlying RandomVariable, and
 * P(y|this) is the prior probability of RandomVariable y in this.
 *
 * @param  
 *      Type of data in this mixture model
 * @param 
 *      The type of the internal distributions inside the mixture.
 * @author Kevin R. Dixon
 * @since  1.0
 *
 */
@PublicationReference(
    author="Wikipedia",
    title="Mixture Model",
    type=PublicationType.WebPage,
    year=2009,
    url="http://en.wikipedia.org/wiki/Mixture_model"
)
public abstract class LinearMixtureModel>
    extends AbstractDistribution
{
    
    /**
     * Underlying distributions from which we sample
     */
    protected ArrayList distributions;

    /**
     * Weights proportionate by which the distributions are sampled
     */
    protected double[] priorWeights;

    /**
     * Creates a new instance of LinearMixtureModel
     * @param distributions
     * Underlying distributions from which we sample
     */
    public LinearMixtureModel(
        Collection distributions )
    {
        this( distributions, null );
    }

    /**
     * Creates a new instance of LinearMixtureModel
     * @param distributions
     * Underlying distributions from which we sample
     * @param priorWeights
     * Weights proportionate by which the distributions are sampled
     */
    public LinearMixtureModel(
        Collection distributions,
        double[] priorWeights)
    {

        if( priorWeights == null )
        {
            priorWeights = new double[distributions.size()];
            Arrays.fill(priorWeights, 1.0);
        }

        if( distributions.size() != priorWeights.length )
        {
            throw new IllegalArgumentException(
                "Distribution count must equal number of priors" );
        }
        for( int i = 0; i < priorWeights.length; i++ )
        {
            if( priorWeights[i] < 0.0 )
            {
                throw new IllegalArgumentException( "weights must be >= 0.0!" );
            }
        }

        this.setDistributions(CollectionUtil.asArrayList(distributions));
        this.setPriorWeights(priorWeights);
    }

    @Override
    @SuppressWarnings("unchecked")
    public LinearMixtureModel clone()
    {
        LinearMixtureModel clone = 
            (LinearMixtureModel) super.clone();
        clone.setDistributions( ObjectUtil.cloneSmartElementsAsArrayList(
            this.getDistributions() ) );
        clone.setPriorWeights( ObjectUtil.cloneSmart( this.getPriorWeights() ) );
        return clone;
    }

    @Override
    public String toString()
    {
        StringBuilder retval = new StringBuilder(1000);
        retval.append("LinearMixtureModel has " + this.getDistributionCount() + " distributions:\n");
        int k = 0;
        for( DistributionType distribution : this.getDistributions() )
        {
            retval.append( " " + k + ": Prior: " + this.getPriorWeights()[k] + ", Distribution: " + distribution + "\n" );
            k++;
        }
        return retval.toString();
    }

    /**
     * Getter for distributions
     * @return
     * Underlying distributions from which we sample
     */
    public ArrayList getDistributions()
    {
        return this.distributions;
    }

    /**
     * Setter for distributions
     * @param distributions
     * Underlying distributions from which we sample
     */
    public void setDistributions(
        ArrayList distributions)
    {
        this.distributions = distributions;
    }
    
    /**
     * Gets the number of distributions in the model
     * @return
     * Number of distributions in the model
     */
    public int getDistributionCount()
    {
        return this.distributions.size();
    }

    @Override
    public DataType sample(
        Random random)
    {
        final DistributionType d = ProbabilityMassFunctionUtil.sampleSingle(
            this.getPriorWeights(), this.getDistributions(), random);
        return d.sample(random);
    }
    
    @Override
    public void sampleInto(
        final Random random,
        final int sampleCount,
        final Collection output)
    {
        // Build the cumulative distribution for batch sampling.
        final int distributionCount = this.getDistributionCount();
        final double[] priorWeights = this.getPriorWeights();
        final double[] cumulativeWeights = new double[distributionCount];
        double sum = 0.0;
        for(int n = 0; n < distributionCount; n++)
        {
            sum += priorWeights[n];
            cumulativeWeights[n] = sum;
        }

        // Sample each of the mixtures.
        for (int i = 0; i < sampleCount; i++)
        {
            final DistributionType d = ProbabilityMassFunctionUtil.sample(
                cumulativeWeights, this.getDistributions(), random);
            output.add(d.sample(random));
        }
    }

    /**
     * Getter for priorWeights
     * @return
     * Weights proportionate by which the distributions are sampled
     */
    public double[] getPriorWeights()
    {
        return this.priorWeights;
    }

    /**
     * Getter for priorWeights
     * @param priorWeights
     * Weights proportionate by which the distributions are sampled
     */
    public void setPriorWeights(
        double[] priorWeights)
    {
        this.priorWeights = priorWeights;
    }

    /**
     * Computes the sum of the prior weights
     * @return
     * Sum of the prior weights
     */
    public double getPriorWeightSum()
    {
        double sum = 0.0;
        final int K = this.getPriorWeights().length;
        for( int k = 0; k < K; k++ )
        {
            sum += this.getPriorWeights()[k];
        }
        return (sum <= 0.0) ? 1.0 : sum;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy