
smile.classification.LogisticRegression

/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 *   
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *  
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

package smile.classification;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;

import smile.math.Math;
import smile.math.DifferentiableMultivariateFunction;
import smile.util.MulticoreExecutor;

/**
 * Logistic regression. Logistic regression (logit model) is a generalized
 * linear model used for binomial regression. Logistic regression applies
 * maximum likelihood estimation after transforming the dependent into
 * a logit variable. A logit is the natural log of the odds of the dependent
 * equaling a certain value or not (usually 1 in binary logistic models,
 * the highest value in multinomial models). In this way, logistic regression
 * estimates the odds of a certain event (value) occurring. 
 * <p>
 * Goodness-of-fit tests such as the likelihood ratio test are available
 * as indicators of model appropriateness, as is the Wald statistic to test
 * the significance of individual independent variables.
 * <p>
 * Logistic regression has many analogies to ordinary least squares (OLS)
 * regression. Unlike OLS regression, however, logistic regression does not
 * assume linearity of the relationship between the raw values of the
 * independent variables and the dependent, does not require normally
 * distributed variables, does not assume homoscedasticity, and in general
 * has less stringent requirements.
 * <p>
 * Compared with linear discriminant analysis, logistic regression has several
 * advantages:
 * <ul>
 * <li> It is more robust: the independent variables don't have to be normally
 * distributed, or have equal variance in each group.
 * <li> It does not assume a linear relationship between the independent
 * variables and the dependent variable.
 * <li> It may handle nonlinear effects since one can add explicit interaction
 * and power terms.
 * </ul>
 * However, it requires much more data to achieve stable, meaningful results.
 * <p>
 * Logistic regression also has strong connections with neural network and
 * maximum entropy modeling. For example, binary logistic regression is
 * equivalent to a one-layer, single-output neural network with a logistic
 * activation function trained under log loss. Similarly, multinomial logistic
 * regression is equivalent to a one-layer, softmax-output neural network.
 * <p>
 * Logistic regression estimation also obeys the maximum entropy principle, and
 * thus logistic regression is sometimes called "maximum entropy modeling",
 * and the resulting classifier the "maximum entropy classifier".
 *
 * @see NeuralNetwork
 * @see Maxent
 * @see LDA
 *
 * @author Haifeng Li
 */
public class LogisticRegression implements Classifier<double[]> {
    /**
     * The dimension of input space.
     */
    private int p;

    /**
     * The number of classes.
     */
    private int k;

    /**
     * The log-likelihood of learned model.
     */
    private double L;

    /**
     * The linear weights for binary logistic regression.
     */
    private double[] w;

    /**
     * The linear weights for multi-class logistic regression.
     */
    private double[][] W;

    /**
     * Trainer for logistic regression.
     */
    public static class Trainer extends ClassifierTrainer<double[]> {
        /**
         * Regularization factor. λ > 0 gives a "regularized" estimate
         * of linear weights which often has superior generalization
         * performance, especially when the dimensionality is high.
         */
        private double lambda = 0.0;
        /**
         * The tolerance for BFGS stopping iterations.
         */
        private double tol = 1E-5;
        /**
         * The maximum number of BFGS iterations.
         */
        private int maxIter = 500;

        /**
         * Constructor.
         */
        public Trainer() {
        }

        /**
         * Sets the regularization factor. λ > 0 gives a "regularized"
         * estimate of linear weights which often has superior generalization
         * performance, especially when the dimensionality is high.
         *
         * @param lambda regularization factor.
         */
        public Trainer setRegularizationFactor(double lambda) {
            this.lambda = lambda;
            return this;
        }

        /**
         * Sets the tolerance for BFGS stopping iterations.
         *
         * @param tol the tolerance for stopping iterations.
         */
        public Trainer setTolerance(double tol) {
            if (tol <= 0.0) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }

            this.tol = tol;
            return this;
        }

        /**
         * Sets the maximum number of iterations.
         *
         * @param maxIter the maximum number of iterations.
         */
        public Trainer setMaxNumIteration(int maxIter) {
            if (maxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }

            this.maxIter = maxIter;
            return this;
        }

        @Override
        public LogisticRegression train(double[][] x, int[] y) {
            return new LogisticRegression(x, y, lambda, tol, maxIter);
        }
    }

    /**
     * Constructor. No regularization.
     *
     * @param x training samples.
     * @param y training labels in [0, k), where k is the number of classes.
     */
    public LogisticRegression(double[][] x, int[] y) {
        this(x, y, 0.0);
    }

    /**
     * Constructor.
     *
     * @param x training samples.
     * @param y training labels in [0, k), where k is the number of classes.
     * @param lambda λ > 0 gives a "regularized" estimate of linear
     * weights which often has superior generalization performance, especially
     * when the dimensionality is high.
     */
    public LogisticRegression(double[][] x, int[] y, double lambda) {
        this(x, y, lambda, 1E-5, 500);
    }

    /**
     * Constructor.
     *
     * @param x training samples.
     * @param y training labels in [0, k), where k is the number of classes.
     * @param lambda λ > 0 gives a "regularized" estimate of linear
     * weights which often has superior generalization performance, especially
     * when the dimensionality is high.
     * @param tol the tolerance for stopping iterations.
     * @param maxIter the maximum number of iterations.
     */
    public LogisticRegression(double[][] x, int[] y, double lambda, double tol, int maxIter) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        if (lambda < 0.0) {
            throw new IllegalArgumentException("Invalid regularization factor: " + lambda);
        }

        if (tol <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }

        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        // class label set.
        int[] labels = Math.unique(y);
        Arrays.sort(labels);

        for (int i = 0; i < labels.length; i++) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }

            if (i > 0 && labels[i] - labels[i-1] > 1) {
                throw new IllegalArgumentException("Missing class: " + (labels[i] + 1));
            }
        }

        k = labels.length;
        if (k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }

        p = x[0].length;

        if (k == 2) {
            BinaryObjectiveFunction func = new BinaryObjectiveFunction(x, y, lambda);

            w = new double[p + 1];

            L = 0.0;
            try {
                L = -Math.min(func, 5, w, tol, maxIter);
            } catch (Exception ex) {
                // If L-BFGS doesn't work, let's try BFGS.
                L = -Math.min(func, w, tol, maxIter);
            }
        } else {
            MultiClassObjectiveFunction func = new MultiClassObjectiveFunction(x, y, k, lambda);

            w = new double[k * (p + 1)];

            L = 0.0;
            try {
                L = -Math.min(func, 5, w, tol, maxIter);
            } catch (Exception ex) {
                // If L-BFGS doesn't work, let's try BFGS.
                L = -Math.min(func, w, tol, maxIter);
            }

            W = new double[k][p+1];
            for (int i = 0, m = 0; i < k; i++) {
                for (int j = 0; j <= p; j++, m++) {
                    W[i][j] = w[m];
                }
            }

            w = null;
        }
    }

    /**
     * Returns natural log(1+exp(x)) without overflow.
     */
    private static double log1pe(double x) {
        double y = 0.0;
        if (x > 15) {
            y = x;
        } else {
            y += Math.log1p(Math.exp(x));
        }

        return y;
    }

    /**
     * Binary-class logistic regression objective function.
     */
    static class BinaryObjectiveFunction implements DifferentiableMultivariateFunction {

        /**
         * Training instances.
         */
        double[][] x;
        /**
         * Training labels.
         */
        int[] y;
        /**
         * Regularization factor.
         */
        double lambda;
        /**
         * Parallel computing of objective function.
         */
        List<FTask> ftasks = null;
        /**
         * Parallel computing of objective function and gradient.
         */
        List<GTask> gtasks = null;

        /**
         * Constructor.
         */
        BinaryObjectiveFunction(double[][] x, int[] y, double lambda) {
            this.x = x;
            this.y = y;
            this.lambda = lambda;

            int n = x.length;
            int m = MulticoreExecutor.getThreadPoolSize();
            if (n >= 1000 && m >= 2) {
                ftasks = new ArrayList<>(m + 1);
                gtasks = new ArrayList<>(m + 1);
                int step = n / m;
                if (step < 100) {
                    step = 100;
                }

                int start = 0;
                int end = step;
                for (int i = 0; i < m - 1; i++) {
                    ftasks.add(new FTask(start, end));
                    gtasks.add(new GTask(start, end));
                    start += step;
                    end += step;
                }
                ftasks.add(new FTask(start, n));
                gtasks.add(new GTask(start, n));
            }
        }

        /**
         * Task to calculate the objective function.
         */
        class FTask implements Callable<Double> {
            /**
             * The parameter vector.
             */
            double[] w;
            /**
             * The start index of data portion for this task.
             */
            int start;
            /**
             * The end index of data portion for this task.
             */
            int end;

            FTask(int start, int end) {
                this.start = start;
                this.end = end;
            }

            @Override
            public Double call() {
                double f = 0.0;

                for (int i = start; i < end; i++) {
                    double wx = dot(x[i], w);
                    f += log1pe(wx) - y[i] * wx;
                }

                return f;
            }
        }

        @Override
        public double f(double[] w) {
            double f = Double.NaN;
            int p = w.length - 1;

            if (ftasks != null) {
                for (FTask task : ftasks) {
                    task.w = w;
                }

                try {
                    f = 0.0;
                    for (double fi : MulticoreExecutor.run(ftasks)) {
                        f += fi;
                    }
                } catch (Exception ex) {
                    System.err.println(ex);
                    f = Double.NaN;
                }
            }

            if (Double.isNaN(f)) {
                f = 0.0;
                int n = x.length;

                for (int i = 0; i < n; i++) {
                    double wx = dot(x[i], w);
                    f += log1pe(wx) - y[i] * wx;
                }
            }

            if (lambda != 0.0) {
                double w2 = 0.0;
                for (int i = 0; i < p; i++) {
                    w2 += w[i] * w[i];
                }

                f += 0.5 * lambda * w2;
            }

            return f;
        }

        /**
         * Task to calculate the objective function and gradient.
         */
        class GTask implements Callable<double[]> {
            /**
             * The parameter vector.
             */
            double[] w;
            /**
             * The start index of data portion for this task.
             */
            int start;
            /**
             * The end index of data portion for this task.
             */
            int end;

            GTask(int start, int end) {
                this.start = start;
                this.end = end;
            }

            @Override
            public double[] call() {
                double f = 0.0;
                int p = w.length - 1;
                double[] g = new double[w.length + 1];

                for (int i = start; i < end; i++) {
                    double wx = dot(x[i], w);
                    f += log1pe(wx) - y[i] * wx;
                    double yi = y[i] - Math.logistic(wx);
                    for (int j = 0; j < p; j++) {
                        g[j] -= yi * x[i][j];
                    }
                    g[p] -= yi;
                }

                g[w.length] = f;
                return g;
            }
        }

        @Override
        public double f(double[] w, double[] g) {
            double f = Double.NaN;
            int p = w.length - 1;
            Arrays.fill(g, 0.0);

            if (gtasks != null) {
                for (GTask task : gtasks) {
                    task.w = w;
                }

                try {
                    f = 0.0;
                    for (double[] gi : MulticoreExecutor.run(gtasks)) {
                        f += gi[w.length];
                        for (int i = 0; i < w.length; i++) {
                            g[i] += gi[i];
                        }
                    }
                } catch (Exception ex) {
                    System.err.println(ex);
                    f = Double.NaN;
                }
            }

            if (Double.isNaN(f)) {
                f = 0.0;
                int n = x.length;

                for (int i = 0; i < n; i++) {
                    double wx = dot(x[i], w);
                    f += log1pe(wx) - y[i] * wx;
                    double yi = y[i] - Math.logistic(wx);
                    for (int j = 0; j < p; j++) {
                        g[j] -= yi * x[i][j];
                    }
                    g[p] -= yi;
                }
            }

            if (lambda != 0.0) {
                double w2 = 0.0;
                for (int i = 0; i < p; i++) {
                    w2 += w[i] * w[i];
                }

                f += 0.5 * lambda * w2;

                for (int j = 0; j < p; j++) {
                    g[j] += lambda * w[j];
                }
            }

            return f;
        }
    }

    /**
     * Returns natural log without underflow.
     */
    private static double log(double x) {
        double y = 0.0;
        if (x < 1E-300) {
            y = -690.7755;
        } else {
            y = Math.log(x);
        }

        return y;
    }

    /**
     * Multi-class logistic regression objective function.
     */
    static class MultiClassObjectiveFunction implements DifferentiableMultivariateFunction {

        /**
         * Training instances.
         */
        double[][] x;
        /**
         * Training labels.
         */
        int[] y;
        /**
         * The number of classes.
         */
        int k;
        /**
         * Regularization factor.
         */
        double lambda;
        /**
         * Parallel computing of objective function.
         */
        List<FTask> ftasks = null;
        /**
         * Parallel computing of objective function and gradient.
         */
        List<GTask> gtasks = null;

        /**
         * Constructor.
         */
        MultiClassObjectiveFunction(double[][] x, int[] y, int k, double lambda) {
            this.x = x;
            this.y = y;
            this.k = k;
            this.lambda = lambda;

            int n = x.length;
            int m = MulticoreExecutor.getThreadPoolSize();
            if (n >= 1000 && m >= 2) {
                ftasks = new ArrayList<>(m + 1);
                gtasks = new ArrayList<>(m + 1);
                int step = n / m;
                if (step < 100) {
                    step = 100;
                }

                int start = 0;
                int end = step;
                for (int i = 0; i < m - 1; i++) {
                    ftasks.add(new FTask(start, end));
                    gtasks.add(new GTask(start, end));
                    start += step;
                    end += step;
                }
                ftasks.add(new FTask(start, n));
                gtasks.add(new GTask(start, n));
            }
        }

        /**
         * Task to calculate the objective function.
         */
        class FTask implements Callable<Double> {
            /**
             * The parameter vector.
             */
            double[] w;
            /**
             * The start index of data portion for this task.
             */
            int start;
            /**
             * The end index of data portion for this task.
             */
            int end;

            FTask(int start, int end) {
                this.start = start;
                this.end = end;
            }

            @Override
            public Double call() {
                double f = 0.0;
                int p = x[0].length;
                double[] prob = new double[k];

                for (int i = start; i < end; i++) {
                    for (int j = 0; j < k; j++) {
                        prob[j] = dot(x[i], w, j * (p + 1));
                    }

                    softmax(prob);

                    f -= log(prob[y[i]]);
                }

                return f;
            }
        }

        @Override
        public double f(double[] w) {
            double f = Double.NaN;
            int p = x[0].length;
            double[] prob = new double[k];

            if (ftasks != null) {
                for (FTask task : ftasks) {
                    task.w = w;
                }

                try {
                    f = 0.0;
                    for (double fi : MulticoreExecutor.run(ftasks)) {
                        f += fi;
                    }
                } catch (Exception ex) {
                    System.err.println(ex);
                    f = Double.NaN;
                }
            }

            if (Double.isNaN(f)) {
                f = 0.0;
                int n = x.length;

                for (int i = 0; i < n; i++) {
                    for (int j = 0; j < k; j++) {
                        prob[j] = dot(x[i], w, j * (p + 1));
                    }

                    softmax(prob);

                    f -= log(prob[y[i]]);
                }
            }

            if (lambda != 0.0) {
                double w2 = 0.0;
                for (int i = 0; i < k; i++) {
                    for (int j = 0; j < p; j++) {
                        w2 += Math.sqr(w[i*(p+1) + j]);
                    }
                }

                f += 0.5 * lambda * w2;
            }

            return f;
        }

        /**
         * Task to calculate the objective function and gradient.
         */
        class GTask implements Callable<double[]> {
            /**
             * The parameter vector.
             */
            double[] w;
            /**
             * The start index of data portion for this task.
             */
            int start;
            /**
             * The end index of data portion for this task.
             */
            int end;

            GTask(int start, int end) {
                this.start = start;
                this.end = end;
            }

            @Override
            public double[] call() {
                double f = 0.0;
                double[] g = new double[w.length+1];
                int p = x[0].length;
                double[] prob = new double[k];

                for (int i = start; i < end; i++) {
                    for (int j = 0; j < k; j++) {
                        prob[j] = dot(x[i], w, j * (p + 1));
                    }

                    softmax(prob);

                    f -= log(prob[y[i]]);

                    double yi = 0.0;
                    for (int j = 0; j < k; j++) {
                        yi = (y[i] == j ? 1.0 : 0.0) - prob[j];

                        for (int l = 0, pos = j * (p + 1); l < p; l++) {
                            g[pos + l] -= yi * x[i][l];
                        }
                        g[j * (p + 1) + p] -= yi;
                    }
                }

                g[w.length] = f;
                return g;
            }
        }

        @Override
        public double f(double[] w, double[] g) {
            double f = Double.NaN;
            int p = x[0].length;
            double[] prob = new double[k];
            Arrays.fill(g, 0.0);

            if (gtasks != null) {
                for (GTask task : gtasks) {
                    task.w = w;
                }

                try {
                    f = 0.0;
                    for (double[] gi : MulticoreExecutor.run(gtasks)) {
                        f += gi[w.length];
                        for (int i = 0; i < w.length; i++) {
                            g[i] += gi[i];
                        }
                    }
                } catch (Exception ex) {
                    System.err.println(ex);
                    f = Double.NaN;
                }
            }

            if (Double.isNaN(f)) {
                f = 0.0;
                int n = x.length;

                for (int i = 0; i < n; i++) {
                    for (int j = 0; j < k; j++) {
                        prob[j] = dot(x[i], w, j * (p + 1));
                    }

                    softmax(prob);

                    f -= log(prob[y[i]]);

                    double yi = 0.0;
                    for (int j = 0; j < k; j++) {
                        yi = (y[i] == j ? 1.0 : 0.0) - prob[j];

                        for (int l = 0, pos = j * (p + 1); l < p; l++) {
                            g[pos + l] -= yi * x[i][l];
                        }
                        g[j * (p + 1) + p] -= yi;
                    }
                }
            }

            if (lambda != 0.0) {
                double w2 = 0.0;
                for (int i = 0; i < k; i++) {
                    for (int j = 0; j < p; j++) {
                        int pos = i * (p+1) + j;
                        w2 += w[pos] * w[pos];
                        g[pos] += lambda * w[pos];
                    }
                }

                f += 0.5 * lambda * w2;
            }

            return f;
        }
    }

    /**
     * Calculate softmax function without overflow.
     */
    private static void softmax(double[] prob) {
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < prob.length; i++) {
            if (prob[i] > max) {
                max = prob[i];
            }
        }

        double Z = 0.0;
        for (int i = 0; i < prob.length; i++) {
            double p = Math.exp(prob[i] - max);
            prob[i] = p;
            Z += p;
        }

        for (int i = 0; i < prob.length; i++) {
            prob[i] /= Z;
        }
    }

    /**
     * Returns the dot product between weight vector and x (augmented with 1).
     */
    private static double dot(double[] x, double[] w) {
        int i = 0;
        double dot = 0.0;
        for (; i < x.length; i++) {
            dot += x[i] * w[i];
        }

        return dot + w[i];
    }

    /**
     * Returns the dot product between weight vector and x (augmented with 1).
     */
    private static double dot(double[] x, double[] w, int pos) {
        int i = 0;
        double dot = 0.0;
        for (; i < x.length; i++) {
            dot += x[i] * w[pos+i];
        }

        return dot + w[pos+i];
    }

    /**
     * Returns the log-likelihood of model.
     */
    public double loglikelihood() {
        return L;
    }

    @Override
    public int predict(double[] x) {
        return predict(x, null);
    }

    @Override
    public int predict(double[] x, double[] posteriori) {
        if (x.length != p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, p));
        }

        if (posteriori != null && posteriori.length != k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }

        if (k == 2) {
            // f is the estimated probability of class 1 under the logistic model.
            double f = 1.0 / (1.0 + Math.exp(-dot(x, w)));

            if (posteriori != null) {
                posteriori[0] = 1.0 - f;
                posteriori[1] = f;
            }

            if (f < 0.5) {
                return 0;
            } else {
                return 1;
            }
        } else {
            int label = -1;
            double max = Double.NEGATIVE_INFINITY;

            for (int i = 0; i < k; i++) {
                double prob = dot(x, W[i]);
                if (prob > max) {
                    max = prob;
                    label = i;
                }

                if (posteriori != null) {
                    posteriori[i] = prob;
                }
            }

            if (posteriori != null) {
                double Z = 0.0;
                for (int i = 0; i < k; i++) {
                    posteriori[i] = Math.exp(posteriori[i] - max);
                    Z += posteriori[i];
                }

                for (int i = 0; i < k; i++) {
                    posteriori[i] /= Z;
                }
            }

            return label;
        }
    }
}
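
For orientation, a minimal usage sketch built only from the public constructors and methods shown above. The two-feature toy data, the regularization value 0.1, and the wrapper class name LogisticRegressionExample are illustrative assumptions, not part of the library; it assumes the smile-core artifact is on the classpath.

import smile.classification.LogisticRegression;

public class LogisticRegressionExample {
    public static void main(String[] args) {
        // Toy binary training set: one row per sample, labels in [0, k). Values are made up.
        double[][] x = {
            {0.0, 0.1}, {0.2, 0.0}, {0.1, 0.3},   // class 0
            {1.0, 0.9}, {0.8, 1.1}, {1.2, 1.0}    // class 1
        };
        int[] y = {0, 0, 0, 1, 1, 1};

        // Train with a small regularization factor (lambda = 0.1 is an arbitrary choice).
        LogisticRegression model = new LogisticRegression(x, y, 0.1);
        System.out.println("log-likelihood = " + model.loglikelihood());

        // Predict the label and posterior probabilities of a new sample.
        double[] posteriori = new double[2];
        int label = model.predict(new double[] {0.9, 1.0}, posteriori);
        System.out.printf("label = %d, posteriori = [%.3f, %.3f]%n", label, posteriori[0], posteriori[1]);

        // Equivalently, configure the fluent Trainer defined above.
        LogisticRegression model2 = new LogisticRegression.Trainer()
                .setRegularizationFactor(0.1)
                .setTolerance(1E-5)
                .setMaxNumIteration(500)
                .train(x, y);
        System.out.println("trainer log-likelihood = " + model2.loglikelihood());
    }
}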




