gov.sandia.cognition.statistics.bayesian.ImportanceSampling Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: ImportanceSampling.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Oct 22, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.statistics.bayesian;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.AbstractRandomized;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;
/**
* Importance sampling is a Monte Carlo inference technique where we sample
* from an easy distribution over the hidden variables (parameters) and then
* weight the result by the ratio of the likelihood of the parameters given
* the evidence and the likelihood of generating the parameters. This is a
* simple alternative to MCMC that is computationally simple, but does not
* scale well to many data points or many dimensions.
* @param Type of observation
* @param Type of parameters to infer
* @author Kevin R. Dixon
* @since 3.0
*/
@PublicationReference(
author="Wikipedia",
title="Importance Sampling",
type=PublicationType.WebPage,
year=2009,
url="http://en.wikipedia.org/wiki/Importance_sampling"
)
public class ImportanceSampling
extends AbstractRandomized
implements BayesianEstimator>
{
/**
* Default maximum number of samples, {@value}.
*/
public static final int DEFAULT_NUM_SAMPLES = 1000;
/**
* Updater for the ImportanceSampling algorithm.
*/
protected ImportanceSampling.Updater updater;
/**
* Number of samples.
*/
private int numSamples;
/**
* Creates a new instance of ImportanceSampling
*/
public ImportanceSampling()
{
super( null );
this.setNumSamples(numSamples);
this.numSamples = DEFAULT_NUM_SAMPLES;
}
@Override
public ImportanceSampling clone()
{
@SuppressWarnings("unchecked")
ImportanceSampling clone =
(ImportanceSampling) super.clone();
clone.setUpdater( ObjectUtil.cloneSafe( this.getUpdater() ) );
return clone;
}
@Override
public DataDistribution learn(
final Collection extends ObservationType> data)
{
ArrayList> weightedSamples =
new ArrayList>( this.getNumSamples());
double maxWeight = Double.NEGATIVE_INFINITY;
for( int n = 0; n < this.getNumSamples(); n++ )
{
ParameterType parameter = this.getUpdater().makeProposal(random);
double ll = this.getUpdater().computeLogLikelihood(parameter, data);
double lq = this.getUpdater().computeLogImportanceValue(parameter);
double weight = ll - lq;
if( maxWeight < weight )
{
maxWeight = weight;
}
weightedSamples.add( new DefaultWeightedValue( parameter, weight ) );
}
maxWeight -= Math.log(Double.MAX_VALUE/ this.getNumSamples() / 2.0 );
DataDistribution retval =
new DefaultDataDistribution( this.getNumSamples());
for( DefaultWeightedValue weightedSample : weightedSamples )
{
double mass = Math.exp(weightedSample.getWeight() - maxWeight);
retval.increment( weightedSample.getValue(), mass );
}
return retval;
}
/**
* Getter for updater
* @return
* Updater for the ImportanceSampling algorithm.
*/
public ImportanceSampling.Updater getUpdater()
{
return this.updater;
}
/**
* Setter for updater
* @param updater
* Updater for the ImportanceSampling algorithm.
*/
public void setUpdater(
final ImportanceSampling.Updater updater)
{
this.updater = updater;
}
/**
* Getter for numSamples
* @return
* Number of samples.
*/
public int getNumSamples()
{
return this.numSamples;
}
/**
* Setter for numSamples
* @param numSamples
* Number of samples.
*/
public void setNumSamples(
final int numSamples)
{
this.numSamples = numSamples;
}
/**
* Updater for ImportanceSampling
* @param Type of observation
* @param Type of parameters to infer
*/
public static interface Updater
extends CloneableSerializable
{
/**
* Computes the log likelihood of the data given the parameter
* @param parameter
* Parameter to consider
* @param data
* Data to consider
* @return
* log likelihood of the data given the parameter
*/
public double computeLogLikelihood(
final ParameterType parameter,
final Iterable extends ObservationType> data );
/**
* Computes the parameter's importance weight.
* @param parameter
* Parameter to consider
* @return
* Importance value
*/
public double computeLogImportanceValue(
final ParameterType parameter );
/**
* Samples from the parameter prior
* @param random
* Random number generator.
* @return
* Location of the proposed sample
*/
public ParameterType makeProposal(
final Random random );
}
/**
* Default ImportanceSampling Updater that uses a BayesianParameter
* to compute the quantities of interest.
* @param Type of observation
* @param Type of parameters to infer
*/
public static class DefaultUpdater
extends AbstractCloneableSerializable
implements Updater
{
/**
* Defines the parameter that connects the conditional and prior
* distributions.
*/
protected BayesianParameter,? extends ProbabilityFunction> conjuctive;
/**
* Default constructor.
*/
public DefaultUpdater()
{
this( null );
}
/**
* Creates a new instance of DefaultUpdater
* @param conjuctive
* Defines the parameter that connects the conditional and prior
* distributions.
*/
public DefaultUpdater(
final BayesianParameter,? extends ProbabilityFunction> conjuctive)
{
this.setConjuctive(conjuctive);
}
@Override
public double computeLogLikelihood(
final ParameterType parameter,
final Iterable extends ObservationType> data)
{
this.conjuctive.setValue(parameter);
return BayesianUtil.logLikelihood(
this.conjuctive.getConditionalDistribution(), data);
}
@Override
public double computeLogImportanceValue(
final ParameterType parameter)
{
return this.conjuctive.getParameterPrior().logEvaluate(parameter);
}
@Override
public ParameterType makeProposal(
final Random random)
{
return this.conjuctive.getParameterPrior().sample(random);
}
/**
* Getter for conjunctive
* @return
* Defines the parameter that connects the conditional and prior
* distributions.
*/
public BayesianParameter,? extends ProbabilityFunction> getConjuctive()
{
return this.conjuctive;
}
/**
* Setter for conjunctive
* @param conjuctive
* Defines the parameter that connects the conditional and prior
* distributions.
*/
public void setConjuctive(
final BayesianParameter,? extends ProbabilityFunction> conjuctive)
{
this.conjuctive = conjuctive;
}
}
}