All downloads are free. The search and download functionality uses the official Maven repository.

weka.classifiers.functions.SGDText Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This version represents the developer version, the "bleeding edge" of development, you could say. New functionality gets added to this version.

There is a newer version: 3.9.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    SGDText.java
 *    Copyright (C) 2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.functions;

import java.io.File;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.RandomizableClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Aggregateable;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Stopwords;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.stemmers.NullStemmer;
import weka.core.stemmers.Stemmer;
import weka.core.tokenizers.Tokenizer;
import weka.core.tokenizers.WordTokenizer;

/**
 *  Implements stochastic gradient descent for learning
 * a linear binary class SVM or binary class logistic regression on text data.
 * Operates directly (and only) on String attributes. Other types of input
 * attributes are accepted but ignored during training and classification.
 * 

* * * Valid options are: *

* *

 * -F
 *  Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression)
 *  (default = 0)
 * 
* *
 * -outputProbs
 *  Output probabilities for SVMs (fits a logistic
 *  model to the output of the SVM)
 * 
* *
 * -L
 *  The learning rate (default = 0.01).
 * 
* *
 * -R <double>
 *  The lambda regularization constant (default = 0.0001)
 * 
* *
 * -E <integer>
 *  The number of epochs to perform (batch learning only, default = 500)
 * 
* *
 * -W
 *  Use word frequencies instead of binary bag of words.
 * 
* *
 * -P <# instances>
 *  How often to prune the dictionary of low frequency words (default = 0, i.e. don't prune)
 * 
* *
 * -M <double>
 *  Minimum word frequency. Words with less than this frequency are ignored.
 *  If periodic pruning is turned on then this is also used to determine which
 *  words to remove from the dictionary (default = 3).
 * 
* *
 * -normalize
 *  Normalize document length (use in conjunction with -norm and -lnorm)
 * 
* *
 * -norm <num>
 *  Specify the norm that each instance must have (default 1.0)
 * 
* *
 * -lnorm <num>
 *  Specify L-norm to use (default 2.0)
 * 
* *
 * -lowercase
 *  Convert all tokens to lowercase before adding to the dictionary.
 * 
* *
 * -stoplist
 *  Ignore words that are in the stoplist.
 * 
* *
 * -stopwords <file>
 *  A file containing stopwords to override the default ones.
 *  Using this option automatically sets the flag ('-stoplist') to use the
 *  stoplist if the file exists.
 *  Format: one stopword per line, lines starting with '#'
 *  are interpreted as comments and ignored.
 * 
* *
 * -tokenizer <spec>
 *  The tokenizing algorithm (classname plus parameters) to use.
 *  (default: weka.core.tokenizers.WordTokenizer)
 * 
* *
 * -stemmer <spec>
 *  The stemming algorithm (classname plus parameters) to use.
 * 
* * * * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) * @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz) * */ public class SGDText extends RandomizableClassifier implements UpdateableClassifier, WeightedInstancesHandler, Aggregateable { /** For serialization */ private static final long serialVersionUID = 7200171484002029584L; public static class Count implements Serializable { /** * For serialization */ private static final long serialVersionUID = 2104201532017340967L; public double m_count; public double m_weight; public Count(double c) { m_count = c; } } /** * The number of training instances at which to periodically prune the * dictionary of min frequency words. Empty or null string indicates don't * prune */ protected int m_periodicP = 0; /** * Only consider dictionary words (features) that occur at least this many * times */ protected double m_minWordP = 3; /** Use word frequencies rather than bag-of-words if true */ protected boolean m_wordFrequencies = false; /** Whether to normalized document length or not */ protected boolean m_normalize = false; /** The length that each document vector should have in the end */ protected double m_norm = 1.0; /** The L-norm to use */ protected double m_lnorm = 2.0; /** The dictionary (and term weights) */ protected LinkedHashMap m_dictionary; /** Default (rainbow) stopwords */ protected transient Stopwords m_stopwords; /** * a file containing stopwords for using others than the default Rainbow ones. */ protected File m_stopwordsFile = new File(System.getProperty("user.dir")); /** The tokenizer to use */ protected Tokenizer m_tokenizer = new WordTokenizer(); /** Whether or not to convert all tokens to lowercase */ protected boolean m_lowercaseTokens; /** The stemming algorithm. 
*/ protected Stemmer m_stemmer = new NullStemmer(); /** Whether or not to use a stop list */ protected boolean m_useStopList; /** The regularization parameter */ protected double m_lambda = 0.0001; /** The learning rate */ protected double m_learningRate = 0.01; /** Holds the current iteration number */ protected double m_t; /** Holds the bias term */ protected double m_bias; /** The number of training instances */ protected double m_numInstances; /** The header of the training data */ protected Instances m_data; /** * The number of epochs to perform (batch learning). Total iterations is * m_epochs * num instances */ protected int m_epochs = 500; /** * Holds the current document vector (LinkedHashMap is more efficient when * iterating over EntrySet than HashMap) */ protected transient LinkedHashMap m_inputVector; /** the hinge loss function. */ public static final int HINGE = 0; /** the log loss function. */ public static final int LOGLOSS = 1; /** The current loss function to minimize */ protected int m_loss = HINGE; /** Loss functions to choose from */ public static final Tag[] TAGS_SELECTION = { new Tag(HINGE, "Hinge loss (SVM)"), new Tag(LOGLOSS, "Log loss (logistic regression)") }; /** Used for producing probabilities for SVM via SGD logistic regression */ protected SGD m_svmProbs; /** * True if a logistic regression is to be fit to the output of the SVM for * producing probability estimates */ protected boolean m_fitLogistic = false; protected Instances m_fitLogisticStructure; protected double dloss(double z) { if (m_loss == HINGE) { return (z < 1) ? 1 : 0; } else { // log loss if (z < 0) { return 1.0 / (Math.exp(z) + 1.0); } else { double t = Math.exp(-z); return t / (t + 1); } } } /** * Returns default capabilities of the classifier. 
* * @return the capabilities of this classifier */ @Override public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); // attributes result.enable(Capability.STRING_ATTRIBUTES); result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.DATE_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); result.enable(Capability.BINARY_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); // instances result.setMinimumNumberInstances(0); return result; } /** * the stemming algorithm to use, null means no stemming at all (i.e., the * NullStemmer is used). * * @param value the configured stemming algorithm, or null * @see NullStemmer */ public void setStemmer(Stemmer value) { if (value != null) { m_stemmer = value; } else { m_stemmer = new NullStemmer(); } } /** * Returns the current stemming algorithm, null if none is used. * * @return the current stemming algorithm, null if none set */ public Stemmer getStemmer() { return m_stemmer; } /** * Returns the tip text for this property. * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String stemmerTipText() { return "The stemming algorithm to use on the words."; } /** * the tokenizer algorithm to use. * * @param value the configured tokenizing algorithm */ public void setTokenizer(Tokenizer value) { m_tokenizer = value; } /** * Returns the current tokenizer algorithm. * * @return the current tokenizer algorithm */ public Tokenizer getTokenizer() { return m_tokenizer; } /** * Returns the tip text for this property. 
* * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String tokenizerTipText() { return "The tokenizing algorithm to use on the strings."; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String useWordFrequenciesTipText() { return "Use word frequencies rather than binary " + "bag of words representation"; } /** * Set whether to use word frequencies rather than binary bag of words * representation. * * @param u true if word frequencies are to be used. */ public void setUseWordFrequencies(boolean u) { m_wordFrequencies = u; } /** * Get whether to use word frequencies rather than binary bag of words * representation. * * @param u true if word frequencies are to be used. */ public boolean getUseWordFrequencies() { return m_wordFrequencies; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String lowercaseTokensTipText() { return "Whether to convert all tokens to lowercase"; } /** * Set whether to convert all tokens to lowercase * * @param l true if all tokens are to be converted to lowercase */ public void setLowercaseTokens(boolean l) { m_lowercaseTokens = l; } /** * Get whether to convert all tokens to lowercase * * @return true true if all tokens are to be converted to lowercase */ public boolean getLowercaseTokens() { return m_lowercaseTokens; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String useStopListTipText() { return "If true, ignores all words that are on the stoplist."; } /** * Set whether to ignore all words that are on the stoplist. * * @param u true to ignore all words on the stoplist. 
*/ public void setUseStopList(boolean u) { m_useStopList = u; } /** * Get whether to ignore all words that are on the stoplist. * * @return true to ignore all words on the stoplist. */ public boolean getUseStopList() { return m_useStopList; } /** * sets the file containing the stopwords, null or a directory unset the * stopwords. If the file exists, it automatically turns on the flag to use * the stoplist. * * @param value the file containing the stopwords */ public void setStopwords(File value) { if (value == null) { value = new File(System.getProperty("user.dir")); } m_stopwordsFile = value; if (value.exists() && value.isFile()) { setUseStopList(true); } } /** * returns the file used for obtaining the stopwords, if the file represents a * directory then the default ones are used. * * @return the file containing the stopwords */ public File getStopwords() { return m_stopwordsFile; } /** * Returns the tip text for this property. * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String stopwordsTipText() { return "The file containing the stopwords (if this is a directory then the default ones are used)."; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String periodicPruningTipText() { return "How often (number of instances) to prune " + "the dictionary of low frequency terms. " + "0 means don't prune. 
Setting a positive " + "integer n means prune after every n instances"; } /** * Set how often to prune the dictionary * * @param p how often to prune */ public void setPeriodicPruning(int p) { m_periodicP = p; } /** * Get how often to prune the dictionary * * @return how often to prune the dictionary */ public int getPeriodicPruning() { return m_periodicP; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String minWordFrequencyTipText() { return "Ignore any words that don't occur at least " + "min frequency times in the training data. If periodic " + "pruning is turned on, then the dictionary is pruned " + "according to this value"; } /** * Set the minimum word frequency. Words that don't occur at least min freq * times are ignored when updating weights. If periodic pruning is turned on, * then min frequency is used when removing words from the dictionary. * * @param minFreq the minimum word frequency to use */ public void setMinWordFrequency(double minFreq) { m_minWordP = minFreq; } /** * Get the minimum word frequency. Words that don't occur at least min freq * times are ignored when updating weights. If periodic pruning is turned on, * then min frequency is used when removing words from the dictionary. 
* * @param return the minimum word frequency to use */ public double getMinWordFrequency() { return m_minWordP; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String normalizeDocLengthTipText() { return "If true then document length is normalized according " + "to the settings for norm and lnorm"; } /** * Set whether to normalize the length of each document * * @param norm true if document lengths is to be normalized */ public void setNormalizeDocLength(boolean norm) { m_normalize = norm; } /** * Get whether to normalize the length of each document * * @return true if document lengths is to be normalized */ public boolean getNormalizeDocLength() { return m_normalize; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String normTipText() { return "The norm of the instances after normalization."; } /** * Get the instance's Norm. * * @return the Norm */ public double getNorm() { return m_norm; } /** * Set the norm of the instances * * @param newNorm the norm to wich the instances must be set */ public void setNorm(double newNorm) { m_norm = newNorm; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String LNormTipText() { return "The LNorm to use for document length normalization."; } /** * Get the L Norm used. * * @return the L-norm used */ public double getLNorm() { return m_lnorm; } /** * Set the L-norm to used * * @param newLNorm the L-norm */ public void setLNorm(double newLNorm) { m_lnorm = newLNorm; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String lambdaTipText() { return "The regularization constant. 
(default = 0.0001)"; } /** * Set the value of lambda to use * * @param lambda the value of lambda to use */ public void setLambda(double lambda) { m_lambda = lambda; } /** * Get the current value of lambda * * @return the current value of lambda */ public double getLambda() { return m_lambda; } /** * Set the learning rate. * * @param lr the learning rate to use. */ public void setLearningRate(double lr) { m_learningRate = lr; } /** * Get the learning rate. * * @return the learning rate */ public double getLearningRate() { return m_learningRate; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String learningRateTipText() { return "The learning rate."; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String epochsTipText() { return "The number of epochs to perform (batch learning). " + "The total number of iterations is epochs * num" + " instances."; } /** * Set the number of epochs to use * * @param e the number of epochs to use */ public void setEpochs(int e) { m_epochs = e; } /** * Get current number of epochs * * @return the current number of epochs */ public int getEpochs() { return m_epochs; } /** * Set the loss function to use. * * @param function the loss function to use. */ public void setLossFunction(SelectedTag function) { if (function.getTags() == TAGS_SELECTION) { m_loss = function.getSelectedTag().getID(); } } /** * Get the current loss function. * * @return the current loss function. */ public SelectedTag getLossFunction() { return new SelectedTag(m_loss, TAGS_SELECTION); } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String lossFunctionTipText() { return "The loss function to use. 
Hinge loss (SVM), " + "log loss (logistic regression) or " + "squared loss (regression)."; } /** * Set whether to fit a logistic regression (itself trained using SGD) to the * outputs of the SVM (if an SVM is being learned). * * @param o true if a logistic regression is to be fit to the output of the * SVM to produce probability estimates. */ public void setOutputProbsForSVM(boolean o) { m_fitLogistic = o; } /** * Get whether to fit a logistic regression (itself trained using SGD) to the * outputs of the SVM (if an SVM is being learned). * * @return true if a logistic regression is to be fit to the output of the SVM * to produce probability estimates. */ public boolean getOutputProbsForSVM() { return m_fitLogistic; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String outputProbsForSVMTipText() { return "Fit a logistic regression to the output of SVM for " + "producing probability estimates"; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ @Override public Enumeration




© 2015 - 2024 Weber Informatics LLC | Privacy Policy