All downloads are free. The search and download functionality uses the official Maven repository.

com.enterprisemath.math.statistics.DiagonalNormalDistributionMixtureEstimator Maven / Gradle / Ivy

The newest version!
package com.enterprisemath.math.statistics;

import org.apache.commons.lang3.builder.ToStringBuilder;

import com.enterprisemath.math.algebra.Hypercube;
import com.enterprisemath.math.algebra.Vector;
import com.enterprisemath.math.probability.DiagonalNormalDistribution;
import com.enterprisemath.math.probability.DiagonalNormalDistributionMixture;
import com.enterprisemath.math.statistics.observation.ObservationIterator;
import com.enterprisemath.math.statistics.observation.ObservationProvider;
import com.enterprisemath.utils.ValidationUtils;

/**
 * Estimates a diagonal normal distribution mixture from observations.
 * The user sets only high level limits (maximum number of components, minimum component weight
 * and minimum sigma) and the algorithm does the rest.
 * <p>
 * Estimation starts with a single component fitted to all observations. In every outer iteration
 * one component is selected and split into two, and then the classic EM algorithm is run
 * on the resulting mixture. The outer loop stops when the log-likelihood no longer improves,
 * when the maximum number of components is reached or after a fixed number of splits.
 *
 * @author radek.hecl
 *
 */
public class DiagonalNormalDistributionMixtureEstimator implements Estimator {

    /**
     * Constant for truncating components. A component whose log-likelihood for an observation
     * is at least this far below the best component's one is treated as contributing zero.
     */
    private static final double COMP_TRUNC = Math.log(Double.MIN_VALUE);

    /**
     * Largest absolute value of any observation coordinate accepted by {@link #estimate}.
     */
    private static final int OBSERVATION_LIMIT = 1000000;

    /**
     * Minimum improvement of ln(L) which is still considered as progress.
     */
    private static final double LNL_TOLERANCE = 0.01;

    /**
     * Maximum number of split iterations of the outer loop.
     */
    private static final int MAX_SPLIT_ITERATIONS = 100;

    /**
     * Number of EM iterations which are always performed after every split.
     */
    private static final int MIN_EM_ITERATIONS = 5;

    /**
     * Builder object.
     */
    public static class Builder {

        /**
         * Maximum allowed number of components.
         */
        private Integer maxComponents = 20;

        /**
         * The minimum weight which is allowed for a component during estimation.
         * Every component which has less than the minWeight will be removed.
         * This essentially means that a group of observations statistically less important
         * than minWeight may (but need not) be ignored.
         * Value must be in interval [0, 1).
         */
        private Double minWeight = 0.05;

        /**
         * Minimum allowed sigma.
         */
        private Double minSigma = 0.01;

        /**
         * Step listener.
         */
        private EstimatorStepListener stepListener = EmptyEstimatorStepListener.create();

        /**
         * Sets maximum allowed number of components.
         *
         * @param maxComponents maximum number of components in the result, must be positive
         * @return this instance
         */
        public Builder setMaxComponents(int maxComponents) {
            this.maxComponents = maxComponents;
            return this;
        }

        /**
         * Sets minimal weight for the components.
         * Every component with weight less than minWeight will be removed from the result.
         * This essentially means that a group of observations statistically less important
         * than minWeight may (but need not) be ignored.
         * Value must be in interval [0, 1).
         *
         * @param minWeight component minimal weight value
         * @return this instance
         */
        public Builder setMinWeight(double minWeight) {
            this.minWeight = minWeight;
            return this;
        }

        /**
         * Sets minimal allowed sigma for all components.
         *
         * @param minSigma minimal allowed sigma for all components, must be positive
         * @return this instance
         */
        public Builder setMinSigma(Double minSigma) {
            this.minSigma = minSigma;
            return this;
        }

        /**
         * Sets step listener. The listener is notified with the initial mixture,
         * with every accepted intermediate mixture and with the final result.
         *
         * @param stepListener step listener, cannot be null
         * @return this instance
         */
        public Builder setStepListener(EstimatorStepListener stepListener) {
            this.stepListener = stepListener;
            return this;
        }

        /**
         * Builds the result object.
         *
         * @return created object
         */
        public DiagonalNormalDistributionMixtureEstimator build() {
            return new DiagonalNormalDistributionMixtureEstimator(this);
        }
    }

    /**
     * Maximum allowed number of components.
     */
    private final Integer maxComponents;

    /**
     * The minimum weight which is allowed for a component during estimation.
     * Every component which has less than the minWeight will be removed.
     * Value must be in interval [0, 1).
     */
    private final Double minWeight;

    /**
     * Minimum allowed sigma.
     */
    private final Double minSigma;

    /**
     * Step listener.
     */
    private final EstimatorStepListener stepListener;

    /**
     * Creates new instance.
     *
     * @param builder builder object
     */
    public DiagonalNormalDistributionMixtureEstimator(Builder builder) {
        minWeight = builder.minWeight;
        maxComponents = builder.maxComponents;
        minSigma = builder.minSigma;
        stepListener = builder.stepListener;
        guardInvariants();
    }

    /**
     * Guards this object to be consistent. Throws exception if this is not the case.
     */
    private void guardInvariants() {
        ValidationUtils.guardPositiveInt(maxComponents, "maxComponents must be positive");
        ValidationUtils.guardNotNegativeDouble(minWeight, "minWeight cannot be negative");
        ValidationUtils.guardGreaterDouble(1, minWeight, "minWeight must be less than 1");
        ValidationUtils.guardPositiveDouble(minSigma, "minSigma must be positive");
        ValidationUtils.guardNotNull(stepListener, "stepListener cannot be null");
    }

    /**
     * Estimates the mixture for the given observations.
     * Every coordinate of every observation must lie within
     * [-{@value #OBSERVATION_LIMIT}, {@value #OBSERVATION_LIMIT}], otherwise the validation fails.
     *
     * @param observations observations, iterated several times during the estimation
     * @return estimated mixture
     */
    @Override
    public DiagonalNormalDistributionMixture estimate(ObservationProvider observations) {
        Hypercube bounds = extractHypercube(observations);
        for (int i = 0; i < bounds.getDimension(); ++i) {
            ValidationUtils.guardGreaterOrEqualDouble(bounds.getMin().getComponent(i), -OBSERVATION_LIMIT,
                    "observation is out of range for calcualtion");
            ValidationUtils.guardGreaterOrEqualDouble(OBSERVATION_LIMIT, bounds.getMax().getComponent(i),
                    "observation is out of range for calcualtion");
        }

        DiagonalNormalDistributionMixture res = initializeOneComponent(observations, bounds.getDimension());
        stepListener.stepDone(res);
        double resL = Double.NEGATIVE_INFINITY;
        double newL = countLnL(observations, res);

        int iteration = 0;
        while (newL - resL > LNL_TOLERANCE && res.getNumComponents() < maxComponents
                && iteration < MAX_SPLIT_ITERATIONS) {
            ++iteration;
            resL = newL;
            DiagonalNormalDistributionMixture candidate = splitComponent(res);

            // EM iterations for the candidate; at least MIN_EM_ITERATIONS are always done,
            // then the loop continues for as long as ln(L) keeps improving
            double previousL = Double.NEGATIVE_INFINITY;
            int emIteration = 0;
            while (emIteration < MIN_EM_ITERATIONS || newL - previousL > LNL_TOLERANCE) {
                ++emIteration;
                previousL = newL;
                candidate = nextIteration(observations, candidate);
                newL = countLnL(observations, candidate);
                if (smallestWeight(candidate) < minWeight) {
                    // drop insignificant components and re-evaluate the reduced mixture
                    candidate = candidate.createSignificantComponentMixture(minWeight);
                    newL = countLnL(observations, candidate);
                }
                // accept the candidate only if it beats the mixture from before the split
                if (newL > resL) {
                    res = candidate;
                    stepListener.stepDone(res);
                }
            }
        }
        stepListener.stepDone(res);
        return res;
    }

    /**
     * Creates a new mixture in which one component of the source is split into two.
     * The split is made for the component and dimension with the highest value of
     * (component weight * sigma in the dimension). The two new components share the original sigma,
     * each takes half of the original weight and their means are shifted by half of sigma
     * in opposite directions along the selected dimension.
     * All other components are copied, except of these with weight less than minWeight.
     *
     * @param source source mixture
     * @return mixture with the split component
     */
    private DiagonalNormalDistributionMixture splitComponent(DiagonalNormalDistributionMixture source) {
        // find the component and dimension with the maximum weight * sigma
        double splitValue = 0;
        int splitCompIdx = 0;
        int splitDimIdx = 0;
        for (int i = 0; i < source.getNumComponents(); ++i) {
            double weight = source.getWeights().get(i);
            DiagonalNormalDistribution comp = source.getComponents().get(i);
            for (int j = 0; j < comp.getDimension(); ++j) {
                double value = weight * comp.getSigma().getComponent(j);
                if (value > splitValue) {
                    splitValue = value;
                    splitCompIdx = i;
                    splitDimIdx = j;
                }
            }
        }

        // build the new mixture
        DiagonalNormalDistributionMixture.Builder builder = new DiagonalNormalDistributionMixture.Builder();
        for (int i = 0; i < source.getNumComponents(); ++i) {
            double weight = source.getWeights().get(i);
            DiagonalNormalDistribution comp = source.getComponents().get(i);
            if (i == splitCompIdx) {
                double[] mi1 = new double[source.getDimension()];
                double[] mi2 = new double[source.getDimension()];
                for (int j = 0; j < comp.getDimension(); ++j) {
                    double shift = j == splitDimIdx ? comp.getSigma().getComponent(j) / 2 : 0;
                    mi1[j] = comp.getMi().getComponent(j) + shift;
                    mi2[j] = comp.getMi().getComponent(j) - shift;
                }
                builder.addComponent(weight / 2, Vector.create(mi1), comp.getSigma());
                builder.addComponent(weight / 2, Vector.create(mi2), comp.getSigma());
            }
            else if (weight >= minWeight) {
                builder.addComponent(weight, comp);
            }
        }
        return builder.build();
    }

    /**
     * Extracts hypercube from the specified observations.
     *
     * @param observations observations
     * @return hypercube built from all observations
     */
    private Hypercube extractHypercube(ObservationProvider observations) {
        Hypercube.Builder res = new Hypercube.Builder();
        ObservationIterator iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            res.addVector(iterator.getNext());
        }
        return res.build();
    }

    /**
     * Makes the one component initialization for the EM algorithm.
     * The component mean is the mean of the observations and sigma is their standard deviation
     * in every dimension, bounded from below by minSigma.
     *
     * @param observations observation for which the initialization should be calculated
     * @param dimension dimension of the observations
     * @return mixture with initial parameters
     */
    private DiagonalNormalDistributionMixture initializeOneComponent(ObservationProvider observations, int dimension) {
        // accumulate sums of x and x^2 in every dimension
        double[] mean = new double[dimension];
        double[] sigma = new double[dimension];
        ObservationIterator iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            Vector x = iterator.getNext();
            for (int i = 0; i < dimension; ++i) {
                double xi = x.getComponent(i);
                mean[i] += xi;
                sigma[i] += xi * xi;
            }
        }

        double count = iterator.getNumIterated();
        for (int i = 0; i < dimension; ++i) {
            mean[i] /= count;
            // E[x^2] - E[x]^2 might come out slightly negative because of rounding
            // (e.g. for constant data), toSigma takes care of that
            sigma[i] = toSigma(sigma[i] / count - mean[i] * mean[i]);
        }

        return new DiagonalNormalDistributionMixture.Builder().
                addComponent(1, DiagonalNormalDistribution.create(Vector.create(mean), Vector.create(sigma))).
                build();
    }

    /**
     * Makes one iteration of the EM algorithm. Returns the mixture after the iteration.
     * A component which receives no responsibility from any observation keeps
     * its previous mean and sigma and gets zero weight.
     *
     * @param observations observation for which the iteration should be calculated
     * @param start starting position
     * @return mixture after the iteration
     */
    private DiagonalNormalDistributionMixture nextIteration(ObservationProvider observations, DiagonalNormalDistributionMixture start) {
        int numComponents = start.getNumComponents();
        int dim = start.getDimension();
        double[] newW = new double[numComponents];
        double[][] newMi = new double[numComponents][dim];
        double[][] newSigma = new double[numComponents][dim];
        double[] c = new double[numComponents];

        // adding values
        ObservationIterator iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            Vector x = iterator.getNext();
            fillScaledLikelihoods(start, x, c);
            double h = 0;
            for (double value : c) {
                h += value;
            }

            for (int j = 0; j < numComponents; ++j) {
                if (c[j] == 0) {
                    continue;
                }
                // responsibility of the component j for the observation x
                double qmx = c[j] / h;
                newW[j] += qmx;
                for (int k = 0; k < dim; ++k) {
                    double xk = x.getComponent(k);
                    newMi[j][k] += xk * qmx;
                    newSigma[j][k] += xk * xk * qmx;
                }
            }
        }

        // finishing and creating new instance
        double count = iterator.getNumIterated();
        DiagonalNormalDistributionMixture.Builder builder = new DiagonalNormalDistributionMixture.Builder();
        for (int i = 0; i < numComponents; ++i) {
            double[] mi = newMi[i];
            double[] sigma = newSigma[i];
            if (newW[i] > 0) {
                for (int j = 0; j < dim; ++j) {
                    mi[j] /= newW[i];
                    sigma[j] = toSigma(sigma[j] / newW[i] - mi[j] * mi[j]);
                }
            }
            else {
                // no observation supports this component, dividing by the zero weight
                // would produce NaN mean, so the previous parameters are kept
                DiagonalNormalDistribution previous = start.getComponents().get(i);
                for (int j = 0; j < dim; ++j) {
                    mi[j] = previous.getMi().getComponent(j);
                    sigma[j] = previous.getSigma().getComponent(j);
                }
            }
            builder.addComponent(newW[i] / count, new DiagonalNormalDistribution.Builder().
                    setMi(Vector.create(mi)).
                    setSigma(Vector.create(sigma)).
                    build());
        }
        return builder.build();
    }

    /**
     * Calculates the ln(L) value.
     * Where L = prod_x( sum_i(w(i)P(x|i)) ) and ln(L) = sum_x( ln(sum_i(w(i)P(x|i))) ).
     * The value is not divided by the number of observations.
     *
     * @param observations observations
     * @param mixture mixture
     * @return ln(L) value
     */
    private double countLnL(ObservationProvider observations, DiagonalNormalDistributionMixture mixture) {
        double[] c = new double[mixture.getNumComponents()];
        double lnL = 0;
        ObservationIterator iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            lnL += fillScaledLikelihoods(mixture, iterator.getNext(), c);
        }
        return lnL;
    }

    /**
     * Calculates weighted likelihoods of all components for one observation.
     * The calculation is done in the log space and the values are scaled by the biggest one
     * before exponentiation. Values which are COMP_TRUNC or more below the biggest one are set to zero.
     *
     * @param mixture mixture
     * @param x observation
     * @param c output array of length equal to the number of components; on return c[j] contains
     *  w(j)P(x|j) divided by the biggest such value over all components (or zero if truncated)
     * @return ln(sum_j(w(j)P(x|j))), contribution of the observation to ln(L)
     */
    private static double fillScaledLikelihoods(DiagonalNormalDistributionMixture mixture, Vector x, double[] c) {
        int numComponents = mixture.getNumComponents();
        double c0 = -Double.MAX_VALUE;
        for (int j = 0; j < numComponents; ++j) {
            c[j] = Math.log(mixture.getWeights().get(j)) + mixture.getComponents().get(j).getLnValue(x);
            if (c[j] > c0) {
                c0 = c[j];
            }
        }

        double h = 0;
        for (int j = 0; j < numComponents; ++j) {
            double shifted = c[j] - c0;
            if (shifted > COMP_TRUNC) {
                c[j] = Math.exp(shifted);
                h += c[j];
            }
            else {
                c[j] = 0;
            }
        }
        return Math.log(h) + c0;
    }

    /**
     * Converts variance to sigma which is never less than minSigma.
     * Variance which is not positive or which is not a number results in minSigma.
     *
     * @param variance variance
     * @return sigma, always at least minSigma
     */
    private double toSigma(double variance) {
        if (Double.isNaN(variance) || variance <= 0) {
            return minSigma;
        }
        return Math.max(minSigma, Math.sqrt(variance));
    }

    /**
     * Returns minimal weight.
     *
     * @param mixture mixture
     * @return minimal weight among the components, or 1 if no weight is less than 1
     */
    private static double smallestWeight(DiagonalNormalDistributionMixture mixture) {
        double res = 1;
        for (double w : mixture.getWeights()) {
            if (w < res) {
                res = w;
            }
        }
        return res;
    }

    @Override
    public String toString() {
        return ToStringBuilder.reflectionToString(this);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy