All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.aliasi.test.unit.crf.ChainCrfTest Maven / Gradle / Ivy

Go to download

This is the original Lingpipe: http://alias-i.com/lingpipe/web/download.html There were not made any changes to the source code.

There is a newer version: 4.1.2-JL1.0
Show newest version
package com.aliasi.test.unit.crf;

import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;

import com.aliasi.crf.ChainCrf;
import com.aliasi.crf.ChainCrfFeatureExtractor;
import com.aliasi.crf.ChainCrfFeatures;

import com.aliasi.io.LogLevel;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;

import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Vector;

import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;

import com.aliasi.symbol.SymbolTable;
import com.aliasi.symbol.SymbolTableCompiler;

import com.aliasi.tag.ScoredTagging;
import com.aliasi.tag.Tagging;
import com.aliasi.tag.TagLattice;

import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.ObjectToDoubleMap;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

import java.io.IOException;
import java.io.Serializable;

import org.junit.Test;

import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertNotNull;
import static junit.framework.Assert.fail;
import static org.junit.Assert.assertArrayEquals;

import static com.aliasi.test.unit.Asserts.succeed;

public class ChainCrfTest {

    static String CAT1 = "X";
    static String CAT2 = "Y";
    static String CAT3 = "Z";
    static String[] TAGS = new String[] { CAT1, CAT2, CAT3 };

    static String X1 = "a";
    static String X2 = "b";
    static String X3 = "c";
    static String X4 = "d";
    static String[] TOKENS = new String[] { X1, X2, X3, X4 };

    static String[] FEATURES =
        new String[] { CAT1, CAT2, CAT3, X1, X2, X3, X4 };

    static double XX = 1.0;
    static double XY = 1.0;
    static double XZ = 2.0;
    static double YX = 2.0;
    static double YY = -1.0;
    static double YZ = 4.0;
    static double ZX = 3.0;
    static double ZY = 1.0;
    static double ZZ = 6.0;

    static double[][] TRANSITION_WEIGHTS = new double[][] {
        { XX, YX, ZX },
        { XY, YY, ZY },
        { XZ, YZ, ZZ }
    };

    static double Xa = 4.0;
    static double Xb = 5.0;
    static double Xc = 6.0;
    static double Xd = 7.0;

    static double Ya = -1.0;
    static double Yb = 10.0;
    static double Yc = -1.0;
    static double Yd = 1.0;

    static double Za = -2.0;
    static double Zb = -4.0;
    static double Zc = -6.0;
    static double Zd = 15.0;

    static double[][] TOKEN_WEIGHTS = new double[][] {
        { Xa, Xb, Xc, Xd },
        { Ya, Yb, Yc, Yd },
        { Za, Zb, Zc, Zd }
    };



    static int NUM_TAGS = TAGS.length;


    static Vector[] COEFFICIENTS = new DenseVector[] {
        new DenseVector(new double[] { XX, YX, ZX,
                                       Xa, Xb, Xc, Xd }),

        new DenseVector(new double[] { XY, YY, ZY,
                                       Ya, Yb, Yc, Yd }),

        new DenseVector(new double[] { XZ, YZ, ZZ,
                                       Za, Zb, Zc, Zd })
    };



    static final SymbolTable FEATURE_SYMBOL_TABLE
        = SymbolTableCompiler.asSymbolTable(FEATURES);

    static final ChainCrfFeatureExtractor FEATURE_EXTRACTOR
        = new TestFeatureExtractor();

    static class TestFeatureExtractor
        implements ChainCrfFeatureExtractor,
                   Serializable {

        public ChainCrfFeatures extract(List tokens, List tags) {
            return new TestCrfFeatures(tokens,tags);
        }
    }

    static class TestCrfFeatures
        extends ChainCrfFeatures {
        // could cache maps -- only need one per token and per tag
        public TestCrfFeatures(List tokens, List tags) {
            super(tokens,tags);
        }
        public Map nodeFeatures(int n) {
            return Collections.singletonMap(token(n),
                                            Integer.valueOf(1));
        }
        public Map edgeFeatures(int n, int prevTagIndex) {
            return Collections.singletonMap(tag(prevTagIndex),
                                            Integer.valueOf(1));
        }
    }

    static boolean ADD_INTERCEPT_FEATURE = false;

    static ChainCrf CRF
        = new ChainCrf(TAGS,
                               COEFFICIENTS,
                               FEATURE_SYMBOL_TABLE,
                               FEATURE_EXTRACTOR,
                               ADD_INTERCEPT_FEATURE);

    @Test
    public void testDecoder() throws IOException {
        @SuppressWarnings("unchecked")
        ChainCrf crf2
            = (ChainCrf) AbstractExternalizable.serializeDeserialize(CRF);

        assertEquals(CRF.addInterceptFeature(), crf2.addInterceptFeature());
        assertEquals(CRF.featureSymbolTable().numSymbols(),
                     crf2.featureSymbolTable().numSymbols());
        for (int i = 0; i < CRF.featureSymbolTable().numSymbols(); ++i)
            assertEquals(CRF.featureSymbolTable().idToSymbol(i),
                         crf2.featureSymbolTable().idToSymbol(i));
        assertEquals(CRF.tags(), crf2.tags());

        Vector[] coeffsCRF = CRF.coefficients();
        Vector[] coeffsCrf2 = crf2.coefficients();
        assertEquals(coeffsCRF.length, coeffsCrf2.length);
        for (int i = 0; i < coeffsCRF.length; ++i) {
            assertEquals(coeffsCRF[i].numDimensions(),
                         coeffsCrf2[i].numDimensions());
            assertArrayEquals(coeffsCRF[i].nonZeroDimensions(),
                              coeffsCrf2[i].nonZeroDimensions());
            for (int d : coeffsCRF[i].nonZeroDimensions())
                assertEquals(coeffsCRF[i].value(d),
                             coeffsCrf2[i].value(d), 0.0001);
        }


        // tests all 4**0, 4**1, 4**2, and 4**3 length inputs
        for (int length = 0; length < 5; ++length) {
            for (int[] tokenIds : allArrays(length,TOKENS.length)) {
                List tokenList = new ArrayList(length);
                for (int i = 0; i < tokenIds.length; ++i)
                    tokenList.add(TOKENS[tokenIds[i]]);

                // brute force
                ObjectToDoubleMap otdMap = bruteForce(tokenIds,TAGS.length,
                                                             TRANSITION_WEIGHTS,
                                                             TOKEN_WEIGHTS);

                // first best eval
                assertCorrectAnswer(CRF,tokenList,otdMap,TAGS);
                assertCorrectAnswer(crf2,tokenList,otdMap,TAGS);

                // n-best eval
                Iterator> nBest = CRF.tagNBest(tokenList,Integer.MAX_VALUE);
                assertCorrectNBest(otdMap,nBest,TAGS,false);

                Iterator> nBestCond
                    = CRF.tagNBestConditional(tokenList,Integer.MAX_VALUE);
                assertCorrectNBest(otdMap,nBestCond,TAGS,true);

                // marginal eval
                TagLattice tagLattice = CRF.tagMarginal(tokenList);
                assertCorrectMarginal(otdMap,tagLattice,TAGS,tokenList);
            }
        }
    }

    void assertCorrectMarginal(ObjectToDoubleMap otdMap,
                               TagLattice tagLattice,
                               String[] tags,
                               List tokenList) {
        assertEquals(tokenList,tagLattice.tokenList());

        double logZ = logZ(otdMap);
        assertEquals(logZ,tagLattice.logZ(),0.001);

        List tagList = tagLattice.tagList();
        for (int pos = 0; pos < tokenList.size(); ++pos) {
            double sum = 0.0;
            for (int tagId = 0; tagId < tagList.size(); ++tagId) {
                sum += Math.exp(tagLattice.logProbability(pos,tagId));
                assertEquals(logMarginal(otdMap,pos,tagId,tags.length,logZ),
                             tagLattice.logProbability(pos,tagId),
                             0.0001);
            }
            assertEquals("marginals norm " + pos + " " + tokenList,1.0,sum,0.01);
        }
    }

    static double logMarginal(ObjectToDoubleMap otdMap,
                              int pos,
                              int tagId,
                              int numTags,
                              double logZ) {
        int count = 0;
        for (int[] key : otdMap.keySet()) {
            if (key[pos] == tagId)
                ++count;
        }
        double[] xs = new double[count];
        count = 0;
        for (Map.Entry entry : otdMap.entrySet())
            if (tagId == entry.getKey()[pos])
                xs[count++] = entry.getValue();
        return com.aliasi.util.Math.logSumOfExponentials(xs) - logZ;
    }

    static double logZ(ObjectToDoubleMap otdMap) {
        double[] xs = new double[otdMap.size()];
        int idx = 0;
        for (double x : otdMap.values())
            xs[idx++] = x;
        return com.aliasi.util.Math.logSumOfExponentials(xs);
    }


    void assertCorrectNBest(ObjectToDoubleMap otdMap,
                            Iterator> nBest,
                            String[] tags,
                            boolean conditional) {
        double logZ = conditional ? logZ(otdMap) : 0.0;

        ObjectToDoubleMap otdMap2 = new ObjectToDoubleMap();
        int count = 0;
        Set expectedTaggingSet = new TreeSet();
        for (Map.Entry entry : otdMap.entrySet()) {
            Double val = entry.getValue();
            int[] tagIds = entry.getKey();
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < tagIds.length; ++i) {
                sb.append(tags[tagIds[i]]);
            }
            String tagRep = sb.toString();
            otdMap2.put(tagRep,val);
            expectedTaggingSet.add(tagRep);
            ++count;
        }
        Set foundTaggingSet = new TreeSet();
        for (count = 0; nBest.hasNext(); ++count) {
            ScoredTagging scoredTagging = nBest.next();
            double val = scoredTagging.score();
            List tagList = scoredTagging.tags();
            StringBuilder sb = new StringBuilder();
            for (String tag : tagList)
                sb.append(tag);
            String tagRep = sb.toString();
            foundTaggingSet.add(tagRep);
            double expectedVal = otdMap2.get(tagRep) - logZ;
            assertEquals(expectedVal,val,0.0001);
        }
        assertEquals(expectedTaggingSet,foundTaggingSet);
    }

    @Test
    public void testAllOutputsSizes() {
        assertEquals(1,allArrays(0,5).size());
        assertEquals(5,allArrays(1,5).size());
        assertEquals(25,allArrays(2,5).size());
        assertEquals(125,allArrays(3,5).size());
    }

    static void assertCorrectAnswer(ChainCrf crf,
                                    List tokenList,
                                    ObjectToDoubleMap otdMap,
                                    String[] tags) {
        // complexity is dealing with n-best
        Tagging tagging = crf.tag(tokenList);
        List foundTags = tagging.tags();
        List keysList = otdMap.keysOrderedByValueList();
        double score = otdMap.getValue(keysList.get(0));
        for (int[] keys : keysList) {
            double score2 = otdMap.getValue(keys);
            if (score2 < score) {
                fail();
            }
            if (areEqualTags(foundTags,keys,tags)) {
                succeed();
                return;
            }
        }
    }

    static boolean areEqualTags(List foundTags,
                                int[] expectedTags,
                                String[] tags) {
        for (int i = 0; i < expectedTags.length; ++i)
            if (!foundTags.get(i).equals(tags[expectedTags[i]]))
                return false;
        return true;
    }

    static ObjectToDoubleMap  bruteForce(int[] tokens, int numTags,
                                                double[][] transitionWeights, double[][] tokenWeights) {
        ObjectToDoubleMap outputMap = new ObjectToDoubleMap();
        List allArrays = allArrays(tokens.length,numTags);
        for (int[] output : allArrays) {
            double score = score(tokens,output,transitionWeights,tokenWeights);
            outputMap.put(output,score);
        }
        return outputMap;
    }

    static double score(int[] tokens, int[] output, double[][] transitionWeights, double[][] tokenWeights) {
        double score = 0.0;
        for (int i = 0; i < tokens.length; ++i)
            score += tokenWeights[output[i]][tokens[i]];
        for (int i = 1; i < tokens.length; ++i)
            score += transitionWeights[output[i]][output[i-1]];
        return score;
    }

    static List allArrays(int size, int maxVal) {
        List result = new ArrayList();
        allArrays(size,maxVal,new int[size],result);
        return result;
    }

    static void allArrays(int size, int maxVal, int[] buf, List result) {
        if (size == 0) {
            result.add(buf.clone());
            return;
        }
        for (int i = 0; i < maxVal; ++i) {
            buf[size-1] = i;
            allArrays(size-1,maxVal,buf,result);
        }
    }

    static class TestCorpus extends Corpus>> {
        static final String[][][] WORDS_TAGSS = new String[][][] {
            { { }, { } },
            { { "." }, { "EOS" } },
            { { "John", "ran", "." }, { "PN", "IV", "EOS" } },
            { { "Mary", "ran", "." }, { "PN", "IV", "EOS" } },
            { { "John", "jumped", "!" }, { "PN", "IV", "EOS" } },
            { { "The", "dog", "jumped", "!" }, { "DET", "N", "IV", "EOS" } },
            { { "The", "dog", "sat", "." }, { "DET", "N", "IV", "EOS" } },
            { { "Mary", "sat", "!" }, { "PN", "IV", "EOS" } },
            { { "Mary", "likes", "John", "." }, { "PN", "TV", "PN", "EOS" } },
            { { "The", "dog", "likes", "Mary", "." }, { "DET", "N", "TV", "PN", "EOS" } },
            { { "John", "likes", "the", "dog", "." }, { "PN", "TV", "DET", "N", "EOS" } },
            { { "The", "dog", "ran", "." }, { "DET", "N", "IV", "EOS", } },
            { { "The", "dog", "ran", "." }, { "DET", "N", "IV", "EOS", } }
        };
        public void visitTrain(ObjectHandler> handler) {
            for (String[][] wordsTags : WORDS_TAGSS) {
                String[] words = wordsTags[0];
                String[] tags = wordsTags[1];
                Tagging tagging
                    = new Tagging(Arrays.asList(words),
                                          Arrays.asList(tags));
                handler.handle(tagging);
            }
        }
        public void visitTest(ObjectHandler> handler) {
        }
    }


    @Test
    public void testEstimate() throws Exception {
        Corpus>> corpus = new TestCorpus();
        int minCount = 1;
        boolean addIntercept = true;
        boolean cacheFeatureVectors = true;
        boolean allowUnseenTransitions = true;
        RegressionPrior prior = RegressionPrior.gaussian(10.0,true);
        int priorBlockSize = 3;
        AnnealingSchedule annealingSchedule
            = AnnealingSchedule.exponential(0.02,0.995);
        double minImprovement = 0.00001;
        int minEpochs = 2;
        int maxEpochs = 2000;
        Reporter reporter = null; // Reporters.stdOut().setLevel(LogLevel.DEBUG);
        ChainCrf crf
            = ChainCrf.estimate(corpus,
                                FEATURE_EXTRACTOR,
                                addIntercept,
                                minCount,
                                cacheFeatureVectors,
                                allowUnseenTransitions,
                                prior,
                                priorBlockSize,
                                annealingSchedule,
                                minImprovement,
                                minEpochs,
                                maxEpochs,
                                reporter);

        assertTagging(Arrays.asList("John","ran","."),
                      Arrays.asList("PN","IV","EOS"),
                      crf);

        assertTagging(Arrays.asList("Mary","ran","."),
                      Arrays.asList("PN","IV","EOS"),
                      crf);
        assertTagging(Arrays.asList("The","dog","ran","."),
                      Arrays.asList("DET","N","IV","EOS"),
                      crf);
        assertTagging(Arrays.asList("The","dog","ran","!"),
                      Arrays.asList("DET","N","IV","EOS"),
                      crf);
        assertTagging(Arrays.asList("The","dog","sat","!"),
                      Arrays.asList("DET","N","IV","EOS"),
                      crf);
        assertTagging(Arrays.asList("The","dog","sat","."),
                      Arrays.asList("DET","N","IV","EOS"),
                      crf);
        assertTagging(Arrays.asList("John","likes","Mary","."),
                      Arrays.asList("PN","TV","PN","EOS"),
                      crf);
        assertTagging(Arrays.asList("Mary","likes","John","."),
                      Arrays.asList("PN","TV","PN","EOS"),
                      crf);

        // don't barf on unknown words
        assertNotNull(crf.tag(Arrays.asList("Fred","likes","John",".")));
        assertNotNull(crf.tag(Arrays.asList(";",".","likes","likes")));
    }

    static  void assertTagging(List tokens,
                                  List tagsExpected,
                                  ChainCrf crf) {
        Tagging tagging = crf.tag(tokens);
        // System.out.println(tagging);
        List tagsFound = tagging.tags();
        assertEquals(tagsExpected,tagsFound);
    }


}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy