All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.stat.distribution.MultivariateGaussianDistribution Maven / Gradle / Ivy

The newest version!
/*******************************************************************************
 * Copyright (c) 2010-2020 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation, either version 3 of
 * the License, or (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 ******************************************************************************/

package smile.stat.distribution;

import smile.math.MathEx;
import smile.math.blas.UPLO;
import smile.math.matrix.Matrix;

/**
 * Multivariate Gaussian distribution.
 * <p>
 * An instance holds the mean vector and covariance matrix, plus quantities
 * derived once from the Cholesky factorization of the covariance (inverse,
 * determinant and triangular factor) that are reused by {@link #logp},
 * {@link #cdf} and {@link #rand()}.
 *
 * @see GaussianDistribution
 *
 * @author Haifeng Li
 */
public class MultivariateGaussianDistribution implements MultivariateDistribution, MultivariateExponentialFamily {
    private static final long serialVersionUID = 2L;

    /** log(2 * pi * e), the per-dimension constant of the entropy formula. */
    private static final double LOG2PIE = Math.log(2 * Math.PI * Math.E);

    /** The mean vector. */
    public final double[] mu;
    /** The covariance matrix. */
    public final Matrix sigma;
    /** True if the covariance matrix is diagonal. */
    public final boolean diagonal;

    /** The dimension of the distribution, i.e. the length of {@link #mu}. */
    private int dim;
    /** The inverse of the covariance matrix, obtained from its Cholesky factorization. */
    private Matrix sigmaInv;
    /** The Cholesky factor of the covariance matrix; only entries (i, j) with j &lt;= i are read. */
    private Matrix sigmaL;
    /** The determinant of the covariance matrix. */
    private double sigmaDet;
    /** The log of the normalization constant: (dim * log(2 * pi) + log|sigma|) / 2. */
    private double pdfConstant;
    /**
     * The value returned by {@link #length()}: the number of mean entries plus
     * the number of distinct variance/covariance entries for the chosen
     * covariance structure.
     */
    private final int length;

    /**
     * Constructor. The distribution will have a diagonal covariance matrix of
     * the same variance.
     *
     * @param mean mean vector. It is copied.
     * @param variance variance shared by all dimensions.
     * @throws IllegalArgumentException if the variance is not positive.
     */
    public MultivariateGaussianDistribution(double[] mean, double variance) {
        if (variance <= 0) {
            throw new IllegalArgumentException("Variance is not positive: " + variance);
        }

        mu = mean.clone();
        sigma = new Matrix(mu.length, mu.length);
        for (int i = 0; i < mu.length; i++) {
            sigma.set(i, i, variance);
        }

        diagonal = true;
        // d mean entries + 1 shared variance.
        length = mu.length + 1;

        init();
    }

    /**
     * Constructor. The distribution will have a diagonal covariance matrix.
     * Each element has different variance.
     *
     * @param mean mean vector. It is copied.
     * @param variance variance vector.
     * @throws IllegalArgumentException if the two vectors have different
     *         lengths or any variance is not positive.
     */
    public MultivariateGaussianDistribution(double[] mean, double[] variance) {
        if (mean.length != variance.length) {
            throw new IllegalArgumentException("Mean vector and covariance matrix have different dimension");
        }

        // Validate before building the matrix so that bad input does no work.
        for (double v : variance) {
            if (v <= 0) {
                throw new IllegalArgumentException("Variance is not positive: " + v);
            }
        }

        mu = mean.clone();
        sigma = Matrix.diag(variance);

        diagonal = true;
        // d mean entries + d variances.
        length = 2 * mu.length;

        init();
    }

    /**
     * Constructor.
     *
     * @param mean mean vector. It is copied.
     * @param cov covariance matrix. It is stored by reference (not copied)
     *            and its UPLO flag is set to {@code LOWER} during initialization.
     * @throws IllegalArgumentException if the covariance matrix is not square
     *         or its size does not match the mean vector.
     */
    public MultivariateGaussianDistribution(double[] mean, Matrix cov) {
        if (mean.length != cov.nrows()) {
            throw new IllegalArgumentException("Mean vector and covariance matrix have different dimension");
        }

        // A column vector of variances also has nrows == mean.length, so the
        // row check alone cannot tell it apart from a covariance matrix.
        if (cov.nrows() != cov.ncols()) {
            throw new IllegalArgumentException("Covariance matrix is not square: " + cov.nrows() + " x " + cov.ncols());
        }

        mu = mean.clone();
        sigma = cov;

        diagonal = false;
        // d mean entries + d(d+1)/2 distinct entries of a symmetric matrix.
        length = mu.length + mu.length * (mu.length + 1) / 2;

        init();
    }

    /**
     * Estimates the mean and the full (non-diagonal) covariance matrix from
     * the data. Equivalent to {@code fit(data, false)}.
     *
     * @param data the training data, one sample per row.
     * @return the fitted distribution.
     */
    public static MultivariateGaussianDistribution fit(double[][] data) {
        return fit(data, false);
    }

    /**
     * Estimates the mean and covariance from the data.
     * <p>
     * In the diagonal case the sum of squared deviations is divided by
     * {@code n - 1}, where {@code n} is the number of samples.
     *
     * @param data the training data, one sample per row.
     * @param diagonal true if covariance matrix is diagonal.
     * @return the fitted distribution.
     */
    public static MultivariateGaussianDistribution fit(double[][] data, boolean diagonal) {
        double[] mu = MathEx.colMeans(data);
        int n = data.length;
        int d = mu.length;

        if (diagonal) {
            double[] variance = new double[d];
            for (int i = 0; i < n; i++) {
                double[] x = data[i];
                for (int j = 0; j < d; j++) {
                    double diff = x[j] - mu[j];
                    variance[j] += diff * diff;
                }
            }

            int n1 = n - 1;
            for (int j = 0; j < d; j++) {
                variance[j] /= n1;
            }

            return new MultivariateGaussianDistribution(mu, variance);
        } else {
            return new MultivariateGaussianDistribution(mu, new Matrix(MathEx.cov(data, mu)));
        }
    }

    /**
     * Initialize the object: factorizes the covariance matrix once and caches
     * the inverse, determinant, Cholesky factor and log normalization constant.
     */
    private void init() {
        dim = mu.length;
        sigma.uplo(UPLO.LOWER);
        Matrix.Cholesky cholesky = sigma.cholesky();
        sigmaInv = cholesky.inverse();
        sigmaDet = cholesky.det();
        sigmaL = cholesky.lu;
        pdfConstant = (dim * Math.log(2 * Math.PI) + Math.log(sigmaDet)) / 2.0;
    }

    /**
     * Returns the parameter count fixed at construction time.
     *
     * @return the number of parameters for the chosen covariance structure.
     */
    @Override
    public int length() {
        return length;
    }

    /**
     * Returns the entropy, (dim * log(2 * pi * e) + log|sigma|) / 2.
     *
     * @return the entropy of the distribution.
     */
    @Override
    public double entropy() {
        return (dim * LOG2PIE + Math.log(sigmaDet)) / 2;
    }

    /**
     * Returns the mean vector.
     *
     * @return the internal mean array (not a copy).
     */
    @Override
    public double[] mean() {
        return mu;
    }

    /**
     * Returns the covariance matrix.
     *
     * @return the internal covariance matrix (not a copy).
     */
    @Override
    public Matrix cov() {
        return sigma;
    }

    /**
     * Returns the scatter of distribution, which is defined as |&Sigma;|.
     *
     * @return the determinant of the covariance matrix.
     */
    public double scatter() {
        return sigmaDet;
    }

    /**
     * Returns the log density at x:
     * -(x - mu)' * inverse(sigma) * (x - mu) / 2 minus the log normalization constant.
     *
     * @param x the point to evaluate. It is not modified.
     * @return the log of the probability density at x.
     * @throws IllegalArgumentException if x has a different dimension.
     */
    @Override
    public double logp(double[] x) {
        if (x.length != dim) {
            throw new IllegalArgumentException("Sample has different dimension.");
        }

        // Work on a copy so that the caller's array is left untouched.
        double[] v = x.clone();
        MathEx.sub(v, mu);
        double result = sigmaInv.xAx(v) / -2.0;
        return result - pdfConstant;
    }

    /**
     * Returns the probability density at x.
     *
     * @param x the point to evaluate.
     * @return exp of {@link #logp(double[])}.
     */
    @Override
    public double p(double[] x) {
        return Math.exp(logp(x));
    }

    /**
     * Algorithm from Alan Genz (1992) Numerical Computation of 
     * Multivariate Normal Probabilities, Journal of Computational and 
     * Graphical Statistics, pp. 141-149.
     *
     * The difference between returned value and the true value of the
     * CDF is less than 0.001 in 99.9% time. The maximum number of iterations
     * is set to 10000.
     * <p>
     * The estimate is driven by random draws from {@code MathEx.random}, so
     * repeated calls with the same argument may return slightly different values.
     *
     * @param x the upper integration limits. It is not modified.
     * @return the estimated cumulative probability at x.
     * @throws IllegalArgumentException if x has a different dimension.
     */
    @Override
    public double cdf(double[] x) {
        if (x.length != dim) {
            throw new IllegalArgumentException("Sample has different dimension.");
        }

        final int maxIterations = 10000;
        final double errMax = 0.001;
        // Look the standard normal up once instead of on every inner iteration.
        final GaussianDistribution standard = GaussianDistribution.getInstance();
        final double alph = standard.quantile(0.999);

        double[] v = x.clone();
        MathEx.sub(v, mu);

        double p = 0.0;
        double varSum = 0.0;

        // d is always zero
        double[] e = new double[dim];
        double[] f = new double[dim];
        e[0] = standard.cdf(v[0] / sigmaL.get(0, 0));
        f[0] = e[0];

        double[] y = new double[dim];

        // Start above the tolerance so that at least one sample is drawn.
        double err = 2 * errMax;
        for (int iter = 1; err > errMax && iter <= maxIterations; iter++) {
            double[] w = MathEx.random(dim - 1);
            for (int i = 1; i < dim; i++) {
                y[i - 1] = standard.quantile(w[i - 1] * e[i - 1]);
                double q = 0.0;
                for (int j = 0; j < i; j++) {
                    q += sigmaL.get(i, j) * y[j];
                }

                e[i] = standard.cdf((v[i] - q) / sigmaL.get(i, i));
                f[i] = e[i] * f[i - 1];
            }

            // Running update of the estimate and of its error bound.
            double del = (f[dim - 1] - p) / iter;
            p += del;
            varSum = (iter - 2) * varSum / iter + del * del;
            err = alph * Math.sqrt(varSum);
        }

        return p;
    }

    /**
     * Generate a random multivariate Gaussian sample.
     *
     * @return a new array holding one sample.
     */
    public double[] rand() {
        double[] spt = new double[mu.length];

        // Draw independent standard normal deviates by rejection sampling;
        // the constants match Leva's ratio-of-uniforms method.
        for (int i = 0; i < mu.length; i++) {
            double u, v, q;
            do {
                u = MathEx.random();
                v = 1.7156 * (MathEx.random() - 0.5);
                double x = u - 0.449871;
                double y = Math.abs(v) + 0.386595;
                q = x * x + y * (0.19600 * y - 0.25472 * x);
            } while (q > 0.27597 && (q > 0.27846 || v * v > -4 * Math.log(u) * u * u));

            spt[i] = v / u;
        }

        double[] pt = new double[sigmaL.nrows()];

        // pt = sigmaL * spt, using only the lower triangle of sigmaL.
        for (int i = 0; i < pt.length; i++) {
            for (int j = 0; j <= i; j++) {
                pt[i] += sigmaL.get(i, j) * spt[j];
            }
        }

        MathEx.add(pt, mu);

        return pt;
    }

    /**
     * Generates a set of random numbers following this distribution.
     *
     * @param n the number of samples to generate.
     * @return an array of n samples, one per row.
     */
    public double[][] rand(int n) {
        double[][] data = new double[n][];
        for (int i = 0; i < n; i++) {
            data[i] = rand();
        }
        return data;
    }

    /**
     * The M step of the EM algorithm: re-estimates a Gaussian component from
     * the data weighted by the posterior probabilities. The new component
     * keeps the covariance structure (diagonal or full) of this distribution.
     *
     * @param data the samples, one per row.
     * @param posteriori the weight of each sample.
     * @return the component whose priori is the sum of the weights and whose
     *         distribution is the weighted Gaussian estimate.
     */
    @Override
    public MultivariateMixture.Component M(double[][] data, double[] posteriori) {
        int n = data.length;
        int d = data[0].length;

        double alpha = 0.0;
        double[] mean = new double[d];

        for (int k = 0; k < n; k++) {
            alpha += posteriori[k];
            double[] x = data[k];
            for (int i = 0; i < d; i++) {
                mean[i] += x[i] * posteriori[k];
            }
        }

        for (int i = 0; i < d; i++) {
            mean[i] /= alpha;
        }

        MultivariateGaussianDistribution gaussian;
        if (diagonal) {
            double[] variance = new double[d];
            for (int k = 0; k < n; k++) {
                double[] x = data[k];
                for (int i = 0; i < d; i++) {
                    variance[i] += (x[i] - mean[i]) * (x[i] - mean[i]) * posteriori[k];
                }
            }

            for (int i = 0; i < d; i++) {
                variance[i] /= alpha;
            }

            // Pass the variance vector itself so that the diagonal constructor
            // is selected; wrapping it in a Matrix would dispatch to the
            // full-covariance constructor instead.
            gaussian = new MultivariateGaussianDistribution(mean, variance);
        } else {
            Matrix cov = new Matrix(d, d);
            for (int k = 0; k < n; k++) {
                double[] x = data[k];
                for (int i = 0; i < d; i++) {
                    for (int j = 0; j < d; j++) {
                        cov.add(i, j, (x[i] - mean[i]) * (x[j] - mean[j]) * posteriori[k]);
                    }
                }
            }

            for (int i = 0; i < d; i++) {
                for (int j = 0; j < d; j++) {
                    cov.div(i, j, alpha);
                }

                // make sure the covariance matrix is positive definite.
                cov.mul(i, i, 1.00001);
            }

            gaussian = new MultivariateGaussianDistribution(mean, cov);
        }

        return new MultivariateMixture.Component(alpha, gaussian);
    }

    /**
     * Returns a multi-line text form listing the mean vector and every entry
     * of the covariance matrix.
     *
     * @return the string representation.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder("Multivariate Gaussian Distribution:\nmu = [");
        for (int i = 0; i < mu.length; i++) {
            builder.append(mu[i]).append(" ");
        }
        // Replace the trailing separator with the closing bracket.
        builder.setCharAt(builder.length() - 1, ']');
        builder.append("\nSigma = [\n");
        for (int i = 0; i < sigma.nrows(); i++) {
            builder.append('\t');
            for (int j = 0; j < sigma.ncols(); j++) {
                builder.append(sigma.get(i, j)).append(" ");
            }
            builder.append('\n');
        }
        builder.append("\t]");
        return builder.toString();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy