All downloads are free. The search and download functionality uses the official Maven repository.

edu.berkeley.nlp.classify.NaiveBayesClassifier Maven / Gradle / Ivy

Go to download

The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).

The newest version!
package edu.berkeley.nlp.classify;

import java.util.ArrayList;
import java.util.List;

import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.CounterMap;

public class NaiveBayesClassifier implements ProbabilisticClassifier {
	
	private CounterMap featureProbs ;
	private Counter backoffProbs ;
	private Counter labelProbs ;
	private FeatureExtractor featureExtractor;
	private double alpha = 0.1;
		
	public static class Factory implements ProbabilisticClassifierFactory {
		
		private FeatureExtractor featureExtractor;
		
		public Factory(FeatureExtractor featureExtractor) {
			this.featureExtractor = featureExtractor;
		}
		
		public ProbabilisticClassifier trainClassifier(List> trainingData) {
			CounterMap featureProbs = new CounterMap();
			Counter backoffProbs = new Counter();
			Counter labelProbs = new Counter();
			for (LabeledInstance instance: trainingData) {
				L label = instance.getLabel();
				labelProbs.incrementCount(label, 1.0);			
				I inst = instance.getInput();
				Counter featCounts = featureExtractor.extractFeatures(inst);
				for (F feat: featCounts.keySet()) {
					double count = featCounts.getCount(feat);
					backoffProbs.incrementCount(feat, count);
					featureProbs.incrementCount(label, feat, count);
				}				
			}
			featureProbs.normalize();
			labelProbs.normalize();
			backoffProbs.normalize();
			return new NaiveBayesClassifier(featureProbs, backoffProbs, labelProbs, featureExtractor);		
		}
		
	}

	public Counter getProbabilities(I instance) {
		Counter posteriors = new Counter();
		List logPosteriorsUnnormed = new ArrayList();
		for (L label: labelProbs.keySet()) {
			double logPrior = Math.log(labelProbs.getCount(label));
			double logPosteriorUnnorm = logPrior;
			Counter featCounts =featureExtractor.extractFeatures(instance);
			for (F feat: featCounts.keySet()) {
				double count = featCounts.getCount(feat);
				logPosteriorUnnorm += count * Math.log( getFeatureProb(feat, label) );				
			}
			logPosteriorsUnnormed.add(logPosteriorUnnorm);
			posteriors.setCount(label, logPosteriorUnnorm);
		}
		double logPosteriorNorm = SloppyMath.logAdd(logPosteriorsUnnormed);
		for (L label: labelProbs.keySet()) {			
			double logPosteriorUnnorm = posteriors.getCount(label);
			double logPosterior = logPosteriorUnnorm - logPosteriorNorm;
			double posterior = Math.exp(logPosterior);
			posteriors.setCount(label, posterior);
		} 		
		// TODO Auto-generated method stub
		return posteriors;
	}
	
	private double getFeatureProb(F feat, L label) {
		double mleProb = featureProbs.getCount(label, feat);
		double backoffProb = backoffProbs.getCount(feat);
		return (1-alpha) * mleProb + alpha * backoffProb;
	}

	public L getLabel(I instance) {
		// TODO Auto-generated method stub
		return getProbabilities(instance).argMax();
	}

	public NaiveBayesClassifier(CounterMap featureProbs,
			Counter backoffProbs, Counter labelProbs,
			FeatureExtractor featureExtractor) {
		super();
		this.featureProbs = featureProbs;
		this.backoffProbs = backoffProbs;
		this.labelProbs = labelProbs;
		this.featureExtractor = featureExtractor;
	}
	
	

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy