All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.commons.rng.sampling.distribution.DirichletSampler Maven / Gradle / Ivy

Go to download

The Apache Commons RNG Sampling module provides samplers for various distributions.

There is a newer version: 1.6
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.rng.sampling.distribution;

import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.SharedStateObjectSampler;

/**
 * Sampling from a Dirichlet
 * distribution.
 *
 * 

Sampling uses:

* *
    *
  • {@link UniformRandomProvider#nextLong()} *
  • {@link UniformRandomProvider#nextDouble()} *
* * @since 1.4 */ public abstract class DirichletSampler implements SharedStateObjectSampler { /** The minimum number of categories. */ private static final int MIN_CATGEORIES = 2; /** RNG (used for the toString() method). */ private final UniformRandomProvider rng; /** * Sample from a Dirichlet distribution with different concentration parameters * for each category. */ private static final class GeneralDirichletSampler extends DirichletSampler { /** Samplers for each category. */ private final SharedStateContinuousSampler[] samplers; /** * @param rng Generator of uniformly distributed random numbers. * @param samplers Samplers for each category. */ GeneralDirichletSampler(UniformRandomProvider rng, SharedStateContinuousSampler[] samplers) { super(rng); // Array is stored directly as it is generated within the DirichletSampler class this.samplers = samplers; } @Override protected int getK() { return samplers.length; } @Override protected double nextGamma(int i) { return samplers[i].sample(); } @Override public GeneralDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) { final SharedStateContinuousSampler[] newSamplers = new SharedStateContinuousSampler[samplers.length]; for (int i = 0; i < newSamplers.length; i++) { newSamplers[i] = samplers[i].withUniformRandomProvider(rng); } return new GeneralDirichletSampler(rng, newSamplers); } } /** * Sample from a symmetric Dirichlet distribution with the same concentration parameter * for each category. */ private static final class SymmetricDirichletSampler extends DirichletSampler { /** Number of categories. */ private final int k; /** Sampler for the categories. */ private final SharedStateContinuousSampler sampler; /** * @param rng Generator of uniformly distributed random numbers. * @param k Number of categories. * @param sampler Sampler for the categories. 
*/ SymmetricDirichletSampler(UniformRandomProvider rng, int k, SharedStateContinuousSampler sampler) { super(rng); this.k = k; this.sampler = sampler; } @Override protected int getK() { return k; } @Override protected double nextGamma(int i) { return sampler.sample(); } @Override public SymmetricDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) { return new SymmetricDirichletSampler(rng, k, sampler.withUniformRandomProvider(rng)); } } /** * @param rng Generator of uniformly distributed random numbers. */ private DirichletSampler(UniformRandomProvider rng) { this.rng = rng; } /** {@inheritDoc} */ @Override public String toString() { return "Dirichlet deviate [" + rng.toString() + "]"; } @Override public double[] sample() { // Create Gamma(alpha_i, 1) deviates for all alpha final double[] y = new double[getK()]; double norm = 0; for (int i = 0; i < y.length; i++) { final double yi = nextGamma(i); norm += yi; y[i] = yi; } // Normalize by dividing by the sum of the samples norm = 1.0 / norm; // Detect an invalid normalization, e.g. cases of all zero samples if (!isNonZeroPositiveFinite(norm)) { // Sample again using recursion. // A stack overflow due to a broken RNG will eventually occur // rather than the alternative which is an infinite loop. return sample(); } // Normalise for (int i = 0; i < y.length; i++) { y[i] *= norm; } return y; } /** * Gets the number of categories. * * @return k */ protected abstract int getK(); /** * Create a gamma sample for the given category. * * @param category Category. * @return the sample */ protected abstract double nextGamma(int category); /** {@inheritDoc} */ // Redeclare the signature to return a DirichletSampler not a SharedStateObjectSampler @Override public abstract DirichletSampler withUniformRandomProvider(UniformRandomProvider rng); /** * Creates a new Dirichlet distribution sampler. * * @param rng Generator of uniformly distributed random numbers. * @param alpha Concentration parameters. 
* @return the sampler * @throws IllegalArgumentException if the number of concentration parameters * is less than 2; or if any concentration parameter is not strictly positive. */ public static DirichletSampler of(UniformRandomProvider rng, double... alpha) { validateNumberOfCategories(alpha.length); final SharedStateContinuousSampler[] samplers = new SharedStateContinuousSampler[alpha.length]; for (int i = 0; i < samplers.length; i++) { samplers[i] = createSampler(rng, alpha[i]); } return new GeneralDirichletSampler(rng, samplers); } /** * Creates a new symmetric Dirichlet distribution sampler using the same concentration * parameter for each category. * * @param rng Generator of uniformly distributed random numbers. * @param k Number of categories. * @param alpha Concentration parameter. * @return the sampler * @throws IllegalArgumentException if the number of categories is * less than 2; or if the concentration parameter is not strictly positive. */ public static DirichletSampler symmetric(UniformRandomProvider rng, int k, double alpha) { validateNumberOfCategories(k); final SharedStateContinuousSampler sampler = createSampler(rng, alpha); return new SymmetricDirichletSampler(rng, k, sampler); } /** * Validate the number of categories. * * @param k Number of categories. * @throws IllegalArgumentException if the number of categories is * less than 2. */ private static void validateNumberOfCategories(int k) { if (k < MIN_CATGEORIES) { throw new IllegalArgumentException("Invalid number of categories: " + k); } } /** * Creates a gamma sampler for a category with the given concentration parameter. * * @param rng Generator of uniformly distributed random numbers. * @param alpha Concentration parameter. * @return the sampler * @throws IllegalArgumentException if the concentration parameter is not strictly positive. 
*/ private static SharedStateContinuousSampler createSampler(UniformRandomProvider rng, double alpha) { // Negation of logic will detect NaN if (!isNonZeroPositiveFinite(alpha)) { throw new IllegalArgumentException("Invalid concentration: " + alpha); } // Create a Gamma(shape=alpha, scale=1) sampler. if (alpha == 1) { // Special case // Gamma(shape=1, scale=1) == Exponential(mean=1) return ZigguratSampler.Exponential.of(rng); } return AhrensDieterMarsagliaTsangGammaSampler.of(rng, alpha, 1); } /** * Return true if the value is non-zero, positive and finite. * * @param x Value. * @return true if non-zero positive finite */ private static boolean isNonZeroPositiveFinite(double x) { return x > 0 && x < Double.POSITIVE_INFINITY; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy