smile.classification.QDA
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.classification;

import java.io.Serial;
import java.util.Properties;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.util.IntSet;
import smile.util.Strings;

/**
 * Quadratic discriminant analysis. QDA is closely related to linear discriminant
 * analysis (LDA). Like LDA, QDA models the class-conditional probability
 * density functions as Gaussian distributions, then uses the posterior
 * distributions to estimate the class of a given test sample. Unlike LDA, however,
 * in QDA there is no assumption that the covariance of each of the classes
 * is identical. Therefore, the resulting separating surface between
 * the classes is quadratic.
 * <p>
 * The Gaussian parameters for each class can be estimated from training data
 * with maximum likelihood (ML) estimation. However, when the number of
 * training instances is small compared to the dimension of input space,
 * the ML covariance estimation can be ill-posed. One approach to resolve
 * the ill-posed estimation is to regularize the covariance estimation.
 * One of these regularization methods is {@link RDA regularized discriminant analysis}.
 *
 * @see LDA
 * @see RDA
 * @see NaiveBayes
 *
 * @author Haifeng Li
 */
public class QDA extends AbstractClassifier<double[]> {
    @Serial
    private static final long serialVersionUID = 2L;

    /**
     * The dimensionality of data.
     */
    private final int p;

    /**
     * The number of classes.
     */
    private final int k;

    /**
     * Constant term of discriminant function of each class.
     */
    private final double[] logppriori;

    /**
     * A priori probabilities of each class.
     */
    private final double[] priori;

    /**
     * Mean vectors of each class.
     */
    private final double[][] mu;

    /**
     * Eigen values of each covariance matrix.
     */
    private final double[][] eigen;

    /**
     * Eigen vectors of each covariance matrix, which transforms observations
     * to discriminant functions, normalized so that within groups covariance
     * matrix is spherical.
     */
    private final Matrix[] scaling;

    /**
     * Constructor.
     * @param priori a priori probabilities of each class.
     * @param mu the mean vectors of each class.
     * @param eigen the eigen values of each variance matrix.
     * @param scaling the eigen vectors of each covariance matrix.
     */
    public QDA(double[] priori, double[][] mu, double[][] eigen, Matrix[] scaling) {
        this(priori, mu, eigen, scaling, IntSet.of(priori.length));
    }

    /**
     * Constructor.
     * @param priori a priori probabilities of each class.
     * @param mu the mean vectors of each class.
     * @param eigen the eigen values of each variance matrix.
     * @param scaling the eigen vectors of each covariance matrix.
     * @param labels the class label encoder.
     */
    public QDA(double[] priori, double[][] mu, double[][] eigen, Matrix[] scaling, IntSet labels) {
        super(labels);
        this.k = priori.length;
        this.p = mu[0].length;
        this.priori = priori;
        this.mu = mu;
        this.eigen = eigen;
        this.scaling = scaling;

        logppriori = new double[k];
        for (int i = 0; i < k; i++) {
            double logev = 0.0;
            for (int j = 0; j < p; j++) {
                logev += Math.log(eigen[i][j]);
            }

            logppriori[i] = Math.log(priori[i]) - 0.5 * logev;
        }
    }

    /**
     * Fits quadratic discriminant analysis.
     * @param x training samples.
     * @param y training labels in [0, k), where k is the number of classes.
     * @return the model.
     */
    public static QDA fit(double[][] x, int[] y) {
        return fit(x, y, null, 1E-4);
    }

    /**
     * Fits quadratic discriminant analysis.
     * @param x training samples.
     * @param y training labels.
     * @param params the hyperparameters.
     * @return the model.
     */
    public static QDA fit(double[][] x, int[] y, Properties params) {
        double[] priori = Strings.parseDoubleArray(params.getProperty("smile.qda.priori"));
        double tol = Double.parseDouble(params.getProperty("smile.qda.tolerance", "1E-4"));
        return fit(x, y, priori, tol);
    }

    /**
     * Fits quadratic discriminant analysis.
     * @param x training samples.
     * @param y training labels in [0, k), where k is the number of classes.
     * @param priori the priori probability of each class. If null, it will be
     *               estimated from the training data.
     * @param tol a tolerance to decide if a covariance matrix is singular;
     *            it will reject variables whose variance is less than tol<sup>2</sup>.
     * @return the model.
     */
    public static QDA fit(double[][] x, int[] y, double[] priori, double tol) {
        DiscriminantAnalysis da = DiscriminantAnalysis.fit(x, y, priori, tol);
        Matrix[] cov = DiscriminantAnalysis.cov(x, y, da.mu, da.ni);

        int k = cov.length;
        int p = cov[0].nrow();
        double[][] eigen = new double[k][];
        Matrix[] scaling = new Matrix[k];

        tol = tol * tol;
        for (int i = 0; i < k; i++) {
            // quick test of singularity
            for (int j = 0; j < p; j++) {
                if (cov[i].get(j, j) < tol) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix (column %d) is close to singular.", i, j));
                }
            }

            Matrix.EVD evd = cov[i].eigen(false, true, true).sort();

            for (double s : evd.wr) {
                if (s < tol) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix is close to singular.", i));
                }
            }

            eigen[i] = evd.wr;
            scaling[i] = evd.Vr;
        }

        return new QDA(da.priori, da.mu, eigen, scaling, da.labels);
    }

    /**
     * Returns a priori probabilities.
     * @return a priori probabilities.
     */
    public double[] priori() {
        return priori;
    }

    @Override
    public int predict(double[] x) {
        return predict(x, new double[k]);
    }

    @Override
    public boolean soft() {
        return true;
    }

    @Override
    public int predict(double[] x, double[] posteriori) {
        if (x.length != p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, p));
        }

        double[] d = new double[p];
        double[] ux = new double[p];

        for (int i = 0; i < k; i++) {
            double[] mui = mu[i];
            for (int j = 0; j < p; j++) {
                d[j] = x[j] - mui[j];
            }

            scaling[i].tv(d, ux);

            double f = 0.0;
            double[] ev = eigen[i];
            for (int j = 0; j < p; j++) {
                f += ux[j] * ux[j] / ev[j];
            }

            posteriori[i] = logppriori[i] - 0.5 * f;
        }

        return classes.valueOf(MathEx.softmax(posteriori));
    }
}
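Below is a minimal usage sketch, not part of the Smile sources: it fits a QDA model on a small synthetic two-class data set via the fit(double[][], int[]) factory method above and then queries the predicted label and posterior probabilities of a new sample. The QDAExample class name and all data values are made up for illustration.

import smile.classification.QDA;

public class QDAExample {
    public static void main(String[] args) {
        // Synthetic 2-dimensional data: class 0 around the origin, class 1 near (5, 5).
        double[][] x = {
            {0.0, 0.1}, {0.2, -0.1}, {-0.1, 0.0}, {0.1, 0.2}, {-0.2, -0.2},
            {5.0, 5.2}, {5.3, 4.8}, {4.7, 5.1}, {5.1, 4.9}, {4.8, 5.3}
        };
        int[] y = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1};

        // Priors estimated from the data, singularity tolerance 1E-4 (the defaults).
        QDA model = QDA.fit(x, y);

        // predict(x, posteriori) also fills in the per-class posterior probabilities.
        // Internally, class i is scored by log(priori[i]) - 0.5*log|Cov_i|
        // - 0.5*(x - mu_i)' inv(Cov_i) (x - mu_i); softmax then normalizes the scores.
        double[] posteriori = new double[2];
        int label = model.predict(new double[]{4.9, 5.0}, posteriori);
        System.out.printf("label = %d, posteriori = [%.4f, %.4f]%n",
                label, posteriori[0], posteriori[1]);
    }
}

Because each class keeps its own covariance matrix, the decision boundary between the two classes is quadratic rather than linear, which is exactly the difference from LDA described in the class javadoc.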




