All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.classification.DiscriminantAnalysis Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.classification;

import smile.math.MathEx;
import smile.math.blas.UPLO;
import smile.math.matrix.Matrix;
import smile.util.IntSet;

/**
 * Common functions for various discriminant analysis.
 * <p>
 * An instance bundles the quantities computed by {@link #fit}: the encoded
 * class labels, the class sizes, the priori probabilities, the overall mean
 * vector and the per-class mean vectors. The static helpers {@link #St} and
 * {@link #cov} build scatter/covariance matrices from those quantities.
 *
 * @author Haifeng Li
 */
class DiscriminantAnalysis {
    /** The number of classes. */
    final int k;
    /** The class labels in [0, k). */
    final int[] y;
    /** The original class labels. */
    final IntSet labels;
    /** The number of instances in each class. */
    final int[] ni;
    /** The priori probabilities. */
    final double[] priori;
    /** The mean vector. */
    final double[] mean;
    /** The mean vector per class. */
    final double[][] mu;

    /**
     * Constructor.
     * @param codec the class label encoder; its {@code k}, {@code ni},
     *              {@code y} and {@code classes} fields are stored as is.
     * @param priori a priori probabilities of each class.
     * @param mean the mean vector of all samples.
     * @param mu the mean vectors of each class.
     */
    public DiscriminantAnalysis(ClassLabels codec, double[] priori, double[] mean, double[][] mu) {
        this.k = codec.k;
        this.ni = codec.ni;
        this.y = codec.y;
        this.labels = codec.classes;
        this.priori = priori;
        this.mean = mean;
        this.mu = mu;
    }

    /**
     * Computes the common information in various discriminant analysis algorithms.
     * @param x training samples.
     * @param y training labels.
     * @param priori the priori probability of each class. If null, it will be
     * estimated from the training data.
     * @param tol a tolerance to decide if a covariance matrix is singular; it
     * will reject variables whose variance is less than tol<sup>2</sup>.
     * This method only validates that it is a positive number.
     * @return the class labels, priori probabilities, overall mean and
     * per-class means of the training data.
     * @throws IllegalArgumentException if the sizes of x and y differ, if tol
     * is not a positive number, if there are not more samples than classes,
     * or if the given priori probabilities are invalid.
     */
    public static DiscriminantAnalysis fit(double[][] x, int[] y, double[] priori, double tol) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        // Negated positive test so that NaN is rejected as well:
        // every ordered comparison with NaN evaluates to false.
        if (!(tol > 0.0)) {
            throw new IllegalArgumentException("Invalid tol: " + tol);
        }

        int n = x.length;

        // class label set.
        ClassLabels codec = ClassLabels.fit(y);
        int k = codec.k;
        int[] encoded = codec.y;
        int[] ni = codec.ni;

        if (n <= k) {
            throw new IllegalArgumentException(String.format("Sample size is too small: %d <= %d", n, k));
        }

        if (priori == null) {
            priori = codec.priori;
        } else {
            validatePriori(priori, k);
        }

        int p = x[0].length;
        // Common mean vector.
        double[] mean = MathEx.colMeans(x);
        // Class mean vectors.
        double[][] mu = classMeans(x, encoded, ni, k, p);

        return new DiscriminantAnalysis(codec, priori, mean, mu);
    }

    /**
     * Checks that user supplied priori probabilities have one entry per
     * class, each strictly inside (0, 1), and that they sum to one.
     */
    private static void validatePriori(double[] priori, int k) {
        if (priori.length != k) {
            throw new IllegalArgumentException("Invalid number of priori probabilities: " + priori.length);
        }

        double sum = 0.0;
        for (double pr : priori) {
            // Negated range test: a NaN entry fails both comparisons and is
            // rejected here instead of silently poisoning the sum below.
            if (!(pr > 0.0 && pr < 1.0)) {
                throw new IllegalArgumentException("Invalid priori probability: " + pr);
            }
            sum += pr;
        }

        if (Math.abs(sum - 1.0) > 1E-10) {
            throw new IllegalArgumentException("The sum of priori probabilities is not one: " + sum);
        }
    }

    /**
     * Computes the mean vector of each class: accumulates the samples of
     * each class and divides by the class size.
     */
    private static double[][] classMeans(double[][] x, int[] y, int[] ni, int k, int p) {
        double[][] mu = new double[k][p];

        for (int i = 0; i < x.length; i++) {
            double[] xi = x[i];
            double[] mui = mu[y[i]];
            for (int j = 0; j < p; j++) {
                mui[j] += xi[j];
            }
        }

        for (int i = 0; i < k; i++) {
            int m = ni[i];
            double[] mui = mu[i];
            for (int j = 0; j < p; j++) {
                mui[j] /= m;
            }
        }

        return mu;
    }

    /**
     * Computes the scatter matrix of all samples around the given mean,
     * scaled by {@code 1 / (n - k)} where n is the number of samples.
     * @param x the samples.
     * @param mean the vector subtracted from every sample.
     * @param k the number of classes.
     * @param tol the tolerance; a variable whose scaled variance (diagonal
     * entry) is less than tol<sup>2</sup> is rejected.
     * @return the symmetric p-by-p matrix.
     * @throws IllegalArgumentException if there are not more samples than
     * classes, or if a diagonal entry is less than tol<sup>2</sup>.
     */
    public static Matrix St(double[][] x, double[] mean, int k, double tol) {
        int n = x.length;
        // The scaling below divides by (n - k); a zero or negative
        // denominator would yield NaN/Infinity or negative variances.
        if (n <= k) {
            throw new IllegalArgumentException(String.format("Sample size is too small: %d <= %d", n, k));
        }

        int p = x[0].length;

        Matrix St = new Matrix(p, p);
        St.uplo(UPLO.LOWER);

        // Accumulate the lower triangle only.
        for (double[] xi : x) {
            for (int j = 0; j < p; j++) {
                double dj = xi[j] - mean[j];
                for (int l = 0; l <= j; l++) {
                    St.add(j, l, dj * (xi[l] - mean[l]));
                }
            }
        }

        double threshold = tol * tol;
        int df = n - k;
        for (int j = 0; j < p; j++) {
            for (int l = 0; l <= j; l++) {
                St.div(j, l, df);
                // Mirror the lower triangle into the upper triangle.
                St.set(l, j, St.get(j, l));
            }

            if (St.get(j, j) < threshold) {
                throw new IllegalArgumentException(String.format("Covariance matrix (column %d) is close to singular.", j));
            }
        }

        return St;
    }

    /**
     * Computes the covariance matrix of each class. Each matrix is the
     * scatter of the class samples around the class mean, divided by the
     * class size minus one.
     * @param x the samples.
     * @param y the class label of each sample, used as an index into
     * {@code mu}.
     * @param mu the mean vector of each class.
     * @param ni the number of samples in each class.
     * @return one symmetric p-by-p matrix per class.
     * @throws IllegalArgumentException if a class does not have more samples
     * than the number of variables.
     */
    public static Matrix[] cov(double[][] x, int[] y, double[][] mu, int[] ni) {
        int n = x.length;
        int p = x[0].length;
        int k = mu.length;

        Matrix[] cov = new Matrix[k];

        for (int i = 0; i < k; i++) {
            if (ni[i] <= p) {
                throw new IllegalArgumentException(String.format("The sample size of class %d is too small.", i));
            }

            cov[i] = new Matrix(p, p);
            cov[i].uplo(UPLO.LOWER);
        }

        // Accumulate the lower triangle of each class scatter matrix.
        for (int i = 0; i < n; i++) {
            Matrix v = cov[y[i]];
            double[] mui = mu[y[i]];
            double[] xi = x[i];
            for (int j = 0; j < p; j++) {
                double dj = xi[j] - mui[j];
                for (int l = 0; l <= j; l++) {
                    v.add(j, l, dj * (xi[l] - mui[l]));
                }
            }
        }

        for (int i = 0; i < k; i++) {
            Matrix v = cov[i];
            int m = ni[i] - 1;
            for (int j = 0; j < p; j++) {
                for (int l = 0; l <= j; l++) {
                    v.div(j, l, m);
                    // Mirror the lower triangle into the upper triangle.
                    v.set(l, j, v.get(j, l));
                }
            }
        }

        return cov;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy