weka.classifiers.functions.SGDText Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-dev Show documentation
Show all versions of weka-dev Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This version represents the developer version, the
"bleeding edge" of development, you could say. New functionality gets added
to this version.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* SGDText.java
* Copyright (C) 2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.functions;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.UpdateableBatchProcessor;
import weka.classifiers.UpdateableClassifier;
import weka.core.Aggregateable;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.stemmers.NullStemmer;
import weka.core.stemmers.Stemmer;
import weka.core.stopwords.Null;
import weka.core.stopwords.StopwordsHandler;
import weka.core.tokenizers.Tokenizer;
import weka.core.tokenizers.WordTokenizer;
/**
* Implements stochastic gradient descent for learning a linear binary class SVM or binary class logistic regression on text data. Operates directly (and only) on String attributes. Other types of input attributes are accepted but ignored during training and classification.
*
*
* Valid options are:
*
* -F
* Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression)
* (default = 0)
*
* -outputProbs
* Output probabilities for SVMs (fits a logistic
* model to the output of the SVM)
*
* -L
* The learning rate (default = 0.01).
*
* -R <double>
* The lambda regularization constant (default = 0.0001)
*
* -E <integer>
* The number of epochs to perform (batch learning only, default = 500)
*
* -W
* Use word frequencies instead of binary bag of words.
*
* -P <# instances>
* How often to prune the dictionary of low frequency words (default = 0, i.e. don't prune)
*
* -M <double>
* Minimum word frequency. Words with a lower frequency than this are ignored.
* If periodic pruning is turned on then this is also used to determine which
* words to remove from the dictionary (default = 3).
*
* -min-coeff <double>
* Minimum absolute value of coefficients in the model.
* If periodic pruning is turned on then this
* is also used to prune words from the dictionary
* (default = 0.001)
*
* -normalize
* Normalize document length (use in conjunction with -norm and -lnorm)
*
* -norm <num>
* Specify the norm that each instance must have (default 1.0)
*
* -lnorm <num>
* Specify L-norm to use (default 2.0)
*
* -lowercase
* Convert all tokens to lowercase before adding to the dictionary.
*
* -stopwords-handler
* The stopwords handler to use (default Null).
*
* -tokenizer <spec>
* The tokenizing algorithm (classname plus parameters) to use.
* (default: weka.core.tokenizers.WordTokenizer)
*
* -stemmer <spec>
* The stemming algorithm (classname plus parameters) to use.
*
* -S <num>
* Random number seed.
* (default 1)
*
* -output-debug-info
* If set, classifier is run in debug mode and
* may output additional info to the console
*
* -do-not-check-capabilities
* If set, classifier capabilities are not checked before classifier is built
* (use with caution).
*
*
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
* @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz)
*
*/
public class SGDText extends RandomizableClassifier implements
UpdateableClassifier, UpdateableBatchProcessor,
WeightedInstancesHandler, Aggregateable {
/**
 * For serialization: explicit version identifier for serialized SGDText
 * instances. Changing it breaks compatibility with previously serialized
 * objects.
 */
private static final long serialVersionUID = 7200171484002029584L;
public static class Count implements Serializable {

  /** For serialization */
  private static final long serialVersionUID = 2104201532017340967L;

  /** Accumulated count for a term (presumably its occurrence frequency - confirm). */
  public double m_count;

  /**
   * Weight associated with a term (presumably its model coefficient - confirm).
   * Starts at 0.0.
   */
  public double m_weight;

  /**
   * Creates an entry with the given initial count and a zero weight.
   *
   * @param initialCount the starting value for {@link #m_count}
   */
  public Count(double initialCount) {
    this.m_count = initialCount;
  }
}
/**
 * How often (in number of training instances) to periodically prune the
 * dictionary of low-frequency words. 0 (the default) means don't prune.
 */
protected int m_periodicP = 0;
/**
 * Only consider dictionary words (features) that occur at least this many
 * times.
 */
protected double m_minWordP = 3;
/**
 * Prune terms from the model whose coefficient is smaller than this in
 * absolute value.
 */
protected double m_minAbsCoefficient = 0.001;
/** Use word frequencies rather than binary bag-of-words if true */
protected boolean m_wordFrequencies = false;
/** Whether to normalize document length or not */
protected boolean m_normalize = false;
/** The length (norm) that each document vector should have after normalization */
protected double m_norm = 1.0;
/** The L-norm to use when normalizing document length */
protected double m_lnorm = 2.0;
/**
 * The dictionary (and term weights).
 * NOTE(review): declared with a raw LinkedHashMap type; presumably keyed by
 * token String with Count values - confirm and add type arguments.
 */
protected LinkedHashMap m_dictionary;
/** Stopword handler to use; defaults to Null (no stopwords). */
protected StopwordsHandler m_StopwordsHandler = new Null();
/** The tokenizer to use; defaults to WordTokenizer. */
protected Tokenizer m_tokenizer = new WordTokenizer();
/** Whether or not to convert all tokens to lowercase (false unless set) */
protected boolean m_lowercaseTokens;
/** The stemming algorithm; defaults to NullStemmer (no stemming). */
protected Stemmer m_stemmer = new NullStemmer();
/** The regularization parameter (lambda) */
protected double m_lambda = 0.0001;
/** The learning rate */
protected double m_learningRate = 0.01;
/** Holds the current iteration number */
protected double m_t;
/** Holds the bias term */
protected double m_bias;
/** The number of training instances */
protected double m_numInstances;
/** The header of the training data */
protected Instances m_data;
/**
 * The number of epochs to perform (batch learning). Total iterations is
 * m_epochs * num instances
 */
protected int m_epochs = 500;
/**
 * Holds the current document vector (LinkedHashMap is more efficient when
 * iterating over EntrySet than HashMap). Marked transient, so it is excluded
 * from Java serialization.
 * NOTE(review): raw LinkedHashMap type - confirm key/value types.
 */
protected transient LinkedHashMap m_inputVector;
/** the hinge loss function. */
public static final int HINGE = 0;
/** the log loss function. */
public static final int LOGLOSS = 1;
/** The current loss function to minimize (HINGE or LOGLOSS) */
protected int m_loss = HINGE;
/**
 * Loss functions to choose from. setLossFunction() only accepts a SelectedTag
 * built on this exact array, because it compares the tag arrays by reference.
 */
public static final Tag[] TAGS_SELECTION = {
new Tag(HINGE, "Hinge loss (SVM)"),
new Tag(LOGLOSS, "Log loss (logistic regression)") };
/** Used for producing probabilities for SVM via SGD logistic regression */
protected SGD m_svmProbs;
/**
 * True if a logistic regression is to be fit to the output of the SVM for
 * producing probability estimates
 */
protected boolean m_fitLogistic = false;
/**
 * NOTE(review): undocumented in the original; presumably the header of the
 * data used to train m_svmProbs - confirm.
 */
protected Instances m_fitLogisticStructure;
/**
 * Returns -dL/dz, the negated derivative of the configured loss L at z.
 * <p>
 * Hinge loss max(0, 1 - z) yields 1 when z &lt; 1 and 0 otherwise; log loss
 * log(1 + e^-z) yields 1 / (1 + e^z). z is presumably the margin y * f(x)
 * computed by the caller - confirm.
 *
 * @param z the point at which to evaluate the derivative
 * @return the negated loss derivative at z
 */
protected double dloss(double z) {
  if (m_loss != HINGE) {
    // Log loss: choose the algebraic form whose exponent is non-positive so
    // that Math.exp can never overflow.
    if (z >= 0) {
      double expNeg = Math.exp(-z);
      return expNeg / (expNeg + 1);
    }
    return 1.0 / (Math.exp(z) + 1.0);
  }
  return (z < 1) ? 1 : 0;
}
/**
 * Describes the data this classifier can handle: string, nominal, date and
 * numeric attributes (only string attributes are used for learning; the
 * others are accepted but ignored), missing values, and a binary class that
 * may itself be missing. No minimum number of training instances is required.
 *
 * @return the capabilities of this classifier
 */
@Override
public Capabilities getCapabilities() {
  Capabilities caps = super.getCapabilities();
  caps.disableAll();

  Capability[] accepted = {
    // attributes
    Capability.STRING_ATTRIBUTES, Capability.NOMINAL_ATTRIBUTES,
    Capability.DATE_ATTRIBUTES, Capability.NUMERIC_ATTRIBUTES,
    Capability.MISSING_VALUES,
    // class
    Capability.BINARY_CLASS, Capability.MISSING_CLASS_VALUES };
  for (Capability cap : accepted) {
    caps.enable(cap);
  }

  // instances
  caps.setMinimumNumberInstances(0);
  return caps;
}
/**
 * Sets the stemming algorithm to use on the words. A {@code null} argument
 * disables stemming by installing a {@link NullStemmer}.
 *
 * @param value the stemmer to use, or null for no stemming
 * @see NullStemmer
 */
public void setStemmer(Stemmer value) {
  m_stemmer = (value == null) ? new NullStemmer() : value;
}
/**
 * Returns the current stemming algorithm. The field defaults to a
 * {@link NullStemmer}, and {@link #setStemmer(Stemmer)} replaces null with one.
 * So when stemming is turned off this returns a NullStemmer, not null
 * (unless a subclass assigns the field directly).
 *
 * @return the current stemming algorithm (a NullStemmer if stemming is off)
 */
public Stemmer getStemmer() {
  return m_stemmer;
}
/**
 * Tip text for the stemmer property, shown in the Explorer/Experimenter GUI.
 *
 * @return the tip text for the stemmer property
 */
public String stemmerTipText() {
  final String tip = "The stemming algorithm to use on the words.";
  return tip;
}
/**
 * Sets the tokenizing algorithm to use on the strings. Passing {@code null}
 * restores the default {@link WordTokenizer}. This mirrors how
 * setStemmer and setStopwordsHandler treat null, so the configured tokenizer
 * is never null.
 *
 * @param value the configured tokenizing algorithm, or null for the default
 */
public void setTokenizer(Tokenizer value) {
  // Substitute the default rather than storing null, which would presumably
  // only surface later as a failure when documents are tokenized.
  m_tokenizer = (value != null) ? value : new WordTokenizer();
}
/**
 * Gets the tokenizing algorithm applied to string attribute values.
 *
 * @return the configured tokenizer
 */
public Tokenizer getTokenizer() {
  return this.m_tokenizer;
}
/**
 * Tip text for the tokenizer property, shown in the Explorer/Experimenter GUI.
 *
 * @return the tip text for the tokenizer property
 */
public String tokenizerTipText() {
  final String tip = "The tokenizing algorithm to use on the strings.";
  return tip;
}
/**
 * Tip text for the useWordFrequencies property, shown in the
 * Explorer/Experimenter GUI.
 *
 * @return the tip text for the useWordFrequencies property
 */
public String useWordFrequenciesTipText() {
  final String tip = "Use word frequencies rather than "
    + "binary bag of words representation";
  return tip;
}
/**
 * Chooses between word-frequency features (true) and a binary bag-of-words
 * representation (false).
 *
 * @param useFrequencies true to use word frequencies
 */
public void setUseWordFrequencies(boolean useFrequencies) {
  this.m_wordFrequencies = useFrequencies;
}
/**
 * Reports whether word frequencies are used instead of a binary bag-of-words
 * representation.
 *
 * @return true if word frequencies are used
 */
public boolean getUseWordFrequencies() {
  return this.m_wordFrequencies;
}
/**
 * Tip text for the lowercaseTokens property, shown in the
 * Explorer/Experimenter GUI.
 *
 * @return the tip text for the lowercaseTokens property
 */
public String lowercaseTokensTipText() {
  final String tip = "Whether to convert all tokens to lowercase";
  return tip;
}
/**
 * Enables or disables lowercasing of all tokens.
 *
 * @param lowercase true to convert every token to lowercase
 */
public void setLowercaseTokens(boolean lowercase) {
  this.m_lowercaseTokens = lowercase;
}
/**
 * Reports whether all tokens are converted to lowercase.
 *
 * @return true if tokens are lowercased
 */
public boolean getLowercaseTokens() {
  return this.m_lowercaseTokens;
}
/**
 * Installs the stopwords handler. A {@code null} argument installs the
 * {@link Null} handler, meaning no stopwords are used.
 *
 * @param value the stopwords handler, or null for none
 */
public void setStopwordsHandler(StopwordsHandler value) {
  this.m_StopwordsHandler = (value == null) ? new Null() : value;
}
/**
 * Returns the configured stopwords handler.
 *
 * @return the stopwords handler in use
 */
public StopwordsHandler getStopwordsHandler() {
  return this.m_StopwordsHandler;
}
/**
 * Tip text for the stopwordsHandler property, shown in the
 * Explorer/Experimenter GUI.
 *
 * @return the tip text for the stopwordsHandler property
 */
public String stopwordsHandlerTipText() {
  final String tip =
    "The stopwords handler to use (Null means no stopwords are used).";
  return tip;
}
/**
 * Tip text for the periodicPruning property, shown in the
 * Explorer/Experimenter GUI.
 *
 * @return the tip text for the periodicPruning property
 */
public String periodicPruningTipText() {
  final String tip = "How often (number of instances) to prune the "
    + "dictionary of low frequency terms. 0 means don't prune. "
    + "Setting a positive integer n means prune after every n instances";
  return tip;
}
/**
 * Sets the pruning interval in instances; 0 disables periodic pruning of the
 * dictionary.
 *
 * @param interval number of instances between prunes, or 0 for never
 */
public void setPeriodicPruning(int interval) {
  this.m_periodicP = interval;
}
/**
 * Returns the dictionary pruning interval in instances (0 means never).
 *
 * @return the pruning interval
 */
public int getPeriodicPruning() {
  return this.m_periodicP;
}
/**
 * Tip text for the minWordFrequency property, shown in the
 * Explorer/Experimenter GUI.
 *
 * @return the tip text for the minWordFrequency property
 */
public String minWordFrequencyTipText() {
  final String tip = "Ignore any words that don't occur at least min "
    + "frequency times in the training data. If periodic pruning is "
    + "turned on, then the dictionary is pruned according to this value";
  return tip;
}
/**
 * Sets the minimum word frequency. Words occurring fewer times are ignored
 * when updating weights. With periodic pruning on, the same threshold
 * decides which words are dropped from the dictionary.
 *
 * @param minFrequency the minimum word frequency
 */
public void setMinWordFrequency(double minFrequency) {
  this.m_minWordP = minFrequency;
}
/**
 * Returns the minimum word frequency. Words occurring fewer times are ignored
 * when updating weights. With periodic pruning on, the same threshold
 * decides which words are dropped from the dictionary.
 *
 * @return the minimum word frequency
 */
public double getMinWordFrequency() {
  return this.m_minWordP;
}
/**
 * Tip text for the minAbsoluteCoefficientValue property, shown in the
 * Explorer/Experimenter GUI.
 *
 * @return the tip text for the minAbsoluteCoefficientValue property
 */
public String minAbsoluteCoefficientValueTipText() {
  final String tip = "The minimum absolute magnitude for model "
    + "coefficients. Terms with weights smaller than this value are "
    + "ignored. If periodic pruning is turned on then this is also used "
    + "to determine if a word should be removed from the dictionary.";
  return tip;
}
/**
 * Sets the smallest coefficient magnitude a term may have. Terms with smaller
 * weights are ignored. With periodic pruning on, this threshold also decides
 * whether a word is removed from the dictionary.
 *
 * @param threshold the minimum absolute coefficient value
 */
public void setMinAbsoluteCoefficientValue(double threshold) {
  this.m_minAbsCoefficient = threshold;
}
/**
 * Returns the smallest coefficient magnitude a term may have. Terms with
 * smaller weights are ignored. With periodic pruning on, this threshold also
 * decides whether a word is removed from the dictionary.
 *
 * @return the minimum absolute coefficient value
 */
public double getMinAbsoluteCoefficientValue() {
  return this.m_minAbsCoefficient;
}
/**
 * Tip text for the normalizeDocLength property, shown in the
 * Explorer/Experimenter GUI.
 *
 * @return the tip text for the normalizeDocLength property
 */
public String normalizeDocLengthTipText() {
  final String tip = "If true then document length is normalized "
    + "according to the settings for norm and lnorm";
  return tip;
}
/**
 * Enables or disables document length normalization.
 *
 * @param normalize true to normalize each document's length
 */
public void setNormalizeDocLength(boolean normalize) {
  this.m_normalize = normalize;
}
/**
 * Reports whether document lengths are normalized.
 *
 * @return true if document length normalization is enabled
 */
public boolean getNormalizeDocLength() {
  return this.m_normalize;
}
/**
 * Tip text for the norm property, shown in the Explorer/Experimenter GUI.
 *
 * @return the tip text for the norm property
 */
public String normTipText() {
  final String tip = "The norm of the instances after normalization.";
  return tip;
}
/**
 * Returns the norm that each instance is scaled to when normalizing.
 *
 * @return the target norm
 */
public double getNorm() {
  return this.m_norm;
}
/**
 * Sets the norm to which each instance is scaled when normalizing.
 *
 * @param norm the target norm
 */
public void setNorm(double norm) {
  this.m_norm = norm;
}
/**
 * Tip text for the LNorm property, shown in the Explorer/Experimenter GUI.
 *
 * @return the tip text for the LNorm property
 */
public String LNormTipText() {
  final String tip = "The LNorm to use for document length normalization.";
  return tip;
}
/**
 * Returns the L-norm used for document length normalization.
 *
 * @return the L-norm
 */
public double getLNorm() {
  return this.m_lnorm;
}
/**
 * Sets the L-norm used for document length normalization.
 *
 * @param lNorm the L-norm to use
 */
public void setLNorm(double lNorm) {
  this.m_lnorm = lNorm;
}
/**
 * Tip text for the lambda property, shown in the Explorer/Experimenter GUI.
 *
 * @return the tip text for the lambda property
 */
public String lambdaTipText() {
  final String tip = "The regularization constant. (default = 0.0001)";
  return tip;
}
/**
 * Sets the regularization constant lambda.
 *
 * @param value the new value of lambda
 */
public void setLambda(double value) {
  this.m_lambda = value;
}
/**
 * Returns the regularization constant lambda.
 *
 * @return the current value of lambda
 */
public double getLambda() {
  return this.m_lambda;
}
/**
 * Sets the learning rate.
 *
 * @param rate the learning rate to use
 */
public void setLearningRate(double rate) {
  this.m_learningRate = rate;
}
/**
 * Returns the learning rate.
 *
 * @return the current learning rate
 */
public double getLearningRate() {
  return this.m_learningRate;
}
/**
 * Tip text for the learningRate property, shown in the Explorer/Experimenter
 * GUI.
 *
 * @return the tip text for the learningRate property
 */
public String learningRateTipText() {
  final String tip = "The learning rate.";
  return tip;
}
/**
 * Tip text for the epochs property, shown in the Explorer/Experimenter GUI.
 *
 * @return the tip text for the epochs property
 */
public String epochsTipText() {
  final String tip = "The number of epochs to perform (batch learning). The "
    + "total number of iterations is epochs * num instances.";
  return tip;
}
/**
 * Sets the number of epochs for batch learning.
 *
 * @param numEpochs the number of epochs to perform
 */
public void setEpochs(int numEpochs) {
  this.m_epochs = numEpochs;
}
/**
 * Returns the number of epochs for batch learning.
 *
 * @return the configured number of epochs
 */
public int getEpochs() {
  return this.m_epochs;
}
/**
 * Selects the loss function. The argument is only honoured when it was built
 * on {@code TAGS_SELECTION} (checked by reference); any other tag set is
 * silently ignored.
 *
 * @param function the selected loss function
 */
public void setLossFunction(SelectedTag function) {
  if (function.getTags() != TAGS_SELECTION) {
    return;
  }
  m_loss = function.getSelectedTag().getID();
}
/**
 * Returns the current loss function as a tag drawn from
 * {@code TAGS_SELECTION}.
 *
 * @return the current loss function
 */
public SelectedTag getLossFunction() {
  SelectedTag current = new SelectedTag(m_loss, TAGS_SELECTION);
  return current;
}
/**
 * Returns the tip text for the lossFunction property, for display in the
 * Explorer/Experimenter GUI. Only the losses offered by TAGS_SELECTION are
 * listed.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String lossFunctionTipText() {
  // Only hinge and log loss are selectable for this binary-class learner;
  // squared loss (regression) is not offered, so it must not be advertised.
  return "The loss function to use. Hinge loss (SVM) or "
    + "log loss (logistic regression).";
}
/**
 * Turns on or off fitting a logistic regression (itself trained with SGD) to
 * the SVM outputs, so that probability estimates can be produced when an SVM
 * is learned.
 *
 * @param fitLogistic true to fit the logistic model to the SVM output
 */
public void setOutputProbsForSVM(boolean fitLogistic) {
  this.m_fitLogistic = fitLogistic;
}
/**
 * Reports whether a logistic regression (itself trained with SGD) is fit to
 * the SVM outputs to produce probability estimates.
 *
 * @return true if the logistic model is fit to the SVM output
 */
public boolean getOutputProbsForSVM() {
  return this.m_fitLogistic;
}
/**
 * Tip text for the outputProbsForSVM property, shown in the
 * Explorer/Experimenter GUI.
 *
 * @return the tip text for the outputProbsForSVM property
 */
public String outputProbsForSVMTipText() {
  final String tip = "Fit a logistic regression to the output of "
    + "SVM for producing probability estimates";
  return tip;
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
@Override
public Enumeration
© 2015 - 2024 Weber Informatics LLC | Privacy Policy