All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.enterprisemath.math.probability.DiagonalNormalDistribution Maven / Gradle / Ivy

The newest version!
package com.enterprisemath.math.probability;

import java.security.SecureRandom;

import org.apache.commons.lang3.builder.EqualsBuilder;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import org.apache.commons.lang3.builder.ToStringBuilder;

import com.enterprisemath.math.algebra.Hypercube;
import com.enterprisemath.math.algebra.Vector;
import com.enterprisemath.utils.ValidationUtils;

/**
 * Multidimensional normal distribution whose components are independent
 * (diagonal covariance matrix).
 * <p>
 * The distribution is parameterized by a mean vector {@code mi} and a vector {@code sigma}.
 * Note that {@code sigma} is used by this class as the per-dimension <em>standard deviation</em>:
 * it is squared in the density formula and multiplied directly onto the standard
 * gaussian sample in {@link #generateRandom()}.
 *
 * @author radek.hecl
 *
 */
public class DiagonalNormalDistribution implements ProbabilityDistribution {

    /**
     * Builder object.
     */
    public static class Builder {

        /**
         * Mi (mean) value of the distribution.
         */
        private Vector mi;

        /**
         * Sigma value of the distribution, one component per dimension.
         * The distribution uses each component as the standard deviation of that dimension.
         */
        private Vector sigma;

        /**
         * Sets the value mi.
         *
         * @param mi mi value
         * @return this instance
         */
        public Builder setMi(Vector mi) {
            this.mi = mi;
            return this;
        }

        /**
         * Sets the value sigma, one component per dimension.
         * The distribution uses each component as the standard deviation of that dimension.
         *
         * @param sigma sigma value
         * @return this instance
         */
        public Builder setSigma(Vector sigma) {
            this.sigma = sigma;
            return this;
        }

        /**
         * Creates the result object.
         *
         * @return created object
         */
        public DiagonalNormalDistribution build() {
            return new DiagonalNormalDistribution(this);
        }
    }

    /**
     * Object for generating random values.
     */
    private static final SecureRandom random = new SecureRandom();

    /**
     * Constant ln(sqrt(2 * pi)), computed once instead of once per dimension.
     * Being static, it is not part of the reflection based equals / hashCode / toString.
     */
    private static final double LN_SQRT_2PI = Math.log(Math.sqrt(2 * Math.PI));

    /**
     * Mi (mean) value of the distribution.
     */
    private final Vector mi;

    /**
     * Sigma value of the distribution, one component per dimension.
     * Each component is used as the standard deviation of that dimension.
     */
    private final Vector sigma;

    /**
     * This is cached constant calculated to speed up. Value is ln (1 / (sigma[i] * sqrt(2pi))).
     */
    private final Vector lnFraction;

    /**
     * Creates new instance.
     *
     * @param builder builder object
     */
    public DiagonalNormalDistribution(Builder builder) {
        mi = builder.mi;
        sigma = builder.sigma;
        guardInvariants();

        int dimension = mi.getDimension();
        double[] comps = new double[dimension];
        for (int i = 0; i < dimension; ++i) {
            // ln(1 / (sigma * sqrt(2pi))) = -ln(sigma) - ln(sqrt(2pi))
            comps[i] = -Math.log(sigma.getComponent(i)) - LN_SQRT_2PI;
            ValidationUtils.guardBoundedDouble(comps[i],
                    "-Math.log(sigma[i]) - Math.log(Math.sqrt(2 * Math.PI)) went out of range: sigma = " + sigma);
        }
        lnFraction = Vector.create(comps);
    }

    /**
     * Guards this object to be consistent. Throws exception if this is not the case.
     * Checks that mi and sigma have the same dimension, that every sigma component
     * is positive and that every component of both vectors is bounded.
     */
    private void guardInvariants() {
        ValidationUtils.guardEquals(mi.getDimension(), sigma.getDimension(), "mi and sigma must have same dimension");
        for (int i = 0; i < mi.getDimension(); ++i) {
            ValidationUtils.guardPositiveDouble(sigma.getComponent(i), "sigma must be positive");
            ValidationUtils.guardBoundedDouble(mi.getComponent(i), "mi must be bounded");
            ValidationUtils.guardBoundedDouble(sigma.getComponent(i), "sigma must be bounded");
        }
    }

    /**
     * Returns the density of this distribution in the given point.
     * The value is computed as the exponential of the logarithmic density.
     *
     * @param x point, must have the same dimension as this distribution
     * @return density in the point x
     */
    @Override
    public double getValue(Vector x) {
        return Math.exp(lnDensity(x));
    }

    /**
     * Returns the natural logarithm of the density of this distribution in the given point.
     *
     * @param x point, must have the same dimension as this distribution
     * @return natural logarithm of the density in the point x
     */
    @Override
    public double getLnValue(Vector x) {
        return lnDensity(x);
    }

    /**
     * Calculates the natural logarithm of the density in the given point.
     * Shared by {@link #getValue(Vector)} and {@link #getLnValue(Vector)} so both
     * validate the input and sum the per-dimension terms in exactly the same way.
     *
     * @param x point, must have the same dimension as this distribution
     * @return natural logarithm of the density in the point x
     */
    private double lnDensity(Vector x) {
        ValidationUtils.guardEquals(mi.getDimension(), x.getDimension(), "x must have same dimension as distribution");
        double value = 0;
        for (int i = 0; i < mi.getDimension(); ++i) {
            double deviation = x.getComponent(i) - mi.getComponent(i);
            value += lnFraction.getComponent(i) - sqr(deviation) / (2 * sqr(sigma.getComponent(i)));
        }
        return value;
    }

    /**
     * Generates random vector. Every component is an independent gaussian sample
     * scaled by sigma and shifted by mi of the matching dimension.
     *
     * @return generated vector
     */
    @Override
    public Vector generateRandom() {
        double[] res = new double[mi.getDimension()];
        for (int i = 0; i < res.length; ++i) {
            res[i] = random.nextGaussian() * sigma.getComponent(i) + mi.getComponent(i);
        }
        return Vector.create(res);
    }

    /**
     * Returns dimension of this distribution.
     *
     * @return dimension of this distribution
     */
    public int getDimension() {
        return mi.getDimension();
    }

    /**
     * Returns the value of the parameter mi.
     *
     * @return value of the parameter mi
     */
    public Vector getMi() {
        return mi;
    }

    /**
     * Returns the value of the parameter sigma.
     *
     * @return value of the parameter sigma
     */
    public Vector getSigma() {
        return sigma;
    }

    /**
     * Returns hypercube with at least specified probability that random elements
     * generated by this distribution will be inside.
     * <p>
     * The hypercube is centered in mi. Its half-width in every dimension is sigma of that
     * dimension multiplied by a fixed factor (1, 2, 3, 4 or 10, chosen by the thresholds
     * 0.6, 0.9, 0.95 and 0.98 on {@code min}) and by the dimension of this distribution.
     *
     * @param min minimum probability, must be in interval [0, 1)
     * @return interval with at least specified probability that random elements generated by this distribution will be inside
     */
    public Hypercube getProbableHypercube(double min) {
        ValidationUtils.guardNotNegativeDouble(min, "min cannot be negative");
        ValidationUtils.guardGreaterDouble(1, min, "min must be less than 1");
        final int sigmas;
        if (min < 0.6) {
            sigmas = 1;
        }
        else if (min < 0.9) {
            sigmas = 2;
        }
        else if (min < 0.95) {
            sigmas = 3;
        }
        else if (min < 0.98) {
            sigmas = 4;
        }
        else {
            sigmas = 10;
        }
        // the factor is widened by the dimension, same for lower and upper corner
        int scale = sigmas * mi.getDimension();
        return Hypercube.createFromVectors(createMiCenteredVector(-scale), createMiCenteredVector(scale));
    }

    /**
     * Returns the x * x.
     *
     * @param x value x
     * @return x * x
     */
    private double sqr(double x) {
        return x * x;
    }

    /**
     * Creates vector which has distance from mi value of the specified sigma length.
     * Component i of the result is mi[i] + scale * sigma[i].
     *
     * @param scale scale of the sigma value, negative values move below mi
     * @return created vector
     */
    private Vector createMiCenteredVector(double scale) {
        double[] res = new double[mi.getDimension()];
        for (int i = 0; i < res.length; ++i) {
            res[i] = mi.getComponent(i) + scale * sigma.getComponent(i);
        }
        return Vector.create(res);
    }

    @Override
    public int hashCode() {
        return HashCodeBuilder.reflectionHashCode(this);
    }

    @Override
    public boolean equals(Object obj) {
        return EqualsBuilder.reflectionEquals(this, obj);
    }

    @Override
    public String toString() {
        return ToStringBuilder.reflectionToString(this);
    }

    /**
     * Creates normal distribution.
     *
     * @param mi mi value
     * @param sigma sigma value, one component per dimension, used as the standard deviation of that dimension
     * @return created distribution
     */
    public static DiagonalNormalDistribution create(Vector mi, Vector sigma) {
        return new DiagonalNormalDistribution.Builder().
                setMi(mi).
                setSigma(sigma).
                build();
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy