All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.glm.GLM Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.glm;

import java.io.Serial;
import java.io.Serializable;
import java.util.Properties;
import java.util.stream.IntStream;

import smile.data.CategoricalEncoder;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.glm.model.Model;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.math.special.Erf;
import smile.stat.Hypothesis;
import smile.validation.ModelSelection;

/**
 * Generalized linear models. The generalized linear model (GLM) is a flexible
 * generalization of ordinary linear regression that allows for response
 * variables that have error distribution models other than a normal
 * distribution. The GLM generalizes linear regression by allowing the
 * linear model to be related to the response variable via a link function
 * and by allowing the magnitude of the variance of each measurement to be
 * a function of its predicted value.
 * 

* In GLM, each outcome Y of the dependent variables is assumed * to be generated from a particular distribution in an exponential family. * The mean, μ, of the distribution depends on the * independent variables, X, through: *

* E(Y) = μ = g-1(Xβ) *

* where E(Y) is the expected value of Y; * is the linear combination of linear predictors * and unknown parameters β; g is the link function that is a monotonic, * differentiable function. THe link function that transforms the mean to * the natural parameter is called the canonical link. *

* In this framework, the variance is typically a function, V, * of the mean: *

* Var(Y) = V(μ) = V(g-1(Xβ)) *

* It is convenient if V follows from an exponential family * of distributions, but it may simply be that the variance is a function * of the predicted value, such as V(μi) = μi * for the Poisson, V(μi) = μi(1 - μi) * for the Bernoulli, and V(μi) = σ2 * (i.e., constant) for the normal. *

* The unknown parameters, β, are typically estimated * with maximum likelihood, maximum quasi-likelihood, or Bayesian techniques. * * @author Haifeng Li */ public class GLM implements Serializable { @Serial private static final long serialVersionUID = 2L; private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(GLM.class); /** * The symbolic description of the model to be fitted. */ protected final Formula formula; /** * The predictors of design matrix. */ final String[] predictors; /** * The model specifications (link function, deviance, etc.). */ protected final Model model; /** * The linear weights. */ protected final double[] beta; /** * The coefficients, their standard errors, z-scores, and p-values. */ protected final double[][] ztest; /** * The fitted mean values. */ protected final double[] mu; /** * The null deviance = 2 * (LogLikelihood(Saturated Model) - LogLikelihood(Null Model)). *

* The saturated model, also referred to as the full model or maximal model, * allows a different mean response for each group of replicates. * One can think of the saturated model as having the most general * possible mean structure for the data since the means are unconstrained. *

* The null model assumes that all observations have the same distribution * with common parameter. Like the saturated model, the null model does not * depend on predictor variables. While the saturated most is the most * general model, the null model is the most restricted model. */ protected final double nullDeviance; /** * The deviance = 2 * (LogLikelihood(Saturated Model) - LogLikelihood(Proposed Model)). */ protected final double deviance; /** * The deviance residuals. */ protected final double[] devianceResiduals; /** * The degrees of freedom of the residual deviance. */ protected final int df; /** * Log-likelihood. */ protected final double logLikelihood; /** * Constructor. * @param formula the model formula. * @param predictors the predictors of design matrix. * @param model the generalized linear model specification. * @param beta the linear weights. * @param logLikelihood the log-likelihood. * @param deviance the deviance. * @param nullDeviance the null deviance. * @param mu the fitted mean values. * @param residuals the residuals of fitted values of training data. * @param ztest the z-test of the coefficients. */ public GLM(Formula formula, String[] predictors, Model model, double[] beta, double logLikelihood, double deviance, double nullDeviance, double[] mu, double[] residuals, double[][] ztest) { this.formula = formula; this.model = model; this.predictors = predictors; this.beta = beta; this.logLikelihood = logLikelihood; this.deviance = deviance; this.nullDeviance = nullDeviance; this.mu = mu; this.devianceResiduals = residuals; this.ztest = ztest; df = mu.length - beta.length; } /** * Returns an array of size (p+1) containing the linear weights * of binary logistic regression, where p is the dimension of * feature vectors. The last element is the weight of bias. * * @return the linear weights. */ public double[] coefficients() { return beta; } /** * Returns the z-test of the coefficients (including intercept). 
* The first column is the coefficients, the second column is the standard * error of coefficients, the third column is the z-score of the hypothesis * test if the coefficient is zero, the fourth column is the p-values of * test. The last row is of intercept. * * @return the z-test of the coefficients. */ public double[][] ztest() { return ztest; } /** * Returns the deviance residuals. * @return the deviance residuals. */ public double[] devianceResiduals() { return devianceResiduals; } /** * Returns the fitted mean values. * @return the fitted mean values. */ public double[] fittedValues() { return mu; } /** * Returns the deviance of model. * @return the deviance of model. */ public double deviance() { return deviance; } /** * Returns the log-likelihood of model. * @return the log-likelihood of model. */ public double logLikelihood() { return logLikelihood; } /** * Returns the AIC score. * @return the AIC score. */ public double AIC() { return ModelSelection.AIC(logLikelihood, beta.length); } /** * Returns the BIC score. * @return the BIC score. */ public double BIC() { return ModelSelection.BIC(logLikelihood, beta.length, mu.length); } /** * Predicts the mean response. * @param x the instance. * @return the mean response. */ public double predict(Tuple x) { double[] a = formula.x(x).toArray(true, CategoricalEncoder.DUMMY); int p = beta.length; double dot = 0.0; for (int i = 0; i < p; i++) { dot += a[i] * beta[i]; } return model.invlink(dot); } /** * Predicts the mean response. * @param data the data frame. * @return the mean response. 
*/ public double[] predict(DataFrame data) { Matrix X = formula.matrix(data, true); double[] y = X.mv(beta); int n = y.length; for (int i = 0; i < n; i++) { y[i] = model.invlink(y[i]); } return y; } @Override public String toString() { StringBuilder builder = new StringBuilder(); builder.append(String.format("Generalized Linear Model - %s:\n", model)); double[] r = devianceResiduals.clone(); builder.append("\nDeviance Residuals:\n"); builder.append(" Min 1Q Median 3Q Max\n"); builder.append(String.format("%10.4f %10.4f %10.4f %10.4f %10.4f%n", MathEx.min(r), MathEx.q1(r), MathEx.median(r), MathEx.q3(r), MathEx.max(r))); int p = beta.length - 1; builder.append("\nCoefficients:\n"); if (ztest != null) { builder.append(" Estimate Std. Error z value Pr(>|z|)\n"); for (int i = 0; i < p; i++) { builder.append(String.format("%-15s %10.3e %10.3e %10.4f %10.5f %s%n", predictors[i], ztest[i][0], ztest[i][1], ztest[i][2], ztest[i][3], Hypothesis.significance(ztest[i][3]))); } builder.append("---------------------------------------------------------------------\n"); builder.append("Significance codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n"); } else { builder.append(String.format("Intercept %10.4f%n", beta[p])); for (int i = 0; i < p; i++) { builder.append(String.format("%-15s %10.4f%n", predictors[i], beta[i])); } } builder.append(String.format("%n Null deviance: %.1f on %d degrees of freedom", nullDeviance, df+p)); builder.append(String.format("%nResidual deviance: %.1f on %d degrees of freedom", deviance, df)); builder.append(String.format("%nAIC: %.4f BIC: %.4f%n", AIC(), BIC())); return builder.toString(); } /** * Fits the generalized linear model with IWLS (iteratively reweighted least squares). * * @param formula a symbolic description of the model to be fitted. * @param data the data frame of the explanatory and response variables. * @param model the generalized linear model specification. * @return the model. 
*/ public static GLM fit(Formula formula, DataFrame data, Model model) { return fit(formula, data, model, new Properties()); } /** * Fits the generalized linear model with IWLS (iteratively reweighted least squares). * * @param formula a symbolic description of the model to be fitted. * @param data the data frame of the explanatory and response variables. * @param model the generalized linear model specification. * @param params the hyperparameters. * @return the model. */ public static GLM fit(Formula formula, DataFrame data, Model model, Properties params) { double tol = Double.parseDouble(params.getProperty("smile.glm.tolerance", "1E-5")); int maxIter = Integer.parseInt(params.getProperty("smile.glm.iterations", "50")); return fit(formula, data, model, tol, maxIter); } /** * Fits the generalized linear model with IWLS (iteratively reweighted least squares). * * @param formula a symbolic description of the model to be fitted. * @param data the data frame of the explanatory and response variables. * @param model the generalized linear model specification. * @param tol the tolerance for stopping iterations. * @param maxIter the maximum number of iterations. * @return the model. 
*/ public static GLM fit(Formula formula, DataFrame data, Model model, double tol, int maxIter) { if (tol <= 0.0) { throw new IllegalArgumentException("Invalid tolerance: " + tol); } if (maxIter <= 0) { throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter); } Matrix X = formula.matrix(data, true); Matrix XW = new Matrix(X.nrow(), X.ncol()); double[] y = formula.y(data).toDoubleArray(); int n = X.nrow(); int p = X.ncol(); if (n <= p) { throw new IllegalArgumentException(String.format("The input matrix is not over determined: %d rows, %d columns", n, p)); } double[] eta = new double[n]; double[] mu = new double[n]; double[] w = new double[n]; // sqrt of diagonal of W double[] z = new double[n]; double[] residuals = new double[n]; // Initialization IntStream.range(0, n).parallel().forEach(i -> { mu[i] = model.mustart(y[i]); eta[i] = model.link(mu[i]); double g = model.dlink(mu[i]); // z[i] = eta[i] + (y[i] - mu[i]) * g; double v = model.variance(mu[i]); w[i] = 1.0 / (g * Math.sqrt(v)); z[i] *= w[i]; }); for (int j = 0; j < p; j++) { for (int i = 0; i < n; i++) { XW.set(i, j, X.get(i, j) * w[i]); } } Matrix.QR qr = XW.qr(true); double[] beta = qr.solve(z); double dev = Double.POSITIVE_INFINITY; for (int iter = 0; iter < maxIter; iter++) { X.mv(beta, eta); IntStream.range(0, n).parallel().forEach(i -> { mu[i] = model.invlink(eta[i]); double g = model.dlink(mu[i]); z[i] = eta[i] + (y[i] - mu[i]) * g; double v = model.variance(mu[i]); w[i] = 1.0 / (g * Math.sqrt(v)); z[i] *= w[i]; }); double newDev = model.deviance(y, mu, residuals); if (iter > 0) { logger.info("Deviance after {} iterations: {}", iter, dev); } if (dev - newDev < tol) { break; } dev = newDev; for (int j = 0; j < p; j++) { for (int i = 0; i < n; i++) { XW.set(i, j, X.get(i, j) * w[i]); } } qr = XW.qr(true); beta = qr.solve(z); } Matrix.Cholesky cholesky = qr.CholeskyOfAtA(); Matrix inv = cholesky.inverse(); double[][] ztest = new double[p][4]; for (int i = 0; i < p; i++) { 
ztest[i][0] = beta[i]; ztest[i][1] = Math.sqrt(inv.get(i, i)); ztest[i][2] = ztest[i][0] / ztest[i][1]; ztest[i][3] = 2.0 - Erf.erfc(-0.707106781186547524 * Math.abs(ztest[i][2])); } return new GLM(formula, X.colNames(), model, beta, model.logLikelihood(y, mu), dev, model.nullDeviance(y, MathEx.mean(y)), mu, residuals, ztest); } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy