weka.classifiers.functions.SGDText Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-dev Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This version represents the developer version, the
"bleeding edge" of development, you could say. New functionality gets added
to this version.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see .
*/
/*
* SGDText.java
* Copyright (C) 2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.functions;
import java.io.File;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Aggregateable;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Stopwords;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.stemmers.NullStemmer;
import weka.core.stemmers.Stemmer;
import weka.core.tokenizers.Tokenizer;
import weka.core.tokenizers.WordTokenizer;
/**
* Implements stochastic gradient descent for learning
* a linear binary class SVM or binary class logistic regression on text data.
* Operates directly (and only) on String attributes. Other types of input
* attributes are accepted but ignored during training and classification.
*
*
*
* Valid options are:
*
*
*
* -F
* Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression)
* (default = 0)
*
*
*
* -outputProbs
* Output probabilities for SVMs (fits a logistic
* model to the output of the SVM)
*
*
*
* -L
* The learning rate (default = 0.01).
*
*
*
* -R <double>
* The lambda regularization constant (default = 0.0001)
*
*
*
* -E <integer>
* The number of epochs to perform (batch learning only, default = 500)
*
*
*
* -W
* Use word frequencies instead of binary bag of words.
*
*
*
* -P <# instances>
* How often to prune the dictionary of low frequency words (default = 0, i.e. don't prune)
*
*
*
* -M <double>
* Minimum word frequency. Words with less than this frequency are ignored.
* If periodic pruning is turned on then this is also used to determine which
* words to remove from the dictionary (default = 3).
*
*
*
* -normalize
* Normalize document length (use in conjunction with -norm and -lnorm)
*
*
*
* -norm <num>
* Specify the norm that each instance must have (default 1.0)
*
*
*
* -lnorm <num>
* Specify L-norm to use (default 2.0)
*
*
*
* -lowercase
* Convert all tokens to lowercase before adding to the dictionary.
*
*
*
* -stoplist
* Ignore words that are in the stoplist.
*
*
*
* -stopwords <file>
* A file containing stopwords to override the default ones.
* Using this option automatically sets the flag ('-stoplist') to use the
* stoplist if the file exists.
* Format: one stopword per line, lines starting with '#'
* are interpreted as comments and ignored.
*
*
*
* -tokenizer <spec>
* The tokenizing algorithm (classname plus parameters) to use.
* (default: weka.core.tokenizers.WordTokenizer)
*
*
*
* -stemmer <spec>
* The stemming algorithm (classname plus parameters) to use.
*
*
*
*
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
* @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz)
*
*/
public class SGDText extends RandomizableClassifier implements
UpdateableClassifier, WeightedInstancesHandler, Aggregateable {
/** For serialization. */
private static final long serialVersionUID = 7200171484002029584L;
public static class Count implements Serializable {

  /** For serialization. */
  private static final long serialVersionUID = 2104201532017340967L;

  /** Occurrence count for a dictionary word. */
  public double m_count;

  /** Weight learned for the corresponding feature. */
  public double m_weight;

  /**
   * Constructs a Count with the given initial occurrence count.
   *
   * @param c the initial count
   */
  public Count(double c) {
    m_count = c;
  }
}
/**
 * The number of training instances at which to periodically prune the
 * dictionary of min frequency words. 0 (the default) means don't prune.
 */
protected int m_periodicP = 0;

/**
 * Only consider dictionary words (features) that occur at least this many
 * times.
 */
protected double m_minWordP = 3;

/** Use word frequencies rather than bag-of-words if true. */
protected boolean m_wordFrequencies = false;

/** Whether to normalize document length or not. */
protected boolean m_normalize = false;

/** The length that each document vector should have in the end. */
protected double m_norm = 1.0;

/** The L-norm to use. */
protected double m_lnorm = 2.0;

/**
 * The dictionary (and term weights): maps each word to its Count holder.
 * NOTE(review): generic type arguments restored here; they appear to have
 * been stripped from this copy of the source (values stored are Count
 * objects) — confirm against upstream.
 */
protected LinkedHashMap<String, Count> m_dictionary;

/** Default (rainbow) stopwords. */
protected transient Stopwords m_stopwords;

/**
 * A file containing stopwords for using others than the default Rainbow
 * ones. Defaults to a directory, which means "use the built-in list".
 */
protected File m_stopwordsFile = new File(System.getProperty("user.dir"));

/** The tokenizer to use. */
protected Tokenizer m_tokenizer = new WordTokenizer();

/** Whether or not to convert all tokens to lowercase. */
protected boolean m_lowercaseTokens;

/** The stemming algorithm. */
protected Stemmer m_stemmer = new NullStemmer();

/** Whether or not to use a stop list. */
protected boolean m_useStopList;

/** The regularization parameter. */
protected double m_lambda = 0.0001;

/** The learning rate. */
protected double m_learningRate = 0.01;

/** Holds the current iteration number. */
protected double m_t;

/** Holds the bias term. */
protected double m_bias;

/** The number of training instances. */
protected double m_numInstances;

/** The header of the training data. */
protected Instances m_data;

/**
 * The number of epochs to perform (batch learning). Total iterations is
 * m_epochs * num instances.
 */
protected int m_epochs = 500;

/**
 * Holds the current document vector (LinkedHashMap is more efficient when
 * iterating over EntrySet than HashMap). Generic type arguments restored;
 * see the note on m_dictionary.
 */
protected transient LinkedHashMap<String, Count> m_inputVector;

/** The hinge loss function. */
public static final int HINGE = 0;

/** The log loss function. */
public static final int LOGLOSS = 1;

/** The current loss function to minimize. */
protected int m_loss = HINGE;

/** Loss functions to choose from. */
public static final Tag[] TAGS_SELECTION = {
  new Tag(HINGE, "Hinge loss (SVM)"),
  new Tag(LOGLOSS, "Log loss (logistic regression)") };

/** Used for producing probabilities for SVM via SGD logistic regression. */
protected SGD m_svmProbs;

/**
 * True if a logistic regression is to be fit to the output of the SVM for
 * producing probability estimates.
 */
protected boolean m_fitLogistic = false;

/** Header of the data used to fit the logistic model (if any). */
protected Instances m_fitLogisticStructure;
/**
 * Magnitude of the derivative of the selected loss function at margin z.
 *
 * @param z the margin (y * prediction)
 * @return the (positive) gradient factor for the update
 */
protected double dloss(double z) {
  if (m_loss == HINGE) {
    // Hinge: gradient is 1 inside the margin, 0 outside.
    return (z < 1) ? 1 : 0;
  }

  // Log loss: sigmoid(-z), split on the sign of z for numerical stability
  // (avoids overflow in exp for large |z|).
  if (z < 0) {
    return 1.0 / (Math.exp(z) + 1.0);
  }
  double expNegZ = Math.exp(-z);
  return expNegZ / (expNegZ + 1);
}
/**
 * Returns default capabilities of the classifier.
 *
 * @return the capabilities of this classifier
 */
@Override
public Capabilities getCapabilities() {
  Capabilities caps = super.getCapabilities();
  caps.disableAll();

  // Attributes: strings carry the text; other types are tolerated (they
  // are ignored during training/classification).
  caps.enable(Capability.STRING_ATTRIBUTES);
  caps.enable(Capability.NOMINAL_ATTRIBUTES);
  caps.enable(Capability.DATE_ATTRIBUTES);
  caps.enable(Capability.NUMERIC_ATTRIBUTES);
  caps.enable(Capability.MISSING_VALUES);

  // Class: binary problems only.
  caps.enable(Capability.BINARY_CLASS);
  caps.enable(Capability.MISSING_CLASS_VALUES);

  // Incremental learner: no minimum number of training instances.
  caps.setMinimumNumberInstances(0);

  return caps;
}
/**
 * Sets the stemming algorithm to use; null means no stemming at all (the
 * NullStemmer is substituted).
 *
 * @param value the configured stemming algorithm, or null
 * @see NullStemmer
 */
public void setStemmer(Stemmer value) {
  m_stemmer = (value == null) ? new NullStemmer() : value;
}
/**
 * Returns the stemming algorithm currently in use.
 *
 * @return the current stemming algorithm
 */
public Stemmer getStemmer() {
  return m_stemmer;
}
/**
 * Tip text for the stemmer property, shown in the explorer/experimenter
 * GUI.
 *
 * @return the tip text
 */
public String stemmerTipText() {
  return "The stemming algorithm to use on the words.";
}
/**
 * Sets the tokenizer algorithm to use.
 *
 * @param value the configured tokenizing algorithm
 */
public void setTokenizer(Tokenizer value) {
  m_tokenizer = value;
}
/**
 * Returns the tokenizer algorithm currently in use.
 *
 * @return the current tokenizer algorithm
 */
public Tokenizer getTokenizer() {
  return m_tokenizer;
}
/**
 * Tip text for the tokenizer property, shown in the explorer/experimenter
 * GUI.
 *
 * @return the tip text
 */
public String tokenizerTipText() {
  return "The tokenizing algorithm to use on the strings.";
}
/**
 * Tip text for the useWordFrequencies property, shown in the
 * explorer/experimenter GUI.
 *
 * @return the tip text
 */
public String useWordFrequenciesTipText() {
  return "Use word frequencies rather than binary "
    + "bag of words representation";
}
/**
 * Sets whether to use word frequencies rather than the binary bag-of-words
 * representation.
 *
 * @param u true if word frequencies are to be used
 */
public void setUseWordFrequencies(boolean u) {
  m_wordFrequencies = u;
}
/**
 * Gets whether word frequencies are used rather than the binary
 * bag-of-words representation.
 *
 * @return true if word frequencies are to be used
 */
public boolean getUseWordFrequencies() {
  return m_wordFrequencies;
}
/**
 * Tip text for the lowercaseTokens property, shown in the
 * explorer/experimenter GUI.
 *
 * @return the tip text
 */
public String lowercaseTokensTipText() {
  return "Whether to convert all tokens to lowercase";
}
/**
 * Sets whether to convert all tokens to lowercase before dictionary lookup.
 *
 * @param l true if all tokens are to be converted to lowercase
 */
public void setLowercaseTokens(boolean l) {
  m_lowercaseTokens = l;
}
/**
 * Gets whether all tokens are converted to lowercase.
 *
 * @return true if all tokens are to be converted to lowercase
 */
public boolean getLowercaseTokens() {
  return m_lowercaseTokens;
}
/**
 * Tip text for the useStopList property, shown in the explorer/experimenter
 * GUI.
 *
 * @return the tip text
 */
public String useStopListTipText() {
  return "If true, ignores all words that are on the stoplist.";
}
/**
 * Sets whether to ignore all words that are on the stoplist.
 *
 * @param u true to ignore all words on the stoplist
 */
public void setUseStopList(boolean u) {
  m_useStopList = u;
}
/**
 * Gets whether all words on the stoplist are ignored.
 *
 * @return true if words on the stoplist are ignored
 */
public boolean getUseStopList() {
  return m_useStopList;
}
/**
 * Sets the file containing the stopwords. Passing null (or a directory)
 * reverts to the default stopword list. If the file exists, the
 * use-stoplist flag is automatically turned on.
 *
 * @param value the file containing the stopwords
 */
public void setStopwords(File value) {
  // A directory stands in for "no custom stopword file".
  m_stopwordsFile = (value == null)
    ? new File(System.getProperty("user.dir")) : value;

  if (m_stopwordsFile.isFile() && m_stopwordsFile.exists()) {
    setUseStopList(true);
  }
}
/**
 * Returns the file used for obtaining the stopwords. If the file represents
 * a directory then the default list is used instead.
 *
 * @return the file containing the stopwords
 */
public File getStopwords() {
  return m_stopwordsFile;
}
/**
 * Tip text for the stopwords property, shown in the explorer/experimenter
 * GUI.
 *
 * @return the tip text
 */
public String stopwordsTipText() {
  return "The file containing the stopwords (if this is a directory then the default ones are used).";
}
/**
 * Tip text for the periodicPruning property, shown in the
 * explorer/experimenter GUI.
 *
 * @return the tip text
 */
public String periodicPruningTipText() {
  return "How often (number of instances) to prune "
    + "the dictionary of low frequency terms. "
    + "0 means don't prune. Setting a positive "
    + "integer n means prune after every n instances";
}
/**
 * Sets how often (in instances) the dictionary is pruned.
 *
 * @param p the pruning interval; 0 disables pruning
 */
public void setPeriodicPruning(int p) {
  m_periodicP = p;
}
/**
 * Gets how often (in instances) the dictionary is pruned.
 *
 * @return the pruning interval; 0 means no pruning
 */
public int getPeriodicPruning() {
  return m_periodicP;
}
/**
 * Tip text for the minWordFrequency property, shown in the
 * explorer/experimenter GUI.
 *
 * @return the tip text
 */
public String minWordFrequencyTipText() {
  return "Ignore any words that don't occur at least "
    + "min frequency times in the training data. If periodic "
    + "pruning is turned on, then the dictionary is pruned "
    + "according to this value";
}
/**
 * Sets the minimum word frequency. Words occurring fewer than minFreq times
 * are ignored when updating weights; when periodic pruning is on, this is
 * also the removal threshold.
 *
 * @param minFreq the minimum word frequency to use
 */
public void setMinWordFrequency(double minFreq) {
  m_minWordP = minFreq;
}
/**
 * Gets the minimum word frequency. Words occurring fewer than this many
 * times are ignored when updating weights; when periodic pruning is on,
 * this is also the removal threshold.
 *
 * @return the minimum word frequency to use
 */
public double getMinWordFrequency() {
  return m_minWordP;
}
/**
 * Tip text for the normalizeDocLength property, shown in the
 * explorer/experimenter GUI.
 *
 * @return the tip text
 */
public String normalizeDocLengthTipText() {
  return "If true then document length is normalized according "
    + "to the settings for norm and lnorm";
}
/**
 * Sets whether to normalize the length of each document.
 *
 * @param norm true if document lengths are to be normalized
 */
public void setNormalizeDocLength(boolean norm) {
  m_normalize = norm;
}
/**
 * Gets whether the length of each document is normalized.
 *
 * @return true if document lengths are to be normalized
 */
public boolean getNormalizeDocLength() {
  return m_normalize;
}
/**
 * Tip text for the norm property, shown in the explorer/experimenter GUI.
 *
 * @return the tip text
 */
public String normTipText() {
  return "The norm of the instances after normalization.";
}
/**
 * Gets the norm that instances are scaled to after normalization.
 *
 * @return the norm
 */
public double getNorm() {
  return m_norm;
}
/**
 * Sets the norm that the instances must be scaled to.
 *
 * @param newNorm the norm to which the instances must be set
 */
public void setNorm(double newNorm) {
  m_norm = newNorm;
}
/**
 * Tip text for the LNorm property, shown in the explorer/experimenter GUI.
 *
 * @return the tip text
 */
public String LNormTipText() {
  return "The LNorm to use for document length normalization.";
}
/**
 * Gets the L-norm used for document length normalization.
 *
 * @return the L-norm used
 */
public double getLNorm() {
  return m_lnorm;
}
/**
 * Sets the L-norm to use for document length normalization.
 *
 * @param newLNorm the L-norm
 */
public void setLNorm(double newLNorm) {
  m_lnorm = newLNorm;
}
/**
 * Tip text for the lambda property, shown in the explorer/experimenter GUI.
 *
 * @return the tip text
 */
public String lambdaTipText() {
  return "The regularization constant. (default = 0.0001)";
}
/**
 * Sets the regularization constant lambda.
 *
 * @param lambda the value of lambda to use
 */
public void setLambda(double lambda) {
  m_lambda = lambda;
}
/**
 * Gets the regularization constant lambda.
 *
 * @return the current value of lambda
 */
public double getLambda() {
  return m_lambda;
}
/**
 * Sets the learning rate used by the gradient descent updates.
 *
 * @param lr the learning rate to use
 */
public void setLearningRate(double lr) {
  m_learningRate = lr;
}
/**
 * Gets the learning rate used by the gradient descent updates.
 *
 * @return the learning rate
 */
public double getLearningRate() {
  return m_learningRate;
}
/**
 * Tip text for the learningRate property, shown in the
 * explorer/experimenter GUI.
 *
 * @return the tip text
 */
public String learningRateTipText() {
  return "The learning rate.";
}
/**
 * Tip text for the epochs property, shown in the explorer/experimenter GUI.
 *
 * @return the tip text
 */
public String epochsTipText() {
  return "The number of epochs to perform (batch learning). "
    + "The total number of iterations is epochs * num" + " instances.";
}
/**
 * Sets the number of epochs to perform in batch learning.
 *
 * @param e the number of epochs to use
 */
public void setEpochs(int e) {
  m_epochs = e;
}
/**
 * Gets the number of epochs performed in batch learning.
 *
 * @return the current number of epochs
 */
public int getEpochs() {
  return m_epochs;
}
/**
 * Sets the loss function to minimize. Ignored unless the tag comes from
 * TAGS_SELECTION.
 *
 * @param function the loss function to use
 */
public void setLossFunction(SelectedTag function) {
  if (function.getTags() == TAGS_SELECTION) {
    m_loss = function.getSelectedTag().getID();
  }
}
/**
 * Gets the loss function currently being minimized.
 *
 * @return the current loss function, as a tag from TAGS_SELECTION
 */
public SelectedTag getLossFunction() {
  return new SelectedTag(m_loss, TAGS_SELECTION);
}
/**
 * Tip text for the lossFunction property, shown in the
 * explorer/experimenter GUI.
 *
 * @return the tip text
 */
public String lossFunctionTipText() {
  // Only hinge and log loss are offered (see TAGS_SELECTION); the previous
  // text incorrectly advertised a squared-loss option.
  return "The loss function to use. Hinge loss (SVM) or "
    + "log loss (logistic regression).";
}
/**
 * Sets whether to fit a logistic regression (itself trained using SGD) to
 * the outputs of the SVM, if an SVM is being learned.
 *
 * @param o true if a logistic regression is to be fit to the output of the
 *          SVM to produce probability estimates
 */
public void setOutputProbsForSVM(boolean o) {
  m_fitLogistic = o;
}
/**
 * Gets whether a logistic regression (itself trained using SGD) is fit to
 * the outputs of the SVM, if an SVM is being learned.
 *
 * @return true if a logistic regression is to be fit to the output of the
 *         SVM to produce probability estimates
 */
public boolean getOutputProbsForSVM() {
  return m_fitLogistic;
}
/**
 * Tip text for the outputProbsForSVM property, shown in the
 * explorer/experimenter GUI.
 *
 * @return the tip text
 */
public String outputProbsForSVMTipText() {
  return "Fit a logistic regression to the output of SVM for "
    + "producing probability estimates";
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
@Override
public Enumeration