All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.aliasi.test.unit.classify.TradNaiveBayesClassifierTest Maven / Gradle / Ivy

Go to download

This is the original Lingpipe: http://alias-i.com/lingpipe/web/download.html There were not made any changes to the source code.

There is a newer version: 4.1.2-JL1.0
Show newest version
package com.aliasi.test.unit.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classified;
import com.aliasi.classify.ConditionalClassification;
import com.aliasi.classify.ConditionalClassifier;
import com.aliasi.classify.JointClassifier;
import com.aliasi.classify.TradNaiveBayesClassifier;

import com.aliasi.tokenizer.IndoEuropeanTokenizerFactory;
import com.aliasi.tokenizer.TokenizerFactory;

import com.aliasi.util.AbstractExternalizable;

import java.io.IOException;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import org.junit.Test;
import static com.aliasi.test.unit.Asserts.succeed;
import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertTrue;
import static junit.framework.Assert.assertFalse;
import static junit.framework.Assert.fail;

public class TradNaiveBayesClassifierTest {

    static final String[] CATS_2 = new String[] { "a", "b" };
    static final Set CAT_SET_2 = listToSet(CATS_2);
    static final TokenizerFactory TOKENIZER_FACTORY
        = IndoEuropeanTokenizerFactory.INSTANCE;

    @Test
    public void testSetLengthNorm() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(listToSet(new String[] { "a", "b" }),
                                           TOKENIZER_FACTORY,
                                           1,1,10.0);
        assertEquals(10.0,classifier.lengthNorm(),0.0001);
        classifier.setLengthNorm(Double.NaN);
        assertTrue(Double.isNaN(classifier.lengthNorm()));

        classifier.setLengthNorm(20.0);
        assertEquals(20.0,classifier.lengthNorm(),0.0001);
    }

    @Test(expected=IllegalArgumentException.class)
    public void testSetLengthNormNegExc() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(listToSet(new String[] { "a", "b" }),
                                           TOKENIZER_FACTORY,
                                           1,1,-1);
    }

    @Test(expected=IllegalArgumentException.class)
    public void testSetLengthNormNegExc2() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(listToSet(new String[] { "a", "b" }),
                                           TOKENIZER_FACTORY,
                                           1,1,10);
        classifier.setLengthNorm(Double.POSITIVE_INFINITY);
    }

    @Test(expected=IllegalArgumentException.class)
    public void testSetLengthNormInfExc() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(listToSet(new String[] { "a", "b" }),
                                           TOKENIZER_FACTORY,
                                           1,1,Double.POSITIVE_INFINITY);
    }

    @Test(expected=IllegalArgumentException.class)
    public void testSetLengthNormInfExc2() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(listToSet(new String[] { "a", "b" }),
                                           TOKENIZER_FACTORY,
                                           1,1,10.0);
        classifier.setLengthNorm(Double.POSITIVE_INFINITY);
    }




    @Test
    public void testTrain() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(listToSet(new String[] { "a", "b" }),
                                           TOKENIZER_FACTORY,
                                           1,1,Double.NaN);

        try {
            classifier.train("", new Classification("a"), -1);
            fail();
        } catch (IllegalArgumentException e) {
            succeed();
        }

        classifier.train("john ran", new Classification("a"), 2);
        classifier.train("john ran", new Classification("a"), -1);
        classifier.train("john ran", new Classification("a"), -1);
        classifier.train("run, mary, run", new Classification("b"),3);
        classifier.train("john jumped", new Classification("a"),3);

        try {
            classifier.train("john ran", new Classification("a"), -1);
            fail();
        } catch (IllegalArgumentException e) {
            succeed();
        }

        try {
            classifier.train("mary hopscotched", new Classification("a"), -1);
            fail();
        } catch (IllegalArgumentException e) {
            succeed();
        }

        classifier.train("john ran",new Classification("a"),3);
        classifier.train("mary ran", new Classification("b"),3);

        classifier.train("john ran",new Classification("a"),-2);
        classifier.train("john jumped", new Classification("a"),-2);
        classifier.train("mary ran", new Classification("b"),-2);
        classifier.train("run, mary, run", new Classification("b"),-2);

        classifier.setLengthNorm(10.0); // shouldn't actually matter here

        assertEquals((2.0 + 1.0)/(4.0 + 6.0 * 1.0), classifier.probToken("john","a"),0.001);
        assertEquals((1.0 + 1.0)/(7.0 + 6.0 * 1.0), classifier.probToken("ran","b"),0.001);
        assertEquals((0.0 + 1.0)/(4.0 + 6.0 * 1.0), classifier.probToken("run","a"),0.001);

        assertEquals(0.5, classifier.probCat("a"), 0.001);
        assertEquals(0.5, classifier.probCat("b"), 0.001);
    }

    @Test
    public void testLengthNormStore() throws IOException {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(listToSet(new String[] { "a", "b" }),
                                           TOKENIZER_FACTORY,
                                           1,1,42.0);
        assertEquals(42.0,classifier.lengthNorm(),0.0001);

        TradNaiveBayesClassifier classifier2
            = (TradNaiveBayesClassifier) AbstractExternalizable.serializeDeserialize(classifier);

        assertEquals(42.0,classifier2.lengthNorm(),0.0001);
    }


    @Test(expected = IllegalArgumentException.class)
    public void testCatsEmpty() {
        new TradNaiveBayesClassifier(listToSet(new String[] { }),
                                     TOKENIZER_FACTORY,
                                     1,1,1.0);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testCats1() {
        String[] cats = new String[] { "foo" };
        new TradNaiveBayesClassifier(listToSet(cats),
                                     TOKENIZER_FACTORY,
                                     1,1,1.0);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testNegCatPrior() {
        new TradNaiveBayesClassifier(CAT_SET_2,
                                     TOKENIZER_FACTORY,
                                     -1,1,1.0);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testNaNCatPrior() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,
                                           TOKENIZER_FACTORY,
                                           Double.NaN,1,1.0);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testInfCatPrior() {
        new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                     Double.POSITIVE_INFINITY,1,1.0);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testNegWordPrior() {
        new TradNaiveBayesClassifier(CAT_SET_2,
                                     TOKENIZER_FACTORY,
                                     1,-1,1.0);
    }

    @Test
    public void testCatSet() {
        String[] cats = new String[] { "a", "c", "d", "b" };
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(listToSet(cats),TOKENIZER_FACTORY,
                                           1,1,1.0);
        assertEquals(new HashSet(Arrays.asList(cats)),
                     classifier.categorySet());
    }

    @Test(expected = IllegalArgumentException.class)
    public void testLengthNormInf() {
        new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                     1,1,Double.POSITIVE_INFINITY);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testLengthNormZero() {
        new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                     1,1,0.0);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testLengthNormNeg() {
        new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                     1,1,-12);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testUnknownCatTrainException() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,1);
        classifier.handle(new Classified("any old string",
                                                       new Classification("unknownCat")));
    }

    @Test
    public void knownTokenTest() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,1);
        classifier.handle(new Classified("john ran",new Classification("a")));
        assertTrue(classifier.isKnownToken("john"));
        assertTrue(classifier.isKnownToken("ran"));
        assertFalse(classifier.isKnownToken("unknownTok"));
    }

    static void handle(TradNaiveBayesClassifier classifier,
                       String input, 
                       Classification c) {
        classifier.handle(new Classified(input,c));
    }

    @Test
    public void log2CaseProbTest() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,Double.NaN);
        classifier.handle(new Classified("john ran",new Classification("a")));
        handle(classifier,"john jumped", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));
        handle(classifier,"run, mary, run", new Classification("b"));

        assertEquals(com.aliasi.util.Math.log2(classifier.probCat("a") * classifier.probToken("mary","a") * classifier.probToken("jumped","a")
                                               + classifier.probCat("b") * classifier.probToken("mary","b") * classifier.probToken("jumped","b")),
                     classifier.log2CaseProb("mary jumped"),
                     0.001);
        assertEquals(com.aliasi.util.Math.log2(classifier.probCat("a") + classifier.probCat("b")), // should be 1.0
                     classifier.log2CaseProb(""), // normalized by length, should also be 1.0
                     0.0001);
    }

    /*
    @Test
    public void log2ModelProbTest() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,Double.NaN);
        handle(classifier,"john ran",new Classification("a"));
        handle(classifier,"john jumped", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));
        handle(classifier,"run, mary, run", new Classification("b"));

        assertEquals(com.aliasi.util.Math.log2(classifier.probCat("a") + classifier.probCat("b")), // should be 1.0
                     classifier.log2CaseProb(""), // normalized by length, should also be 1.0
                     0.0001);
    }
    */


    @Test
    public void probTokenTest() throws IOException {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,Double.NaN);
        handle(classifier,"john ran",new Classification("a"));
        handle(classifier,"john jumped", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));
        handle(classifier,"run, mary, run", new Classification("b"));

        assertEquals((2.0 + 1.0)/(4.0 + 6.0 * 1.0), classifier.probToken("john","a"),0.001);
        assertEquals((1.0 + 1.0)/(7.0 + 6.0 * 1.0), classifier.probToken("ran","b"),0.001);
        assertEquals((0.0 + 1.0)/(4.0 + 6.0 * 1.0), classifier.probToken("run","a"),0.001);

        TradNaiveBayesClassifier classifier2
            = (TradNaiveBayesClassifier) AbstractExternalizable.serializeDeserialize(classifier);

        assertEquals((2.0 + 1.0)/(4.0 + 6.0 * 1.0), classifier2.probToken("john","a"),0.001);
        assertEquals((1.0 + 1.0)/(7.0 + 6.0 * 1.0), classifier2.probToken("ran","b"),0.001);
        assertEquals((0.0 + 1.0)/(4.0 + 6.0 * 1.0), classifier2.probToken("run","a"),0.001);

        assertEquals(classifier.categorySet(), classifier2.categorySet());
        assertEquals(classifier.knownTokenSet(), classifier2.knownTokenSet());

        assertFalse(classifier2.isKnownToken("thisisoneidon'tknow"));
        for (String token : classifier.knownTokenSet()) {
            assertTrue(classifier.isKnownToken(token));
            assertTrue(classifier2.isKnownToken(token));
        }
    }

    @Test(expected = IllegalArgumentException.class)
    public void probTokenTestUnknownToken() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,1);
        handle(classifier,"john ran", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));

        classifier.probToken("unknownTok","a");
    }

    @Test(expected = IllegalArgumentException.class)
    public void probTokenTestUnknownTokenSerDeser() throws IOException {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,1);
        handle(classifier,"john ran", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));

        TradNaiveBayesClassifier classifier2
            = (TradNaiveBayesClassifier) AbstractExternalizable.serializeDeserialize(classifier);

        classifier2.probToken("unknownTok","a");
    }


    @Test(expected = IllegalArgumentException.class)
    public void probTokenTestUnknownCat() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,1);
        handle(classifier,"john ran", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));

        classifier.probToken("john","unknownCat");
    }



    @Test(expected = IllegalArgumentException.class)
    public void probTokenTestUnknownCatSerDeser() throws IOException {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,1);
        handle(classifier,"john ran", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));

        TradNaiveBayesClassifier classifier2
            = (TradNaiveBayesClassifier) AbstractExternalizable.serializeDeserialize(classifier);

        classifier2.probToken("john","unknownCat");
    }


    @Test
    public void testKnownTokenSet() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,1);

        Set knownTokenSet = classifier.knownTokenSet();
        Set expectedTokenSet = new HashSet();
        assertEquals(expectedTokenSet,knownTokenSet);

        handle(classifier,"john ran", new Classification("a"));
        expectedTokenSet.add("john");
        expectedTokenSet.add("ran");
        assertEquals(expectedTokenSet,knownTokenSet);

        handle(classifier,"mary ran", new Classification("b"));
        expectedTokenSet.add("mary");
        assertEquals(expectedTokenSet,knownTokenSet);

        for (String token : knownTokenSet)
            assertTrue(classifier.isKnownToken(token));

    }

    @Test
    public void testProbCat() throws IOException {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,Double.NaN);
        handle(classifier,"john ran",new Classification("a"));
        handle(classifier,"john jumped", new Classification("a"));
        handle(classifier,"john ran and jumped", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));
        handle(classifier,"run, mary, run", new Classification("b"));

        assertEquals((3.0 + 1.0) / (5.0 + 2.0*1.0),classifier.probCat("a"),0.00001);
        assertEquals(1.0 - (3.0 + 1.0) / (5.0 + 2.0*1.0),classifier.probCat("b"),0.00001);

        TradNaiveBayesClassifier classifier2
            = (TradNaiveBayesClassifier) AbstractExternalizable.serializeDeserialize(classifier);

        assertEquals((3.0 + 1.0) / (5.0 + 2.0*1.0),classifier2.probCat("a"),0.00001);
        assertEquals(1.0 - (3.0 + 1.0) / (5.0 + 2.0*1.0),classifier2.probCat("b"),0.00001);

    }

    @Test(expected = IllegalArgumentException.class)
    public void testProbCatUnknownCatExc() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,1);
        handle(classifier,"john ran",new Classification("a"));
        handle(classifier,"john jumped", new Classification("a"));
        handle(classifier,"john ran and jumped", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));
        handle(classifier,"run, mary, run", new Classification("b"));

        classifier.probCat("unknownCat");
    }


    @Test(expected = IllegalArgumentException.class)
    public void testProbCatUnknownCatExcSerDeser() throws IOException {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,1);
        handle(classifier,"john ran",new Classification("a"));
        handle(classifier,"john jumped", new Classification("a"));
        handle(classifier,"john ran and jumped", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));
        handle(classifier,"run, mary, run", new Classification("b"));

        TradNaiveBayesClassifier classifier2
            = (TradNaiveBayesClassifier) AbstractExternalizable.serializeDeserialize(classifier);

        classifier2.probCat("unknownCat");
    }


    @Test
    public void testClassification() {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,Double.NaN);
        handle(classifier,"john ran",new Classification("a"));
        handle(classifier,"john jumped", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));
        handle(classifier,"run, mary, run", new Classification("b"));
        handle(classifier,"mary run", new Classification("b"));

        ConditionalClassification c = classifier.classify("");
        assertEquals(2,c.size());
        assertEquals("a",c.category(1));
        assertEquals("b",c.category(0));
        assertEquals(3.0/7.0,c.conditionalProbability(1),0.0001);
        assertEquals(4.0/7.0,c.conditionalProbability(0),0.0001);

        ConditionalClassification c2 = classifier.classify("backbends");
        assertEquals(2,c2.size());
        assertEquals("a",c2.category(1));
        assertEquals("b",c2.category(0));
        assertEquals(3.0/7.0,c2.conditionalProbability(1),0.0001);
        assertEquals(4.0/7.0,c2.conditionalProbability(0),0.0001);

        ConditionalClassification c3 = classifier.classify("john");
        assertEquals("a",c3.category(0));
        assertEquals("b",c3.category(1));
        double z3 = 9.0/70.0 + 4.0/105.0;
        assertEquals(9.0/70.0 / z3,c3.conditionalProbability(0),0.0001);
        assertEquals(4.0/105.0 / z3,c3.conditionalProbability(1),0.0001);

        ConditionalClassification c4 = classifier.classify("john smith was here");
        assertEquals("a",c4.category(0));
        assertEquals("b",c4.category(1));
        double z4 = 9.0/70.0 + 4.0/105.0;
        assertEquals(9.0/70.0 / z4,c4.conditionalProbability(0),0.0001);
        assertEquals(4.0/105.0 / z4,c4.conditionalProbability(1),0.0001);

        ConditionalClassification c5 = classifier.classify("john saw mary");
        assertEquals("a",c5.category(0));
        assertEquals("b",c5.category(1));
        double z5 = 9.0/700.0 + 16.0/1575.0;
        assertEquals(9.0/700.0/z5,c5.conditionalProbability(0),0.0001);
        assertEquals(16.0/1575.0/z5,c5.conditionalProbability(1),0.0001);
    }


    @Test
    public void testClassificationSerDeser() throws IOException {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,Double.NaN);
        handle(classifier,"john ran",new Classification("a"));
        handle(classifier,"john jumped", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));
        handle(classifier,"run, mary, run", new Classification("b"));
        handle(classifier,"mary run", new Classification("b"));

        TradNaiveBayesClassifier classifier2
            = (TradNaiveBayesClassifier) AbstractExternalizable.serializeDeserialize(classifier);

        ConditionalClassification c = classifier2.classify("");
        assertEquals(2,c.size());
        assertEquals("a",c.category(1));
        assertEquals("b",c.category(0));
        assertEquals(3.0/7.0,c.conditionalProbability(1),0.0001);
        assertEquals(4.0/7.0,c.conditionalProbability(0),0.0001);

        ConditionalClassification c2 = classifier2.classify("backbends");
        assertEquals(2,c2.size());
        assertEquals("a",c2.category(1));
        assertEquals("b",c2.category(0));
        assertEquals(3.0/7.0,c2.conditionalProbability(1),0.0001);
        assertEquals(4.0/7.0,c2.conditionalProbability(0),0.0001);

        ConditionalClassification c3 = classifier2.classify("john");
        assertEquals("a",c3.category(0));
        assertEquals("b",c3.category(1));
        double z3 = 9.0/70.0 + 4.0/105.0;
        assertEquals(9.0/70.0 / z3,c3.conditionalProbability(0),0.0001);
        assertEquals(4.0/105.0 / z3,c3.conditionalProbability(1),0.0001);

        ConditionalClassification c4 = classifier2.classify("john smith was here");
        assertEquals("a",c4.category(0));
        assertEquals("b",c4.category(1));
        double z4 = 9.0/70.0 + 4.0/105.0;
        assertEquals(9.0/70.0 / z4,c4.conditionalProbability(0),0.0001);
        assertEquals(4.0/105.0 / z4,c4.conditionalProbability(1),0.0001);

        ConditionalClassification c5 = classifier2.classify("john saw mary");
        assertEquals("a",c5.category(0));
        assertEquals("b",c5.category(1));
        double z5 = 9.0/700.0 + 16.0/1575.0;
        assertEquals(9.0/700.0/z5,c5.conditionalProbability(0),0.0001);
        assertEquals(16.0/1575.0/z5,c5.conditionalProbability(1),0.0001);
    }


    @Test
    public void testClassificationCompile() throws ClassNotFoundException, IOException {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,Double.NaN);
        handle(classifier,"john ran",new Classification("a"));
        handle(classifier,"john jumped", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));
        handle(classifier,"run, mary, run", new Classification("b"));
        handle(classifier,"mary run", new Classification("b"));

        ConditionalClassifier classifier2
            = (ConditionalClassifier)
            AbstractExternalizable.compile(classifier);

        ConditionalClassification c = classifier2.classify("");
        assertEquals(2,c.size());
        assertEquals("a",c.category(1));
        assertEquals("b",c.category(0));
        assertEquals(3.0/7.0,c.conditionalProbability(1),0.0001);
        assertEquals(4.0/7.0,c.conditionalProbability(0),0.0001);

        ConditionalClassification c2 = classifier2.classify("backbends");
        assertEquals(2,c2.size());
        assertEquals("a",c2.category(1));
        assertEquals("b",c2.category(0));
        assertEquals(3.0/7.0,c2.conditionalProbability(1),0.0001);
        assertEquals(4.0/7.0,c2.conditionalProbability(0),0.0001);

        ConditionalClassification c3 = classifier2.classify("john");
        assertEquals("a",c3.category(0));
        assertEquals("b",c3.category(1));
        double z3 = 9.0/70.0 + 4.0/105.0;
        assertEquals(9.0/70.0 / z3,c3.conditionalProbability(0),0.0001);
        assertEquals(4.0/105.0 / z3,c3.conditionalProbability(1),0.0001);

        ConditionalClassification c4 = classifier2.classify("john smith was here");
        assertEquals("a",c4.category(0));
        assertEquals("b",c4.category(1));
        double z4 = 9.0/70.0 + 4.0/105.0;
        assertEquals(9.0/70.0 / z4,c4.conditionalProbability(0),0.0001);
        assertEquals(4.0/105.0 / z4,c4.conditionalProbability(1),0.0001);

        ConditionalClassification c5 = classifier2.classify("john saw mary");
        assertEquals("a",c5.category(0));
        assertEquals("b",c5.category(1));
        double z5 = 9.0/700.0 + 16.0/1575.0;
        assertEquals(9.0/700.0/z5,c5.conditionalProbability(0),0.0001);
        assertEquals(16.0/1575.0/z5,c5.conditionalProbability(1),0.0001);
    }

    @Test
    public void testLengthNorm() throws IOException, ClassNotFoundException {
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(CAT_SET_2,TOKENIZER_FACTORY,
                                           1,1,Double.NaN);

        // train without norm
        handle(classifier,"john ran",new Classification("a"));
        handle(classifier,"john jumped", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));
        handle(classifier,"run, mary, run", new Classification("b"));
        handle(classifier,"mary run", new Classification("b"));

        // eval with norm
        classifier.setLengthNorm(10.0);

        JointClassifier classifierSer
            = (JointClassifier)
            AbstractExternalizable.serializeDeserialize(classifier);

        
        ConditionalClassifier classifierComp
            = (ConditionalClassifier)
            AbstractExternalizable.compile(classifier);

        ConditionalClassification c1 = classifier.classify("");
        assertEquals("b",c1.category(0));
        assertEquals("a",c1.category(1));
        assertEquals(4.0/7.0,c1.conditionalProbability(0),0.0001);
        assertEquals(3.0/7.0,c1.conditionalProbability(1),0.0001);

        ConditionalClassification c1c = classifierComp.classify("unknownToken");
        assertEquals("b",c1c.category(0));
        assertEquals("a",c1c.category(1));
        assertEquals(4.0/7.0,c1c.conditionalProbability(0),0.0001);
        assertEquals(3.0/7.0,c1c.conditionalProbability(1),0.0001);

        ConditionalClassification c1s = classifierSer.classify("unknownToken unknownToken");
        assertEquals("b",c1s.category(0));
        assertEquals("a",c1s.category(1));
        assertEquals(4.0/7.0,c1s.conditionalProbability(0),0.0001);
        assertEquals(3.0/7.0,c1s.conditionalProbability(1),0.0001);

        double jointA = 3.0/7.0 * java.lang.Math.pow(3.0/10.0,10.0);
        double jointB = 4.0/7.0 * java.lang.Math.pow(1.0/15.0,10.0);

        double expA = jointA / (jointA + jointB);
        double expB = jointB / (jointA + jointB);

        ConditionalClassification c2 = classifier.classify("john");
        assertEquals(2,c2.size());
        assertEquals("a",c2.category(0));
        assertEquals("b",c2.category(1));
        assertEquals(expA,c2.conditionalProbability(0),0.0001);
        assertEquals(expB,c2.conditionalProbability(1),0.0001);

        ConditionalClassification c2s = classifierSer.classify("john");
        assertEquals(2,c2s.size());
        assertEquals("a",c2s.category(0));
        assertEquals("b",c2s.category(1));
        assertEquals(expA,c2s.conditionalProbability(0),0.0001);
        assertEquals(expB,c2s.conditionalProbability(1),0.0001);

        ConditionalClassification c2c = classifierComp.classify("john");
        assertEquals(2,c2c.size());
        assertEquals("a",c2c.category(0));
        assertEquals("b",c2c.category(1));
        assertEquals(expA,c2c.conditionalProbability(0),0.0001);
        assertEquals(expB,c2c.conditionalProbability(1),0.0001);

        jointA = 3.0/7.0 * java.lang.Math.pow(3.0/10.0 * 2.0/10.0,10.0/2.0);
        jointB = 4.0/7.0 * java.lang.Math.pow(1.0/15.0 * 2.0/10.0,10.0/2.0);

        expA = jointA / (jointA + jointB);
        expB = jointB / (jointA + jointB);

        ConditionalClassification c3 = classifier.classify("john");
        assertEquals(2,c3.size());
        assertEquals("a",c3.category(0));
        assertEquals("b",c3.category(1));
        assertEquals(expA,c3.conditionalProbability(0),0.001);
        assertEquals(expB,c3.conditionalProbability(1),0.001);

        ConditionalClassification c3s = classifierSer.classify("john");
        assertEquals(2,c3s.size());
        assertEquals("a",c3s.category(0));
        assertEquals("b",c3s.category(1));
        assertEquals(expA,c3s.conditionalProbability(0),0.001);
        assertEquals(expB,c3s.conditionalProbability(1),0.001);

        ConditionalClassification c3c = classifierComp.classify("john");
        assertEquals(2,c3c.size());
        assertEquals("a",c3c.category(0));
        assertEquals("b",c3c.category(1));
        assertEquals(expA,c3c.conditionalProbability(0),0.001);
        assertEquals(expB,c3c.conditionalProbability(1),0.001);
    }


    @Test
    public void testLengthNormTernary() throws IOException, ClassNotFoundException {
        Set catSet3 = new HashSet(Arrays.asList(new String[] { "a", "b", "c" }));
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(catSet3,TOKENIZER_FACTORY,
                                           1,1,Double.NaN);

        // train without length norm
        handle(classifier,"john ran",new Classification("a"));
        handle(classifier,"john jumped", new Classification("a"));
        handle(classifier,"mary ran", new Classification("b"));
        handle(classifier,"run, mary, run", new Classification("b"));
        handle(classifier,"mary run", new Classification("b"));
        handle(classifier,"bill ran", new Classification("c"));

        // set it for running
        classifier.setLengthNorm(10.0);

        JointClassifier classifierSer
            = (JointClassifier)
            AbstractExternalizable.serializeDeserialize(classifier);

        ConditionalClassifier classifierComp
            = (ConditionalClassifier)
            AbstractExternalizable.compile(classifier);

        double jointA = 3.0/9.0;
        double jointB = 4.0/9.0;
        double jointC = 2.0/9.0;

        double expA = jointA/(jointA+jointB+jointC);
        double expB = jointB/(jointA+jointB+jointC);
        double expC = jointC/(jointA+jointB+jointC);

        ConditionalClassification c = classifier.classify("");
        assertEquals(3,c.size());
        assertEquals("b",c.category(0));
        assertEquals("a",c.category(1));
        assertEquals("c",c.category(2));
        assertEquals(expB,c.conditionalProbability(0),0.001);
        assertEquals(expA,c.conditionalProbability(1),0.001);
        assertEquals(expC,c.conditionalProbability(2),0.001);

        c = classifierSer.classify("");
        assertEquals(3,c.size());
        assertEquals("b",c.category(0));
        assertEquals("a",c.category(1));
        assertEquals("c",c.category(2));
        assertEquals(expB,c.conditionalProbability(0),0.001);
        assertEquals(expA,c.conditionalProbability(1),0.001);
        assertEquals(expC,c.conditionalProbability(2),0.001);

        c = classifierComp.classify("");
        assertEquals(3,c.size());
        assertEquals("b",c.category(0));
        assertEquals("a",c.category(1));
        assertEquals("c",c.category(2));
        assertEquals(expB,c.conditionalProbability(0),0.001);
        assertEquals(expA,c.conditionalProbability(1),0.001);
        assertEquals(expC,c.conditionalProbability(2),0.001);

        c = classifier.classify("blah blah");
        assertEquals(3,c.size());
        assertEquals("b",c.category(0));
        assertEquals("a",c.category(1));
        assertEquals("c",c.category(2));
        assertEquals(expB,c.conditionalProbability(0),0.001);
        assertEquals(expA,c.conditionalProbability(1),0.001);
        assertEquals(expC,c.conditionalProbability(2),0.001);

        c = classifierSer.classify("blah blah blah");
        assertEquals(3,c.size());
        assertEquals("b",c.category(0));
        assertEquals("a",c.category(1));
        assertEquals("c",c.category(2));
        assertEquals(expB,c.conditionalProbability(0),0.001);
        assertEquals(expA,c.conditionalProbability(1),0.001);
        assertEquals(expC,c.conditionalProbability(2),0.001);

        c = classifierComp.classify("blah blah blah blah blah");
        assertEquals(3,c.size());
        assertEquals("b",c.category(0));
        assertEquals("a",c.category(1));
        assertEquals("c",c.category(2));
        assertEquals(expB,c.conditionalProbability(0),0.001);
        assertEquals(expA,c.conditionalProbability(1),0.001);
        assertEquals(expC,c.conditionalProbability(2),0.001);

        jointA = 3.0/9.0 * java.lang.Math.pow(2.0/11.0, 10.0/1.0);
        jointB = 4.0/9.0 * java.lang.Math.pow(2.0/16.0, 10.0/1.0);
        jointC = 2.0/9.0 * java.lang.Math.pow(2.0/9.0, 10.0/1.0);

        expA = jointA/(jointA+jointB+jointC);
        expB = jointB/(jointA+jointB+jointC);
        expC = jointC/(jointA+jointB+jointC);

        c = classifier.classify("ran");
        assertEquals(3,c.size());
        assertEquals("c",c.category(0));
        assertEquals("a",c.category(1));
        assertEquals("b",c.category(2));
        assertEquals(expC,c.conditionalProbability(0),0.001);
        assertEquals(expA,c.conditionalProbability(1),0.001);
        assertEquals(expB,c.conditionalProbability(2),0.001);

        c = classifierSer.classify("ran");
        assertEquals(3,c.size());
        assertEquals("c",c.category(0));
        assertEquals("a",c.category(1));
        assertEquals("b",c.category(2));
        assertEquals(expC,c.conditionalProbability(0),0.001);
        assertEquals(expA,c.conditionalProbability(1),0.001);
        assertEquals(expB,c.conditionalProbability(2),0.001);

        c = classifierComp.classify("ran");
        assertEquals(3,c.size());
        assertEquals("c",c.category(0));
        assertEquals("a",c.category(1));
        assertEquals("b",c.category(2));
        assertEquals(expC,c.conditionalProbability(0),0.001);
        assertEquals(expA,c.conditionalProbability(1),0.001);
        assertEquals(expB,c.conditionalProbability(2),0.001);


        jointA = 3.0/9.0 * java.lang.Math.pow((1.0/11.0) * (1.0/11.0), 10.0/2.0);
        jointB = 4.0/9.0 * java.lang.Math.pow((1.0/16.0) * (4.0/16.0), 10.0/2.0);
        jointC = 2.0/9.0 * java.lang.Math.pow((2.0/9.0) * (1.0/9.0), 10.0/2.0);

        expA = jointA/(jointA+jointB+jointC);
        expB = jointB/(jointA+jointB+jointC);
        expC = jointC/(jointA+jointB+jointC);

        c = classifier.classify("bill run");
        assertEquals(3,c.size());
        assertEquals("c",c.category(0));
        assertEquals("b",c.category(1));
        assertEquals("a",c.category(2));
        assertEquals(expC,c.conditionalProbability(0),0.001);
        assertEquals(expB,c.conditionalProbability(1),0.001);
        assertEquals(expA,c.conditionalProbability(2),0.001);

        c = classifierSer.classify("bill run");
        assertEquals(3,c.size());
        assertEquals("c",c.category(0));
        assertEquals("b",c.category(1));
        assertEquals("a",c.category(2));
        assertEquals(expC,c.conditionalProbability(0),0.001);
        assertEquals(expB,c.conditionalProbability(1),0.001);
        assertEquals(expA,c.conditionalProbability(2),0.001);

        c = classifierComp.classify("bill run");
        assertEquals(3,c.size());
        assertEquals("c",c.category(0));
        assertEquals("b",c.category(1));
        assertEquals("a",c.category(2));
        assertEquals(expC,c.conditionalProbability(0),0.001);
        assertEquals(expB,c.conditionalProbability(1),0.001);
        assertEquals(expA,c.conditionalProbability(2),0.001);
    }

    @Test
    public void testNormTrain() {
        Set catSet2 = new HashSet(Arrays.asList(new String[] { "A", "B" }));
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(catSet2,TOKENIZER_FACTORY,
                                           10.0, 0.5, 1.0);

        classifier.train("aa aa bb bb", new Classification("A"), 1.0);
        classifier.train("aa aa aa", new Classification("A"), 1.0);
        classifier.train("bb bb bb", new Classification("B"), 1.0);
        classifier.train("bb cc cc bb", new Classification("B"), 1.0);
        classifier.train("bb bb bb dd", new Classification("B"), 1.0);

        assertEquals((2.0 + 10.0)/(5.0 + 2.0*10.0), classifier.probCat("A"), 0.0001);
        assertEquals((3.0 + 10.0)/(5.0 + 2.0*10.0), classifier.probCat("B"), 0.0001);

        assertEquals(2.0/4.0, classifier.probToken("aa","A"), 0.0001);
        assertEquals(1.0/4.0, classifier.probToken("bb","A"), 0.0001);
        assertEquals(0.5/4.0, classifier.probToken("cc","A"), 0.0001);
        assertEquals(0.5/4.0, classifier.probToken("dd","A"), 0.0001);

        assertEquals(0.5/5.0, classifier.probToken("aa","B"), 0.0001);
        assertEquals(2.75/5.0, classifier.probToken("bb","B"), 0.0001);
        assertEquals(1.0/5.0, classifier.probToken("cc","B"), 0.0001);
        assertEquals(0.75/5.0, classifier.probToken("dd","B"), 0.0001);

    }

    @Test
    public void testTrainConditional() {
        Set catSet2 = new HashSet(Arrays.asList(new String[] { "A", "B" }));
        TradNaiveBayesClassifier classifier
            = new TradNaiveBayesClassifier(catSet2,TOKENIZER_FACTORY,
                                           0.5, 0.5, Double.NaN);
        ConditionalClassification c1
            = new ConditionalClassification(new String[] { "A", "B" },
                                            new double[] { .75, .25 });
        classifier.trainConditional("aa", c1, 4.0, 0.0);

        ConditionalClassification c2
            = new ConditionalClassification(new String[] { "B", "A" },
                                            new double[] { 0.9, 0.1 });
        classifier.trainConditional("bb",c2, 4.0, 0.0);

        assertEquals((3.4+0.5)/(8.0 + 2.0*0.5),classifier.probCat("A"),0.0001);
        assertEquals((4.6+0.5)/(8.0 + 2.0*0.5),classifier.probCat("B"),0.0001);

        // A:aa:3
        // B:aa:1


        assertEquals(3.5/4.4,classifier.probToken("aa","A"),0.0001);
        assertEquals(0.9/4.4,classifier.probToken("bb","A"),0.0001);
        assertEquals(1.5/5.6,classifier.probToken("aa","B"),0.0001);
        assertEquals(4.1/5.6,classifier.probToken("bb","B"),0.0001);
    }


    static Set listToSet(String[] xs) {
        return new HashSet(Arrays.asList(xs));
    }




}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy