All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.bayes.NaiveBayesMultinomialText Maven / Gradle / Ivy

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    NaiveBayesMultinomialText.java
 *    Copyright (C) 2012 University of Waikato, Hamilton, New Zealand
 */

package weka.classifiers.bayes;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.Vector;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableBatchProcessor;
import weka.classifiers.UpdateableClassifier;
import weka.core.Aggregateable;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.stemmers.NullStemmer;
import weka.core.stemmers.Stemmer;
import weka.core.stopwords.Null;
import weka.core.stopwords.StopwordsHandler;
import weka.core.tokenizers.Tokenizer;
import weka.core.tokenizers.WordTokenizer;

/**
 
 * Multinomial naive bayes for text data. Operates directly (and only) on String attributes. Other types of input attributes are accepted but ignored during training and classification
 * 

* * Valid options are:

* *

 -W
 *  Use word frequencies instead of binary bag of words.
* *
 -P <# instances>
 *  How often to prune the dictionary of low frequency words (default = 0, i.e. don't prune)
* *
 -M <double>
 *  Minimum word frequency. Words with less than this frequence are ignored.
 *  If periodic pruning is turned on then this is also used to determine which
 *  words to remove from the dictionary (default = 3).
* *
 -normalize
 *  Normalize document length (use in conjunction with -norm and -lnorm)
* *
 -norm <num>
 *  Specify the norm that each instance must have (default 1.0)
* *
 -lnorm <num>
 *  Specify L-norm to use (default 2.0)
* *
 -lowercase
 *  Convert all tokens to lowercase before adding to the dictionary.
* *
 -stopwords-handler
 *  The stopwords handler to use (default Null).
* *
 -tokenizer <spec>
 *  The tokenizing algorihtm (classname plus parameters) to use.
 *  (default: weka.core.tokenizers.WordTokenizer)
* *
 -stemmer <spec>
 *  The stemmering algorihtm (classname plus parameters) to use.
* *
 -output-debug-info
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console
* *
 -do-not-check-capabilities
 *  If set, classifier capabilities are not checked before classifier is built
 *  (use with caution).
* * * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) * @author Andrew Golightly ([email protected]) * @author Bernhard Pfahringer ([email protected]) * */ public class NaiveBayesMultinomialText extends AbstractClassifier implements UpdateableClassifier, UpdateableBatchProcessor, WeightedInstancesHandler, Aggregateable { /** For serialization */ private static final long serialVersionUID = 2139025532014821394L; private static class Count implements Serializable { /** * For serialization */ private static final long serialVersionUID = 2104201532017340967L; public double m_count; public Count(double c) { m_count = c; } } /** The header of the training data */ protected Instances m_data; protected double[] m_probOfClass; protected double[] m_wordsPerClass; protected Map> m_probOfWordGivenClass; /** * Holds the current document vector (LinkedHashMap is more efficient when * iterating over EntrySet than HashMap) */ protected transient LinkedHashMap m_inputVector; /** Stopword handler to use. */ protected StopwordsHandler m_StopwordsHandler = new Null(); /** The tokenizer to use */ protected Tokenizer m_tokenizer = new WordTokenizer(); /** Whether or not to convert all tokens to lowercase */ protected boolean m_lowercaseTokens; /** The stemming algorithm. */ protected Stemmer m_stemmer = new NullStemmer(); /** * The number of training instances at which to periodically prune the * dictionary of min frequency words. Empty or null string indicates don't * prune */ protected int m_periodicP = 0; /** * Only consider dictionary words (features) that occur at least this many * times */ protected double m_minWordP = 3; /** Use word frequencies rather than bag-of-words if true */ protected boolean m_wordFrequencies = false; /** normailize document length ? 
*/ protected boolean m_normalize = false; /** The length that each document vector should have in the end */ protected double m_norm = 1.0; /** The L-norm to use */ protected double m_lnorm = 2.0; /** Leplace-like correction factor for zero frequency */ protected double m_leplace = 1.0; /** Holds the current instance number */ protected double m_t; /** * Returns a string describing classifier * * @return a description suitable for displaying in the explorer/experimenter * gui */ public String globalInfo() { return "Multinomial naive bayes for text data. Operates " + "directly (and only) on String attributes. " + "Other types of input attributes are accepted but " + "ignored during training and classification"; } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ @Override public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); // attributes result.enable(Capability.STRING_ATTRIBUTES); result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.DATE_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); result.enable(Capability.MISSING_CLASS_VALUES); result.enable(Capability.NOMINAL_CLASS); // instances result.setMinimumNumberInstances(0); return result; } /** * Generates the classifier. * * @param data set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ @Override public void buildClassifier(Instances data) throws Exception { reset(); // can classifier handle the data? 
getCapabilities().testWithFail(data); m_data = new Instances(data, 0); data = new Instances(data); m_wordsPerClass = new double[data.numClasses()]; m_probOfClass = new double[data.numClasses()]; m_probOfWordGivenClass = new HashMap>(); double laplace = 1.0; for (int i = 0; i < data.numClasses(); i++) { LinkedHashMap dict = new LinkedHashMap(10000 / data.numClasses()); m_probOfWordGivenClass.put(i, dict); m_probOfClass[i] = laplace; // this needs to be updated for laplace correction every time we see a new // word (attribute) m_wordsPerClass[i] = 0; } for (int i = 0; i < data.numInstances(); i++) { updateClassifier(data.instance(i)); } if (data.numInstances() > 0) { pruneDictionary(true); } } /** * Updates the classifier with the given instance. * * @param instance the new training instance to include in the model * @throws Exception if the instance could not be incorporated in the model. */ @Override public void updateClassifier(Instance instance) throws Exception { updateClassifier(instance, true); } protected void updateClassifier(Instance instance, boolean updateDictionary) throws Exception { if (!instance.classIsMissing()) { int classIndex = (int) instance.classValue(); m_probOfClass[classIndex] += instance.weight(); tokenizeInstance(instance, updateDictionary); m_t++; } } /** * Calculates the class membership probabilities for the given test instance. 
* * @param instance the instance to be classified * @return predicted class probability distribution * @throws Exception if there is a problem generating the prediction */ @Override public double[] distributionForInstance(Instance instance) throws Exception { tokenizeInstance( instance, false ); double[] probOfClassGivenDoc = new double[m_data.numClasses()]; double[] logDocGivenClass = new double[m_data.numClasses()]; for (int i = 0; i < m_data.numClasses(); i++) { logDocGivenClass[i] += Math.log(m_probOfClass[i]); LinkedHashMap dictForClass = m_probOfWordGivenClass.get(i); int allWords = 0; // for document normalization (if in use) double iNorm = 0; double fv = 0; if (m_normalize) { for (Map.Entry feature : m_inputVector.entrySet()) { String word = feature.getKey(); Count c = feature.getValue(); // check the word against all the dictionaries (all classes) boolean ok = false; for (int clss = 0; clss < m_data.numClasses(); clss++) { if (m_probOfWordGivenClass.get(clss).get(word) != null) { ok = true; break; } } // only normalize with respect to those words that we've seen during // training // (i.e. dictionary over all classes) if (ok) { // word counts or bag-of-words? fv = (m_wordFrequencies) ? c.m_count : 1.0; iNorm += Math.pow(Math.abs(fv), m_lnorm); } } iNorm = Math.pow(iNorm, 1.0 / m_lnorm); } // System.out.println("---- " + m_inputVector.size()); for (Map.Entry feature : m_inputVector.entrySet()) { String word = feature.getKey(); Count dictCount = dictForClass.get(word); // System.out.print(word + " "); /* * if (dictCount != null) { System.out.println(dictCount.m_count); } * else { System.out.println("*1"); } */ // check the word against all the dictionaries (all classes) boolean ok = false; for (int clss = 0; clss < m_data.numClasses(); clss++) { if (m_probOfWordGivenClass.get(clss).get(word) != null) { ok = true; break; } } // ignore words we haven't seen in the training data if (ok) { double freq = (m_wordFrequencies) ? 
feature.getValue().m_count : 1.0; // double freq = (feature.getValue().m_count / iNorm * m_norm); if (m_normalize) { freq *= (m_norm / iNorm); } allWords += freq; if (dictCount != null) { logDocGivenClass[i] += freq * Math.log(dictCount.m_count); } else { // leplace for zero frequency logDocGivenClass[i] += freq * Math.log(m_leplace); } } } if (m_wordsPerClass[i] > 0) { logDocGivenClass[i] -= allWords * Math.log(m_wordsPerClass[i]); } } double max = logDocGivenClass[Utils.maxIndex(logDocGivenClass)]; for (int i = 0; i < m_data.numClasses(); i++) { probOfClassGivenDoc[i] = Math.exp(logDocGivenClass[i] - max); } Utils.normalize(probOfClassGivenDoc); return probOfClassGivenDoc; } protected void tokenizeInstance(Instance instance, boolean updateDictionary) { if (m_inputVector == null) { m_inputVector = new LinkedHashMap(); } else { m_inputVector.clear(); } for (int i = 0; i < instance.numAttributes(); i++) { if (instance.attribute(i).isString() && !instance.isMissing(i)) { m_tokenizer.tokenize(instance.stringValue(i)); while (m_tokenizer.hasMoreElements()) { String word = m_tokenizer.nextElement(); if (m_lowercaseTokens) { word = word.toLowerCase(); } word = m_stemmer.stem(word); if (m_StopwordsHandler.isStopword(word)) { continue; } Count docCount = m_inputVector.get(word); if (docCount == null) { m_inputVector.put(word, new Count(instance.weight())); } else { docCount.m_count += instance.weight(); } } } } if (updateDictionary) { int classValue = (int) instance.classValue(); LinkedHashMap dictForClass = m_probOfWordGivenClass.get(classValue); // document normalization double iNorm = 0; double fv = 0; if (m_normalize) { for (Count c : m_inputVector.values()) { // word counts or bag-of-words? fv = (m_wordFrequencies) ? c.m_count : 1.0; iNorm += Math.pow(Math.abs(fv), m_lnorm); } iNorm = Math.pow(iNorm, 1.0 / m_lnorm); } for (Map.Entry feature : m_inputVector.entrySet()) { String word = feature.getKey(); double freq = (m_wordFrequencies) ? 
feature.getValue().m_count : 1.0; // double freq = (feature.getValue().m_count / iNorm * m_norm); if (m_normalize) { freq *= (m_norm / iNorm); } // check all classes for (int i = 0; i < m_data.numClasses(); i++) { LinkedHashMap dict = m_probOfWordGivenClass.get(i); if (dict.get(word) == null) { dict.put(word, new Count(m_leplace)); m_wordsPerClass[i] += m_leplace; } } Count dictCount = dictForClass.get(word); /* * if (dictCount == null) { dictForClass.put(word, new Count(m_leplace + * freq)); m_wordsPerClass[classValue] += (m_leplace + freq); } else { */ dictCount.m_count += freq; m_wordsPerClass[classValue] += freq; // } } pruneDictionary(false); } } protected void pruneDictionary(boolean force) { if ((m_periodicP <= 0 || m_t % m_periodicP > 0) && !force) { return; } Set classesSet = m_probOfWordGivenClass.keySet(); for (Integer classIndex : classesSet) { LinkedHashMap dictForClass = m_probOfWordGivenClass.get(classIndex); Iterator> entries = dictForClass.entrySet().iterator(); while (entries.hasNext()) { Map.Entry entry = entries.next(); if (entry.getValue().m_count < m_minWordP) { m_wordsPerClass[classIndex] -= entry.getValue().m_count; entries.remove(); } } } } /** * Reset the classifier. */ public void reset() { m_t = 1; m_wordsPerClass = null; m_probOfWordGivenClass = null; m_probOfClass = null; } /** * the stemming algorithm to use, null means no stemming at all (i.e., the * NullStemmer is used). * * @param value the configured stemming algorithm, or null * @see NullStemmer */ public void setStemmer(Stemmer value) { if (value != null) { m_stemmer = value; } else { m_stemmer = new NullStemmer(); } } /** * Returns the current stemming algorithm, null if none is used. * * @return the current stemming algorithm, null if none set */ public Stemmer getStemmer() { return m_stemmer; } /** * Returns the tip text for this property. 
*
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String stemmerTipText() {
    return "The stemming algorithm to use on the words.";
  }

  /**
   * Sets the tokenizer algorithm to use.
   *
   * @param value the configured tokenizing algorithm
   */
  public void setTokenizer(Tokenizer value) {
    m_tokenizer = value;
  }

  /**
   * Gets the tokenizer algorithm currently in use.
   *
   * @return the current tokenizer algorithm
   */
  public Tokenizer getTokenizer() {
    return m_tokenizer;
  }

  /**
   * Tip text for the tokenizer property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String tokenizerTipText() {
    return "The tokenizing algorithm to use on the strings.";
  }

  /**
   * Tip text for the useWordFrequencies property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String useWordFrequenciesTipText() {
    return "Use word frequencies rather than binary "
      + "bag of words representation";
  }

  /**
   * Chooses between word frequencies and the binary bag of words
   * representation.
   *
   * @param u true if word frequencies are to be used.
   */
  public void setUseWordFrequencies(boolean u) {
    m_wordFrequencies = u;
  }

  /**
   * Reports whether word frequencies are used instead of the binary bag of
   * words representation.
   *
   * @return true if word frequencies are to be used.
   */
  public boolean getUseWordFrequencies() {
    return m_wordFrequencies;
  }

  /**
   * Tip text for the lowercaseTokens property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String lowercaseTokensTipText() {
    return "Whether to convert all tokens to lowercase";
  }

  /**
   * Chooses whether all tokens are converted to lowercase.
   *
   * @param l true if all tokens are to be converted to lowercase
   */
  public void setLowercaseTokens(boolean l) {
    m_lowercaseTokens = l;
  }

  /**
   * Reports whether all tokens are converted to lowercase.
   *
   * @return true if all tokens are to be converted to lowercase
   */
  public boolean getLowercaseTokens() {
    return m_lowercaseTokens;
  }

  /**
   * Tip text for the periodicPruning property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String periodicPruningTipText() {
    return "How often (number of instances) to prune "
      + "the dictionary of low frequency terms. "
      + "0 means don't prune. Setting a positive "
      + "integer n means prune after every n instances";
  }

  /**
   * Sets the dictionary pruning period.
   *
   * @param p how often to prune
   */
  public void setPeriodicPruning(int p) {
    m_periodicP = p;
  }

  /**
   * Gets the dictionary pruning period.
   *
   * @return how often to prune the dictionary
   */
  public int getPeriodicPruning() {
    return m_periodicP;
  }

  /**
   * Tip text for the minWordFrequency property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String minWordFrequencyTipText() {
    return "Ignore any words that don't occur at least "
      + "min frequency times in the training data. If periodic "
      + "pruning is turned on, then the dictionary is pruned "
      + "according to this value";
  }

  /**
   * Sets the minimum word frequency. Words that don't occur at least min freq
   * times are ignored when updating weights. If periodic pruning is turned on,
   * then min frequency is used when removing words from the dictionary.
   *
   * @param minFreq the minimum word frequency to use
   */
  public void setMinWordFrequency(double minFreq) {
    m_minWordP = minFreq;
  }

  /**
   * Gets the minimum word frequency. Words that don't occur at least min freq
   * times are ignored when updating weights. If periodic pruning is turned on,
   * then min frequency is used when removing words from the dictionary.
   *
   * @return the minimum word frequency to use
   */
  public double getMinWordFrequency() {
    return m_minWordP;
  }

  /**
   * Tip text for the normalizeDocLength property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String normalizeDocLengthTipText() {
    return "If true then document length is normalized according "
      + "to the settings for norm and lnorm";
  }

  /**
   * Chooses whether the length of each document is normalized.
   *
   * @param norm true if document length is to be normalized
   */
  public void setNormalizeDocLength(boolean norm) {
    m_normalize = norm;
  }

  /**
   * Reports whether the length of each document is normalized.
   *
   * @return true if document length is to be normalized
   */
  public boolean getNormalizeDocLength() {
    return m_normalize;
  }

  /**
   * Tip text for the norm property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String normTipText() {
    return "The norm of the instances after normalization.";
  }

  /**
   * Gets the norm that instances are set to.
   *
   * @return the norm
   */
  public double getNorm() {
    return m_norm;
  }

  /**
   * Sets the norm of the instances.
   *
   * @param newNorm the norm to which the instances must be set
   */
  public void setNorm(double newNorm) {
    m_norm = newNorm;
  }

  /**
   * Tip text for the LNorm property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String LNormTipText() {
    return "The LNorm to use for document length normalization.";
  }

  /**
   * Gets the L-norm in use.
   *
   * @return the L-norm used
   */
  public double getLNorm() {
    return m_lnorm;
  }

  /**
   * Sets the L-norm to use.
   *
   * @param newLNorm the L-norm
   */
  public void setLNorm(double newLNorm) {
    m_lnorm = newLNorm;
  }

  /**
   * Sets the stopwords handler to use.
   *
   * @param value the stopwords handler, if null, Null is used
   */
  public void setStopwordsHandler(StopwordsHandler value) {
    m_StopwordsHandler = (value == null) ? new Null() : value;
  }

  /**
   * Gets the stopwords handler.
   *
   * @return the stopwords handler
   */
  public StopwordsHandler getStopwordsHandler() {
    return m_StopwordsHandler;
  }

  /**
   * Tip text for the stopwordsHandler property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String stopwordsHandlerTipText() {
    return "The stopwords handler to use (Null means no stopwords are used).";
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  @Override
  public Enumeration




© 2015 - 2025 Weber Informatics LLC | Privacy Policy