// gov.sandia.cognition.statistics.bayesian.MetropolisHastingsAlgorithm Maven / Gradle / Ivy
// Go to download
// Show more of this group Show more artifacts with this name
// Show all versions of cognitive-foundry Show documentation
// A single jar with all the Cognitive Foundry components.
/*
* File: MetropolisHastingsAlgorithm.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Sep 30, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.statistics.bayesian;
import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.WeightedValue;
/**
* An implementation of the Metropolis-Hastings MCMC algorithm, which is the
* most general formulation of MCMC but can be slow.
* @author Kevin R. Dixon
* @since 3.0
* @param
* Type of observations handled by the MCMC algorithm.
* @param
* Type of parameters to infer.
*/
@PublicationReference(
author="Wikipedia",
title="Metropolis–Hastings algorithm",
type=PublicationType.WebPage,
year=2009,
url="http://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm"
)
public class MetropolisHastingsAlgorithm
extends AbstractMarkovChainMonteCarlo
implements MeasurablePerformanceAlgorithm
{
/**
* Performance statistic name, {@value}.
*/
public static final String PERFORMANCE_NAME = "Current Log Likelihood";
/**
* Log Likelihood of the current parameters.
*/
private double currentLogLikelihood;
/**
* The object that makes proposal samples from the current location.
*/
protected MetropolisHastingsAlgorithm.Updater updater;
/**
* Creates a new instance of MetropolisHastingsAlgorithm.
*/
public MetropolisHastingsAlgorithm()
{
super();
this.setBurnInIterations(this.maxIterations/10);
this.setIterationsPerSample(this.maxIterations/100);
}
@Override
public MetropolisHastingsAlgorithm clone()
{
@SuppressWarnings("unchecked")
MetropolisHastingsAlgorithm clone =
(MetropolisHastingsAlgorithm) super.clone();
clone.setUpdater( ObjectUtil.cloneSafe( this.getUpdater() ) );
return clone;
}
@Override
protected boolean initializeAlgorithm()
{
this.currentLogLikelihood = Double.NEGATIVE_INFINITY;
return super.initializeAlgorithm();
}
@Override
protected void mcmcUpdate()
{
WeightedValue proposal = null;
double proposalLogLikelihood = 0.0;
boolean acceptProposal = false;
while( !acceptProposal )
{
proposal = this.updater.makeProposal(this.currentParameter);
proposalLogLikelihood = this.getUpdater().computeLogLikelihood(
proposal.getValue(), this.data );
if( Double.isInfinite(this.currentLogLikelihood) )
{
acceptProposal = true;
break;
}
final double pratio = Math.exp( proposalLogLikelihood - this.currentLogLikelihood );
final double qratio = proposal.getWeight();
final double a = qratio * pratio;
// If the proposal has higher probability, then always accept it
if( a >= 1.0 )
{
acceptProposal = true;
}
else
{
// Allow the sample with probability "a"
double r = this.random.nextDouble();
if( r <= a )
{
acceptProposal = true;
}
}
}
this.currentParameter = proposal.getValue();
this.currentLogLikelihood = proposalLogLikelihood;
}
public NamedValue getPerformance()
{
return new DefaultNamedValue( PERFORMANCE_NAME, this.currentLogLikelihood );
}
/**
* Gets the object that makes proposal samples from the current location.
* @return
* The object that makes proposal samples from the current location.
*/
public MetropolisHastingsAlgorithm.Updater getUpdater()
{
return this.updater;
}
/**
* Sets the object that makes proposal samples from the current location.
* @param updater
* The object that makes proposal samples from the current location.
*/
public void setUpdater(
MetropolisHastingsAlgorithm.Updater updater)
{
this.updater = updater;
}
public ParameterType createInitialLearnedObject()
{
return this.getUpdater().createInitialParameter();
}
/**
* Creates proposals for the MCMC steps.
* @param
* Type of observations handled by the MCMC algorithm.
* @param
* Type of parameters to infer.
*/
public static interface Updater
extends CloneableSerializable
{
/**
* Creates the initial parameterization
* @return
* Initial parameters
*/
public ParameterType createInitialParameter();
/**
* Computes the log likelihood of the data given the parameter
* @param parameter
* Parameter to consider
* @param data
* Data to consider
* @return
* log likelihood of the data given the parameter
*/
public double computeLogLikelihood(
ParameterType parameter,
Iterable extends ObservationType> data );
/**
* Makes a proposal update given the current parameter set
* @param location Location from which to make a proposal
* @return
* Location of the proposed sample, weighted by the "q" ratio.
*/
public WeightedValue makeProposal(
ParameterType location );
}
}