All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.stat.distribution.MultivariateExponentialFamilyMixture Maven / Gradle / Ivy

/*******************************************************************************
 * Copyright (c) 2010-2020 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation, either version 3 of
 * the License, or (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 ******************************************************************************/

package smile.stat.distribution;

import smile.math.MathEx;

/**
 * The finite mixture of distributions from multivariate exponential family.
 * The EM algorithm can be used to learn the mixture model from data.
 *
 * @author Haifeng Li
 */
public class MultivariateExponentialFamilyMixture extends MultivariateMixture {
    private static final long serialVersionUID = 2L;
    private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(MultivariateExponentialFamilyMixture.class);

    /** The log-likelihood when the distribution is fit on a sample data. */
    public final double L;
    /** The BIC score when the distribution is fit on a sample data. */
    public final double bic;

    /**
     * Constructor.
     * @param components a list of multivariate exponential family distributions.
     * @throws IllegalArgumentException if a component's distribution is not a
     *         {@link MultivariateExponentialFamily}.
     */
    public MultivariateExponentialFamilyMixture(Component... components) {
        // n = 1 makes the log(n) penalty vanish, so bic == L == 0.0 here.
        this(0.0, 1, components);
    }

    /**
     * Constructor.
     * @param L the log-likelihood.
     * @param n the number of samples to fit the distribution.
     * @param components a list of multivariate exponential family distributions.
     * @throws IllegalArgumentException if a component's distribution is not a
     *         {@link MultivariateExponentialFamily}.
     */
    MultivariateExponentialFamilyMixture(double L, int n, Component... components) {
        super(components);

        for (Component component : components) {
            if (!(component.distribution instanceof MultivariateExponentialFamily)) {
                throw new IllegalArgumentException("Component " + component + " is not of multivariate exponential family.");
            }
        }

        this.L = L;
        // BIC = L - 0.5 * length() * log(n); length() is inherited from
        // MultivariateMixture (presumably the number of free parameters).
        this.bic = L - 0.5 * length() * Math.log(n);
    }

    /**
     * Fits the mixture model with the EM algorithm, using regularization
     * {@code gamma = 0.2}, at most 500 iterations and tolerance {@code 1E-4}.
     *
     * @param x the training data.
     * @param components the initial configuration of mixture. Components may have
     *                   different distribution form. The array is overwritten in
     *                   place with the fitted components.
     * @return the fitted mixture model.
     * @throws IllegalArgumentException if there are too many components for the data.
     */
    public static MultivariateExponentialFamilyMixture fit(double[][] x, Component... components) {
        return fit(x, components, 0.2, 500, 1E-4);
    }

    /**
     * Fits the mixture model with the (optionally regularized) EM algorithm.
     *
     * @param x the training data.
     * @param components the initial configuration of mixture. Components may have
     *                   different distribution form. The array is overwritten in
     *                   place with the fitted components.
     * @param gamma the regularization parameter, in [0, 0.2]; 0 means plain EM.
     * @param maxIter the maximum number of iterations.
     * @param tol the tolerance of convergence test: iteration stops once the
     *            log-likelihood improves by no more than {@code tol}.
     * @return the fitted mixture model.
     * @throws IllegalArgumentException if there are too many components for the
     *         data, or gamma is outside [0, 0.2].
     */
    public static MultivariateExponentialFamilyMixture fit(double[][] x, Component[] components, double gamma, int maxIter, double tol) {
        if (x.length < components.length / 2) {
            throw new IllegalArgumentException("Too many components");
        }

        if (gamma < 0.0 || gamma > 0.2) {
            throw new IllegalArgumentException("Invalid regularization factor gamma.");
        }

        int n = x.length;
        int k = components.length;

        double[][] posteriori = new double[k][n];

        // Log-likelihood of the current model; stays 0.0 if maxIter < 1.
        double L = 0.0;

        // EM loop until convergence
        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; iter++) {
            expectationStep(x, components, posteriori, gamma);
            maximizationStep(x, components, posteriori);
            double loglikelihood = mixtureLogLikelihood(x, components);

            // On the first iteration there is no previous estimate. Comparing
            // against the initial 0.0 would end EM immediately whenever the
            // log-likelihood is negative, which is the common case.
            diff = iter == 1 ? Double.MAX_VALUE : loglikelihood - L;
            L = loglikelihood;

            if (iter % 10 == 0) {
                logger.info(String.format("The log-likelihood after %d iterations: %.4f", iter, L));
            }
        }

        return new MultivariateExponentialFamilyMixture(L, x.length, components);
    }

    /**
     * E-step: writes into {@code posteriori[i][j]} the (optionally regularized)
     * posterior probability that sample {@code j} belongs to component {@code i}.
     */
    private static void expectationStep(double[][] x, Component[] components, double[][] posteriori, double gamma) {
        int n = x.length;
        int k = components.length;

        for (int i = 0; i < k; i++) {
            Component c = components[i];
            for (int j = 0; j < n; j++) {
                posteriori[i][j] = c.priori * c.distribution.p(x[j]);
            }
        }

        for (int j = 0; j < n; j++) {
            double p = 0.0;
            for (int i = 0; i < k; i++) {
                p += posteriori[i][j];
            }

            // Zero (or NaN) total density: 0/0 would yield NaN responsibilities,
            // which would otherwise reach the M-step unclamped when gamma == 0.
            // Treat the sample as carrying no weight for any component.
            if (!(p > 0.0)) {
                for (int i = 0; i < k; i++) {
                    posteriori[i][j] = 0.0;
                }
                continue;
            }

            // Normalize posteriori probability.
            for (int i = 0; i < k; i++) {
                posteriori[i][j] /= p;
            }

            // Regularized EM: scale each responsibility by 1 + gamma * log2(post),
            // which is <= 1 since post <= 1. Negative results and NaN
            // (0 * -inf when post == 0) are clamped to 0.
            if (gamma > 0) {
                for (int i = 0; i < k; i++) {
                    posteriori[i][j] *= (1 + gamma * MathEx.log2(posteriori[i][j]));
                    if (Double.isNaN(posteriori[i][j]) || posteriori[i][j] < 0.0) {
                        posteriori[i][j] = 0.0;
                    }
                }
            }
        }
    }

    /**
     * M-step: re-estimates each component from the weighted samples via
     * {@link MultivariateExponentialFamily#M}, then rescales the priori so
     * they sum to 1. Updates {@code components} in place.
     */
    private static void maximizationStep(double[][] x, Component[] components, double[][] posteriori) {
        int k = components.length;

        double Z = 0.0;
        for (int i = 0; i < k; i++) {
            components[i] = ((MultivariateExponentialFamily) components[i].distribution).M(x, posteriori[i]);
            Z += components[i].priori;
        }

        // After regularization the responsibilities need not sum to 1 per
        // sample, so the priori returned by M are renormalized explicitly.
        for (int i = 0; i < k; i++) {
            components[i] = new Component(components[i].priori / Z, components[i].distribution);
        }
    }

    /**
     * Returns the log-likelihood of the data under the mixture. Samples with
     * zero (or NaN) mixture density are skipped rather than contributing -inf.
     */
    private static double mixtureLogLikelihood(double[][] x, Component[] components) {
        double loglikelihood = 0.0;
        for (double[] xi : x) {
            double p = 0.0;
            for (Component c : components) {
                p += c.priori * c.distribution.p(xi);
            }
            if (p > 0) {
                loglikelihood += Math.log(p);
            }
        }
        return loglikelihood;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy