/*
 * Copyright (c) 2010-2025 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile. If not, see <https://www.gnu.org/licenses/>.
 */
package smile.regression;

import java.util.Properties;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.math.matrix.Matrix;

/**
 * Elastic Net regularization. The elastic net is a regularized regression
 * method that linearly combines the L1 and L2 penalties of the lasso and ridge
 * methods.
 * <p>
 * The elastic net problem can be reduced to a lasso problem on modified data
 * and response. Note that the penalty function of the elastic net is strictly
 * convex, so there is a unique global minimum even if the input data matrix
 * is not full rank.
 *
 * <h2>References</h2>
 * <ol>
 * <li>Kevin P. Murphy: Machine Learning: A Probabilistic Perspective,
 *     Section 13.5.3, 2012.</li>
 * <li>Zou, Hui and Hastie, Trevor: Regularization and Variable Selection
 *     via the Elastic Net, 2005.</li>
 * </ol>
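 * <p>
 * A minimal usage sketch; the input file and the response column {@code y}
 * are hypothetical, not part of this class:
 * <pre>{@code
 * DataFrame data = Read.csv("data.csv");  // hypothetical input
 * LinearModel model = ElasticNet.fit(Formula.lhs("y"), data, 0.1, 0.01);
 * }</pre>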
 *
 * @author rayeaster
 */
public class ElasticNet {
    /** Private constructor to prevent object creation. */
    private ElasticNet() {

    }

    /**
     * Elastic Net hyperparameters.
     * @param lambda1 the L1 shrinkage/regularization parameter.
     * @param lambda2 the L2 shrinkage/regularization parameter.
     * @param tol the tolerance of convergence test (relative target duality gap).
     * @param maxIter the maximum number of IPM (Newton) iterations.
     * @param alpha the minimum fraction of decrease in the objective function.
     * @param beta the step size decrease factor.
     * @param eta the tolerance for PCG termination.
     * @param lsMaxIter the maximum number of backtracking line search iterations.
     * @param pcgMaxIter the maximum number of PCG iterations.
     */
    public record Options(double lambda1, double lambda2, double tol, int maxIter,
                          double alpha, double beta, double eta, int lsMaxIter, int pcgMaxIter) {
        /** Constructor. */
        public Options {
            if (lambda1 <= 0) {
                throw new IllegalArgumentException("Please use Ridge instead, wrong L1 portion setting: " + lambda1);
            }

            if (lambda2 <= 0) {
                throw new IllegalArgumentException("Please use LASSO instead, wrong L2 portion setting: " + lambda2);
            }

            if (tol <= 0) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }

            if (maxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }

            if (alpha <= 0.0) {
                throw new IllegalArgumentException("Invalid alpha: " + alpha);
            }

            if (beta <= 0.0) {
                throw new IllegalArgumentException("Invalid beta: " + beta);
            }

            if (eta <= 0.0) {
                throw new IllegalArgumentException("Invalid eta: " + eta);
            }

            if (lsMaxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of line search iterations: " + lsMaxIter);
            }

            if (pcgMaxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of PCG iterations: " + pcgMaxIter);
            }
        }

        /**
         * Constructor.
         * @param lambda1 the L1 shrinkage/regularization parameter.
         * @param lambda2 the L2 shrinkage/regularization parameter.
         */
        public Options(double lambda1, double lambda2) {
            this(lambda1, lambda2, 1E-4, 1000);
        }

        /**
         * Constructor.
         * @param lambda1 the L1 shrinkage/regularization parameter.
         * @param lambda2 the L2 shrinkage/regularization parameter.
         * @param tol the tolerance of convergence test (relative target duality gap).
         * @param maxIter the maximum number of IPM (Newton) iterations.
         */
        public Options(double lambda1, double lambda2, double tol, int maxIter) {
            this(lambda1, lambda2, tol, maxIter, 0.01, 0.5, 1E-3, 100, 5000);
        }

        /**
         * Returns the persistent set of hyperparameters including
         * <ul>
         * <li>smile.elastic_net.lambda1 is the L1 shrinkage/regularization parameter</li>
         * <li>smile.elastic_net.lambda2 is the L2 shrinkage/regularization parameter</li>
         * <li>smile.elastic_net.tolerance is the tolerance for stopping iterations (relative target duality gap).</li>
         * <li>smile.elastic_net.iterations is the maximum number of IPM (Newton) iterations.</li>
         * </ul>
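         * <p>
         * For illustration, the mapping round-trips through {@code of(Properties)};
         * the values below are arbitrary:
         * <pre>{@code
         * var options = new Options(0.1, 0.01);
         * var restored = Options.of(options.toProperties());
         * assert options.equals(restored);
         * }</pre>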
         * @return the persistent set.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.elastic_net.lambda1", Double.toString(lambda1));
            props.setProperty("smile.elastic_net.lambda2", Double.toString(lambda2));
            props.setProperty("smile.elastic_net.tolerance", Double.toString(tol));
            props.setProperty("smile.elastic_net.iterations", Integer.toString(maxIter));
            props.setProperty("smile.elastic_net.alpha", Double.toString(alpha));
            props.setProperty("smile.elastic_net.beta", Double.toString(beta));
            props.setProperty("smile.elastic_net.eta", Double.toString(eta));
            props.setProperty("smile.elastic_net.line_search_iterations", Integer.toString(lsMaxIter));
            props.setProperty("smile.elastic_net.pcg_iterations", Integer.toString(pcgMaxIter));
            return props;
        }

        /**
         * Returns the options from properties.
         *
         * @param props the hyperparameters.
         * @return the options.
         */
        public static Options of(Properties props) {
            double lambda1 = Double.parseDouble(props.getProperty("smile.elastic_net.lambda1"));
            double lambda2 = Double.parseDouble(props.getProperty("smile.elastic_net.lambda2"));
            double tol = Double.parseDouble(props.getProperty("smile.elastic_net.tolerance", "1E-4"));
            int maxIter = Integer.parseInt(props.getProperty("smile.elastic_net.iterations", "1000"));
            double alpha = Double.parseDouble(props.getProperty("smile.elastic_net.alpha", "0.01"));
            double beta = Double.parseDouble(props.getProperty("smile.elastic_net.beta", "0.5"));
            double eta = Double.parseDouble(props.getProperty("smile.elastic_net.eta", "1E-3"));
            int lsMaxIter = Integer.parseInt(props.getProperty("smile.elastic_net.line_search_iterations", "100"));
            int pcgMaxIter = Integer.parseInt(props.getProperty("smile.elastic_net.pcg_iterations", "5000"));
            return new Options(lambda1, lambda2, tol, maxIter, alpha, beta, eta, lsMaxIter, pcgMaxIter);
        }
    }

    /**
     * Fits an Elastic Net model.
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     *             NO NEED to include a constant column of 1s for bias.
     * @param lambda1 the L1 shrinkage/regularization parameter.
     * @param lambda2 the L2 shrinkage/regularization parameter.
     * @return the model.
     */
    public static LinearModel fit(Formula formula, DataFrame data, double lambda1, double lambda2) {
        return fit(formula, data, new Options(lambda1, lambda2));
    }

    /**
     * Fits an Elastic Net model.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     *             NO NEED to include a constant column of 1s for bias.
     * @param options the hyperparameters.
     * @return the model.
     */
    public static LinearModel fit(Formula formula, DataFrame data, Options options) {
        double c = 1 / Math.sqrt(1 + options.lambda2);

        formula = formula.expand(data.schema());
        StructType schema = formula.bind(data.schema());
        Matrix X = formula.matrix(data, false);
        double[] y = formula.y(data).toDoubleArray();

        int n = X.nrow();
        int p = X.ncol();
        double[] center = X.colMeans();
        double[] scale = X.colSds();

        // Pads 0 at the tail
        double[] y2 = new double[n + p];

        // Center y2 before calling LASSO.
        // Otherwise, padding zeros become negative when LASSO centers y2 again.
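
        // Together with the padding below, this builds the augmented system of
        // Zou & Hastie (2005): X2 = c * [(X - center) / scale; sqrt(lambda2) * I]
        // and y2 = [y - mean(y); 0] with c = 1 / sqrt(1 + lambda2), so that a
        // lasso with penalty c * lambda1 solves the elastic net; the solution
        // is mapped back to the original scale by w = c * w2 / scale.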
        double ym = MathEx.mean(y);
        for (int i = 0; i < n; i++) {
            y2[i] = y[i] - ym;
        }

        // Scales the original data array and pads a weighted identity matrix
        Matrix X2 = new Matrix(X.nrow() + p, p);
        double padding = c * Math.sqrt(options.lambda2);
        for (int j = 0; j < p; j++) {
            for (int i = 0; i < n; i++) {
                X2.set(i, j, c * (X.get(i, j) - center[j]) / scale[j]);
            }

            X2.set(j + n, j, padding);
        }

        var lasso = new LASSO.Options(options.lambda1 * c, options.tol, options.maxIter,
                options.alpha, options.beta, options.eta, options.lsMaxIter, options.pcgMaxIter);
        double[] w = LASSO.train(X2, y2, lasso);
        for (int i = 0; i < p; i++) {
            w[i] = c * w[i] / scale[i];
        }

        double b = ym - MathEx.dot(w, center);
        return new LinearModel(formula, schema, X, y, w, b);
    }
}