weka.classifiers.bayes.NaiveBayesMultinomialText Maven / Gradle / Ivy
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* NaiveBayesMultinomialText.java
* Copyright (C) 2012 University of Waikato, Hamilton, New Zealand
*/
package weka.classifiers.bayes;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableBatchProcessor;
import weka.classifiers.UpdateableClassifier;
import weka.core.Aggregateable;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.stemmers.NullStemmer;
import weka.core.stemmers.Stemmer;
import weka.core.stopwords.Null;
import weka.core.stopwords.StopwordsHandler;
import weka.core.tokenizers.Tokenizer;
import weka.core.tokenizers.WordTokenizer;
/**
* Multinomial naive bayes for text data. Operates directly (and only) on String attributes. Other types of input attributes are accepted but ignored during training and classification
*
*
* Valid options are:
*
* -W
* Use word frequencies instead of binary bag of words.
*
* -P <# instances>
* How often to prune the dictionary of low frequency words (default = 0, i.e. don't prune)
*
* -M <double>
* Minimum word frequency. Words with less than this frequency are ignored.
* If periodic pruning is turned on then this is also used to determine which
* words to remove from the dictionary (default = 3).
*
* -normalize
* Normalize document length (use in conjunction with -norm and -lnorm)
*
* -norm <num>
* Specify the norm that each instance must have (default 1.0)
*
* -lnorm <num>
* Specify L-norm to use (default 2.0)
*
* -lowercase
* Convert all tokens to lowercase before adding to the dictionary.
*
* -stopwords-handler
* The stopwords handler to use (default Null).
*
* -tokenizer <spec>
* The tokenizing algorithm (classname plus parameters) to use.
* (default: weka.core.tokenizers.WordTokenizer)
*
* -stemmer <spec>
* The stemming algorithm (classname plus parameters) to use.
*
* -output-debug-info
* If set, classifier is run in debug mode and
* may output additional info to the console
*
* -do-not-check-capabilities
* If set, classifier capabilities are not checked before classifier is built
* (use with caution).
*
*
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
* @author Andrew Golightly ([email protected])
* @author Bernhard Pfahringer ([email protected])
*
*/
public class NaiveBayesMultinomialText extends AbstractClassifier implements
UpdateableClassifier, UpdateableBatchProcessor, WeightedInstancesHandler,
Aggregateable {
/** For serialization */
private static final long serialVersionUID = 2139025532014821394L;
/**
 * Mutable, serializable holder for a single floating-point count. The value
 * is exposed as a public field so that it can be incremented in place while
 * it sits inside a map.
 */
private static class Count implements Serializable {

  /** For serialization */
  private static final long serialVersionUID = 2104201532017340967L;

  /** The current value of the count */
  public double m_count;

  /**
   * Creates a count holding the given initial value.
   *
   * @param c the initial value
   */
  public Count(double c) {
    this.m_count = c;
  }
}
/** The header of the training data */
protected Instances m_data;
protected double[] m_probOfClass;
protected double[] m_wordsPerClass;
protected Map> m_probOfWordGivenClass;
/**
* Holds the current document vector (LinkedHashMap is more efficient when
* iterating over EntrySet than HashMap)
*/
protected transient LinkedHashMap m_inputVector;
/** Stopword handler to use. */
protected StopwordsHandler m_StopwordsHandler = new Null();
/** The tokenizer to use */
protected Tokenizer m_tokenizer = new WordTokenizer();
/** Whether or not to convert all tokens to lowercase */
protected boolean m_lowercaseTokens;
/** The stemming algorithm. */
protected Stemmer m_stemmer = new NullStemmer();
/**
* The number of training instances at which to periodically prune the
* dictionary of min frequency words. Empty or null string indicates don't
* prune
*/
protected int m_periodicP = 0;
/**
* Only consider dictionary words (features) that occur at least this many
* times
*/
protected double m_minWordP = 3;
/** Use word frequencies rather than bag-of-words if true */
protected boolean m_wordFrequencies = false;
/** normailize document length ? */
protected boolean m_normalize = false;
/** The length that each document vector should have in the end */
protected double m_norm = 1.0;
/** The L-norm to use */
protected double m_lnorm = 2.0;
/** Leplace-like correction factor for zero frequency */
protected double m_leplace = 1.0;
/** Holds the current instance number */
protected double m_t;
/**
 * Describes this classifier in a form suitable for the
 * explorer/experimenter gui.
 *
 * @return a short human-readable description of the classifier
 */
public String globalInfo() {
  final StringBuilder description = new StringBuilder();
  description.append("Multinomial naive bayes for text data. Operates ");
  description.append("directly (and only) on String attributes. ");
  description.append("Other types of input attributes are accepted but ");
  description.append("ignored during training and classification");
  return description.toString();
}
/**
 * Returns the default capabilities of the classifier: everything inherited
 * from the superclass is switched off first, then the supported attribute
 * and class capabilities are enabled one by one.
 *
 * @return the capabilities of this classifier
 */
@Override
public Capabilities getCapabilities() {
  final Capabilities caps = super.getCapabilities();
  caps.disableAll();

  // enabled in exactly this order
  final Capability[] supported = {
    // attributes
    Capability.STRING_ATTRIBUTES,
    Capability.NOMINAL_ATTRIBUTES,
    Capability.DATE_ATTRIBUTES,
    Capability.NUMERIC_ATTRIBUTES,
    Capability.MISSING_VALUES,
    // class
    Capability.MISSING_CLASS_VALUES,
    Capability.NOMINAL_CLASS
  };
  for (Capability capability : supported) {
    caps.enable(capability);
  }

  // instances
  caps.setMinimumNumberInstances(0);

  return caps;
}
/**
* Generates the classifier.
*
* @param data set of instances serving as training data
* @throws Exception if the classifier has not been generated successfully
*/
@Override
public void buildClassifier(Instances data) throws Exception {
reset();
// can classifier handle the data?
getCapabilities().testWithFail(data);
m_data = new Instances(data, 0);
data = new Instances(data);
m_wordsPerClass = new double[data.numClasses()];
m_probOfClass = new double[data.numClasses()];
m_probOfWordGivenClass =
new HashMap>();
double laplace = 1.0;
for (int i = 0; i < data.numClasses(); i++) {
LinkedHashMap dict =
new LinkedHashMap(10000 / data.numClasses());
m_probOfWordGivenClass.put(i, dict);
m_probOfClass[i] = laplace;
// this needs to be updated for laplace correction every time we see a new
// word (attribute)
m_wordsPerClass[i] = 0;
}
for (int i = 0; i < data.numInstances(); i++) {
updateClassifier(data.instance(i));
}
if (data.numInstances() > 0) {
pruneDictionary(true);
}
}
/**
 * Incorporates a single training instance into the model, updating the
 * dictionaries as well as the class counts.
 *
 * @param instance the new training instance to include in the model
 * @throws Exception if the instance could not be incorporated in the model.
 */
@Override
public void updateClassifier(Instance instance) throws Exception {
  final boolean updateDictionary = true;
  updateClassifier(instance, updateDictionary);
}
/**
 * Incorporates a single training instance into the model. Instances without
 * a class value are skipped entirely. Otherwise the instance's weight is
 * added to the count of its class, the instance is tokenized and the
 * instance counter is advanced.
 *
 * @param instance the training instance
 * @param updateDictionary passed through to the tokenization step; true if
 *          the per-class dictionaries are to be updated
 * @throws Exception if the instance could not be incorporated in the model.
 */
protected void updateClassifier(Instance instance, boolean updateDictionary)
  throws Exception {
  if (instance.classIsMissing()) {
    return;
  }

  final int label = (int) instance.classValue();
  m_probOfClass[label] += instance.weight();
  tokenizeInstance(instance, updateDictionary);
  m_t++;
}
/**
 * Calculates the class membership probabilities for the given test instance.
 * <p>
 * For every class the log score is
 * {@code log(classCount) + sum_w freq(w) * log(count(w | class))
 * - (sum_w freq(w)) * log(wordsPerClass)}, where only words present in at
 * least one class dictionary take part. The scores are shifted by their
 * maximum before exponentiation and then normalized to sum to one.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @throws Exception if there is a problem generating the prediction
 */
@Override
public double[] distributionForInstance(Instance instance) throws Exception {
  tokenizeInstance(instance, false);

  final int numClasses = m_data.numClasses();
  double[] probOfClassGivenDoc = new double[numClasses];
  double[] logDocGivenClass = new double[numClasses];

  // Document normalization (if in use). The norm depends only on the
  // document and on the union of all dictionaries, not on the class being
  // scored, so it is computed once up front.
  double iNorm = 0;
  if (m_normalize) {
    for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) {
      // only normalize with respect to those words that we've seen during
      // training (i.e. dictionary over all classes)
      if (isInAnyDictionary(feature.getKey())) {
        // word counts or bag-of-words?
        double fv = (m_wordFrequencies) ? feature.getValue().m_count : 1.0;
        iNorm += Math.pow(Math.abs(fv), m_lnorm);
      }
    }
    iNorm = Math.pow(iNorm, 1.0 / m_lnorm);
  }

  for (int i = 0; i < numClasses; i++) {
    logDocGivenClass[i] += Math.log(m_probOfClass[i]);

    LinkedHashMap<String, Count> dictForClass = m_probOfWordGivenClass.get(i);

    // Must be a double: freq is fractional when documents are normalized or
    // instances carry non-integer weights, and "int += double" would
    // silently truncate the running total on every addition.
    double allWords = 0;

    for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) {
      String word = feature.getKey();

      // ignore words we haven't seen in the training data
      if (!isInAnyDictionary(word)) {
        continue;
      }

      double freq = (m_wordFrequencies) ? feature.getValue().m_count : 1.0;
      if (m_normalize) {
        freq *= (m_norm / iNorm);
      }
      allWords += freq;

      Count dictCount = dictForClass.get(word);
      if (dictCount != null) {
        logDocGivenClass[i] += freq * Math.log(dictCount.m_count);
      } else {
        // laplace for zero frequency (word known to another class only)
        logDocGivenClass[i] += freq * Math.log(m_leplace);
      }
    }

    if (m_wordsPerClass[i] > 0) {
      logDocGivenClass[i] -= allWords * Math.log(m_wordsPerClass[i]);
    }
  }

  // subtract the largest log score before exponentiating
  double max = logDocGivenClass[Utils.maxIndex(logDocGivenClass)];
  for (int i = 0; i < numClasses; i++) {
    probOfClassGivenDoc[i] = Math.exp(logDocGivenClass[i] - max);
  }

  Utils.normalize(probOfClassGivenDoc);

  return probOfClassGivenDoc;
}

/**
 * Checks the word against all the dictionaries (all classes).
 *
 * @param word the (already processed) token to look up
 * @return true if at least one class dictionary contains the word
 */
private boolean isInAnyDictionary(String word) {
  for (int clss = 0; clss < m_data.numClasses(); clss++) {
    if (m_probOfWordGivenClass.get(clss).get(word) != null) {
      return true;
    }
  }
  return false;
}
/**
 * Converts the string attributes of an instance into the current document
 * vector ({@code m_inputVector}) and, optionally, folds that vector into the
 * dictionary of the instance's class.
 * <p>
 * Every non-missing string attribute is tokenized; each token is optionally
 * lowercased, then stemmed, and dropped if the stopwords handler rejects it.
 * The remaining tokens are counted, each occurrence contributing the
 * instance's weight.
 * <p>
 * When {@code updateDictionary} is true the instance's class value is read,
 * so the caller must make sure it is not missing.
 *
 * @param instance the instance to tokenize
 * @param updateDictionary true if the per-class dictionaries are to be
 *          updated with the words of this instance (training); false to
 *          only fill the document vector (prediction)
 */
protected void tokenizeInstance(Instance instance, boolean updateDictionary) {
  if (m_inputVector == null) {
    m_inputVector = new LinkedHashMap<String, Count>();
  } else {
    m_inputVector.clear();
  }

  for (int i = 0; i < instance.numAttributes(); i++) {
    if (instance.attribute(i).isString() && !instance.isMissing(i)) {
      m_tokenizer.tokenize(instance.stringValue(i));

      while (m_tokenizer.hasMoreElements()) {
        String word = m_tokenizer.nextElement();
        if (m_lowercaseTokens) {
          word = word.toLowerCase();
        }

        word = m_stemmer.stem(word);

        if (m_StopwordsHandler.isStopword(word)) {
          continue;
        }

        Count docCount = m_inputVector.get(word);
        if (docCount == null) {
          m_inputVector.put(word, new Count(instance.weight()));
        } else {
          docCount.m_count += instance.weight();
        }
      }
    }
  }

  if (updateDictionary) {
    int classValue = (int) instance.classValue();
    LinkedHashMap<String, Count> dictForClass =
      m_probOfWordGivenClass.get(classValue);

    // document normalization
    double iNorm = 0;
    if (m_normalize) {
      for (Count c : m_inputVector.values()) {
        // word counts or bag-of-words?
        double fv = (m_wordFrequencies) ? c.m_count : 1.0;
        iNorm += Math.pow(Math.abs(fv), m_lnorm);
      }
      iNorm = Math.pow(iNorm, 1.0 / m_lnorm);
    }

    for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) {
      String word = feature.getKey();
      double freq = (m_wordFrequencies) ? feature.getValue().m_count : 1.0;

      if (m_normalize) {
        freq *= (m_norm / iNorm);
      }

      // check all classes: a word missing from a class dictionary is added
      // there with the laplace count
      for (int i = 0; i < m_data.numClasses(); i++) {
        LinkedHashMap<String, Count> dict = m_probOfWordGivenClass.get(i);
        if (dict.get(word) == null) {
          dict.put(word, new Count(m_leplace));
          m_wordsPerClass[i] += m_leplace;
        }
      }

      // cannot be null: the loop above has just made sure that every class
      // dictionary (including this one) contains the word
      Count dictCount = dictForClass.get(word);
      dictCount.m_count += freq;
      m_wordsPerClass[classValue] += freq;
    }

    // periodic (non-forced) pruning
    pruneDictionary(false);
  }
}
/**
 * Removes from every class dictionary all words whose count is below the
 * minimum word frequency, subtracting the removed counts from the per-class
 * word totals.
 * <p>
 * Unless forced, nothing happens when periodic pruning is off (period of
 * zero or less) or when the current instance number is not a multiple of the
 * pruning period.
 *
 * @param force true to prune regardless of the periodic pruning settings
 */
protected void pruneDictionary(boolean force) {
  if ((m_periodicP <= 0 || m_t % m_periodicP > 0) && !force) {
    return;
  }

  Set<Integer> classesSet = m_probOfWordGivenClass.keySet();
  for (Integer classIndex : classesSet) {
    LinkedHashMap<String, Count> dictForClass =
      m_probOfWordGivenClass.get(classIndex);

    // remove through the iterator so the map can be modified while it is
    // being traversed
    Iterator<Map.Entry<String, Count>> entries =
      dictForClass.entrySet().iterator();
    while (entries.hasNext()) {
      Map.Entry<String, Count> entry = entries.next();
      if (entry.getValue().m_count < m_minWordP) {
        m_wordsPerClass[classIndex] -= entry.getValue().m_count;
        entries.remove();
      }
    }
  }
}
/**
 * Reset the classifier: drops the class counts, the per-class word totals
 * and the dictionaries, and sets the instance counter back to 1.
 */
public void reset() {
  this.m_probOfClass = null;
  this.m_probOfWordGivenClass = null;
  this.m_wordsPerClass = null;
  this.m_t = 1;
}
/**
 * Sets the stemming algorithm to use. Passing null means no stemming at all,
 * i.e. a fresh NullStemmer is installed.
 *
 * @param value the configured stemming algorithm, or null
 * @see NullStemmer
 */
public void setStemmer(Stemmer value) {
  m_stemmer = (value == null) ? new NullStemmer() : value;
}
/**
 * Returns the current stemming algorithm.
 *
 * @return the current stemming algorithm; a NullStemmer is the field's
 *         default and is also what the setter installs when given null
 */
public Stemmer getStemmer() {
  return this.m_stemmer;
}
/**
 * Tip text for the stemmer property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String stemmerTipText() {
  final String tip = "The stemming algorithm to use on the words.";
  return tip;
}
/**
 * Sets the tokenizer algorithm to use. As with the stemmer and stopwords
 * handler setters, null selects the default, here a WordTokenizer, instead
 * of leaving the classifier with a null tokenizer that would fail on the
 * next instance to be tokenized.
 *
 * @param value the configured tokenizing algorithm, or null for the default
 */
public void setTokenizer(Tokenizer value) {
  m_tokenizer = (value != null) ? value : new WordTokenizer();
}
/**
 * Returns the tokenizer used to split string values into words.
 *
 * @return the current tokenizer algorithm
 */
public Tokenizer getTokenizer() {
  return this.m_tokenizer;
}
/**
 * Tip text for the tokenizer property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String tokenizerTipText() {
  final String tip = "The tokenizing algorithm to use on the strings.";
  return tip;
}
/**
 * Tip text for the useWordFrequencies property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String useWordFrequenciesTipText() {
  final StringBuilder tip = new StringBuilder();
  tip.append("Use word frequencies rather than binary ");
  tip.append("bag of words representation");
  return tip.toString();
}
/**
 * Chooses between word frequencies and the binary bag of words
 * representation.
 *
 * @param u true if word frequencies are to be used.
 */
public void setUseWordFrequencies(boolean u) {
  this.m_wordFrequencies = u;
}
/**
 * Tells whether word frequencies are used instead of the binary bag of words
 * representation.
 *
 * @return true if word frequencies are to be used.
 */
public boolean getUseWordFrequencies() {
  return this.m_wordFrequencies;
}
/**
 * Tip text for the lowercaseTokens property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String lowercaseTokensTipText() {
  final String tip = "Whether to convert all tokens to lowercase";
  return tip;
}
/**
 * Switches lowercasing of tokens on or off.
 *
 * @param l true if all tokens are to be converted to lowercase
 */
public void setLowercaseTokens(boolean l) {
  this.m_lowercaseTokens = l;
}
/**
 * Tells whether tokens are converted to lowercase.
 *
 * @return true if all tokens are to be converted to lowercase
 */
public boolean getLowercaseTokens() {
  return this.m_lowercaseTokens;
}
/**
 * Tip text for the periodicPruning property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String periodicPruningTipText() {
  final StringBuilder tip = new StringBuilder();
  tip.append("How often (number of instances) to prune ");
  tip.append("the dictionary of low frequency terms. ");
  tip.append("0 means don't prune. Setting a positive ");
  tip.append("integer n means prune after every n instances");
  return tip.toString();
}
/**
 * Sets the pruning period, i.e. the number of training instances between
 * two dictionary pruning passes.
 *
 * @param p how often to prune
 */
public void setPeriodicPruning(int p) {
  this.m_periodicP = p;
}
/**
 * Returns the pruning period, i.e. the number of training instances between
 * two dictionary pruning passes.
 *
 * @return how often to prune the dictionary
 */
public int getPeriodicPruning() {
  return this.m_periodicP;
}
/**
 * Tip text for the minWordFrequency property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String minWordFrequencyTipText() {
  final StringBuilder tip = new StringBuilder();
  tip.append("Ignore any words that don't occur at least ");
  tip.append("min frequency times in the training data. If periodic ");
  tip.append("pruning is turned on, then the dictionary is pruned ");
  tip.append("according to this value");
  return tip.toString();
}
/**
 * Sets the minimum word frequency. Dictionary entries whose count is below
 * this value are removed whenever the dictionary is pruned.
 *
 * @param minFreq the minimum word frequency to use
 */
public void setMinWordFrequency(double minFreq) {
  this.m_minWordP = minFreq;
}
/**
 * Returns the minimum word frequency. Dictionary entries whose count is
 * below this value are removed whenever the dictionary is pruned.
 *
 * @return the minimum word frequency to use
 */
public double getMinWordFrequency() {
  return this.m_minWordP;
}
/**
 * Tip text for the normalizeDocLength property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String normalizeDocLengthTipText() {
  final StringBuilder tip = new StringBuilder();
  tip.append("If true then document length is normalized according ");
  tip.append("to the settings for norm and lnorm");
  return tip.toString();
}
/**
 * Switches document length normalization on or off.
 *
 * @param norm true if document length is to be normalized
 */
public void setNormalizeDocLength(boolean norm) {
  this.m_normalize = norm;
}
/**
 * Tells whether document length normalization is switched on.
 *
 * @return true if document length is to be normalized
 */
public boolean getNormalizeDocLength() {
  return this.m_normalize;
}
/**
 * Tip text for the norm property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String normTipText() {
  final String tip = "The norm of the instances after normalization.";
  return tip;
}
/**
 * Returns the norm that document vectors are scaled to when normalization
 * is in use.
 *
 * @return the Norm
 */
public double getNorm() {
  return this.m_norm;
}
/**
 * Sets the norm that document vectors are scaled to when normalization is
 * in use.
 *
 * @param newNorm the norm to which the instances must be set
 */
public void setNorm(double newNorm) {
  this.m_norm = newNorm;
}
/**
 * Tip text for the LNorm property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String LNormTipText() {
  final String tip = "The LNorm to use for document length normalization.";
  return tip;
}
/**
 * Returns the L-norm (the exponent p of the p-norm) used for document
 * length normalization.
 *
 * @return the L-norm used
 */
public double getLNorm() {
  return this.m_lnorm;
}
/**
 * Sets the L-norm (the exponent p of the p-norm) to use for document length
 * normalization.
 *
 * @param newLNorm the L-norm
 */
public void setLNorm(double newLNorm) {
  this.m_lnorm = newLNorm;
}
/**
 * Sets the stopwords handler to use. Passing null installs a fresh Null
 * handler.
 *
 * @param value the stopwords handler, if null, Null is used
 */
public void setStopwordsHandler(StopwordsHandler value) {
  m_StopwordsHandler = (value == null) ? new Null() : value;
}
/**
 * Returns the handler that decides which tokens are stopwords.
 *
 * @return the stopwords handler
 */
public StopwordsHandler getStopwordsHandler() {
  return this.m_StopwordsHandler;
}
/**
 * Tip text for the stopwordsHandler property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String stopwordsHandlerTipText() {
  final String tip =
    "The stopwords handler to use (Null means no stopwords are used).";
  return tip;
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
@Override
public Enumeration
© 2015 - 2025 Weber Informatics LLC | Privacy Policy