// NOTE(review): removed Maven-repository web-page text ("Go to download",
// "Show all versions…") that was accidentally pasted above the license header;
// it was not Java and prevented the file from compiling.
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
@author Andrew McCallum [email protected]
*/
package cc.mallet.classify.tests;
import junit.framework.*;
import java.net.URI;
import java.io.File;
import cc.mallet.classify.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.ArrayIterator;
import cc.mallet.pipe.iterator.FileIterator;
import cc.mallet.types.*;
import cc.mallet.util.*;
/**
 * JUnit tests for {@link NaiveBayes} / {@link NaiveBayesTrainer}: a hand-built
 * untrained classifier, training from string arrays, training on random data,
 * and incremental (two-batch) training from files on disk.
 *
 * The file-based tests read training documents from
 * {@code src/cc/mallet/classify/tests/NaiveBayesData/learn/{a,b}} and therefore
 * must be run from the project root.
 */
public class TestNaiveBayes extends TestCase
{
    /** Both training directories, used for the first (full) training batch. */
    private static final String[] LEARN_DIRS = {
        "src/cc/mallet/classify/tests/NaiveBayesData/learn/a",
        "src/cc/mallet/classify/tests/NaiveBayesData/learn/b"
    };

    /** Only the "b" directory, used for the second (incremental) batch. */
    private static final String[] LEARN_DIRS_B = {
        "src/cc/mallet/classify/tests/NaiveBayesData/learn/b"
    };

    public TestNaiveBayes (String name)
    {
        super (name);
    }

    /**
     * Builds the standard text-to-FeatureVector pipeline shared by the
     * file-based tests: label extraction, tokenization, lowercasing,
     * stopword removal, then conversion to a feature vector.
     */
    private static SerialPipes newTextPipe ()
    {
        return new SerialPipes (new Pipe[] {
            new Target2Label (),
            new Input2CharSequence (),
            // SKIP_HEADER only works for Unix
            // new CharSubsequence (CharSubsequence.SKIP_HEADER),
            new CharSequence2TokenSequence (),
            new TokenSequenceLowercase (),
            new TokenSequenceRemoveStopwords (),
            new TokenSequence2FeatureSequence (),
            new FeatureSequence2FeatureVector () });
    }

    /**
     * Loads every file under {@code dirs} through {@code pipe}, labeling each
     * instance by its starting directory.  The incremental-training tests pass
     * the SAME pipe instance for both batches on purpose: the pipe owns the
     * (growing) data and target alphabets, so the second batch extends the
     * vocabulary seen by the first.
     *
     * @param pipe               pipeline whose alphabets accumulate across calls
     * @param dirs               starting directories to walk
     * @param removeCommonPrefix forwarded to the three-argument
     *                           {@link FileIterator} constructor when true
     * @return a new InstanceList backed by {@code pipe}
     */
    private static InstanceList loadInstances (Pipe pipe, String[] dirs, boolean removeCommonPrefix)
    {
        InstanceList instances = new InstanceList (pipe);
        FileIterator iterator = removeCommonPrefix
            ? new FileIterator (dirs, FileIterator.STARTING_DIRECTORIES, true)
            : new FileIterator (dirs, FileIterator.STARTING_DIRECTORIES);
        instances.addThruPipe (iterator);
        return instances;
    }

    /** Prints the current data/target alphabet sizes of {@code instances}. */
    private static void printAlphabetSizes (InstanceList instances)
    {
        System.out.println ("data alphabet size " + instances.getDataAlphabet().size());
        System.out.println ("target alphabet size " + instances.getTargetAlphabet().size());
    }

    /**
     * Builds a NaiveBayes classifier by hand (no trainer): a uniform class
     * prior and two Laplace-smoothed multinomials over a shared feature
     * alphabet, then checks that a "politics"-flavored document is classified
     * as politics with probability &gt; 0.6.
     */
    public void testNonTrained ()
    {
        Alphabet fdict = new Alphabet ();
        System.out.println ("fdict.size=" + fdict.size());
        LabelAlphabet ldict = new LabelAlphabet ();
        Multinomial.Estimator me1 = new Multinomial.LaplaceEstimator (fdict);
        Multinomial.Estimator me2 = new Multinomial.LaplaceEstimator (fdict);

        // Uniform prior over the two (frozen) classes.
        ldict.lookupIndex ("sports");
        ldict.lookupIndex ("politics");
        ldict.stopGrowth ();
        System.out.println ("ldict.size=" + ldict.size());
        Multinomial prior = new Multinomial (new double[] {.5, .5}, ldict);

        // Sports counts.
        me1.increment ("win", 5);
        me1.increment ("puck", 5);
        me1.increment ("team", 5);
        System.out.println ("fdict.size=" + fdict.size());

        // Politics counts.
        me2.increment ("win", 5);
        me2.increment ("speech", 5);
        me2.increment ("vote", 5);

        // We must estimate from me1 and me2 after all data is incremented,
        // so that the "sports" multinomial knows the full dictionary size!
        Multinomial sports = me1.estimate ();
        Multinomial politics = me2.estimate ();

        Classifier c = new NaiveBayes (new Noop (fdict, ldict),
                                       prior,
                                       new Multinomial[] {sports, politics});

        Instance inst = c.getInstancePipe().instanceFrom (
            new Instance (new FeatureVector (fdict,
                                             new Object[] {"speech", "win"},
                                             new double[] {1, 1}),
                          ldict.lookupLabel ("politics"),
                          null, null));
        System.out.println ("inst.data = " + inst.getData ());

        Classification cf = c.classify (inst);
        LabelVector l = (LabelVector) cf.getLabeling ();
        System.out.println ("l.getBestIndex=" + l.getBestIndex ());

        // Labels are interned by the LabelAlphabet, so == is a valid identity test.
        assertTrue (cf.getLabeling().getBestLabel ()
                    == ldict.lookupLabel ("politics"));
        assertTrue (cf.getLabeling().getBestValue () > 0.6);
    }

    /**
     * Trains from two small arrays of sentences labeled "africa" and "asia",
     * then checks an unseen africa-themed sentence is labeled "africa".
     */
    public void testStringTrained ()
    {
        String[] africaTraining = new String[] {
            "on the plains of africa the lions roar",
            "in swahili ngoma means to dance",
            "nelson mandela became president of south africa",
            "the saraha dessert is expanding"};
        String[] asiaTraining = new String[] {
            "panda bears eat bamboo",
            "china's one child policy has resulted in a surplus of boys",
            "tigers live in the jungle"};

        InstanceList instances =
            new InstanceList (
                new SerialPipes (new Pipe[] {
                    new Target2Label (),
                    new CharSequence2TokenSequence (),
                    new TokenSequence2FeatureSequence (),
                    new FeatureSequence2FeatureVector ()}));
        instances.addThruPipe (new ArrayIterator (africaTraining, "africa"));
        instances.addThruPipe (new ArrayIterator (asiaTraining, "asia"));

        Classifier c = new NaiveBayesTrainer ().train (instances);
        Classification cf = c.classify ("nelson mandela never eats lions");
        assertTrue (cf.getLabeling().getBestLabel ()
                    == ((LabelAlphabet) instances.getTargetAlphabet ()).lookupLabel ("africa"));
    }

    /**
     * Smoke test: trains on randomly generated instances and reports (but does
     * not assert) accuracy when classifying the training set itself.
     */
    public void testRandomTrained ()
    {
        InstanceList ilist = new InstanceList (new Randoms (1), 10, 2);
        Classifier c = new NaiveBayesTrainer ().train (ilist);

        // Evaluate on the training data.
        int numCorrect = 0;
        for (int i = 0; i < ilist.size (); i++) {
            Instance inst = ilist.get (i);
            Classification cf = c.classify (inst);
            cf.print ();
            if (cf.getLabeling().getBestLabel () == inst.getLabeling().getBestLabel ())
                numCorrect++;
        }
        System.out.println ("Accuracy on training set = " + ((double) numCorrect) / ilist.size ());
    }

    /**
     * Trains incrementally in two batches that share one pipe, so the data and
     * target alphabets may grow between batches; prints the sizes so growth is
     * visible in the log.
     */
    public void testIncrementallyTrainedGrowingAlphabets ()
    {
        System.out.println ("testIncrementallyTrainedGrowingAlphabets");
        SerialPipes instPipe = newTextPipe ();

        InstanceList instList = loadInstances (instPipe, LEARN_DIRS, false);
        System.out.println ("Training 1");
        NaiveBayesTrainer trainer = new NaiveBayesTrainer ();
        trainer.trainIncremental (instList);
        printAlphabetSizes (instList);

        // Second batch through the SAME pipe; its alphabets carry over.
        InstanceList instList2 = loadInstances (instPipe, LEARN_DIRS_B, false);
        System.out.println ("Training 2");
        printAlphabetSizes (instList2);
        trainer.trainIncremental (instList2);
    }

    /**
     * Trains a first batch, classifies two probe strings, then trains an
     * incremental second batch with the same trainer and shared pipe.
     */
    public void testIncrementallyTrained ()
    {
        System.out.println ("testIncrementallyTrained");
        SerialPipes instPipe = newTextPipe ();

        InstanceList instList = loadInstances (instPipe, LEARN_DIRS, false);
        System.out.println ("Training 1");
        NaiveBayesTrainer trainer = new NaiveBayesTrainer ();
        NaiveBayes classifier = (NaiveBayes) trainer.trainIncremental (instList);

        Classification initialClassification = classifier.classify ("Hello Everybody");
        Classification initial2Classification = classifier.classify ("Goodbye now");
        System.out.println ("Initial Classification = ");
        initialClassification.print ();
        initial2Classification.print ();
        System.out.println ("data alphabet " + classifier.getAlphabet ());
        System.out.println ("label alphabet " + classifier.getLabelAlphabet ());

        // Incrementally train on the second batch.
        printAlphabetSizes (instList);
        InstanceList instList2 = loadInstances (instPipe, LEARN_DIRS_B, false);
        System.out.println ("Training 2");
        printAlphabetSizes (instList2);
        trainer.trainIncremental (instList2);
    }

    /**
     * Regression test for classifying after incremental training where the
     * second batch's FileIterator strips the common path prefix; re-classifies
     * a probe string with the first classifier afterwards.
     */
    public void testEmptyStringBug ()
    {
        System.out.println ("testEmptyStringBug");
        SerialPipes instPipe = newTextPipe ();

        InstanceList instList = loadInstances (instPipe, LEARN_DIRS, false);
        System.out.println ("Training 1");
        NaiveBayesTrainer trainer = new NaiveBayesTrainer ();
        NaiveBayes classifier = (NaiveBayes) trainer.trainIncremental (instList);

        Classification initialClassification = classifier.classify ("Hello Everybody");
        Classification initial2Classification = classifier.classify ("Goodbye now");
        System.out.println ("Initial Classification = ");
        initialClassification.print ();
        initial2Classification.print ();
        System.out.println ("data alphabet " + classifier.getAlphabet ());
        System.out.println ("label alphabet " + classifier.getLabelAlphabet ());

        // Second batch, with removeCommonPrefix=true on the FileIterator.
        printAlphabetSizes (instList);
        InstanceList instList2 = loadInstances (instPipe, LEARN_DIRS_B, true);
        System.out.println ("Training 2");
        printAlphabetSizes (instList2);
        trainer.trainIncremental (instList2);

        Classification secondClassification = classifier.classify ("Goodbye now");
        secondClassification.print ();
    }

    /** Runs every test* method in this class. */
    static Test suite ()
    {
        return new TestSuite (TestNaiveBayes.class);
    }

    protected void setUp ()
    {
        // No shared fixture; each test builds its own pipeline and data.
    }

    public static void main (String[] args)
    {
        junit.textui.TestRunner.run (suite ());
    }
}