
stream.classifier.NaiveBayes

/**
 * 
 */
package stream.classifier;

import java.rmi.RemoteException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import stream.Data;
import stream.annotations.Description;
import stream.annotations.Parameter;
import stream.distribution.Distribution;
import stream.distribution.NominalDistribution;
import stream.distribution.NumericalDistribution;

/**
 * This class implements a NaiveBayes classifier. It combines the learning
 * algorithm and the model implementation in one. The implementation provides
 * support for numerical (Double) and nominal (String) attributes.
 * 
 * The implementation is a strictly incremental one and supports memory
 * limitation. The memory limitation is carried out by truncating the
 * distribution models built for each of the observed attributes.
 * 
 * @author Christian Bockermann <[email protected]>
 */
@Description(group = "Data Stream.Mining.Classifier")
public class NaiveBayes extends AbstractClassifier {

    /** The unique class ID */
    private static final long serialVersionUID = 1095437834368310484L;

    /* The global logger for this class */
    static Logger log = LoggerFactory.getLogger(NaiveBayes.class);

    /* The La-Place correction term */
    Double laplaceCorrection = 0.0001;

    Double confidenceGap = new Double(0.0d);

    Boolean wop = false;

    /* This is the distribution of the different classes observed */
    protected NominalDistribution classDistribution = null; // createNominalDistribution();

    /* A map providing the distributions of the attributes (nominal,numerical) */
    protected Map<String, Distribution<?>> distributions = new HashMap<String, Distribution<?>>();

    /**
     * Create a new NaiveBayes instance. The label attribute is automatically
     * determined by the learner, if not explicitly set with the
     * setLabelAttribute method.
     */
    public NaiveBayes() {
        classDistribution = createNominalDistribution();
    }

    /**
     * @return the laplaceCorrection
     */
    public Double getLaplaceCorrection() {
        return laplaceCorrection;
    }

    /**
     * @param laplaceCorrection
     *            the laplaceCorrection to set
     */
    @Parameter(required = false, description = "Value of the la-place correction")
    public void setLaplaceCorrection(Double laplaceCorrection) {
        this.laplaceCorrection = laplaceCorrection;
    }

    /**
     * @return the confidenceGap
     */
    public Double getConfidenceGap() {
        return confidenceGap;
    }

    /**
     * @param confidenceGap
     *            the confidenceGap to set
     */
    public void setConfidenceGap(Double confidenceGap) {
        this.confidenceGap = confidenceGap;
    }

    /**
     * @return the wop
     */
    public Boolean getWop() {
        return wop;
    }

    /**
     * @param wop
     *            the wop to set
     */
    public void setWop(Boolean wop) {
        this.wop = wop;
    }

    /**
     * @see stream.model.PredictionModel#predict(java.lang.Object)
     */
    @SuppressWarnings("unchecked")
    public Map<String, Double> vote(Data item) {

        Map<String, Double> classLikeli = new LinkedHashMap<String, Double>();
        log.debug("Predicting one of these classes: {}",
                classDistribution.getElements());

        for (String label : getClassDistribution().getElements()) {
            //
            // class likelihoods
            //
            if (wop) {
                classLikeli.put(label, 1.0d);
            } else {
                Double cl = getClassDistribution().getCount(label)
                        .doubleValue(); // .getHistogram().get( label );
                log.debug("class likelihood for class '" + label
                        + "' is {} / {}", cl, getClassDistribution()
                        .getTotalCount());
                Double p_label = getClassDistribution().getHistogram().get(
                        label)
                        / this.getClassDistribution().getTotalCount();
                classLikeli.put(label, p_label);
            }
        }

        //
        // compute the class likelihood for each class:
        //
        Double max = 0.0d;
        String maxClass = null;
        Double totalLikelihood = 0.0d;

        for (String clazz : classLikeli.keySet()) {
            Double likelihood = classLikeli.get(clazz);

            for (String attribute : item.keySet()) {
                if (!this.label.equals(attribute)) {
                    Object value = item.get(attribute);
                    if (value.getClass().equals(Double.class)) {
                        //
                        // multiplying probability for double value
                        //
                        Distribution<Double> dist = (Distribution<Double>) distributions
                                .get(clazz);
                        likelihood *= dist.prob((Double) value);
                    } else {
                        //
                        // determine likelihood for nominal value
                        //
                        String feature = this.getNominalCondition(attribute,
                                item);
                        Double d = ((NominalDistribution) distributions
                                .get(clazz)).getCount(feature).doubleValue();
                        Double total = ((NominalDistribution) getClassDistribution())
                                .getCount(clazz).doubleValue();
                        if (d == null || d == 0.0d) {
                            d = laplaceCorrection;
                            total += laplaceCorrection;
                        }
                        log.debug("   likelihood for {} is {} |" + clazz
                                + " ", feature, d / total);
                        likelihood *= (d / total);
                    }
                }
            }

            classLikeli.put(clazz, likelihood);
            totalLikelihood += likelihood;
        }

        // determine most likely class
        //
        Map<String, Double> probs = new LinkedHashMap<String, Double>();
        for (String clazz : classLikeli.keySet()) {
            Double likelihood = classLikeli.get(clazz) / totalLikelihood;
            probs.put(clazz, likelihood);
            log.debug("probability for {} is {}", clazz, likelihood);
            if (maxClass == null || likelihood > max) {
                maxClass = clazz;
                max = likelihood;
            }
        }

        return probs;
    }

    /**
     * @see stream.learner.AbstractClassifier#predict(java.lang.Object)
     */
    @Override
    public String predict(Data item) {

        Map<String, Double> probs = this.vote(item);

        Double max = 0.0d;
        String maxClass = null;
        Double confidence = 0.0d;

        for (String clazz : probs.keySet()) {
            Double likelihood = probs.get(clazz);
            log.debug("probability for {} is {}", clazz, likelihood);
            // item.put( LearnerUtils.hide( "pred(" + clazz + ")" ), likelihood
            // );
            if (maxClass == null || likelihood > max) {
                maxClass = clazz;
                if (max != null)
                    confidence = 1.0d - Math.abs(likelihood - max);
                else
                    confidence = 1.0d;
                max = likelihood;
            }
        }

        log.info("Predicting class {}, label is: {}, confidence-gap: "
                + confidence + " wop=" + wop, maxClass, item.get(label));
        return maxClass;
    }

    public String getNominalCondition(String attribute, Data item) {
        return attribute + "='" + item.get(attribute) + "'";
    }

    /**
     * @see stream.learner.Learner#learn(java.lang.Object)
     */
    @SuppressWarnings("unchecked")
    @Override
    public void train(Data item) {

        //
        // determine the label attribute, if not already set
        //
        if (label == null) {
            return;
        }

        if (item.get(label) == null) {
            log.warn("Not processing unlabeled data item {}", item);
            return;
        }

        String clazz = item.get(label).toString();
        log.debug("Learning from example with label={}", clazz);

        if (this.classDistribution == null)
            this.classDistribution = createNominalDistribution();

        if (log.isDebugEnabled()) {
            log.debug("Classes: {}", classDistribution.getElements());
            for (String t : classDistribution.getElements())
                log.debug("   {}: {}", t, classDistribution.getCount(t));
        }

        //
        // For learning we update the distributions of each attribute
        //
        for (String attribute : item.keySet()) {

            if (attribute.equalsIgnoreCase(label)) {
                //
                // adjust the class label distribution
                //
                classDistribution.update(clazz);
            } else {
                Object obj = item.get(attribute);
                if (obj.getClass().equals(Double.class)) {
                    Double value = (Double) obj;
                    log.debug("Handling numerical case ({}) with value {}",
                            obj, value);
                    //
                    // manage the case of a numerical attribute
                    //
                    Distribution<Double> numDist = (Distribution<Double>) distributions
                            .get(attribute);
                    if (numDist == null) {
                        numDist = this.createNumericalDistribution();
                        log.debug(
                                "Creating new numerical distribution model for attribute {}",
                                attribute);
                        distributions.put(attribute, numDist);
                    }
                    numDist.update(value);
                } else {
                    String value = this.getNominalCondition(attribute, item);
                    log.debug("Handling nominal case for [ {} | {} ]", value,
                            "class=" + clazz);
                    //
                    // adapt the nominal distribution for this attribute
                    //
                    Distribution<String> nomDist = (Distribution<String>) distributions
                            .get(clazz);
                    if (nomDist == null) {
                        nomDist = this.createNominalDistribution();
                        log.debug(
                                "Creating new nominal distribution model for attribute {}, {}",
                                attribute, "class=" + clazz);
                        distributions.put(clazz, nomDist);
                    }
                    nomDist.update(value);
                }
            }
        }
    }

    /**
     * Returns the class distribution of the current state of the algorithm.
     * 
     * @return
     */
    public NominalDistribution getClassDistribution() {
        if (classDistribution == null)
            classDistribution = createNominalDistribution();
        return this.classDistribution;
    }

    /**
     * Returns the set of numerical distributions of this model.
     * 
     * @return The set of numerical distributions, currently known to this
     *         classifier.
     */
    @SuppressWarnings("unchecked")
    public List<Distribution<Double>> getNumericalDistributions() {
        List<Distribution<Double>> numDists = new ArrayList<Distribution<Double>>();
        for (Distribution<?> d : distributions.values()) {
            if (d instanceof NumericalDistribution)
                numDists.add((Distribution<Double>) d);
        }
        return numDists;
    }

    /**
     * This method creates a new distribution model for nominal values. It can
     * be overwritten by subclasses to make use of a more
     * sophisticated/space-limited assessment of nominal distributions.
     * 
     * @return A new, empty distribution model.
     */
    public NominalDistribution createNominalDistribution() {
        return new NominalDistribution();
    }

    /**
     * This method creates a new distribution model for numerical data.
     * 
     * @return A new, empty distribution model.
     */
    public Distribution<Double> createNumericalDistribution() {
        return new NumericalDistribution();
    }

    /**
     * @see stream.learner.PredictionService#getName()
     */
    @Override
    public String getName() throws RemoteException {
        return "stream.classifier.NaiveBayes";
    }

    /**
     * @see stream.service.Service#reset()
     */
    @Override
    public void reset() throws Exception {
        classDistribution = this.createNominalDistribution();
        distributions.clear();
    }
}
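
The listing below is a minimal usage sketch, not part of the original source. It illustrates what the class above does: for each class c, vote() multiplies the class prior by the per-attribute likelihoods of the item's values given c, applying the La-Place correction for nominal values never seen with that class, and then normalizes over all classes; predict() returns the class with the highest normalized probability. The sketch sticks to nominal (String) attributes and treats several details as assumptions: that Data items can be obtained from stream.data.DataFactory.create(), that Data is map-like so attributes are set with put(), and that the label attribute is configured via a setLabel() setter inherited from AbstractClassifier (the constructor's javadoc refers to a setLabelAttribute method). The attribute names and values are made up for illustration.

import stream.Data;
import stream.classifier.NaiveBayes;
import stream.data.DataFactory; // assumption: the streams framework's Data factory

public class NaiveBayesSketch {

    public static void main(String[] args) throws Exception {
        NaiveBayes nb = new NaiveBayes();

        // smoothing value used when a nominal feature was never observed for a class
        nb.setLaplaceCorrection(0.001);

        // assumption: label setter inherited from AbstractClassifier;
        // train() skips items as long as no label attribute is known
        nb.setLabel("class");

        // two labeled training items with nominal attributes (hypothetical names/values)
        Data first = DataFactory.create();
        first.put("color", "red");
        first.put("shape", "round");
        first.put("class", "apple");
        nb.train(first);

        Data second = DataFactory.create();
        second.put("color", "yellow");
        second.put("shape", "long");
        second.put("class", "banana");
        nb.train(second);

        // query an unlabeled item: vote() returns normalized class probabilities,
        // predict() returns the most probable class label
        Data query = DataFactory.create();
        query.put("color", "red");
        query.put("shape", "round");
        System.out.println(nb.vote(query));    // e.g. {apple=0.99..., banana=0.00...}
        System.out.println(nb.predict(query)); // expected: "apple"
    }
}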
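
The class javadoc mentions memory limitation by truncating the per-attribute distribution models, and the factory methods createNominalDistribution() and createNumericalDistribution() are the documented hooks for this: subclasses can override them to supply space-limited distribution implementations. The sketch below only shows the extension point; the truncating distribution itself is a hypothetical placeholder, since no such implementation appears in this listing.

import stream.classifier.NaiveBayes;
import stream.distribution.NominalDistribution;

public class BoundedNaiveBayes extends NaiveBayes {

    private static final long serialVersionUID = 1L;

    /**
     * Overrides the factory method so that every nominal distribution model is
     * created with a bound on the number of tracked values.
     */
    @Override
    public NominalDistribution createNominalDistribution() {
        // assumption: a TruncatingNominalDistribution that drops low-count values
        // once a size limit is exceeded would be plugged in here, e.g.
        //   return new TruncatingNominalDistribution(1000);
        // such a class is not part of the listing above, so the placeholder below
        // simply falls back to the unbounded default.
        return new NominalDistribution();
    }
}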



