smile.stat.distribution.MultivariateGaussianDistribution Maven / Gradle / Ivy
/******************************************************************************
* Confidential Proprietary *
* (c) Copyright Haifeng Li 2011, All Rights Reserved *
******************************************************************************/
package smile.stat.distribution;
import smile.math.matrix.CholeskyDecomposition;
import smile.math.Math;
/**
* Multivariate Gaussian distribution.
*
* @see GaussianDistribution
*
* @author Haifeng Li
*/
public class MultivariateGaussianDistribution extends AbstractMultivariateDistribution implements MultivariateExponentialFamily {

    /** log(2 * pi * e), the per-dimension term of the entropy formula. */
    private static final double LOG2PIE = Math.log(2 * Math.PI * Math.E);

    /** Mean vector. */
    double[] mu;
    /** Covariance matrix. */
    double[][] sigma;
    /** True if the covariance matrix is modeled as diagonal. */
    boolean diagonal;
    /** Dimension of the distribution (length of {@code mu}). */
    private int dim;
    /** Inverse of the covariance matrix, computed once by {@link #init()}. */
    private double[][] sigmaInv;
    /** Lower-triangular Cholesky factor of the covariance, from {@code CholeskyDecomposition.getL()}. */
    private double[][] sigmaL;
    /** Determinant of the covariance matrix, computed once by {@link #init()}. */
    private double sigmaDet;
    /** Log of the density's normalizing constant: (dim * log(2 pi) + log|sigma|) / 2. */
    private double pdfConstant;
    /** Number of free parameters reported by {@link #npara()}. */
    private int numParameters;

    /**
     * Constructor. The distribution will have a diagonal covariance matrix of
     * the same variance.
     *
     * @param mean mean vector.
     * @param var variance.
     * @throws IllegalArgumentException if {@code var} is not positive.
     */
    public MultivariateGaussianDistribution(double[] mean, double var) {
        if (var <= 0) {
            throw new IllegalArgumentException("Variance is not positive: " + var);
        }

        mu = new double[mean.length];
        sigma = new double[mu.length][mu.length];
        for (int i = 0; i < mu.length; i++) {
            mu[i] = mean[i];
            sigma[i][i] = var;
        }

        diagonal = true;
        // d means plus one shared variance.
        numParameters = mu.length + 1;

        init();
    }

    /**
     * Constructor. The distribution will have a diagonal covariance matrix.
     * Each element has different variance.
     *
     * @param mean mean vector.
     * @param var variance vector.
     * @throws IllegalArgumentException if the lengths differ or any variance is not positive.
     */
    public MultivariateGaussianDistribution(double[] mean, double[] var) {
        if (mean.length != var.length) {
            throw new IllegalArgumentException("Mean vector and covariance matrix have different dimension");
        }

        mu = new double[mean.length];
        sigma = new double[mu.length][mu.length];
        for (int i = 0; i < mu.length; i++) {
            if (var[i] <= 0) {
                throw new IllegalArgumentException("Variance is not positive: " + var[i]);
            }

            mu[i] = mean[i];
            sigma[i][i] = var[i];
        }

        diagonal = true;
        // d means plus d independent variances.
        numParameters = 2 * mu.length;

        init();
    }

    /**
     * Constructor.
     *
     * @param mean mean vector.
     * @param cov covariance matrix. The first {@code mean.length} entries of each row are copied.
     * @throws IllegalArgumentException if {@code cov} has a different number of rows than {@code mean}.
     */
    public MultivariateGaussianDistribution(double[] mean, double[][] cov) {
        if (mean.length != cov.length) {
            throw new IllegalArgumentException("Mean vector and covariance matrix have different dimension");
        }

        mu = new double[mean.length];
        sigma = new double[mean.length][mean.length];
        for (int i = 0; i < mu.length; i++) {
            mu[i] = mean[i];
            System.arraycopy(cov[i], 0, sigma[i], 0, mu.length);
        }

        diagonal = false;
        // d means plus the d(d+1)/2 entries of a symmetric covariance matrix.
        numParameters = mu.length + mu.length * (mu.length + 1) / 2;

        init();
    }

    /**
     * Constructor. Mean and full covariance will be estimated from the data.
     * @param data the training data.
     */
    public MultivariateGaussianDistribution(double[][] data) {
        this(data, false);
    }

    /**
     * Constructor. Mean and covariance will be estimated from the data.
     * The mean is the column mean. For a diagonal model each variance is the
     * sum of squared deviations divided by (n - 1); otherwise the covariance
     * is computed by {@code Math.cov(data, mu)}.
     *
     * @param data the training data.
     * @param diagonal true if covariance matrix is diagonal.
     */
    public MultivariateGaussianDistribution(double[][] data, boolean diagonal) {
        this.diagonal = diagonal;
        mu = Math.colMean(data);

        if (diagonal) {
            sigma = new double[data[0].length][data[0].length];
            for (int i = 0; i < data.length; i++) {
                for (int j = 0; j < mu.length; j++) {
                    sigma[j][j] += (data[i][j] - mu[j]) * (data[i][j] - mu[j]);
                }
            }

            for (int j = 0; j < mu.length; j++) {
                sigma[j][j] /= (data.length - 1);
            }
        } else {
            sigma = Math.cov(data, mu);
        }

        // Keep the parameter count consistent with the other constructors:
        // a diagonal model has d variances, a full model d(d+1)/2 covariance entries.
        numParameters = diagonal ? 2 * mu.length : mu.length + mu.length * (mu.length + 1) / 2;

        init();
    }

    /**
     * Initialize the cached quantities derived from {@code sigma}: its inverse,
     * determinant, Cholesky factor and the log normalizing constant of the density.
     */
    private void init() {
        dim = mu.length;
        CholeskyDecomposition cholesky = new CholeskyDecomposition(sigma);
        sigmaInv = cholesky.inverse();
        sigmaDet = cholesky.det();
        sigmaL = cholesky.getL();
        pdfConstant = (dim * Math.log(2 * Math.PI) + Math.log(sigmaDet)) / 2.0;
    }

    /**
     * Returns true if the covariance matrix is diagonal.
     * @return true if the covariance matrix is diagonal
     */
    public boolean isDiagonal() {
        return diagonal;
    }

    /**
     * Returns the number of free parameters of this model.
     * @return the number of parameters.
     */
    @Override
    public int npara() {
        return numParameters;
    }

    /**
     * Returns the differential entropy (dim * log(2 pi e) + log|sigma|) / 2.
     * @return the entropy.
     */
    @Override
    public double entropy() {
        return (dim * LOG2PIE + Math.log(sigmaDet)) / 2;
    }

    /**
     * Returns the mean vector. This is the internal array, not a copy; modifying
     * it changes this distribution.
     * @return the mean vector.
     */
    @Override
    public double[] mean() {
        return mu;
    }

    /**
     * Returns the covariance matrix. This is the internal array, not a copy.
     * The cached inverse, determinant and Cholesky factor are computed once at
     * construction and are NOT refreshed if the returned matrix is modified.
     * @return the covariance matrix.
     */
    @Override
    public double[][] cov() {
        return sigma;
    }

    /**
     * Returns the scatter of distribution, which is defined as |Σ|.
     * @return the determinant of the covariance matrix.
     */
    public double scatter() {
        return sigmaDet;
    }

    /**
     * Returns the log density at {@code x}:
     * -(x - mu)' sigma^-1 (x - mu) / 2 - (dim * log(2 pi) + log|sigma|) / 2.
     *
     * @param x a sample.
     * @return the log density.
     * @throws IllegalArgumentException if {@code x} has a different dimension.
     */
    @Override
    public double logp(double[] x) {
        if (x.length != dim) {
            throw new IllegalArgumentException("Sample has different dimension.");
        }

        double[] v = x.clone();
        Math.minus(v, mu);
        double result = Math.xax(sigmaInv, v) / -2.0;
        return result - pdfConstant;
    }

    /**
     * Returns the density at {@code x}, i.e. {@code exp(logp(x))}.
     * @param x a sample.
     * @return the density.
     */
    @Override
    public double p(double[] x) {
        return Math.exp(logp(x));
    }

    /**
     * Algorithm from Alan Genz (1992) Numerical Computation of
     * Multivariate Normal Probabilities, Journal of Computational and
     * Graphical Statistics, pp. 141-149.
     *
     * The difference between returned value and the true value of the
     * CDF is less than 0.001 in 99.9% time. The maximum number of iterations
     * is set to 10000.
     *
     * @param x the upper integration limits.
     * @return a Monte Carlo estimate of P(X &lt;= x).
     * @throws IllegalArgumentException if {@code x} has a different dimension.
     */
    @Override
    public double cdf(double[] x) {
        if (x.length != dim) {
            throw new IllegalArgumentException("Sample has different dimension.");
        }

        int Nmax = 10000;
        double alph = GaussianDistribution.getInstance().quantile(0.999);
        double errMax = 0.001;

        double[] v = x.clone();
        Math.minus(v, mu);

        double p = 0.0;
        double varSum = 0.0;

        // The lower limits are -infinity, so Genz's d_i = Phi(-inf) = 0 for all i.
        // e[i] is the conditional upper probability Phi((b_i - sum_j c_ij y_j) / c_ii);
        // f[i] is the running product e[0] * ... * e[i].
        double[] e = new double[dim];
        double[] f = new double[dim];
        e[0] = GaussianDistribution.getInstance().cdf(v[0] / sigmaL[0][0]);
        f[0] = e[0];

        double[] y = new double[dim];
        double err = 2 * errMax;
        for (int N = 1; err > errMax && N <= Nmax; N++) {
            double[] w = Math.random(dim - 1);
            for (int i = 1; i < dim; i++) {
                // Genz: y_{i-1} = Phi^-1(d_{i-1} + w_{i-1} (e_{i-1} - d_{i-1})). This uses the single
                // conditional probability e[i-1], not the cumulative product f[i-1].
                y[i - 1] = GaussianDistribution.getInstance().quantile(w[i - 1] * e[i - 1]);
                double q = 0.0;
                for (int j = 0; j < i; j++) {
                    q += sigmaL[i][j] * y[j];
                }

                e[i] = GaussianDistribution.getInstance().cdf((v[i] - q) / sigmaL[i][i]);
                f[i] = e[i] * f[i - 1];
            }

            // Running mean of f_m and its variance estimate, as in Genz's step 4(c).
            double del = (f[dim - 1] - p) / N;
            p += del;
            varSum = (N - 2) * varSum / N + del * del;
            err = alph * Math.sqrt(varSum);
        }

        return p;
    }

    /**
     * Generate a random multivariate Gaussian sample.
     * Independent standard normals z are drawn with a ratio-of-uniforms
     * rejection sampler (the constants match Leva's algorithm), then
     * transformed as mu + L * z, where L is the lower Cholesky factor.
     *
     * @return a random sample.
     */
    public double[] rand() {
        double[] spt = new double[mu.length];

        for (int i = 0; i < mu.length; i++) {
            double u, v, q;
            do {
                u = Math.random();
                v = 1.7156 * (Math.random() - 0.5);
                double x = u - 0.449871;
                double y = Math.abs(v) + 0.386595;
                q = x * x + y * (0.19600 * y - 0.25472 * x);
            } while (q > 0.27597 && (q > 0.27846 || v * v > -4 * Math.log(u) * u * u));

            spt[i] = v / u;
        }

        double[] pt = new double[sigmaL.length];

        // pt = sigmaL * spt (sigmaL is lower triangular, so only j <= i contributes)
        for (int i = 0; i < pt.length; i++) {
            for (int j = 0; j <= i; j++) {
                pt[i] += sigmaL[i][j] * spt[j];
            }
        }

        Math.plus(pt, mu);

        return pt;
    }

    /**
     * M-step of EM: estimates a new component from posterior-weighted samples.
     * The mean and covariance are the posterior-weighted averages. For a
     * non-diagonal model the covariance diagonal is inflated by a factor of
     * 1.00001 to keep the matrix positive definite.
     *
     * @param x the samples.
     * @param posteriori the posterior weight of each sample for this component.
     * @return a component whose {@code priori} is the (unnormalized) sum of the
     *         posterior weights and whose distribution has the same diagonal setting as this one.
     */
    @Override
    public MultivariateMixture.Component M(double[][] x, double[] posteriori) {
        int n = x[0].length;

        double alpha = 0.0;
        double[] mean = new double[n];
        double[][] cov = new double[n][n];

        for (int k = 0; k < x.length; k++) {
            alpha += posteriori[k];
            for (int i = 0; i < n; i++) {
                mean[i] += x[k][i] * posteriori[k];
            }
        }

        for (int i = 0; i < mean.length; i++) {
            mean[i] /= alpha;
        }

        if (diagonal) {
            for (int k = 0; k < x.length; k++) {
                for (int i = 0; i < n; i++) {
                    cov[i][i] += (x[k][i] - mean[i]) * (x[k][i] - mean[i]) * posteriori[k];
                }
            }

            for (int i = 0; i < cov.length; i++) {
                cov[i][i] /= alpha;
            }
        } else {
            for (int k = 0; k < x.length; k++) {
                for (int i = 0; i < n; i++) {
                    for (int j = 0; j < n; j++) {
                        cov[i][j] += (x[k][i] - mean[i]) * (x[k][j] - mean[j]) * posteriori[k];
                    }
                }
            }

            for (int i = 0; i < cov.length; i++) {
                for (int j = 0; j < cov[i].length; j++) {
                    cov[i][j] /= alpha;
                }

                // make sure the covariance matrix is positive definite.
                cov[i][i] *= 1.00001;
            }
        }

        MultivariateMixture.Component c = new MultivariateMixture.Component();
        c.priori = alpha;
        MultivariateGaussianDistribution g = new MultivariateGaussianDistribution(mean, cov);
        g.diagonal = diagonal;
        if (diagonal) {
            // The full-covariance constructor counted d(d+1)/2 covariance entries;
            // a diagonal model only has d free variances.
            g.numParameters = 2 * n;
        }
        c.distribution = g;

        return c;
    }

    /**
     * Returns a human-readable description listing the mean vector and covariance matrix.
     * @return a string representation.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder("Multivariate Gaussian Distribution:\nmu = [");
        for (int i = 0; i < mu.length; i++) {
            builder.append(mu[i]).append(" ");
        }
        // Replace the trailing space with the closing bracket.
        builder.setCharAt(builder.length() - 1, ']');
        builder.append("\nSigma = [\n");
        for (int i = 0; i < sigma.length; i++) {
            builder.append('\t');
            for (int j = 0; j < sigma[i].length; j++) {
                builder.append(sigma[i][j]).append(" ");
            }
            builder.append('\n');
        }
        builder.append("\t]");
        return builder.toString();
    }
}