All downloads are free. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.statistics.bayesian.MetropolisHastingsAlgorithm Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                MetropolisHastingsAlgorithm.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright Sep 30, 2009, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 *
 */

package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.WeightedValue;

/**
 * An implementation of the Metropolis-Hastings MCMC algorithm, which is the
 * most general formulation of MCMC but can be slow.
 * <p>
 * Proposals are generated and scored by an {@link Updater}. The no-argument
 * constructor leaves the updater unset, and both {@link #mcmcUpdate()} and
 * {@link #createInitialLearnedObject()} dereference it, so an updater must be
 * supplied through {@link #setUpdater(MetropolisHastingsAlgorithm.Updater)}
 * before the algorithm is used.
 * @author Kevin R. Dixon
 * @since 3.0
 * @param <ObservationType>
 * Type of observations handled by the MCMC algorithm.
 * @param <ParameterType>
 * Type of parameters to infer.
 */
@PublicationReference(
    author="Wikipedia",
    title="Metropolis–Hastings algorithm",
    type=PublicationType.WebPage,
    year=2009,
    url="http://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm"
)
public class MetropolisHastingsAlgorithm<ObservationType,ParameterType>
    extends AbstractMarkovChainMonteCarlo<ObservationType,ParameterType>
    implements MeasurablePerformanceAlgorithm
{

    /**
     * Performance statistic name, {@value}.
     */
    public static final String PERFORMANCE_NAME = "Current Log Likelihood";

    /**
     * Log Likelihood of the current parameters.
     */
    private double currentLogLikelihood;

    /**
     * The object that makes proposal samples from the current location.
     */
    protected MetropolisHastingsAlgorithm.Updater<ObservationType,ParameterType> updater;

    /**
     * Creates a new instance of MetropolisHastingsAlgorithm.
     * The burn-in iteration count is set to one tenth, and the
     * iterations-per-sample count to one hundredth, of the inherited
     * maxIterations value (integer division). No updater is assigned.
     */
    public MetropolisHastingsAlgorithm()
    {
        super();
        this.setBurnInIterations(this.maxIterations/10);
        this.setIterationsPerSample(this.maxIterations/100);
    }

    /**
     * Clones this algorithm and gives the copy its own clone of the updater
     * (via {@code ObjectUtil.cloneSafe}) instead of sharing this instance's.
     * @return
     * Clone of this algorithm.
     */
    @Override
    public MetropolisHastingsAlgorithm<ObservationType,ParameterType> clone()
    {
        @SuppressWarnings("unchecked")
        MetropolisHastingsAlgorithm<ObservationType,ParameterType> clone =
            (MetropolisHastingsAlgorithm<ObservationType,ParameterType>) super.clone();
        clone.setUpdater( ObjectUtil.cloneSafe( this.getUpdater() ) );
        return clone;
    }

    /**
     * Resets the current log likelihood to negative infinity and then
     * delegates to the superclass initialization. The infinite value causes
     * the first proposal drawn in {@link #mcmcUpdate()} to be accepted
     * unconditionally.
     * @return
     * Whatever the superclass initialization returns.
     */
    @Override
    protected boolean initializeAlgorithm()
    {
        this.currentLogLikelihood = Double.NEGATIVE_INFINITY;
        return super.initializeAlgorithm();
    }

    /**
     * Draws proposals from the updater until one is accepted, then makes it
     * the current parameter and records its log likelihood.
     * <p>
     * A proposal is accepted outright while the current log likelihood is
     * infinite. Otherwise it is accepted when
     * {@code a = q * exp(proposalLogLikelihood - currentLogLikelihood)} is at
     * least 1, or else with probability {@code a}, where {@code q} is the
     * weight attached to the proposal.
     */
    @Override
    protected void mcmcUpdate()
    {

        // NOTE(review): textbook Metropolis-Hastings keeps the current state
        // when a proposal is rejected; this loop instead redraws until some
        // proposal is accepted. It also never terminates if the updater only
        // yields proposals whose acceptance value is 0 or NaN. Behavior is
        // kept as-is -- confirm that this is intended.
        WeightedValue<ParameterType> proposal = null;
        double proposalLogLikelihood = 0.0;
        boolean acceptProposal = false;
        while( !acceptProposal )
        {
            proposal = this.updater.makeProposal(this.currentParameter);
            proposalLogLikelihood = this.getUpdater().computeLogLikelihood(
                proposal.getValue(), this.data );

            // No finite baseline to compare against (e.g. right after
            // initialization), so take the proposal as it is.
            if( Double.isInfinite(this.currentLogLikelihood) )
            {
                break;
            }

            final double pratio = Math.exp( proposalLogLikelihood - this.currentLogLikelihood );
            final double qratio = proposal.getWeight();

            final double a = qratio * pratio;

            // If the proposal has higher probability, then always accept it;
            // otherwise allow the sample with probability "a". The random
            // draw only happens when "a" is not >= 1.0 (short-circuit).
            acceptProposal = (a >= 1.0) || (this.random.nextDouble() <= a);
        }

        this.currentParameter = proposal.getValue();
        this.currentLogLikelihood = proposalLogLikelihood;

    }

    /**
     * Gets the performance statistic of the algorithm: the log likelihood of
     * the current parameters, named {@link #PERFORMANCE_NAME}.
     * @return
     * Named value holding the current log likelihood.
     */
    public NamedValue<Double> getPerformance()
    {
        return new DefaultNamedValue<Double>( PERFORMANCE_NAME, this.currentLogLikelihood );
    }

    /**
     * Gets the object that makes proposal samples from the current location.
     * @return
     * The object that makes proposal samples from the current location.
     */
    public MetropolisHastingsAlgorithm.Updater<ObservationType,ParameterType> getUpdater()
    {
        return this.updater;
    }

    /**
     * Sets the object that makes proposal samples from the current location.
     * @param updater
     * The object that makes proposal samples from the current location.
     */
    public void setUpdater(
        MetropolisHastingsAlgorithm.Updater<ObservationType,ParameterType> updater)
    {
        this.updater = updater;
    }

    /**
     * Creates the initial parameter by asking the updater for its initial
     * parameterization.
     * @return
     * Initial parameter produced by the updater.
     */
    public ParameterType createInitialLearnedObject()
    {
        return this.getUpdater().createInitialParameter();
    }

    /**
     * Creates proposals for the MCMC steps.
     * @param <ObservationType>
     * Type of observations handled by the MCMC algorithm.
     * @param <ParameterType>
     * Type of parameters to infer.
     */
    public static interface Updater<ObservationType,ParameterType>
        extends CloneableSerializable
    {

        /**
         * Creates the initial parameterization
         * @return
         * Initial parameters
         */
        public ParameterType createInitialParameter();

        /**
         * Computes the log likelihood of the data given the parameter
         * @param parameter
         * Parameter to consider
         * @param data
         * Data to consider
         * @return
         * log likelihood of the data given the parameter
         */
        public double computeLogLikelihood(
            ParameterType parameter,
            Iterable<? extends ObservationType> data );

        /**
         * Makes a proposal update given the current parameter set
         * @param location Location from which to make a proposal
         * @return
         * Location of the proposed sample, weighted by the "q" ratio.
         */
        public WeightedValue<ParameterType> makeProposal(
            ParameterType location );

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy