All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.enterprisemath.math.statistics.NormalDistributionMixtureEstimator Maven / Gradle / Ivy

The newest version!
package com.enterprisemath.math.statistics;

import org.apache.commons.lang3.builder.ToStringBuilder;

import com.enterprisemath.math.algebra.Interval;
import com.enterprisemath.math.probability.NormalDistribution;
import com.enterprisemath.math.probability.NormalDistributionMixture;
import com.enterprisemath.math.statistics.observation.ObservationIterator;
import com.enterprisemath.math.statistics.observation.ObservationProvider;
import com.enterprisemath.utils.ValidationUtils;

/**
 * Estimates a normal distribution mixture from a set of observations.
 * The user only sets high level limits (maximum number of components, minimum component weight,
 * minimum sigma) and the algorithm does the rest.
 * <p>
 * The estimation starts with a single component fitted to the sample mean and standard deviation.
 * In every outer iteration one component (the one with the largest weight * sigma) is split in two
 * and the classic EM algorithm is run on the resulting mixture. Components whose weight falls below
 * the minimum weight are dropped. The estimation stops when the log-likelihood no longer improves
 * by more than {@value #CONVERGENCE_THRESHOLD}, when the maximum number of components is reached,
 * or after {@value #MAX_SPLIT_ITERATIONS} splits.
 *
 * @author radek.hecl
 *
 */
public class NormalDistributionMixtureEstimator implements Estimator {

    /**
     * Observations must lie inside [-MAX_ABS_OBSERVATION, MAX_ABS_OBSERVATION].
     */
    private static final double MAX_ABS_OBSERVATION = 1000000;

    /**
     * Minimal log-likelihood improvement required to keep iterating (both the split loop and the EM loop).
     */
    private static final double CONVERGENCE_THRESHOLD = 0.01;

    /**
     * Maximum number of split steps (outer iterations).
     */
    private static final int MAX_SPLIT_ITERATIONS = 100;

    /**
     * Minimum number of EM iterations performed after every split.
     */
    private static final int MIN_EM_ITERATIONS = 5;

    /**
     * Builder object.
     */
    public static class Builder {

        /**
         * Maximum allowed number of components. Defaults to 20.
         */
        private Integer maxComponents = 20;

        /**
         * The minimum weight which is allowed for a component during estimation.
         * Every component with weight less than minWeight is removed.
         * This essentially means that a group of observations statistically less important than minWeight
         * might (but is not necessarily) be ignored.
         * Value must be in interval [0, 1). Defaults to 0.05.
         */
        private Double minWeight = 0.05;

        /**
         * Minimum allowed sigma. Defaults to 0.01.
         */
        private Double minSigma = 0.01;

        /**
         * Step listener. Defaults to an empty listener.
         */
        private EstimatorStepListener stepListener = EmptyEstimatorStepListener.create();

        /**
         * Sets maximum allowed number of components (default 20).
         *
         * @param maxComponents maximum number of components in the result, must be positive
         * @return this instance
         */
        public Builder setMaxComponents(int maxComponents) {
            this.maxComponents = maxComponents;
            return this;
        }

        /**
         * Sets minimal weight for the components (default 0.05).
         * Every component with weight less than minWeight will be removed from the result.
         * This essentially means that a group of observations statistically less important than minWeight
         * might (but is not necessarily) be ignored.
         * Value must be in interval [0, 1).
         *
         * @param minWeight component minimal weight value
         * @return this instance
         */
        public Builder setMinWeight(double minWeight) {
            this.minWeight = minWeight;
            return this;
        }

        /**
         * Sets minimal allowed sigma for all components (default 0.01).
         *
         * @param minSigma minimal allowed sigma for all components, must be positive
         * @return this instance
         */
        public Builder setMinSigma(Double minSigma) {
            this.minSigma = minSigma;
            return this;
        }

        /**
         * Sets step listener which is notified about intermediate results.
         *
         * @param stepListener step listener, cannot be null
         * @return this instance
         */
        public Builder setStepListener(EstimatorStepListener stepListener) {
            this.stepListener = stepListener;
            return this;
        }

        /**
         * Builds the result object.
         *
         * @return created object
         */
        public NormalDistributionMixtureEstimator build() {
            return new NormalDistributionMixtureEstimator(this);
        }
    }

    /**
     * Maximum allowed number of components.
     */
    private final Integer maxComponents;

    /**
     * The minimum weight which is allowed for a component during estimation.
     * Every component with weight less than minWeight is removed.
     * This essentially means that a group of observations statistically less important than minWeight
     * might (but is not necessarily) be ignored.
     * Value must be in interval [0, 1).
     */
    private final Double minWeight;

    /**
     * Minimum allowed sigma.
     */
    private final Double minSigma;

    /**
     * Step listener.
     */
    private final EstimatorStepListener stepListener;

    /**
     * Creates new instance.
     *
     * @param builder builder object
     */
    public NormalDistributionMixtureEstimator(Builder builder) {
        maxComponents = builder.maxComponents;
        minWeight = builder.minWeight;
        minSigma = builder.minSigma;
        stepListener = builder.stepListener;
        guardInvariants();
    }

    /**
     * Guards this object to be consistent. Throws exception (via ValidationUtils) if this is not the case.
     */
    private void guardInvariants() {
        ValidationUtils.guardPositiveInt(maxComponents, "maxComponents must be positive");
        ValidationUtils.guardNotNegativeDouble(minWeight, "minWeight cannot be negative");
        ValidationUtils.guardGreaterDouble(1, minWeight, "minWeight must be less than 1");
        ValidationUtils.guardPositiveDouble(minSigma, "minSigma must be positive");
        ValidationUtils.guardNotNull(stepListener, "stepListener cannot be null");
    }

    /**
     * Estimates the mixture by repeated component splitting followed by EM iterations.
     * The step listener is notified with the initial one-component mixture, with every improving
     * intermediate mixture and with the final result.
     *
     * @param observations observations, must be non-empty and lie within [-1000000, 1000000]
     * @return estimated mixture
     * @throws IllegalArgumentException if there are no observations
     */
    @Override
    public NormalDistributionMixture estimate(ObservationProvider observations) {
        Interval minMax = extractInterval(observations);
        ValidationUtils.guardGreaterOrEqualDouble(minMax.getMin(), -MAX_ABS_OBSERVATION,
                "observation interval is out of range for calcualtion");
        ValidationUtils.guardGreaterOrEqualDouble(MAX_ABS_OBSERVATION, minMax.getMax(),
                "observation interval is out of range for calcualtion");

        NormalDistributionMixture res = initializeOneComponent(observations);
        stepListener.stepDone(res);
        // resL = ln(L) of the mixture accepted in the previous split step, newL = ln(L) of the latest candidate
        double resL = Double.NEGATIVE_INFINITY;
        double newL = countLnL(observations, res);

        int iteration = 0;
        while (newL - resL > CONVERGENCE_THRESHOLD && res.getNumComponents() < maxComponents
                && iteration < MAX_SPLIT_ITERATIONS) {
            ++iteration;
            resL = newL;
            NormalDistributionMixture newMixture = splitComponent(res, selectComponentToSplit(res));

            // EM iterations for the new mixture: at least MIN_EM_ITERATIONS, then until ln(L) stops improving
            double previousL = Double.NEGATIVE_INFINITY;
            int emIteration = 0;
            while (emIteration < MIN_EM_ITERATIONS || newL - previousL > CONVERGENCE_THRESHOLD) {
                ++emIteration;
                previousL = newL;
                newMixture = nextIteration(observations, newMixture);
                newL = countLnL(observations, newMixture);
                if (findMinComponentWeight(newMixture) < minWeight) {
                    newMixture = newMixture.createSignificantComponentMixture(minWeight);
                    newL = countLnL(observations, newMixture);
                }
                // accept the candidate whenever it beats the mixture from the previous split step
                if (newL > resL) {
                    res = newMixture;
                    stepListener.stepDone(res);
                }
            }
        }
        stepListener.stepDone(res);
        return res;
    }

    /**
     * Selects the component to split: the one with the highest weight * sigma product.
     *
     * @param mixture current mixture
     * @return index of the component to split (0 if no product is positive)
     */
    private int selectComponentToSplit(NormalDistributionMixture mixture) {
        double bestScore = 0;
        int bestIdx = 0;
        for (int i = 0; i < mixture.getNumComponents(); ++i) {
            double score = mixture.getWeights().get(i) * mixture.getComponents().get(i).getSigma();
            if (score > bestScore) {
                bestScore = score;
                bestIdx = i;
            }
        }
        return bestIdx;
    }

    /**
     * Creates a new mixture where the selected component is replaced by two halves with the same sigma,
     * centered sigma/2 above and below the original mean. Other components with weight below minWeight
     * are dropped.
     *
     * @param mixture current mixture
     * @param splitIdx index of the component to split
     * @return mixture with the component split
     */
    private NormalDistributionMixture splitComponent(NormalDistributionMixture mixture, int splitIdx) {
        NormalDistributionMixture.Builder builder = new NormalDistributionMixture.Builder();
        for (int i = 0; i < mixture.getNumComponents(); ++i) {
            double weight = mixture.getWeights().get(i);
            NormalDistribution comp = mixture.getComponents().get(i);
            if (i == splitIdx) {
                builder.addComponent(weight / 2, comp.getMi() + comp.getSigma() / 2, comp.getSigma());
                builder.addComponent(weight / 2, comp.getMi() - comp.getSigma() / 2, comp.getSigma());
            }
            else if (weight >= minWeight) {
                builder.addComponent(weight, comp);
            }
        }
        return builder.build();
    }

    /**
     * Extracts the interval spanned by the specified observations.
     *
     * @param observations observations
     * @return extracted interval
     * @throws IllegalArgumentException if there are no observations
     */
    private Interval extractInterval(ObservationProvider observations) {
        ObservationIterator iterator = observations.getIterator();
        Interval.Builder res = new Interval.Builder();
        while (iterator.isNextAvailable()) {
            res.addPoint(iterator.getNext());
        }
        if (iterator.getNumIterated() == 0) {
            // without observations the mean/variance below would be 0/0 = NaN
            throw new IllegalArgumentException("observations cannot be empty");
        }
        return res.build();
    }

    /**
     * Makes the one component initialization for the EM algorithm, using the sample mean
     * and the sample standard deviation (bounded below by minSigma).
     *
     * @param observations observations for which the initialization should be calculated, non-empty
     * @return mixture with initial parameters
     */
    private NormalDistributionMixture initializeOneComponent(ObservationProvider observations) {
        // first raw moment and second raw moment
        double sum = 0;
        double sumSquares = 0;
        ObservationIterator iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            double x = iterator.getNext();
            sum += x;
            sumSquares += x * x;
        }
        double count = iterator.getNumIterated();
        double mean = sum / count;
        // E[x^2] - E[x]^2 may round to a tiny negative number (e.g. constant data); clamp so sqrt is never NaN
        double variance = Math.max(0, sumSquares / count - mean * mean);
        double sigma = Math.max(minSigma, Math.sqrt(variance));
        return new NormalDistributionMixture.Builder().
                addComponent(1, NormalDistribution.create(mean, sigma)).
                build();
    }

    /**
     * Makes one iteration of the EM algorithm. Returns the mixture after the iteration.
     *
     * @param observations observations for which the iteration should be calculated
     * @param start starting position
     * @return mixture after the iteration
     */
    private NormalDistributionMixture nextIteration(ObservationProvider observations, NormalDistributionMixture start) {
        int numComponents = start.getNumComponents();
        double[] startWeights = new double[numComponents];
        NormalDistribution[] startComponents = new NormalDistribution[numComponents];
        for (int j = 0; j < numComponents; ++j) {
            startWeights[j] = start.getWeights().get(j);
            startComponents[j] = start.getComponents().get(j);
        }

        // accumulated responsibilities and responsibility-weighted sums of x and x^2
        double[] w = new double[numComponents];
        double[] sumX = new double[numComponents];
        double[] sumXX = new double[numComponents];

        // E-step over all observations
        ObservationIterator iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            double x = iterator.getNext();
            double total = start.getValue(x);
            for (int j = 0; j < numComponents; ++j) {
                double q = startWeights[j] * startComponents[j].getValue(x) / total;
                if (Double.isNaN(q)) {
                    // typically 0/0 when the mixture density underflows at x: the observation is ignored
                    q = 0;
                }
                w[j] += q;
                sumX[j] += x * q;
                sumXX[j] += x * x * q;
            }
        }

        // M-step: new weights, means and sigmas
        NormalDistributionMixture.Builder builder = new NormalDistributionMixture.Builder();
        for (int i = 0; i < numComponents; ++i) {
            double mi;
            double sigma;
            if (w[i] > 0) {
                mi = sumX[i] / w[i];
                double variance = sumXX[i] / w[i] - mi * mi;
                // negative (rounding) or NaN variance collapses to minSigma below
                sigma = variance > 0 ? Math.sqrt(variance) : 0;
            }
            else {
                // no responsibility assigned: keep previous parameters instead of producing a NaN mean
                mi = startComponents[i].getMi();
                sigma = startComponents[i].getSigma();
            }
            builder.addComponent(w[i] / iterator.getNumIterated(), new NormalDistribution.Builder().
                    setMi(mi).
                    setSigma(Math.max(minSigma, sigma)).
                    build());
        }
        return builder.build();
    }

    /**
     * Calculates the log-likelihood ln(L) of the observations under the mixture,
     * where L = prod_x( sum_i( w_i * p_i(x) ) ) and ln(L) = sum_x( ln( sum_i( w_i * p_i(x) ) ) ).
     * Computed as the sum of mixture.getLnValue(x) over all observations.
     *
     * @param observations observations
     * @param mixture mixture
     * @return ln(L) value
     */
    private double countLnL(ObservationProvider observations, NormalDistributionMixture mixture) {
        double res = 0;
        ObservationIterator iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            double x = iterator.getNext();
            res += mixture.getLnValue(x);
        }
        return res;
    }

    /**
     * Returns the smallest component weight of the mixture (at most 1).
     *
     * @param mixture mixture
     * @return minimal component weight
     */
    private double findMinComponentWeight(NormalDistributionMixture mixture) {
        double res = 1;
        for (double w : mixture.getWeights()) {
            if (w < res) {
                res = w;
            }
        }
        return res;
    }

    @Override
    public String toString() {
        return ToStringBuilder.reflectionToString(this);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy