gov.sandia.cognition.statistics.bayesian.RejectionSampling Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: RejectionSampling.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Mar 3, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.statistics.bayesian;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.minimization.line.LineMinimizerDerivativeFree;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.UnivariateProbabilityDensityFunction;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.AbstractRandomized;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;
/**
* Rejection sampling is a method of inferring hidden parameters by using
* an easy-to-sample-from distribution (times a scale factor) that envelopes
* another distribution that is difficult to sample from. Typically, we sample
* from the parameter prior to infer the likelihood of the parameters given
* an observation sequence. In my limited experience, vanilla rejection
* sampling, implemented here, is inferior to ImportanceSamping.
* @param Type of observation
* @param Type of parameters to infer
* @author Kevin R. Dixon
* @since 3.0
*/
@PublicationReference(
author="Wikipedia",
title="Rejection Sampling",
type=PublicationType.WebPage,
year=2009,
url="http://en.wikipedia.org/wiki/Rejection_sampling"
)
public class RejectionSampling
extends AbstractRandomized
implements BayesianEstimator>
{
/**
* Default number of samples, {@value}.
*/
public static final int DEFAULT_NUM_SAMPLES = 1000;
/**
* Number of samples.
*/
private int numSamples;
/**
* Updater for the ImportanceSampling algorithm.
*/
protected RejectionSampling.Updater updater;
/**
* Creates a new instance of RejectionSampling
*/
public RejectionSampling()
{
super( null );
this.setNumSamples(DEFAULT_NUM_SAMPLES);
}
@Override
public RejectionSampling clone()
{
@SuppressWarnings("unchecked")
RejectionSampling clone =
(RejectionSampling) super.clone();
clone.setUpdater( ObjectUtil.cloneSafe( this.getUpdater() ) );
return clone;
}
@Override
public DataDistribution learn(
final Collection extends ObservationType> data)
{
DataDistribution retval =
new DefaultDataDistribution( this.getNumSamples() );
for( int n = 0; n < numSamples; n++ )
{
ParameterType parameter = null;
boolean accepted = false;
while( !accepted )
{
parameter = this.getUpdater().makeProposal( this.getRandom() );
double acceptProbability =
this.getUpdater().computeAcceptanceProbability(parameter,data);
double p = this.getRandom().nextDouble();
if( p <= acceptProbability )
{
accepted = true;
break;
}
}
retval.increment(parameter);
}
return retval;
}
/**
* Getter for numSamples
* @return
* Number of samples.
*/
public int getNumSamples()
{
return this.numSamples;
}
/**
* Setter for numSamples
* @param numSamples
* Number of samples.
*/
public void setNumSamples(
final int numSamples)
{
this.numSamples = numSamples;
}
/**
* Getter for updater
* @return
* Updater for the ImportanceSampling algorithm.
*/
public RejectionSampling.Updater getUpdater()
{
return this.updater;
}
/**
* Setter for updater
* @param updater
* Updater for the ImportanceSampling algorithm.
*/
public void setUpdater(
final RejectionSampling.Updater updater)
{
this.updater = updater;
}
/**
* Routine for estimating the minimum scalar needed to envelop the
* conjunctive distribution.
* @param
* Type of observations to use.
*/
public static class ScalarEstimator
{
/**
* Defines the parameter that connects the conditional and prior
* distributions.
*/
BayesianParameter,? extends UnivariateProbabilityDensityFunction> conjunctive;
/**
* Data to consider
*/
Iterable extends ObservationType> data;
/**
* Creates a new instance of ScalarEstimator
* @param conjunctive
* Defines the parameter that connects the conditional and prior
* distributions.
* @param data
* Data to consider
*/
public ScalarEstimator(
final BayesianParameter,? extends UnivariateProbabilityDensityFunction> conjunctive,
final Iterable extends ObservationType> data )
{
this.conjunctive = conjunctive;
this.data = data;
}
/**
* Computes the logarithm of the conjunctive likelihood for the given
* parameter
* @param parameter
* Parameter to update.
* @return
* Logarithm of the conjunctive likelihood for the given parameter
*/
public double logConjunctive(
final Double parameter )
{
double logSum = this.conjunctive.getParameterPrior().logEvaluate(parameter);
if( !Double.isInfinite(logSum) )
{
this.conjunctive.setValue(parameter);
logSum += BayesianUtil.logLikelihood(
this.conjunctive.getConditionalDistribution(), this.data);
}
return logSum;
}
/**
* Estimates the minimum scalar needed for the sampler distribution to
* envelope the conjunctive distribution
* @param sampler
* Distribution from which we sample and envelop the conjunctive
* distribution.
* @return
* Minimum scalar needed for the sampler distribution to envelope the
* conjunctive distribution
*/
public double estimateScalarFactor(
final UnivariateProbabilityDensityFunction sampler )
{
MinimizerFunction f = new MinimizerFunction( sampler );
LineMinimizerDerivativeFree minimizer =
new LineMinimizerDerivativeFree();
minimizer.setInitialGuess( sampler.getMean() );
InputOutputPair mode = minimizer.learn(f);
return Math.exp(-mode.getOutput());
}
/**
* Minimization function that measures the difference between the
* logarithm of the sampler function minus the logarithm of the
* conjunctive distribution.
*/
public class MinimizerFunction
implements Evaluator
{
/**
* Sampler function
*/
protected ProbabilityFunction sampler;
/**
* Creates a new instance of MinimizerFunction
* @param sampler
* Sampler function
*/
public MinimizerFunction(
final ProbabilityFunction sampler)
{
this.sampler = sampler;
}
@Override
public Double evaluate(
final Double parameter)
{
// Find the point where the conjuctive is the largest compared
// to the sampler: min(logSampler - logConjuctive)
final double logSampler = this.sampler.logEvaluate(parameter);
final double logCon = logConjunctive(parameter);
if( Double.isInfinite(logSampler) ||
Double.isInfinite(logCon) )
{
return Double.POSITIVE_INFINITY;
}
else
{
return logSampler-logCon;
}
}
}
}
/**
* Updater for ImportanceSampling
* @param Type of observation
* @param Type of parameters to infer
*/
public static interface Updater
extends CloneableSerializable
{
/**
* Computes the probability of accepting the parameter for the given
* data.
* @param parameter
* Parameter to consider
* @param data
* Data to consider.
* @return
* Probability of accepting the parameter
*/
public double computeAcceptanceProbability(
final ParameterType parameter,
final Iterable extends ObservationType> data );
/**
* Samples from the parameter prior
* @param random
* Random number generator.
* @return
* Location of the proposed sample
*/
public ParameterType makeProposal(
final Random random );
}
/**
* Default ImportanceSampling Updater that uses a BayesianParameter
* to compute the quantities of interest.
* @param Type of observation
* @param Type of parameters to infer
*/
public static class DefaultUpdater
extends AbstractCloneableSerializable
implements Updater
{
/**
* Defines the parameter that connects the conditional and prior
* distributions.
*/
protected BayesianParameter,? extends ProbabilityFunction> conjuctive;
/**
* Scale factor to multiply the sampler function by to envelop the
* conjunctive distribution.
*/
protected Double scale;
/**
* Distribution from which we sample and envelop the conjunctive
* distribution.
*/
private ProbabilityFunction sampler;
/**
* Number of proposals suggested
*/
protected int proposals;
/**
* Default constructor.
*/
public DefaultUpdater()
{
this( null );
}
/**
* Creates a new instance of DefaultUpdater
* @param conjuctive
* Defines the parameter that connects the conditional and prior
* distributions.
*/
public DefaultUpdater(
final BayesianParameter,? extends ProbabilityFunction> conjuctive)
{
this( conjuctive, (conjuctive != null) ? conjuctive.getParameterPrior() : null );
}
/**
* Creates a new instance of DefaultUpdater
* @param conjuctive
* Defines the parameter that connects the conditional and prior
* distributions.
* @param sampler
* Distribution from which we sample and envelop the conjunctive
* distribution.
*/
public DefaultUpdater(
final BayesianParameter,? extends ProbabilityFunction> conjuctive,
final ProbabilityFunction sampler)
{
this( conjuctive, null, sampler );
}
/**
* Creates a new instance of DefaultUpdater
* @param conjuctive
* Defines the parameter that connects the conditional and prior
* distributions.
* @param scale
* Scale factor to multiply the sampler function by to envelop the
* conjunctive distribution.
* @param sampler
* Distribution from which we sample and envelop the conjunctive
* distribution.
*/
public DefaultUpdater(
final BayesianParameter,? extends ProbabilityFunction> conjuctive,
final Double scale,
final ProbabilityFunction sampler)
{
this.setConjuctive(conjuctive);
this.setScale(scale);
this.setSampler(sampler);
this.setProposals(0);
}
@Override
public double computeAcceptanceProbability(
final ParameterType parameter,
final Iterable extends ObservationType> data)
{
if( this.scale == null )
{
this.scale = this.computeOptimalScale(data);
}
this.conjuctive.setValue(parameter);
double logSum = BayesianUtil.logLikelihood(
this.conjuctive.getConditionalDistribution(), data);
logSum += this.conjuctive.getParameterPrior().logEvaluate(parameter);
logSum -= this.sampler.logEvaluate(parameter);
logSum -= Math.log( this.getScale());
return Math.exp( logSum );
}
@Override
public ParameterType makeProposal(
final Random random)
{
this.proposals++;
ParameterType parameter = null;
boolean keepGoing = true;
while( keepGoing )
{
parameter = this.sampler.sample(random);
if( this.conjuctive.getParameterPrior().evaluate(parameter) > 0.0 )
{
keepGoing = false;
}
}
return parameter;
}
/**
* Getter for conjunctive
* @return
* Defines the parameter that connects the conditional and prior
* distributions.
*/
public BayesianParameter,? extends ProbabilityFunction> getConjuctive()
{
return this.conjuctive;
}
/**
* Setter for conjunctive
* @param conjuctive
* Defines the parameter that connects the conditional and prior
* distributions.
*/
public void setConjuctive(
final BayesianParameter,? extends ProbabilityFunction> conjuctive)
{
this.conjuctive = conjuctive;
}
/**
* Computes the optimal scale factor for enveloping the conjunctive
* distribution with the sampler function given the data
* @param data
* Data to consider
* @return
* optimal scale factor for enveloping the conjunctive
* distribution with the sampler function given the data
*/
@SuppressWarnings("unchecked")
public double computeOptimalScale(
final Iterable extends ObservationType> data )
{
ScalarEstimator estimator =
new ScalarEstimator(
(BayesianParameter,? extends UnivariateProbabilityDensityFunction>) this.getConjuctive(), data );
return estimator.estimateScalarFactor(
(UnivariateProbabilityDensityFunction) this.getSampler() );
}
/**
* Computes a Gaussian sample for the parameter, assuming it has is
* a Double, using importance sampling.
* @param data
* (Sub)set of the data to use to estimate the Gaussian
* @param random
* Random number generator
* @param numSamples
* Number of samples to create the Gaussian... doesn't need to be
* very large.
* @return
* Gaussian that has the appropriate mean and variance to generate
* parameters.
*/
@SuppressWarnings("unchecked")
public UnivariateGaussian.PDF computeGaussianSampler(
final Iterable extends ObservationType> data,
final Random random,
final int numSamples )
{
ArrayList extends ParameterType> parameters =
this.conjuctive.getParameterPrior().sample(random, numSamples);
UnivariateGaussian.WeightedMaximumLikelihoodEstimator mle =
new UnivariateGaussian.WeightedMaximumLikelihoodEstimator();
ArrayList> values =
new ArrayList>( parameters.size() );
for( int n = 0; n < parameters.size(); n++ )
{
ParameterType parameter = parameters.get(n);
if( this.conjuctive.getParameterPrior().evaluate(parameter) > 0.0 )
{
this.conjuctive.setValue(parameter);
double weight = BayesianUtil.logLikelihood(
this.conjuctive.getConditionalDistribution(), data);
values.add( new DefaultWeightedValue(
(Double) parameter, weight ) );
}
}
return mle.learn(values);
}
/**
* Getter for scale
* @return
* Scale factor to multiply the sampler function by to envelop the
* conjunctive distribution.
*/
public Double getScale()
{
return this.scale;
}
/**
* Setter for scale
* @param scale
* Scale factor to multiply the sampler function by to envelop the
* conjunctive distribution.
*/
public void setScale(
final Double scale)
{
this.scale = scale;
}
/**
* Getter for proposals
* @return
* Number of proposals suggested
*/
public int getProposals()
{
return this.proposals;
}
/**
* Setter for proposals
* @param proposals
* Number of proposals suggested
*/
protected void setProposals(
final int proposals)
{
this.proposals = proposals;
}
/**
* Getter for sampler
* @return
* Distribution from which we sample and envelop the conjunctive
* distribution.
*/
public ProbabilityFunction getSampler()
{
return this.sampler;
}
/**
* Setter for sampler
* @param sampler
* Distribution from which we sample and envelop the conjunctive
* distribution.
*/
public void setSampler(
final ProbabilityFunction sampler)
{
this.sampler = sampler;
}
}
}