All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.improbable.keanu.Keanu Maven / Gradle / Ivy

package io.improbable.keanu;

import io.improbable.keanu.algorithms.PosteriorSamplingAlgorithm;
import io.improbable.keanu.algorithms.graphtraversal.DifferentiableChecker;
import io.improbable.keanu.algorithms.mcmc.RollBackToCachedValuesOnRejection;
import io.improbable.keanu.algorithms.mcmc.proposal.PriorProposalDistribution;
import io.improbable.keanu.algorithms.variational.optimizer.gradient.GradientOptimizer;
import io.improbable.keanu.algorithms.variational.optimizer.nongradient.NonGradientOptimizer;
import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.network.KeanuProbabilisticModel;
import io.improbable.keanu.network.KeanuProbabilisticModelWithGradient;
import io.improbable.keanu.vertices.Vertex;
import lombok.experimental.UtilityClass;

import java.util.Collection;
import java.util.List;
import java.util.Set;

/**
 * The entry point for creating inference algorithms.
 * <p>
 * {@link Sampling} creates {@link PosteriorSamplingAlgorithm}s such as {@link Sampling.MetropolisHastings}
 * and {@link Sampling.NUTS}. {@link Optimizer} creates optimizers that search for the values of a Bayesian
 * network's latent variables that maximise probability.
 */
@UtilityClass
public class Keanu {

    /**
     * Factories for sampling algorithms. Each algorithm-specific nested class offers {@code withDefaultConfig}
     * overloads and a {@code builder()} for custom configuration; {@link MCMC} picks an algorithm automatically.
     */
    @UtilityClass
    public static class Sampling {

        /**
         * Chooses an appropriate sampling algorithm for a given model.
         * If the model is differentiable with respect to its latent variables, NUTS is chosen;
         * otherwise Metropolis Hastings is chosen.
         * <p>
         * Usage:
         * <pre>{@code
         * PosteriorSamplingAlgorithm samplingAlgorithm = Keanu.Sampling.MCMC.withDefaultConfigFor(yourModel);
         * samplingAlgorithm.getPosteriorSamples(...);
         * }</pre>
         */
        @UtilityClass
        public static class MCMC {

            /**
             * Chooses a sampling algorithm for {@code model}, using the default random number generator.
             *
             * @param model network for which to choose sampling algorithm.
             * @return recommended sampling algorithm for this network.
             */
            public PosteriorSamplingAlgorithm withDefaultConfigFor(KeanuProbabilisticModel model) {
                return withDefaultConfigFor(model, KeanuRandom.getDefaultRandom());
            }

            /**
             * Chooses a sampling algorithm for {@code model}: NUTS (default configuration) when the model is
             * differentiable with respect to its latents, otherwise Metropolis Hastings (default configuration).
             *
             * @param model  network for which to choose sampling algorithm.
             * @param random the random number generator given to the chosen algorithm.
             * @return recommended sampling algorithm for this network.
             */
            public PosteriorSamplingAlgorithm withDefaultConfigFor(KeanuProbabilisticModel model, KeanuRandom random) {
                if (DifferentiableChecker.isDifferentiableWrtLatents(model.getLatentOrObservedVertices())) {
                    return Keanu.Sampling.NUTS.withDefaultConfig(random);
                } else {
                    return Keanu.Sampling.MetropolisHastings.withDefaultConfig(random);
                }
            }
        }

        /**
         * Factories for {@link io.improbable.keanu.algorithms.mcmc.MetropolisHastings}.
         */
        @UtilityClass
        public static class MetropolisHastings {

            /**
             * @return a Metropolis Hastings sampler with the default configuration and the default random number generator.
             */
            public static io.improbable.keanu.algorithms.mcmc.MetropolisHastings withDefaultConfig() {
                return withDefaultConfig(KeanuRandom.getDefaultRandom());
            }

            /**
             * Creates a Metropolis Hastings sampler using a {@link PriorProposalDistribution} and a
             * {@link RollBackToCachedValuesOnRejection} rejection strategy.
             *
             * @param random the random number generator.
             * @return a Metropolis Hastings sampler with the default configuration.
             */
            public static io.improbable.keanu.algorithms.mcmc.MetropolisHastings withDefaultConfig(KeanuRandom random) {
                return builder()
                    .proposalDistribution(new PriorProposalDistribution())
                    .rejectionStrategy(new RollBackToCachedValuesOnRejection())
                    .random(random)
                    .build();
            }

            /**
             * @return a builder for a custom-configured Metropolis Hastings sampler.
             */
            public static io.improbable.keanu.algorithms.mcmc.MetropolisHastings.MetropolisHastingsBuilder builder() {
                return io.improbable.keanu.algorithms.mcmc.MetropolisHastings.builder();
            }
        }

        /**
         * Factories for {@link io.improbable.keanu.algorithms.mcmc.nuts.NUTS}.
         */
        @UtilityClass
        public static class NUTS {

            /**
             * @return a NUTS sampler with the default configuration and the default random number generator.
             */
            public static io.improbable.keanu.algorithms.mcmc.nuts.NUTS withDefaultConfig() {
                return withDefaultConfig(KeanuRandom.getDefaultRandom());
            }

            /**
             * @param random the random number generator.
             * @return a NUTS sampler using {@code random} and the builder's default settings for everything else.
             */
            public static io.improbable.keanu.algorithms.mcmc.nuts.NUTS withDefaultConfig(KeanuRandom random) {
                return builder()
                    .random(random)
                    .build();
            }

            /**
             * @return a builder for a custom-configured NUTS sampler.
             */
            public static io.improbable.keanu.algorithms.mcmc.nuts.NUTS.NUTSBuilder builder() {
                return io.improbable.keanu.algorithms.mcmc.nuts.NUTS.builder();
            }
        }

        /**
         * Factories for {@link io.improbable.keanu.algorithms.sampling.Forward}.
         */
        @UtilityClass
        public static class Forward {

            /**
             * @return a forward sampler with the default configuration and the default random number generator.
             */
            public static io.improbable.keanu.algorithms.sampling.Forward withDefaultConfig() {
                // Delegate like the other factories so the default configuration is defined in one place.
                return withDefaultConfig(KeanuRandom.getDefaultRandom());
            }

            /**
             * @param random the random number generator.
             * @return a forward sampler with the default configuration.
             */
            public static io.improbable.keanu.algorithms.sampling.Forward withDefaultConfig(KeanuRandom random) {
                // NOTE(review): the meaning of the second constructor argument (false) is defined by Forward — confirm.
                return new io.improbable.keanu.algorithms.sampling.Forward(random, false);
            }

            /**
             * @return a builder for a custom-configured forward sampler.
             */
            public static io.improbable.keanu.algorithms.sampling.Forward.ForwardBuilder builder() {
                return io.improbable.keanu.algorithms.sampling.Forward.builder();
            }
        }

        /**
         * Factories for {@link io.improbable.keanu.algorithms.mcmc.SimulatedAnnealing}.
         */
        @UtilityClass
        public static class SimulatedAnnealing {

            /**
             * @return a simulated annealing algorithm with the default configuration and the default random number generator.
             */
            public static io.improbable.keanu.algorithms.mcmc.SimulatedAnnealing withDefaultConfig() {
                return withDefaultConfig(KeanuRandom.getDefaultRandom());
            }

            /**
             * Creates a simulated annealing algorithm using a {@link PriorProposalDistribution} and a
             * {@link RollBackToCachedValuesOnRejection} rejection strategy.
             *
             * @param random the random number generator.
             * @return a simulated annealing algorithm with the default configuration.
             */
            public static io.improbable.keanu.algorithms.mcmc.SimulatedAnnealing withDefaultConfig(KeanuRandom random) {
                return builder()
                    .proposalDistribution(new PriorProposalDistribution())
                    .rejectionStrategy(new RollBackToCachedValuesOnRejection())
                    .random(random)
                    .build();
            }

            /**
             * @return a builder for a custom-configured simulated annealing algorithm.
             */
            public static io.improbable.keanu.algorithms.mcmc.SimulatedAnnealing.SimulatedAnnealingBuilder builder() {
                return io.improbable.keanu.algorithms.mcmc.SimulatedAnnealing.builder();
            }
        }
    }

    /**
     * Factories for optimizers that search for the values of a Bayesian network's latent variables that
     * maximise probability. {@link #of(BayesianNetwork)} chooses between {@link Gradient} and
     * {@link NonGradient} automatically; those nested classes can also be used directly.
     */
    @UtilityClass
    public static class Optimizer {

        /**
         * Creates a Bayesian network from the given vertices and uses this to
         * create an {@link io.improbable.keanu.algorithms.variational.optimizer.Optimizer}. This provides methods for optimizing the values of latent variables
         * of the Bayesian network to maximise probability.
         *
         * @param vertices The vertices to create a Bayesian network from.
         * @return an {@link io.improbable.keanu.algorithms.variational.optimizer.Optimizer}
         * @throws UnsupportedOperationException if the network contains discrete latent vertices
         */
        public io.improbable.keanu.algorithms.variational.optimizer.Optimizer of(Collection vertices) {
            return of(new BayesianNetwork(vertices));
        }

        /**
         * Creates an {@link io.improbable.keanu.algorithms.variational.optimizer.Optimizer} which provides methods for optimizing the values of latent variables
         * of the Bayesian network to maximise probability. A {@link GradientOptimizer} is returned when the network is
         * differentiable with respect to its latents, otherwise a (BOBYQA) {@link NonGradientOptimizer}.
         *
         * @param network The Bayesian network to run optimization on.
         * @return an {@link io.improbable.keanu.algorithms.variational.optimizer.Optimizer}
         * @throws UnsupportedOperationException if the network contains discrete latent vertices
         */
        public io.improbable.keanu.algorithms.variational.optimizer.Optimizer of(BayesianNetwork network) {
            if (DifferentiableChecker.isDifferentiableWrtLatents(network.getLatentOrObservedVertices())) {
                return Gradient.of(network);
            } else {
                return NonGradient.of(network);
            }
        }

        /**
         * Creates a Bayesian network from the graph connected to the given vertex and uses this to
         * create an {@link io.improbable.keanu.algorithms.variational.optimizer.Optimizer}. This provides methods for optimizing the values of latent variables
         * of the Bayesian network to maximise probability.
         *
         * @param vertexFromNetwork A vertex in the graph to create the Bayesian network from.
         * @return an {@link io.improbable.keanu.algorithms.variational.optimizer.Optimizer}
         * @throws UnsupportedOperationException if the network contains discrete latent vertices
         */
        public io.improbable.keanu.algorithms.variational.optimizer.Optimizer ofConnectedGraph(Vertex vertexFromNetwork) {
            return of(vertexFromNetwork.getConnectedGraph());
        }

        /**
         * Factories for non-gradient (BOBYQA) {@link NonGradientOptimizer}s.
         */
        @UtilityClass
        public static class NonGradient {

            /**
             * Creates a BOBYQA {@link NonGradientOptimizer} which provides methods for optimizing the values of latent variables
             * of the Bayesian network to maximise probability.
             *
             * @param bayesNet The Bayesian network to run optimization on.
             * @return a {@link NonGradientOptimizer}
             * @throws UnsupportedOperationException if the network contains discrete latent vertices
             */
            public NonGradientOptimizer of(BayesianNetwork bayesNet) {
                // builderFor validates the network and cascades its observations, so no separate cascade is needed
                // here (previously observations were cascaded twice, and before validation).
                return builderFor(bayesNet).build();
            }

            /**
             * Creates a Bayesian network from the given vertices and uses this to
             * create a BOBYQA {@link NonGradientOptimizer}. This provides methods for optimizing the
             * values of latent variables of the Bayesian network to maximise probability.
             *
             * @param vertices The vertices to create a Bayesian network from.
             * @return a {@link NonGradientOptimizer}
             * @throws UnsupportedOperationException if the network contains discrete latent vertices
             */
            public NonGradientOptimizer of(Collection vertices) {
                return of(new BayesianNetwork(vertices));
            }

            /**
             * Creates a Bayesian network from the graph connected to the given vertex and uses this to
             * create a BOBYQA {@link NonGradientOptimizer}. This provides methods for optimizing the
             * values of latent variables of the Bayesian network to maximise probability.
             *
             * @param vertexFromNetwork A vertex in the graph to create the Bayesian network from
             * @return a {@link NonGradientOptimizer}
             * @throws UnsupportedOperationException if the network contains discrete latent vertices
             */
            public NonGradientOptimizer ofConnectedGraph(Vertex vertexFromNetwork) {
                return of(vertexFromNetwork.getConnectedGraph());
            }

            /**
             * Creates a Bayesian network from the given vertices and returns a {@link NonGradientOptimizer} builder for it.
             *
             * @param vertices The vertices to create a Bayesian network from.
             * @return a builder whose probabilistic model is already set
             * @throws UnsupportedOperationException if the network contains discrete latent vertices
             */
            public NonGradientOptimizer.NonGradientOptimizerBuilder builderFor(Collection vertices) {
                return builderFor(new BayesianNetwork(vertices));
            }

            /**
             * Validates the network, cascades its observations and returns a {@link NonGradientOptimizer} builder
             * whose probabilistic model wraps {@code network}.
             *
             * @param network The Bayesian network to run optimization on.
             * @return a builder whose probabilistic model is already set
             * @throws UnsupportedOperationException if the network contains discrete latent vertices
             */
            public NonGradientOptimizer.NonGradientOptimizerBuilder builderFor(BayesianNetwork network) {
                initializeNetworkForOptimization(network);
                return NonGradientOptimizer.builder().probabilisticModel(new KeanuProbabilisticModel(network));
            }

        }

        /**
         * Factories for gradient-based {@link GradientOptimizer}s.
         */
        @UtilityClass
        public static class Gradient {

            /**
             * Creates a {@link GradientOptimizer} which provides methods for optimizing the values of latent variables
             * of the Bayesian network to maximise probability.
             *
             * @param bayesNet The Bayesian network to run optimization on.
             * @return a {@link GradientOptimizer}
             * @throws UnsupportedOperationException if the network contains discrete latent vertices
             */
            public GradientOptimizer of(BayesianNetwork bayesNet) {
                return builderFor(bayesNet).build();
            }

            /**
             * Creates a Bayesian network from the given vertices and uses this to
             * create a {@link GradientOptimizer}. This provides methods for optimizing the values of latent variables
             * of the Bayesian network to maximise probability.
             *
             * @param vertices The vertices to create a Bayesian network from.
             * @return a {@link GradientOptimizer}
             * @throws UnsupportedOperationException if the network contains discrete latent vertices
             */
            public GradientOptimizer of(Collection vertices) {
                return of(new BayesianNetwork(vertices));
            }


            /**
             * Creates a Bayesian network from the graph connected to the given vertex and uses this to
             * create a {@link GradientOptimizer}. This provides methods for optimizing the values of latent variables
             * of the Bayesian network to maximise probability.
             *
             * @param vertexFromNetwork A vertex in the graph to create the Bayesian network from
             * @return a {@link GradientOptimizer}
             * @throws UnsupportedOperationException if the network contains discrete latent vertices
             */
            public GradientOptimizer ofConnectedGraph(Vertex vertexFromNetwork) {
                return of(vertexFromNetwork.getConnectedGraph());
            }

            /**
             * Creates a Bayesian network from the given vertices and returns a {@link GradientOptimizer} builder for it.
             * Accepts any {@link Collection} (previously only a {@link Set}), matching {@link NonGradient#builderFor(Collection)}.
             *
             * @param vertices The vertices to create a Bayesian network from.
             * @return a builder whose probabilistic model is already set
             * @throws UnsupportedOperationException if the network contains discrete latent vertices
             */
            public GradientOptimizer.GradientOptimizerBuilder builderFor(Collection vertices) {
                return builderFor(new BayesianNetwork(vertices));
            }

            /**
             * Validates the network, cascades its observations and returns a {@link GradientOptimizer} builder
             * whose probabilistic model wraps {@code network}.
             *
             * @param network The Bayesian network to run optimization on.
             * @return a builder whose probabilistic model is already set
             * @throws UnsupportedOperationException if the network contains discrete latent vertices
             */
            public GradientOptimizer.GradientOptimizerBuilder builderFor(BayesianNetwork network) {
                initializeNetworkForOptimization(network);
                return GradientOptimizer.builder().probabilisticModel(new KeanuProbabilisticModelWithGradient(network));
            }
        }

        /**
         * Prepares a network for optimization: rejects networks containing discrete latent vertices,
         * then cascades the network's observations.
         *
         * @param bayesianNetwork the network to prepare
         * @throws UnsupportedOperationException if the network contains any discrete latent vertices
         */
        void initializeNetworkForOptimization(BayesianNetwork bayesianNetwork) {
            // Only the size is needed, so an unbounded wildcard avoids a raw type.
            List<?> discreteLatentVertices = bayesianNetwork.getDiscreteLatentVertices();

            if (!discreteLatentVertices.isEmpty()) {
                throw new UnsupportedOperationException(
                    "Optimization unsupported on networks containing discrete latents. " +
                        "Found " + discreteLatentVertices.size() + " discrete latents.");
            }

            bayesianNetwork.cascadeObservations();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy