All downloads are free. Search and download functionality uses the official Maven repository.

io.improbable.keanu.algorithms.mcmc.nuts.AdaptiveQuadraticPotential Maven / Gradle / Ivy

package io.improbable.keanu.algorithms.mcmc.nuts;

import com.google.common.base.Preconditions;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import lombok.Getter;

import java.util.HashMap;
import java.util.Map;

import static io.improbable.keanu.algorithms.mcmc.nuts.VariableValues.dotProduct;
import static io.improbable.keanu.algorithms.mcmc.nuts.VariableValues.pow;
import static io.improbable.keanu.algorithms.mcmc.nuts.VariableValues.times;
import static io.improbable.keanu.algorithms.mcmc.nuts.VariableValues.withShape;
import static io.improbable.keanu.algorithms.mcmc.nuts.VariableValues.zeros;


/**
 * A {@link Potential} with a diagonal mass matrix whose per-element variance is adapted from the
 * positions passed to {@link #update(Map)}.
 * <p>
 * Two running variance estimators are maintained:
 * <ul>
 *     <li>the <em>forward</em> estimator, seeded by {@link #initialize(Map)} with the configured
 *     initial mean and variance (with {@code initialWeight} as its weight). Its current estimate
 *     is the variance this potential uses;</li>
 *     <li>the <em>background</em> estimator, started from zeros with zero weight, which receives
 *     the same samples.</li>
 * </ul>
 * Whenever the number of samples seen reaches a positive multiple of {@code adaptionWindowSize},
 * the background estimator replaces the forward one and a fresh, empty background estimator is
 * started. After the first window the estimate is therefore built only from recent samples rather
 * than from the initial guess.
 * <p>
 * {@link #initialize(Map)} must be called before {@link #update(Map)},
 * {@link #randomMomentum(KeanuRandom)} or {@link #getVelocity(Map)}.
 */
public class AdaptiveQuadraticPotential implements Potential {

    private final double initialWeight;
    private final double initialMean;
    private final double initialVariance;
    private final int adaptionWindowSize;

    // Estimator whose current variance is exposed via `variance`; null until initialize() runs.
    private VarianceCalculator forwardVariance;
    // Estimator accumulating only the current window's samples; swapped in at each window end.
    private VarianceCalculator backgroundVariance;
    // Number of samples passed to update() since the last initialize().
    private long nSamples;

    /** Current per-variable variance estimate; null until {@link #initialize(Map)} is called. */
    @Getter
    private Map<VariableReference, DoubleTensor> variance;

    /** Per-variable square root of {@link #variance}, recomputed whenever the variance changes. */
    @Getter
    private Map<VariableReference, DoubleTensor> standardDeviation;

    /**
     * Creates an adaptive potential. {@link #initialize(Map)} must be called before use.
     *
     * @param initialMean        value used for every element of the forward estimator's initial mean
     * @param initialVariance    value used for every element of the initial variance
     * @param initialWeight      weight given to the initial mean/variance in the forward estimator
     *                           (presumably a pseudo-sample count; see {@code VarianceCalculator})
     * @param adaptionWindowSize number of samples per adaptation window; must be greater than 1
     * @throws IllegalArgumentException if {@code adaptionWindowSize <= 1}
     */
    public AdaptiveQuadraticPotential(double initialMean,
                                      double initialVariance,
                                      double initialWeight,
                                      int adaptionWindowSize) {
        Preconditions.checkArgument(adaptionWindowSize > 1, "Adapt window size must be greater than 1");

        this.initialWeight = initialWeight;
        this.initialMean = initialMean;
        this.initialVariance = initialVariance;

        this.adaptionWindowSize = adaptionWindowSize;
        this.nSamples = 0;
    }

    /**
     * (Re)initialises the adaptation state for positions shaped like {@code shapeLike}.
     * <p>
     * The variance is reset to {@code initialVariance}, the forward estimator is re-seeded from the
     * initial mean/variance, the background estimator is reset, and the window schedule restarts.
     *
     * @param shapeLike a map whose per-variable tensor shapes define the shapes of the state
     */
    public void initialize(Map<VariableReference, DoubleTensor> shapeLike) {

        Map<VariableReference, DoubleTensor> varianceShapedLike = withShape(initialVariance, shapeLike);
        Map<VariableReference, DoubleTensor> meanShapedLike = withShape(initialMean, shapeLike);

        this.setVariance(varianceShapedLike);

        this.forwardVariance = new VarianceCalculator(meanShapedLike, varianceShapedLike, initialWeight);
        this.backgroundVariance = new VarianceCalculator(zeros(meanShapedLike), zeros(meanShapedLike), 0);

        // Count windows from this initialisation. Otherwise a re-initialised potential would swap
        // estimators at boundaries inherited from an earlier run.
        this.nSamples = 0;
    }

    /** Stores a new variance and recomputes the per-variable standard deviation from it. */
    private void setVariance(Map<VariableReference, DoubleTensor> variance) {
        this.variance = variance;
        this.standardDeviation = pow(this.variance, 0.5);
    }

    /** Fails fast with a clear message if {@link #initialize(Map)} has not been called yet. */
    private void checkInitialized() {
        Preconditions.checkState(forwardVariance != null,
            "initialize must be called before using this potential");
    }

    /**
     * Adds a sampled position to both estimators and refreshes the variance from the forward one.
     *
     * @param position the latest sampled position, shaped like the map given to initialize
     * @throws IllegalStateException if {@link #initialize(Map)} has not been called
     */
    @Override
    public void update(Map<VariableReference, DoubleTensor> position) {
        checkInitialized();

        if (nSamples > 0 && nSamples % adaptionWindowSize == 0) {
            // End of a window: promote the estimator that saw only this window's samples and
            // start collecting the next window from scratch.
            forwardVariance = backgroundVariance;
            backgroundVariance = new VarianceCalculator(zeros(variance), zeros(variance), 0);
        }

        forwardVariance.addSample(position);
        backgroundVariance.addSample(position);

        this.setVariance(forwardVariance.calculateCurrentVariance());

        nSamples++;
    }

    /**
     * Draws a random momentum for every variable. Each draw is a Gaussian sample, presumably
     * standard normal, divided by the current standard deviation. That presumably gives a
     * momentum with per-element variance {@code 1 / variance}.
     *
     * @param random source of Gaussian draws
     * @return a new map from each variable to its momentum tensor
     * @throws IllegalStateException if {@link #initialize(Map)} has not been called
     */
    @Override
    public Map<VariableReference, DoubleTensor> randomMomentum(KeanuRandom random) {
        checkInitialized();

        Map<VariableReference, DoubleTensor> result = new HashMap<>();
        for (Map.Entry<VariableReference, DoubleTensor> entry : standardDeviation.entrySet()) {

            DoubleTensor standardDeviationForVariable = entry.getValue();

            // nextGaussian returns a fresh tensor, so dividing it in place does not alias state.
            DoubleTensor randomForVariable = random
                .nextGaussian(standardDeviationForVariable.getShape())
                .divInPlace(standardDeviationForVariable);

            result.put(entry.getKey(), randomForVariable);
        }

        return result;
    }

    /**
     * Converts a momentum into a velocity by multiplying it with the current variance
     * (presumably element-wise, i.e. applying the diagonal inverse mass matrix).
     *
     * @param momentum per-variable momentum
     * @return per-variable velocity
     * @throws IllegalStateException if {@link #initialize(Map)} has not been called
     */
    @Override
    public Map<VariableReference, DoubleTensor> getVelocity(Map<VariableReference, DoubleTensor> momentum) {
        checkInitialized();
        return times(variance, momentum);
    }

    /**
     * Returns the kinetic energy {@code 0.5 * momentum . velocity}. With the velocity from
     * {@link #getVelocity(Map)}, this is presumably {@code 0.5 * p^T M^-1 p} for the diagonal
     * inverse mass matrix given by the variance.
     *
     * @param momentum per-variable momentum
     * @param velocity per-variable velocity corresponding to {@code momentum}
     * @return the kinetic energy
     */
    @Override
    public double getKineticEnergy(Map<VariableReference, DoubleTensor> momentum,
                                   Map<VariableReference, DoubleTensor> velocity) {

        return 0.5 * dotProduct(momentum, velocity);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy