All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.stat.distribution.DiscreteMixture Maven / Gradle / Ivy

The newest version!
/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 *   
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *  
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
package smile.stat.distribution;

import java.util.List;
import java.util.ArrayList;
import smile.math.Math;

/**
 * The finite mixture of discrete distributions.
 *
 * @author Haifeng Li
 */
public class DiscreteMixture extends DiscreteDistribution {

    /**
     * A component in the mixture distribution is defined by a distribution
     * and its weight in the mixture.
     */
    public static class Component {

        /**
         * The distribution of component.
         */
        public DiscreteDistribution distribution;
        /**
         * The priori probability of component.
         */
        public double priori;
    }
    List components;

    /**
     * Constructor.
     */
    DiscreteMixture() {
        components = new ArrayList();
    }

    /**
     * Constructor.
     * @param mixture a list of discrete distributions.
     */
    public DiscreteMixture(List mixture) {
        components = new ArrayList();
        components.addAll(mixture);

        double sum = 0.0;
        for (Component component : mixture) {
            sum += component.priori;
            if (component.distribution instanceof DiscreteDistribution == false) {
                throw new IllegalArgumentException("Component " + component + " is not a discrete distribution.");
            }
        }

        if (Math.abs(sum - 1.0) > 1E-3) {
            throw new IllegalArgumentException("The sum of priori is not equal to 1.");
        }
    }

    @Override
    public double mean() {
        if (components.size() == 0) {
            throw new IllegalStateException("Mixture is empty!");
        }

        double mu = 0.0;

        for (Component c : components) {
            mu += c.priori * c.distribution.mean();
        }

        return mu;
    }

    @Override
    public double var() {
        if (components.size() == 0) {
            throw new IllegalStateException("Mixture is empty!");
        }

        double variance = 0.0;

        for (Component c : components) {
            variance += c.priori * c.priori * c.distribution.var();
        }

        return variance;
    }

    @Override
    public double sd() {
        return Math.sqrt(var());
    }

    /**
     * Shannon entropy. Not supported.
     */
    @Override
    public double entropy() {
        throw new UnsupportedOperationException("Mixture does not support entropy()");
    }

    @Override
    public double p(int x) {
        if (components.isEmpty()) {
            throw new IllegalStateException("Mixture is empty!");
        }

        double p = 0.0;

        for (Component c : components) {
            p += c.priori * c.distribution.p(x);
        }

        return p;
    }

    @Override
    public double logp(int x) {
        if (components.size() == 0) {
            throw new IllegalStateException("Mixture is empty!");
        }

        return Math.log(p(x));
    }

    @Override
    public double cdf(double x) {
        if (components.size() == 0) {
            throw new IllegalStateException("Mixture is empty!");
        }

        double p = 0.0;

        for (Component c : components) {
            p += c.priori * c.distribution.cdf(x);
        }

        return p;
    }

    @Override
    public double rand() {
        if (components.size() == 0) {
            throw new IllegalStateException("Mixture is empty!");
        }

        double r = Math.random();

        double p = 0.0;
        for (Component g : components) {
            p += g.priori;
            if (r <= p) {
                return g.distribution.rand();
            }
        }

        // we should not arrive here.
        return components.get(components.size() - 1).distribution.rand();
    }

    @Override
    public double quantile(double p) {
        if (components.size() == 0) {
            throw new IllegalStateException("Mixture is empty!");
        }

        if (p < 0.0 || p > 1.0) {
            throw new IllegalArgumentException("Invalid p: " + p);
        }

        // Starting guess near peak of density.
        // Expand interval until we bracket.
        int xl, xu, inc = 1;
        int x = (int) mean();
        if (p < cdf(x)) {
            do {
                x -= inc;
                inc *= 2;
            } while (p < cdf(x));
            xl = x;
            xu = x + inc / 2;
        } else {
            do {
                x += inc;
                inc *= 2;
            } while (p > cdf(x));
            xu = x;
            xl = x - inc / 2;
        }

        return quantile(p, xl, xu);
    }

    @Override
    public int npara() {
        if (components.size() == 0) {
            throw new IllegalStateException("Mixture is empty!");
        }

        int f = 0;
        for (int i = 0; i < components.size(); i++) {
            f += components.get(i).distribution.npara();
        }

        return f;
    }

    /**
     * Returns the number of components in the mixture.
     */
    public int size() {
        return components.size();
    }

    /**
     * BIC score of the mixture for given data.
     */
    public double bic(double[] data) {
        if (components.size() == 0) {
            throw new IllegalStateException("Mixture is empty!");
        }

        int n = data.length;

        double logLikelihood = 0.0;
        for (double x : data) {
            double p = p(x);
            if (p > 0) {
                logLikelihood += Math.log(p);
            }
        }

        return logLikelihood - 0.5 * npara() * Math.log(n);
    }

    /**
     * Returns the list of components in the mixture.
     */
    public List getComponents() {
        return components;
    }

    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("Mixture[");
        builder.append(components.size());
        builder.append("]:{");
        for (Component c : components) {
            builder.append(" (");
            builder.append(c.distribution);
            builder.append(':');
            builder.append(String.format("%.4f", c.priori));
            builder.append(')');
        }
        builder.append("}");
        return builder.toString();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy