All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.stat.distribution.BetaDistribution Maven / Gradle / Ivy

The newest version!
/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 *   
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *  
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
package smile.stat.distribution;

import smile.math.special.Beta;
import smile.math.special.Gamma;
import smile.math.Math;

/**
 * The beta distribution is defined on the interval [0, 1] parameterized by
 * two positive shape parameters, typically denoted by α and β.
 * It is the special case of the Dirichlet distribution with only two parameters.
 * The beta distribution is used as a prior distribution for binomial
 * proportions in Bayesian analysis. In Bayesian statistics, it can be seen as
 * the posterior distribution of the parameter p of a binomial distribution
 * after observing α - 1 independent events with probability p and β - 1 with
 * probability 1 - p, if the prior distribution of p was uniform.
 * If α = 1 and β = 1, the Beta distribution is the uniform [0, 1] distribution.
 * The probability density function of the beta distribution is
 * f(x;α,β) = x<sup>α-1</sup>(1-x)<sup>β-1</sup> / B(α,β)
 * where B(α,β) is the beta function.
 * @author Haifeng Li
 */
public class BetaDistribution extends AbstractDistribution implements ExponentialFamily {

    /** Shape parameter alpha; set by both constructors and returned by getAlpha(). */
    private double alpha;
    /** Shape parameter beta; set by both constructors and returned by getBeta(). */
    private double beta;
    /** Mean alpha / (alpha + beta), computed once in the constructors. */
    private double mean;
    /** Variance, computed once in the constructors from alpha and beta. */
    private double var;
    /** Entropy, computed once in the constructors from alpha and beta. */
    private double entropy;
    /** Random variate generator; remains null until rand() is first called. */
    private RejectionLogLogistic rng;

    /**
     * Constructor.
     * @param alpha shape parameter; must be a positive finite number.
     * @param beta shape parameter; must be a positive finite number.
     * @throws IllegalArgumentException if either parameter is NaN, infinite,
     *         zero or negative.
     */
    public BetaDistribution(double alpha, double beta) {
        // "!(v > 0)" rather than "v <= 0": every ordered comparison with NaN
        // is false, so the latter would silently accept NaN. Infinite values
        // are rejected too because they turn the mean into Infinity / Infinity.
        if (!(alpha > 0) || Double.isInfinite(alpha)) {
            throw new IllegalArgumentException("Invalid alpha: " + alpha);
        }

        if (!(beta > 0) || Double.isInfinite(beta)) {
            throw new IllegalArgumentException("Invalid beta: " + beta);
        }

        this.alpha = alpha;
        this.beta = beta;

        // The moments depend only on the shape parameters, so cache them once.
        mean = alpha / (alpha + beta);
        var = alpha * beta / ((alpha + beta) * (alpha + beta) * (alpha + beta + 1));
        // ln B(a,b) - (a-1) psi(a) - (b-1) psi(b) + (a+b-2) psi(a+b), psi = digamma.
        entropy = Math.log(Beta.beta(alpha, beta)) - (alpha - 1) * Gamma.digamma(alpha) - (beta - 1) * Gamma.digamma(beta) + (alpha + beta - 2) * Gamma.digamma(alpha + beta);
    }

    /**
     * Construct an Beta from the given samples. Parameter
     * will be estimated from the data by the moment method.
     */
    public BetaDistribution(double[] data) {
        for (int i = 0; i < data.length; i++) {
            if (data[i] < 0 || data[i] > 1) {
                throw new IllegalArgumentException("Samples are not in range [0, 1].");
            }
        }

        mean = Math.mean(data);
        var = Math.var(data);

        alpha = mean * (mean * (1 - mean) / var - 1);
        beta = (1 - mean) * (mean * (1 - mean) / var - 1);
        if (alpha <= 0 || beta <= 0) {
            throw new IllegalArgumentException("Samples don't follow Beta Distribution.");
        }

        mean = alpha / (alpha + beta);
        var = alpha * beta / ((alpha + beta) * (alpha + beta) * (alpha + beta + 1));
        entropy = Math.log(Beta.beta(alpha, beta)) - (alpha - 1) * Gamma.digamma(alpha) - (beta - 1) * Gamma.digamma(beta) + (alpha + beta - 2) * Gamma.digamma(alpha + beta);
    }

    /**
     * Gets the shape parameter alpha of this distribution.
     * @return the value of alpha
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Gets the shape parameter beta of this distribution.
     * @return the value of beta
     */
    public double getBeta() {
        return this.beta;
    }

    /**
     * Reports 2, matching the two shape parameters alpha and beta.
     * @return always 2
     */
    @Override
    public int npara() {
        final int parameterCount = 2;
        return parameterCount;
    }

    /**
     * Returns the mean that was computed when this object was constructed.
     * @return the cached mean
     */
    @Override
    public double mean() {
        return this.mean;
    }

    /**
     * Returns the variance that was computed when this object was constructed.
     * @return the cached variance
     */
    @Override
    public double var() {
        return this.var;
    }

    /**
     * Returns the standard deviation, i.e. the square root of the cached variance.
     * @return the standard deviation
     */
    @Override
    public double sd() {
        final double variance = this.var;
        return Math.sqrt(variance);
    }

    /**
     * Returns the entropy that was computed when this object was constructed.
     * @return the cached entropy
     */
    @Override
    public double entropy() {
        return this.entropy;
    }

    /**
     * Formats this distribution with both shape parameters shown to four decimals.
     * @return a description of the form {@code Beta Distribution(alpha, beta)}
     */
    @Override
    public String toString() {
        final String pattern = "Beta Distribution(%.4f, %.4f)";
        return String.format(pattern, this.alpha, this.beta);
    }

    /**
     * Evaluates the probability density at {@code x}.
     * @param x the point at which to evaluate the density
     * @return 0 when x is below 0 or above 1; otherwise
     *         x^(alpha-1) * (1-x)^(beta-1) / B(alpha, beta)
     */
    @Override
    public double p(double x) {
        // Outside the support the density is zero.
        if (x < 0 || x > 1) {
            return 0.0;
        }

        final double leftFactor = Math.pow(x, alpha - 1);
        final double rightFactor = Math.pow(1 - x, beta - 1);
        final double unnormalized = leftFactor * rightFactor;
        return unnormalized / Beta.beta(alpha, beta);
    }

    /**
     * Evaluates the natural logarithm of the probability density at {@code x}.
     * @param x the point at which to evaluate the log density
     * @return negative infinity when x is below 0 or above 1; otherwise
     *         (alpha-1) ln(x) + (beta-1) ln(1-x) - ln B(alpha, beta), where a
     *         term whose coefficient is zero is taken to be zero
     */
    @Override
    public double logp(double x) {
        if (x < 0 || x > 1) {
            return Double.NEGATIVE_INFINITY;
        }

        // A shape parameter equal to 1 gives a zero exponent, so that factor
        // of the density is identically 1 and adds nothing to the log.
        // Multiplying it out would give 0 * ln(0) = 0 * -Infinity = NaN at the
        // matching end point (x = 0 for alpha, x = 1 for beta), whereas p()
        // returns a finite density there.
        double logX = (alpha == 1.0) ? 0.0 : (alpha - 1) * Math.log(x);
        double logOneMinusX = (beta == 1.0) ? 0.0 : (beta - 1) * Math.log(1 - x);
        return logX + logOneMinusX - Math.log(Beta.beta(alpha, beta));
    }

    /**
     * Evaluates the cumulative distribution function at {@code x}.
     * @param x the point at which to evaluate the CDF
     * @return 0 for x at or below 0, 1 for x at or above 1, and the
     *         regularized incomplete beta function of (alpha, beta, x) otherwise
     */
    @Override
    public double cdf(double x) {
        if (x <= 0) {
            return 0.0;
        }
        if (x >= 1) {
            return 1.0;
        }
        return Beta.regularizedIncompleteBetaFunction(alpha, beta, x);
    }

    /**
     * Evaluates the quantile function by inverting the regularized
     * incomplete beta function for this distribution's shape parameters.
     * @param p the probability
     * @return the value of the inverse regularized incomplete beta function at p
     * @throws IllegalArgumentException if p is less than 0 or greater than 1
     */
    @Override
    public double quantile(double p) {
        final boolean outOfRange = p < 0.0 || p > 1.0;
        if (outOfRange) {
            throw new IllegalArgumentException("Invalid p: " + p);
        }

        final double q = Beta.inverseRegularizedIncompleteBetaFunction(alpha, beta, p);
        return q;
    }

    /**
     * Builds a mixture component from weighted samples: the weighted mean
     * and weighted variance of {@code x} are converted into shape parameters
     * by the method of moments.
     * @param x the samples
     * @param posteriori the weight of each sample, indexed like {@code x}
     * @return a component whose {@code priori} is the sum of the weights and
     *         whose {@code distribution} is the fitted Beta distribution
     */
    @Override
    public Mixture.Component M(double[] x, double[] posteriori) {
        final int n = x.length;

        // First pass: total weight and weighted sum of the samples.
        double totalWeight = 0.0;
        double weightedSum = 0.0;
        for (int k = 0; k < n; k++) {
            totalWeight += posteriori[k];
            weightedSum += x[k] * posteriori[k];
        }
        final double center = weightedSum / totalWeight;

        // Second pass: weighted squared deviation around the weighted mean.
        double spread = 0.0;
        for (int k = 0; k < n; k++) {
            final double deviation = x[k] - center;
            spread += deviation * deviation * posteriori[k];
        }
        spread = spread / totalWeight;

        // Method of moments: alpha = m * c, beta = (1 - m) * c.
        final double shapeA = center * (center * (1 - center) / spread - 1);
        final double shapeB = (1 - center) * (center * (1 - center) / spread - 1);

        final Mixture.Component component = new Mixture.Component();
        component.priori = totalWeight;
        component.distribution = new BetaDistribution(shapeA, shapeB);
        return component;
    }

    /**
     * Draws one random variate, creating the underlying generator on
     * the first call and reusing it afterwards.
     * @return a random variate produced by the generator
     */
    @Override
    public double rand() {
        RejectionLogLogistic generator = rng;
        if (generator == null) {
            generator = new RejectionLogLogistic();
            rng = generator;
        }

        return generator.rand();
    }

    /**
     * Implements Beta random variate generators using
     * the rejection method with log-logistic envelopes.
     * Every attempt draws a pair of uniforms with {@code Math.random()};
     * when the candidate is rejected, a fresh pair is drawn and the
     * attempt is repeated until a candidate is accepted.
     */
    class RejectionLogLogistic {

        /** Method tag chosen by the constructor when both shape parameters exceed 1. */
        private static final int BB = 0;
        /** Method tag chosen by the constructor in every other case. */
        private static final int BC = 1;
        /** Either BB or BC; selects the branch taken in rand(). */
        private int method;
        /** BB: the smaller of alpha and beta. BC: the larger of the two. */
        private double am;
        /** BB: the larger of alpha and beta. BC: the smaller of the two. */
        private double bm;
        /** Sum am + bm. */
        private double al;
        /** BC only: al * log(al / am) - 1.386294361; left at 0 for BB. */
        private double alnam;
        /** Multiplier of the logit of the uniform; sqrt((al - 2) / (2 * alpha * beta - al)) for BB, 1 / bm for BC. */
        private double be;
        /** BB only: am + 1 / be; left at 0 for BC. */
        private double ga;
        /** BC only: 1 + am - bm; left at 0 for BB. */
        private double si;
        /** BC only: threshold compared against in the u1 < 0.5 branch of rand(). */
        private double rk1;
        /** BC only: threshold compared against in the u1 >= 0.5 branch of rand(). */
        private double rk2;

        /**
         * Creates a beta random variate generator.
         */
        public RejectionLogLogistic() {
            if (alpha > 1.0 && beta > 1.0) {
                method = BB;
                am = (alpha < beta) ? alpha : beta;
                bm = (alpha > beta) ? alpha : beta;
                al = am + bm;
                be = Math.sqrt((al - 2.0) / (2.0 * alpha * beta - al));
                ga = am + 1.0 / be;
            } else {
                method = BC;
                am = (alpha > beta) ? alpha : beta;
                bm = (alpha < beta) ? alpha : beta;
                al = am + bm;
                alnam = al * Math.log(al / am) - 1.386294361;
                be = 1.0 / bm;
                si = 1.0 + am - bm;
                rk1 = si * (0.013888889 + 0.041666667 * bm) / (am * be - 0.77777778);
                rk2 = 0.25 + (0.5 + 0.25 / si) * bm;
            }
        }

        /**
         * Draws one variate by rejection sampling, using the constants
         * prepared by the constructor for the selected method.
         *
         * Each pass of a loop draws two uniforms with Math.random();
         * "continue" rejects the candidate and starts a new pass, while
         * "break" accepts the value stored in X.
         *
         * NOTE(review): the BB/BC names and the step numbering look like
         * Cheng's (1978) beta generators - confirm against the reference.
         *
         * @return the accepted variate
         * @throws IllegalStateException if method is neither BB nor BC
         */
        public double rand() {
            double X = 0.0;
            double u1, u2, v, w, y, z, r, s, t;
            switch (method) {
                case BB:
                    /* -X- generator code -X- */
                    while (true) {
                        /* Step 1 */
                        u1 = Math.random();
                        u2 = Math.random();
                        // v is the scaled logit of u1; w = am * e^v is the candidate.
                        v = be * Math.log(u1 / (1.0 - u1));
                        w = am * Math.exp(v);
                        z = u1 * u1 * u2;
                        // 1.386294361 equals ln 4 to the digits shown.
                        r = ga * v - 1.386294361;
                        s = am + r - w;

                        /* Step 2 */
                        // 2.609437912 equals 1 + ln 5 to the digits shown. When this
                        // cheap test fails the candidate is accepted without any
                        // further logarithm being evaluated.
                        if (s + 2.609437912 < 5.0 * z) {
                            /* Step 3 */
                            t = Math.log(z);
                            if (s < t) /* Step 4 */ {
                                // Final test; only here is the candidate rejected.
                                if (r + al * Math.log(al / (bm + w)) < t) {
                                    continue;
                                }
                            }
                        }

                        /* Step 5 */
                        // am holds the smaller shape; the ratio returned depends on
                        // whether that is alpha (per Math.equals) or beta.
                        X = Math.equals(am, alpha) ? w / (bm + w) : bm / (bm + w);
                        break;
                    }
                    /* -X- end of generator code -X- */
                    break;

                case BC:
                    while (true) {
                        /* Step 1 */
                        u1 = Math.random();
                        u2 = Math.random();

                        if (u1 < 0.5) {
                            /* Step 2 */
                            y = u1 * u2;
                            z = u1 * y;

                            // Quick rejection against the precomputed threshold rk1.
                            if ((0.25 * u2 - y + z) >= rk1) {
                                continue;  /* goto 1 */
                            }

                            /* Step 5 */
                            v = be * Math.log(u1 / (1.0 - u1));
                            // For v > 80 exp(v) is not evaluated; an end point of
                            // [0, 1] is returned directly instead
                            // (presumably to avoid overflow - confirm).
                            if (v > 80.0) {
                                if (alnam < Math.log(z)) {
                                    continue;
                                }
                                X = Math.equals(am, alpha) ? 1.0 : 0.0;
                                break;
                            } else {
                                w = am * Math.exp(v);
                                // Acceptance test; 1.386294361 equals ln 4 to the digits shown.
                                if ((al * (Math.log(al / (bm + w)) + v) - 1.386294361) <
                                        Math.log(z)) {
                                    continue;  /* goto 1 */
                                }

                                /* Step 6_a */
                                X = !Math.equals(am, alpha) ? bm / (bm + w) : w / (bm + w);
                                break;
                            }
                        } else {
                            /* Step 3 */
                            z = u1 * u1 * u2;
                            if (z < 0.25) {
                                // Immediate acceptance: no further test is applied.
                                /* Step 5 */
                                v = be * Math.log(u1 / (1.0 - u1));
                                if (v > 80.0) {
                                    X = Math.equals(am, alpha) ? 1.0 : 0.0;
                                    break;
                                }

                                w = am * Math.exp(v);
                                X = !Math.equals(am, alpha) ? bm / (bm + w) : w / (bm + w);
                                break;
                            } else {
                                // Quick rejection against the precomputed threshold rk2.
                                if (z >= rk2) {
                                    continue;
                                }
                                v = be * Math.log(u1 / (1.0 - u1));
                                if (v > 80.0) {
                                    if (alnam < Math.log(z)) {
                                        continue;
                                    }
                                    X = Math.equals(am, alpha) ? 1.0 : 0.0;
                                    break;
                                }
                                w = am * Math.exp(v);
                                // Same acceptance test as in the u1 < 0.5 branch.
                                if ((al * (Math.log(al / (bm + w)) + v) - 1.386294361) <
                                        Math.log(z)) {
                                    continue;  /* goto 1 */
                                }

                                /* Step 6_b */
                                X = !Math.equals(am, alpha) ? bm / (bm + w) : w / (bm + w);
                                break;
                            }
                        }
                    }
                    break;
                default:
                    // method is only ever set to BB or BC by the constructor.
                    throw new IllegalStateException();
            }

            return X;
        }
    }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy