All downloads are free. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.statistics.bayesian.ImportanceSampling Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                ImportanceSampling.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright Oct 22, 2009, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 *
 */

package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.AbstractRandomized;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;

/**
 * Importance sampling is a Monte Carlo inference technique where we sample
 * from an easy distribution over the hidden variables (parameters) and then
 * weight the result by the ratio of the likelihood of the parameters given
 * the evidence and the likelihood of generating the parameters.  This is a
 * simple alternative to MCMC that is computationally simple, but does not
 * scale well to many data points or many dimensions.
 * @param  Type of observation
 * @param  Type of parameters to infer
 * @author Kevin R. Dixon
 * @since 3.0
 */
@PublicationReference(
    author="Wikipedia",
    title="Importance Sampling",
    type=PublicationType.WebPage,
    year=2009,
    url="http://en.wikipedia.org/wiki/Importance_sampling"
)
public class ImportanceSampling
    extends AbstractRandomized
    implements BayesianEstimator>
{

    /**
     * Default maximum number of samples, {@value}.
     */
    public static final int DEFAULT_NUM_SAMPLES = 1000;

    /**
     * Updater for the ImportanceSampling algorithm.
     */
    protected ImportanceSampling.Updater updater;

    /**
     * Number of samples.
     */
    private int numSamples;

    /** 
     * Creates a new instance of ImportanceSampling 
     */
    public ImportanceSampling()
    {
        super( null );
        this.setNumSamples(numSamples);
        this.numSamples = DEFAULT_NUM_SAMPLES;
    }

    @Override
    public ImportanceSampling clone()
    {
        @SuppressWarnings("unchecked")
        ImportanceSampling clone =
            (ImportanceSampling) super.clone();
        clone.setUpdater( ObjectUtil.cloneSafe( this.getUpdater() ) );
        return clone;
    }

    @Override
    public DataDistribution learn(
        final Collection data)
    {

        ArrayList> weightedSamples =
            new ArrayList>( this.getNumSamples());

        double maxWeight = Double.NEGATIVE_INFINITY;
        for( int n = 0; n < this.getNumSamples(); n++ )
        {
            ParameterType parameter = this.getUpdater().makeProposal(random);
            double ll = this.getUpdater().computeLogLikelihood(parameter, data);
            double lq = this.getUpdater().computeLogImportanceValue(parameter);
            double weight = ll - lq;
            if( maxWeight < weight )
            {
                maxWeight = weight;
            }
            weightedSamples.add( new DefaultWeightedValue( parameter, weight ) );
        }

        maxWeight -= Math.log(Double.MAX_VALUE/ this.getNumSamples() / 2.0 );

        DataDistribution retval =
            new DefaultDataDistribution( this.getNumSamples());
        for( DefaultWeightedValue weightedSample : weightedSamples )
        {
            double mass = Math.exp(weightedSample.getWeight() - maxWeight);
            retval.increment( weightedSample.getValue(), mass );
        }

        return retval;

    }

    /**
     * Getter for updater
     * @return
     * Updater for the ImportanceSampling algorithm.
     */
    public ImportanceSampling.Updater getUpdater()
    {
        return this.updater;
    }

    /**
     * Setter for updater
     * @param updater
     * Updater for the ImportanceSampling algorithm.
     */
    public void setUpdater(
        final ImportanceSampling.Updater updater)
    {
        this.updater = updater;
    }

    /**
     * Getter for numSamples
     * @return
     * Number of samples.
     */
    public int getNumSamples()
    {
        return this.numSamples;
    }

    /**
     * Setter for numSamples
     * @param numSamples
     * Number of samples.
     */
    public void setNumSamples(
        final int numSamples)
    {
        this.numSamples = numSamples;
    }


    /**
     * Updater for ImportanceSampling
     * @param  Type of observation
     * @param  Type of parameters to infer
     */
    public static interface Updater
        extends CloneableSerializable
    {

        /**
         * Computes the log likelihood of the data given the parameter
         * @param parameter
         * Parameter to consider
         * @param data
         * Data to consider
         * @return
         * log likelihood of the data given the parameter
         */
        public double computeLogLikelihood(
            final ParameterType parameter,
            final Iterable data );

        /**
         * Computes the parameter's importance weight.
         * @param parameter
         * Parameter to consider
         * @return
         * Importance value
         */
        public double computeLogImportanceValue(
            final ParameterType parameter );

        /**
         * Samples from the parameter prior
         * @param random
         * Random number generator.
         * @return
         * Location of the proposed sample
         */
        public ParameterType makeProposal(
            final Random random );

    }

    /**
     * Default ImportanceSampling Updater that uses a BayesianParameter
     * to compute the quantities of interest.
     * @param  Type of observation
     * @param  Type of parameters to infer
     */
    public static class DefaultUpdater
        extends AbstractCloneableSerializable
        implements Updater
    {

        /**
         * Defines the parameter that connects the conditional and prior
         * distributions.
         */
        protected BayesianParameter,? extends ProbabilityFunction> conjuctive;

        /**
         * Default constructor.
         */
        public DefaultUpdater()
        {
            this( null );
        }

        /**
         * Creates a new instance of DefaultUpdater
         * @param conjuctive
         * Defines the parameter that connects the conditional and prior
         * distributions.
         */
        public DefaultUpdater(
            final BayesianParameter,? extends ProbabilityFunction> conjuctive)
        {
            this.setConjuctive(conjuctive);
        }

        @Override
        public double computeLogLikelihood(
            final ParameterType parameter,
            final Iterable data)
        {
            this.conjuctive.setValue(parameter);
            return BayesianUtil.logLikelihood(
                this.conjuctive.getConditionalDistribution(), data);
        }

        @Override
        public double computeLogImportanceValue(
            final ParameterType parameter)
        {
            return this.conjuctive.getParameterPrior().logEvaluate(parameter);
        }

        @Override
        public ParameterType makeProposal(
            final Random random)
        {
            return this.conjuctive.getParameterPrior().sample(random);
        }

        /**
         * Getter for conjunctive
         * @return
         * Defines the parameter that connects the conditional and prior
         * distributions.
         */
        public BayesianParameter,? extends ProbabilityFunction> getConjuctive()
        {
            return this.conjuctive;
        }

        /**
         * Setter for conjunctive
         * @param conjuctive
         * Defines the parameter that connects the conditional and prior
         * distributions.
         */
        public void setConjuctive(
            final BayesianParameter,? extends ProbabilityFunction> conjuctive)
        {
            this.conjuctive = conjuctive;
        }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy