/*
 * Copyright (C) 2017  James Gung
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package io.github.clearwsd.app;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import io.github.clearwsd.WordSenseClassifier;
import io.github.clearwsd.classifier.Classifier;
import io.github.clearwsd.classifier.PaClassifier;
import io.github.clearwsd.corpus.semlink.VerbNetReader;
import io.github.clearwsd.eval.Evaluation;
import io.github.clearwsd.feature.annotator.AggregateAnnotator;
import io.github.clearwsd.feature.context.DepChildrenContextFactory;
import io.github.clearwsd.feature.context.NlpContextFactory;
import io.github.clearwsd.feature.context.OffsetContextFactory;
import io.github.clearwsd.feature.context.RootPathContextFactory;
import io.github.clearwsd.feature.extractor.ConcatenatingFeatureExtractor;
import io.github.clearwsd.feature.extractor.ListConcatenatingFeatureExtractor;
import io.github.clearwsd.feature.extractor.ListLookupFeatureExtractor;
import io.github.clearwsd.feature.extractor.LookupFeatureExtractor;
import io.github.clearwsd.feature.extractor.StringExtractor;
import io.github.clearwsd.feature.extractor.StringFunctionExtractor;
import io.github.clearwsd.feature.extractor.StringListExtractor;
import io.github.clearwsd.feature.extractor.string.LowercaseFunction;
import io.github.clearwsd.feature.function.AggregateFeatureFunction;
import io.github.clearwsd.feature.function.BiasFeatureFunction;
import io.github.clearwsd.feature.function.ConjunctionFunction;
import io.github.clearwsd.feature.function.FeatureFunction;
import io.github.clearwsd.feature.function.MultiStringFeatureFunction;
import io.github.clearwsd.feature.function.StringFeatureFunction;
import io.github.clearwsd.feature.pipeline.DefaultFeaturePipeline;
import io.github.clearwsd.feature.pipeline.FeaturePipeline;
import io.github.clearwsd.feature.pipeline.NlpClassifier;
import io.github.clearwsd.feature.resource.DynamicDependencyNeighborsResource;
import io.github.clearwsd.type.DepNode;
import io.github.clearwsd.type.DepTree;
import io.github.clearwsd.type.FeatureType;
import io.github.clearwsd.type.NlpFocus;
import io.github.clearwsd.utils.LemmaDictionary;
import io.github.clearwsd.verbnet.DefaultVerbNetClassifier;
import io.github.clearwsd.verbnet.VerbNetSenseInventory;
import lombok.Getter;
import lombok.Setter;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;

import static io.github.clearwsd.app.VerbNetClassifierUtils.resourceManager;
import static io.github.clearwsd.type.FeatureType.Dep;

/**
 * VerbNet experiment builder: trains a VerbNet word sense classifier on dependency-parsed
 * SemLink data, saves the model, and evaluates it on the validation and test splits.
 *
 * @author jamesgung
 */
@Slf4j
@Getter
@Setter
@Accessors(fluent = true)
public class VerbNetExperiment {

    private static final int MIN_COUNT = 10; // minimum number of training instances per predicate
    private static boolean single = false; // if true, train a single PA classifier over hand-built features instead of DefaultVerbNetClassifier

    private int minPredicateCount;
    private boolean filterTestVerbs;

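    /**
     * Train a VerbNet sense classifier on the SemLink training split, save the resulting model, and
     * evaluate it on the validation and test splits. Training instances are filtered to polysemous
     * predicates with at least {@link #MIN_COUNT} occurrences, and evaluation is restricted to verbs
     * seen during training.
     */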
    private static void daisukeSetup() throws IOException {
        List<NlpFocus<DepNode, DepTree>> trainData = new VerbNetReader().readInstances(
                new FileInputStream("data/datasets/semlink/train.ud.dep"));
        trainData = filterMinCount(MIN_COUNT, trainData); // only consider verbs with at least MIN_COUNT (10) occurrences
        trainData = filterPolysemous(trainData);

        Set<String> verbs = getVerbs(trainData); // only test on verbs seen in the training data
        List<NlpFocus<DepNode, DepTree>> validData = new VerbNetReader().readInstances(
                new FileInputStream("data/datasets/semlink/valid.ud.dep"));
        validData = filterByVerb(verbs, validData);
        List<NlpFocus<DepNode, DepTree>> testData = new VerbNetReader().readInstances(
                new FileInputStream("data/datasets/semlink/test.ud.dep"));
        testData = filterByVerb(verbs, testData);
        log.debug("{} test instances and {} dev instances", testData.size(), validData.size());

        Classifier<NlpFocus<DepNode, DepTree>, String> multi;
        if (single) {
            // annotate instances with lexical resources, then train a single passive-aggressive classifier
            AggregateAnnotator<NlpFocus<DepNode, DepTree>> annotator
                    = new AggregateAnnotator<>(VerbNetClassifierUtils.annotators());
            annotator.initialize(resourceManager());
            trainData.forEach(annotator::annotate);
            validData.forEach(annotator::annotate);
            testData.forEach(annotator::annotate);
            multi = new NlpClassifier<>(new PaClassifier(), initializeFeatures());
        } else {
            multi = new DefaultVerbNetClassifier();
        }

        WordSenseClassifier classifier = new WordSenseClassifier(multi, new VerbNetSenseInventory(), new LemmaDictionary());

        classifier.train(trainData, validData);
        classifier.save(new ObjectOutputStream(new FileOutputStream("data/models/semlink.model")));

        // evaluate classifier
        Evaluation evaluation = new Evaluation();
        for (NlpFocus<DepNode, DepTree> instance : validData) {
            evaluation.add(classifier.classify(instance), instance.feature(FeatureType.Gold));
        }
        log.debug("Validation data\n{}", evaluation);

        // ensure that loading classifier gives same results
        classifier = new WordSenseClassifier(new ObjectInputStream(new FileInputStream("data/models/semlink.model")));

        evaluation = new Evaluation();
        for (NlpFocus<DepNode, DepTree> instance : testData) {
            evaluation.add(classifier.classify(instance), instance.feature(FeatureType.Gold));
        }
        log.debug("Test data\n{}", evaluation);
    }

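    /**
     * Keep only instances whose focus predicate occurs at least {@code minCount} times in the data.
     */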
    private static List<NlpFocus<DepNode, DepTree>> filterMinCount(int minCount,
                                                                   List<NlpFocus<DepNode, DepTree>> data) {
        Map<String, Integer> counts = new HashMap<>();
        for (NlpFocus<DepNode, DepTree> instance : data) {
            String lemma = instance.focus().feature(FeatureType.Predicate);
            counts.merge(lemma, 1, (prev, one) -> prev + one);
        }
        return data.stream()
                .filter(instance -> {
                    String predicate = instance.focus().feature(FeatureType.Predicate);
                    return counts.get(predicate) >= minCount;
                })
                .collect(Collectors.toList());
    }

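    /**
     * Keep only instances of polysemous predicates, i.e. predicates with more than one distinct
     * gold sense label in the data.
     */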
    private static List<NlpFocus<DepNode, DepTree>> filterPolysemous(
            List<NlpFocus<DepNode, DepTree>> data) {
        Multimap<String, String> labelMap = HashMultimap.create();
        for (NlpFocus<DepNode, DepTree> instance : data) {
            String predicate = instance.focus().feature(FeatureType.Predicate);
            labelMap.put(predicate, instance.feature(FeatureType.Gold));
        }
        return data.stream()
                .filter(instance -> labelMap.get(instance.focus().feature(FeatureType.Predicate)).size() > 1)
                .collect(Collectors.toList());
    }

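    /**
     * Keep only instances whose focus predicate is in the given set of verbs.
     */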
    private static List<NlpFocus<DepNode, DepTree>> filterByVerb(Set<String> verbs,
                                                                 List<NlpFocus<DepNode, DepTree>> data) {
        return data.stream()
                .filter(instance -> {
                    String predicate = instance.focus().feature(FeatureType.Predicate);
                    return verbs.contains(predicate);
                })
                .collect(Collectors.toList());
    }

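    /**
     * Collect the set of distinct predicate lemmas in a dataset.
     */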
    private static Set<String> getVerbs(List<NlpFocus<DepNode, DepTree>> data) {
        return data.stream()
                .map(i -> (String) i.focus().feature(FeatureType.Predicate))
                .collect(Collectors.toSet());
    }

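    /**
     * Build the hand-tuned feature pipeline used when {@link #single} is set: token window features,
     * dependency child features, root path features, word cluster and lexical resource features
     * (DDN, WN), and a bias feature.
     */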
    private static FeaturePipeline<NlpFocus<DepNode, DepTree>> initializeFeatures() {
        List<FeatureFunction<NlpFocus<DepNode, DepTree>>> features = new ArrayList<>();

        StringExtractor<DepNode> text = new StringFunctionExtractor<>(
                new LookupFeatureExtractor<>(FeatureType.Text.name()), new LowercaseFunction());
        StringExtractor<DepNode> lemma = new StringFunctionExtractor<>(
                new LookupFeatureExtractor<>(FeatureType.Lemma.name()), new LowercaseFunction());
        StringExtractor<DepNode> dep = new StringFunctionExtractor<>(
                new LookupFeatureExtractor<>(Dep.name()), new LowercaseFunction());
        StringExtractor<DepNode> pos = new StringFunctionExtractor<>(
                new LookupFeatureExtractor<>(FeatureType.Pos.name()), new LowercaseFunction());

        List<StringExtractor<DepNode>> windowExtractors =
                Arrays.asList(text, lemma, pos);
        List<StringExtractor<DepNode>> depExtractors =
                Stream.of(lemma, pos).map(s -> new ConcatenatingFeatureExtractor<>(
                        Arrays.asList(s, dep))).collect(Collectors.toList());
        List<StringExtractor<DepNode>> depPathExtractors =
                Arrays.asList(lemma, dep, pos);

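        // contexts: dependency children of the focus (excluding punctuation), a [-2, 2] token window, and the root path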
        DepChildrenContextFactory depContexts = new DepChildrenContextFactory(
                Sets.newHashSet("punct"), new HashSet<>());
        NlpContextFactory<NlpFocus<DepNode, DepTree>, DepNode> windowContexts =
                new OffsetContextFactory<>(Arrays.asList(-2, -1, 0, 1, 2));
        NlpContextFactory<NlpFocus<DepNode, DepTree>, DepNode> rootPathContext = new RootPathContextFactory(false, 1);

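        // multi-valued features (word clusters, DDN, WN) over selected dependents (dobj, nmod, xcomp, advmod)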
        DepChildrenContextFactory filteredDepContexts = new DepChildrenContextFactory(
                new HashSet<>(), Sets.newHashSet("dobj", "nmod", "xcomp", "advmod"));
        List<StringListExtractor<DepNode>> clusterExtractors =
                Stream.of("cluster-100", "cluster-320", "cluster-1000", "cluster-3200", "cluster-10000", "brown").map(
                        (Function<String, StringListExtractor<DepNode>>) ListLookupFeatureExtractor::new)
                        .collect(Collectors.toList());
        clusterExtractors = clusterExtractors.stream()
                .map(s -> new ListConcatenatingFeatureExtractor<>(s, dep))
                .collect(Collectors.toList());
        List<StringListExtractor<DepNode>> filteredDepExtractors = new ArrayList<>(clusterExtractors);
        filteredDepExtractors.add(new ListLookupFeatureExtractor<>(DynamicDependencyNeighborsResource.DDN_KEY));
        filteredDepExtractors.add(new ListLookupFeatureExtractor<>("WN"));

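        // conjunctions of POS + dependency relation features over the focus's children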
        StringFeatureFunction<NlpFocus<DepNode, DepTree>, DepNode> depFeatures
                = new StringFeatureFunction<>(depContexts, Collections.singletonList(new ConcatenatingFeatureExtractor<>(pos, dep)));
        ConjunctionFunction<NlpFocus<DepNode, DepTree>> function
                = new ConjunctionFunction<>(depFeatures, depFeatures);

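        // assemble the full feature set, including a bias feature for the linear classifier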
        features.add(function);
        features.add(new StringFeatureFunction<>(windowContexts, windowExtractors));
        features.add(new StringFeatureFunction<>(depContexts, depExtractors));
        features.add(new MultiStringFeatureFunction<>(filteredDepContexts, filteredDepExtractors));
        features.add(new MultiStringFeatureFunction<>(new OffsetContextFactory<>(0), clusterExtractors));
        features.add(new StringFeatureFunction<>(rootPathContext, depPathExtractors));
        features.add(new BiasFeatureFunction<>());

        return new DefaultFeaturePipeline<>(new AggregateFeatureFunction<>(features));
    }

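    /**
     * Run the experiment. Expects dependency-parsed SemLink splits under {@code data/datasets/semlink/}
     * and writes the trained model to {@code data/models/semlink.model}.
     */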
    public static void main(String... args) throws IOException {
        daisukeSetup();
    }

}



