All downloads are free. The search and download functionality uses the official Maven repository.

org.broadinstitute.hellbender.utils.mcmc.MinibatchSliceSampler Maven / Gradle / Ivy

The newest version!
package org.broadinstitute.hellbender.utils.mcmc;


import org.apache.commons.math3.distribution.TDistribution;
import org.apache.commons.math3.primes.Primes;
import org.apache.commons.math3.random.RandomGenerator;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

import java.util.*;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Implements slice sampling of a continuous, univariate, unnormalized probability density function (PDF),
 * which is assumed to be unimodal.  See Neal 2003 at https://projecteuclid.org/euclid.aos/1056562461 for details.
 * Minibatching is implemented as in Dubois et al. 2014 at http://proceedings.mlr.press/v33/dubois14.pdf and requires
 * that the PDF, which is assumed to be a posterior function of a parameter value and the data, is specified in terms
 * of a prior, a likelihood, and the data.
 *
 * Note that likelihoods are cached and retrieved based on the hash of {@code DATA}, i.e., we assume it is
 * cheaper to compute this hash rather than the likelihood.  The cache is keyed by the data point itself, so
 * {@code DATA} should provide {@code equals}/{@code hashCode} implementations under which equal data points
 * have equal likelihoods.
 *
 * @param <DATA> type of a single data point
 *
 * @author Samuel Lee
 */
public final class MinibatchSliceSampler<DATA> extends AbstractSliceSampler {
    private final List<DATA> data;
    private final Function<Double, Double> logPrior;
    private final BiFunction<DATA, Double, Double> logLikelihood;
    private final int minibatchSize;
    private final double approxThreshold;

    private final int numDataPoints;

    //values derived from the most recent xSample; null until the first call to isGreaterThanSliceHeight
    private Double xSampleCache = null;
    private Double logPriorCache = null;
    private Map<DATA, Double> logLikelihoodsCache = null;    //data -> log likelihood

    /**
     * Creates a new sampler for a bounded univariate random variable, given a random number generator, a list of data,
     * a continuous, univariate, unimodal, unnormalized log probability density function
     * (assumed to be a posterior and specified by a prior and a likelihood),
     * hard limits on the random variable, a step width, a minibatch size, and a minibatch approximation threshold.
     * The list of data is defensively copied.
     * @param rng                       random number generator, never {@code null}
     * @param data                      list of data, never {@code null}
     * @param logPrior                  log prior component of continuous, univariate, unimodal log posterior (up to additive constant), never {@code null}
     * @param logLikelihood             log likelihood component of continuous, univariate, unimodal log posterior (up to additive constant), never {@code null}
     * @param xMin                      minimum allowed value of the random variable
     * @param xMax                      maximum allowed value of the random variable
     * @param width                     step width for slice expansion
     * @param minibatchSize             minibatch size, must be greater than 1
     * @param approxThreshold           threshold for approximation used in {@link MinibatchSliceSampler#isGreaterThanSliceHeight},
     *                                  must be non-negative; approximation is exact when this threshold is zero
     */
    public MinibatchSliceSampler(final RandomGenerator rng,
                                 final List<DATA> data,
                                 final Function<Double, Double> logPrior,
                                 final BiFunction<DATA, Double, Double> logLikelihood,
                                 final double xMin,
                                 final double xMax,
                                 final double width,
                                 final int minibatchSize,
                                 final double approxThreshold) {
        super(rng, xMin, xMax, width);
        Utils.nonNull(data);
        Utils.nonNull(logPrior);
        Utils.nonNull(logLikelihood);
        Utils.validateArg(minibatchSize > 1, "Minibatch size must be greater than 1.");
        ParamUtils.isPositiveOrZero(approxThreshold, "Minibatch approximation threshold must be non-negative.");
        this.data = Collections.unmodifiableList(new ArrayList<>(data));
        this.logPrior = logPrior;
        this.logLikelihood = logLikelihood;
        this.minibatchSize = minibatchSize;
        this.approxThreshold = approxThreshold;
        numDataPoints = data.size();
    }

    /**
     * Creates a new sampler for an unbounded univariate random variable, given a random number generator, a list of data,
     * a continuous, univariate, unimodal, unnormalized log probability density function
     * (assumed to be a posterior and specified by a prior and a likelihood),
     * a step width, a minibatch size, and a minibatch approximation threshold.
     * @param rng                       random number generator, never {@code null}
     * @param data                      list of data, never {@code null}
     * @param logPrior                  log prior component of continuous, univariate, unimodal log posterior (up to additive constant), never {@code null}
     * @param logLikelihood             log likelihood component of continuous, univariate, unimodal log posterior (up to additive constant), never {@code null}
     * @param width                     step width for slice expansion
     * @param minibatchSize             minibatch size, must be greater than 1
     * @param approxThreshold           threshold for approximation used in {@link MinibatchSliceSampler#isGreaterThanSliceHeight},
     *                                  must be non-negative; approximation is exact when this threshold is zero
     */
    public MinibatchSliceSampler(final RandomGenerator rng,
                                 final List<DATA> data,
                                 final Function<Double, Double> logPrior,
                                 final BiFunction<DATA, Double, Double> logLikelihood,
                                 final double width,
                                 final int minibatchSize,
                                 final double approxThreshold) {
        this(rng, data, logPrior, logLikelihood, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, width, minibatchSize, approxThreshold);
    }

    /**
     * Implements the OnSlice procedure from Dubois et al. 2014.
     *
     * Data are consumed in minibatches (in a lazily shuffled order when more than one minibatch is needed);
     * after each minibatch a t-test decides whether the comparison against the slice height can already be made
     * to within {@code approxThreshold}.  If the test never terminates early, every data point is consumed and
     * the comparison is exact.
     *
     * @param xProposed proposed value of the random variable
     * @param xSample   current sample value, which defines the slice together with {@code z}
     * @param z         quantity subtracted from the log posterior at {@code xSample} to obtain the log slice height
     * @return {@code false} if {@code xProposed} is outside {@code [xMin, xMax]}; otherwise whether the
     *         (possibly approximated) log posterior at {@code xProposed} exceeds the log slice height
     */
    @Override
    boolean isGreaterThanSliceHeight(final double xProposed,
                                     final double xSample,
                                     final double z) {
        if (xProposed < xMin || xMax < xProposed) {
            return false;
        }

        //we cache values calculated from xSample, since this method is called multiple times for the same value
        //of xSample when expanding slice interval and proposing samples
        if (xSampleCache == null || xSampleCache.doubleValue() != xSample) {
            xSampleCache = xSample;
            logPriorCache = logPrior.apply(xSample);
            logLikelihoodsCache = new HashMap<>(numDataPoints);
        }
        if (!(xSampleCache != null && logPriorCache != null && logLikelihoodsCache != null)) {
            throw new GATKException.ShouldNeverReachHereException("Cache for xSample is in an invalid state.");
        }

        //if no data points are provided, sample from the prior
        if (numDataPoints == 0) {
            return logPrior.apply(xProposed) > logPriorCache - z;
        }

        //see Eq. 1 of Dubois et al. 2014
        final double mu0 = (logPriorCache - logPrior.apply(xProposed) - z) / numDataPoints;

        //number of minibatches needed to cover ALL data points, i.e., ceil(numDataPoints / minibatchSize);
        //written this way to avoid integer overflow (numDataPoints >= 1 is guaranteed here)
        final int numMinibatches = (numDataPoints - 1) / minibatchSize + 1;

        //initialize the lazy data iterator (or just use the standard iterator if only a single batch is required)
        final Iterator<DATA> shuffledDataIterator = numMinibatches > 1
                ? lazyShuffleIterator(rng, data)
                : data.iterator();

        //initialize running quantities needed for statistical test
        int numDataIndicesSeen = 0;
        double logLikelihoodDifferencesMean = 0.;
        double logLikelihoodDifferencesSquaredMean = 0.;

        while (numDataIndicesSeen < numDataPoints) {
            //equals minibatchSize except perhaps for last minibatch
            final int actualMinibatchSize = Math.min(minibatchSize, numDataPoints - numDataIndicesSeen);

            //calculate quantities for this minibatch
            double logLikelihoodDifferencesMinibatchSum = 0.;
            double logLikelihoodDifferencesSquaredMinibatchSum = 0.;
            for (int i = 0; i < actualMinibatchSize; i++) {
                final DATA dataPoint = shuffledDataIterator.next();
                final double logLikelihoodxSample = logLikelihoodsCache.computeIfAbsent(
                        dataPoint, d -> logLikelihood.apply(d, xSample));
                final double logLikelihoodxProposed = logLikelihood.apply(dataPoint, xProposed);
                final double logLikelihoodDifference = logLikelihoodxProposed - logLikelihoodxSample;
                logLikelihoodDifferencesMinibatchSum += logLikelihoodDifference;
                logLikelihoodDifferencesSquaredMinibatchSum += logLikelihoodDifference * logLikelihoodDifference;
            }

            //update running quantities
            final int numDataIndicesSeenUpdated = numDataIndicesSeen + actualMinibatchSize;
            logLikelihoodDifferencesMean =
                    (numDataIndicesSeen * logLikelihoodDifferencesMean + logLikelihoodDifferencesMinibatchSum) /
                            numDataIndicesSeenUpdated;
            logLikelihoodDifferencesSquaredMean =
                    (numDataIndicesSeen * logLikelihoodDifferencesSquaredMean + logLikelihoodDifferencesSquaredMinibatchSum) /
                            numDataIndicesSeenUpdated;
            numDataIndicesSeen = numDataIndicesSeenUpdated;

            //once every data point has been seen, all running quantities are exact, so there is no need to
            //perform the statistical test (whose finite-population correction would make s vanish and
            //lead to a division by zero)
            if (numDataIndicesSeen == numDataPoints) {
                break;
            }

            //perform the statistical test and terminate if appropriate
            //(i.e., if we can determine if we are OnSlice to within the approximation threshold)
            final double s = Math.sqrt(1. - (double) numDataIndicesSeen / numDataPoints) *
                    Math.sqrt((logLikelihoodDifferencesSquaredMean - Math.pow(logLikelihoodDifferencesMean, 2)) / (numDataIndicesSeen - 1));
            //the TDistribution is only used for its CDF, so no random generator is supplied
            final double delta = 1. - new TDistribution(null, numDataIndicesSeen - 1)
                    .cumulativeProbability(Math.abs((logLikelihoodDifferencesMean - mu0) / s));
            if (delta < approxThreshold) {
                break;
            }
        }
        return logLikelihoodDifferencesMean > mu0;
    }

    /**
     * To efficiently sample without replacement with the possibility of early stopping when creating minibatches,
     * we lazily shuffle to avoid unnecessarily shuffling all data.  Uses the properties of relative primes and is
     * random enough for our purposes.  Adapted from https://stackoverflow.com/questions/16165128/lazy-shuffle-algorithms.
     *
     * The iterator steps through the residues modulo the first prime {@code p >= data.size()} with a random
     * stride drawn from {@code [1, p - 1]} and skips residues that are not valid list indices.  Because the stride
     * is never a multiple of {@code p}, each index is visited exactly once before any index repeats.
     *
     * @param rng   random number generator used to draw the stride
     * @param data  list to iterate over; it is read by index and not copied
     * @return iterator over the elements of {@code data} in a pseudo-shuffled order
     */
    private static <T> Iterator<T> lazyShuffleIterator(final RandomGenerator rng,
                                                       final List<T> data) {
        final int numDataPoints = data.size();

        //find first prime greater than or equal to numDataPoints
        final int nextPrime = Primes.nextPrime(numDataPoints);

        return new Iterator<T>() {
            int numSeen = 0;
            //the stride must be nonzero modulo nextPrime; drawing from [1, numDataPoints] instead would allow
            //a stride equal to nextPrime when numDataPoints is itself prime, which would pin the iterator
            //to index 0 forever
            final int increment = rng.nextInt(nextPrime - 1) + 1;
            int index = increment;

            @Override
            public boolean hasNext() {
                return numSeen < numDataPoints;
            }

            @Override
            public T next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                while (true) {
                    //long arithmetic guards against overflow of index + increment for very large lists
                    index = (int) (((long) index + increment) % nextPrime);
                    if (index < numDataPoints) {
                        numSeen++;
                        return data.get(index);
                    }
                }
            }
        };
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy