com.aliasi.test.unit.classify.TfIdfClassifierTrainerTest Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of aliasi-lingpipe Show documentation
Show all versions of aliasi-lingpipe Show documentation
This is the original Lingpipe:
http://alias-i.com/lingpipe/web/download.html
There were not made any changes to the source code.
package com.aliasi.test.unit.classify;
import com.aliasi.classify.TfIdfClassifierTrainer;
import com.aliasi.classify.Classification;
import com.aliasi.classify.Classified;
import com.aliasi.classify.ScoredClassification;
import com.aliasi.classify.ScoredClassifier;
import com.aliasi.tokenizer.IndoEuropeanTokenizerFactory;
import com.aliasi.tokenizer.TokenFeatureExtractor;
import com.aliasi.util.AbstractExternalizable;
import org.junit.Test;
import static junit.framework.Assert.assertEquals;
public class TfIdfClassifierTrainerTest {
@Test
public void testTf() {
TokenFeatureExtractor featureExtractor
= new TokenFeatureExtractor(IndoEuropeanTokenizerFactory.INSTANCE);
TfIdfClassifierTrainer trainer
= new TfIdfClassifierTrainer(featureExtractor);
handle(trainer,"a b b", "cat1");
handle(trainer,"b c c c d", "cat2");
handle(trainer,"c c c c c", "cat3");
handle(trainer,"a d", "cat1");
assertEquals(Math.sqrt(2.0), trainer.tf("a","cat1"));
assertEquals(0.0, trainer.tf("c","cat1"));
assertEquals(Math.sqrt(5.0), trainer.tf("c","cat3"));
assertEquals(0.0, trainer.tf("foo","cat1"));
}
@Test
public void testIdf() {
TokenFeatureExtractor featureExtractor
= new TokenFeatureExtractor(IndoEuropeanTokenizerFactory.INSTANCE);
TfIdfClassifierTrainer trainer
= new TfIdfClassifierTrainer(featureExtractor);
handle(trainer,"a b b", "cat1");
handle(trainer,"b c c c d", "cat2");
handle(trainer,"b c c c c c", "cat3");
assertEquals(Math.log(3.0/1.0), trainer.idf("a"));
assertEquals(Math.log(3.0/2.0), trainer.idf("c"));
assertEquals(Math.log(3.0/3.0), trainer.idf("b"));
assertEquals(0.0, trainer.idf("foo"));
}
@Test
public void testTfIdf() {
TokenFeatureExtractor featureExtractor
= new TokenFeatureExtractor(IndoEuropeanTokenizerFactory.INSTANCE);
TfIdfClassifierTrainer trainer
= new TfIdfClassifierTrainer(featureExtractor);
handle(trainer,"a b b", "cat1");
handle(trainer,"b c c c d", "cat2");
handle(trainer,"b c c c c c", "cat3");
assertEquals(0.0, trainer.tfIdf("foo","bar"));
assertEquals(0.0, trainer.tfIdf("a","bar"));
assertEquals(0.0, trainer.tfIdf("foo","cat1"));
assertEquals(Math.sqrt(1.0), trainer.tf("a","cat1"));
assertEquals(Math.log(3.0/1.0), trainer.idf("a"));
assertEquals(Math.sqrt(1.0) * Math.log(3.0/1.0), trainer.tfIdf("a","cat1"));
assertEquals(Math.sqrt(2.0) * Math.log(3.0/3.0), trainer.tfIdf("b","cat1"));
assertEquals(0.0, trainer.tfIdf("c","cat1"));
assertEquals(0.0, trainer.tfIdf("d","cat1"));
assertEquals(Math.sqrt(0.0) * Math.log(3.0/1.0), trainer.tfIdf("a","cat2"));
assertEquals(Math.sqrt(2.0) * Math.log(3.0/3.0), trainer.tfIdf("b","cat2"));
assertEquals(Math.sqrt(3.0) * Math.log(3.0/2.0), trainer.tfIdf("c","cat2"));
assertEquals(Math.sqrt(1.0) * Math.log(3.0/1.0), trainer.tfIdf("d","cat2"));
assertEquals(Math.sqrt(0.0) * Math.log(3.0/1.0), trainer.tfIdf("a","cat3"));
assertEquals(Math.sqrt(1.0) * Math.log(3.0/3.0), trainer.tfIdf("b","cat3"));
assertEquals(Math.sqrt(5.0) * Math.log(3.0/2.0), trainer.tfIdf("c","cat3"));
assertEquals(0.0, trainer.tfIdf("d","cat3"));
}
static void handle(TfIdfClassifierTrainer classifier,
String input,
Classification c) {
classifier.handle(new Classified(input,c));
}
static void handle(TfIdfClassifierTrainer classifier,
String input,
String cat) {
handle(classifier,input, new Classification(cat));
}
@Test
public void testOne() throws Exception {
TokenFeatureExtractor featureExtractor
= new TokenFeatureExtractor(IndoEuropeanTokenizerFactory.INSTANCE);
TfIdfClassifierTrainer trainer
= new TfIdfClassifierTrainer(featureExtractor);
handle(trainer,"a b b", new Classification("cat1"));
handle(trainer,"b c c c d", new Classification("cat2"));
handle(trainer,"c c c c c", new Classification("cat3"));
double cat1_a = Math.sqrt(1.0) * java.lang.Math.log(3.0/1.0);
double cat1_b = Math.sqrt(2.0) * java.lang.Math.log(3.0/2.0);
double len1 = Math.sqrt(cat1_a * cat1_a + cat1_b * cat1_b);
double cat1_a_n = cat1_a/len1;
double cat1_b_n = cat1_b/len1;
double cat2_b = Math.sqrt(1.0) * java.lang.Math.log(3.0/2.0);
double cat2_c = Math.sqrt(3.0) * java.lang.Math.log(3.0/2.0);
double cat2_d = Math.sqrt(1.0) * java.lang.Math.log(3.0);
double len2 = Math.sqrt(cat2_b*cat2_b + cat2_c*cat2_c + cat2_d*cat2_d);
double cat2_b_n = cat2_b/len2;
double cat2_c_n = cat2_c/len2;
double cat2_d_n = cat2_d/len2;
double cat3_c = Math.sqrt(5.0) * java.lang.Math.log(3.0/2.0);
double len3 = Math.sqrt(cat3_c * cat3_c);
double cat3_c_n = cat3_c/len3; // = 1.0 :-)
ScoredClassifier classifier
= (ScoredClassifier)
AbstractExternalizable.compile(trainer);
ScoredClassification classification
= classifier.classify("a b b");
assertEquals("cat1",classification.bestCategory());
assertEquals("cat1",classification.category(0));
assertEquals(1.0,classification.score(0),0.001);
assertEquals("cat2",classification.category(1));
assertEquals(cat1_b_n * cat2_b_n,classification.score(1),0.05); // off by 0.01
assertEquals("cat3",classification.category(2));
assertEquals(0.0,classification.score(2),0.001);
TfIdfClassifierTrainer trainer2
= (TfIdfClassifierTrainer)
AbstractExternalizable.serializeDeserialize(trainer);
ScoredClassifier classifier2
= (ScoredClassifier)
AbstractExternalizable.compile(trainer2);
assertEquals("cat1",classification.bestCategory());
assertEquals("cat1",classification.category(0));
assertEquals(1.0,classification.score(0),0.001);
assertEquals("cat2",classification.category(1));
assertEquals(cat1_b_n * cat2_b_n,classification.score(1),0.05); // off by 0.01
assertEquals("cat3",classification.category(2));
assertEquals(0.0,classification.score(2),0.001);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy