All downloads are free. Search and download functionality uses the official Maven repository.

weka.classifiers.functions.SGD Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This version represents the developer version, the "bleeding edge" of development, you could say. New functionality gets added to this version.

The newest version!
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    SGD.java
 *    Copyright (C) 2009-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.functions;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.RandomizableClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Aggregateable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 
 * Implements stochastic gradient descent for learning various linear models (binary class SVM, binary class logistic regression, squared loss, Huber loss and epsilon-insensitive loss linear regression). Globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes, so the coefficients in the output are based on the normalized data.
 * For numeric class attributes, the squared, Huber or epsilon-insensitive loss function must be used. Epsilon-insensitive and Huber loss may require a much higher learning rate.

 * Valid options are:

* *

 -F
 *  Set the loss function to minimize.
 *  0 = hinge loss (SVM), 1 = log loss (logistic regression),
 *  2 = squared loss (regression), 3 = epsilon insensitive loss (regression),
 *  4 = Huber loss (regression).
 *  (default = 0)
* *
 -L
 *  The learning rate. If normalization is
 *  turned off (as it is automatically for streaming data), then the
 *  default learning rate will need to be reduced (try 0.0001).
 *  (default = 0.01).
* *
 -R <double>
 *  The lambda regularization constant (default = 0.0001)
* *
 -E <integer>
 *  The number of epochs to perform (batch learning only, default = 500)
* *
 -C <double>
 *  The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
* *
 -N
 *  Don't normalize the data
* *
 -M
 *  Don't replace missing values
* *
 -S <num>
 *  Random number seed.
 *  (default 1)
* *
 -output-debug-info
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console
* *
 -do-not-check-capabilities
 *  If set, classifier capabilities are not checked before classifier is built
 *  (use with caution).
* * * @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz) * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) * @version $Revision: 15519 $ * */ public class SGD extends RandomizableClassifier implements UpdateableClassifier, OptionHandler, Aggregateable { /** For serialization */ private static final long serialVersionUID = -3732968666673530290L; /** Replace missing values */ protected ReplaceMissingValues m_replaceMissing; /** * Convert nominal attributes to numerically coded binary ones. Uses * supervised NominalToBinary in the batch learning case */ protected Filter m_nominalToBinary; /** Normalize the training data */ protected Normalize m_normalize; /** The regularization parameter */ protected double m_lambda = 0.0001; /** The learning rate */ protected double m_learningRate = 0.01; /** Stores the weights (+ bias in the last element) */ protected double[] m_weights; /** The epsilon parameter for epsilon insensitive and Huber loss */ protected double m_epsilon = 1e-3; /** Holds the current iteration number */ protected double m_t; /** The number of training instances */ protected double m_numInstances; /** * The number of epochs to perform (batch learning). Total iterations is * m_epochs * num instances */ protected int m_epochs = 500; /** * Turn off normalization of the input data. This option gets forced for * incremental training. */ protected boolean m_dontNormalize = false; /** * Turn off global replacement of missing values. Missing values will be * ignored instead. This option gets forced for incremental training. */ protected boolean m_dontReplaceMissing = false; /** Holds the header of the training data */ protected Instances m_data; /** * Returns default capabilities of the classifier. 
* * @return the capabilities of this classifier */ @Override public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class if (m_loss == SQUAREDLOSS || m_loss == EPSILON_INSENSITIVE || m_loss == HUBER) result.enable(Capability.NUMERIC_CLASS); else result.enable(Capability.BINARY_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); // instances result.setMinimumNumberInstances(0); return result; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String epsilonTipText() { return "The epsilon threshold for epsilon insensitive and Huber " + "loss. An error with absolute value less that this " + "threshold has loss of 0 for epsilon insensitive loss. " + "For Huber loss this is the boundary between the quadratic " + "and linear parts of the loss function."; } /** * Set the epsilon threshold on the error for epsilon insensitive and Huber * loss functions * * @param e the value of epsilon to use */ public void setEpsilon(double e) { m_epsilon = e; } /** * Get the epsilon threshold on the error for epsilon insensitive and Huber * loss functions * * @return the value of epsilon to use */ public double getEpsilon() { return m_epsilon; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String lambdaTipText() { return "The regularization constant. (default = 0.0001)"; } /** * Set the value of lambda to use * * @param lambda the value of lambda to use */ public void setLambda(double lambda) { m_lambda = lambda; } /** * Get the current value of lambda * * @return the current value of lambda */ public double getLambda() { return m_lambda; } /** * Set the learning rate. 
* * @param lr the learning rate to use. */ public void setLearningRate(double lr) { m_learningRate = lr; } /** * Get the learning rate. * * @return the learning rate */ public double getLearningRate() { return m_learningRate; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String learningRateTipText() { return "The learning rate. If normalization is turned off " + "(as it is automatically for streaming data), then" + "the default learning rate will need to be reduced (" + "try 0.0001)."; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String epochsTipText() { return "The number of epochs to perform (batch learning). " + "The total number of iterations is epochs * num" + " instances."; } /** * Set the number of epochs to use * * @param e the number of epochs to use */ public void setEpochs(int e) { m_epochs = e; } /** * Get current number of epochs * * @return the current number of epochs */ public int getEpochs() { return m_epochs; } /** * Turn normalization off/on. * * @param m true if normalization is to be disabled. */ public void setDontNormalize(boolean m) { m_dontNormalize = m; } /** * Get whether normalization has been turned off. * * @return true if normalization has been disabled. */ public boolean getDontNormalize() { return m_dontNormalize; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String dontNormalizeTipText() { return "Turn normalization off"; } /** * Turn global replacement of missing values off/on. If turned off, then * missing values are effectively ignored. * * @param m true if global replacement of missing values is to be turned off. 
*/
  public void setDontReplaceMissing(boolean m) {
    m_dontReplaceMissing = m;
  }

  /**
   * Get whether global replacement of missing values has been disabled.
   *
   * @return true if global replacement of missing values has been turned off
   */
  public boolean getDontReplaceMissing() {
    return m_dontReplaceMissing;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String dontReplaceMissingTipText() {
    return "Turn off global replacement of missing values";
  }

  /**
   * Set the loss function to use. A tag that does not come from
   * TAGS_SELECTION (compared by reference) is silently ignored and the current
   * loss function is kept.
   *
   * @param function the loss function to use.
   */
  public void setLossFunction(SelectedTag function) {
    if (function.getTags() == TAGS_SELECTION) {
      m_loss = function.getSelectedTag().getID();
    }
  }

  /**
   * Get the current loss function.
   *
   * @return the current loss function, wrapped in a new SelectedTag over
   *         TAGS_SELECTION.
   */
  public SelectedTag getLossFunction() {
    return new SelectedTag(m_loss, TAGS_SELECTION);
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String lossFunctionTipText() {
    // Lists all five losses accepted by the -F option, including the
    // epsilon-insensitive and Huber regression losses.
    return "The loss function to use. Hinge loss (SVM), "
      + "log loss (logistic regression), "
      + "squared loss (regression), "
      + "epsilon-insensitive loss (regression) or "
      + "Huber loss (regression).";
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  @Override public Enumeration




© 2015 - 2024 Weber Informatics LLC | Privacy Policy