io.github.clearwsd.app.VerbNetExperiment (clearwsd-cli)
Command line interfaces for non-programmatic training and experimentation.
/*
* Copyright (C) 2017 James Gung
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package io.github.clearwsd.app;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import io.github.clearwsd.WordSenseClassifier;
import io.github.clearwsd.classifier.Classifier;
import io.github.clearwsd.classifier.PaClassifier;
import io.github.clearwsd.corpus.semlink.VerbNetReader;
import io.github.clearwsd.eval.Evaluation;
import io.github.clearwsd.feature.annotator.AggregateAnnotator;
import io.github.clearwsd.feature.context.DepChildrenContextFactory;
import io.github.clearwsd.feature.context.NlpContextFactory;
import io.github.clearwsd.feature.context.OffsetContextFactory;
import io.github.clearwsd.feature.context.RootPathContextFactory;
import io.github.clearwsd.feature.extractor.ConcatenatingFeatureExtractor;
import io.github.clearwsd.feature.extractor.ListConcatenatingFeatureExtractor;
import io.github.clearwsd.feature.extractor.ListLookupFeatureExtractor;
import io.github.clearwsd.feature.extractor.LookupFeatureExtractor;
import io.github.clearwsd.feature.extractor.StringExtractor;
import io.github.clearwsd.feature.extractor.StringFunctionExtractor;
import io.github.clearwsd.feature.extractor.StringListExtractor;
import io.github.clearwsd.feature.extractor.string.LowercaseFunction;
import io.github.clearwsd.feature.function.AggregateFeatureFunction;
import io.github.clearwsd.feature.function.BiasFeatureFunction;
import io.github.clearwsd.feature.function.ConjunctionFunction;
import io.github.clearwsd.feature.function.FeatureFunction;
import io.github.clearwsd.feature.function.MultiStringFeatureFunction;
import io.github.clearwsd.feature.function.StringFeatureFunction;
import io.github.clearwsd.feature.pipeline.DefaultFeaturePipeline;
import io.github.clearwsd.feature.pipeline.FeaturePipeline;
import io.github.clearwsd.feature.pipeline.NlpClassifier;
import io.github.clearwsd.feature.resource.DynamicDependencyNeighborsResource;
import io.github.clearwsd.type.DepNode;
import io.github.clearwsd.type.DepTree;
import io.github.clearwsd.type.FeatureType;
import io.github.clearwsd.type.NlpFocus;
import io.github.clearwsd.utils.LemmaDictionary;
import io.github.clearwsd.verbnet.DefaultVerbNetClassifier;
import io.github.clearwsd.verbnet.VerbNetSenseInventory;
import lombok.Getter;
import lombok.Setter;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;
import static io.github.clearwsd.app.VerbNetClassifierUtils.resourceManager;
import static io.github.clearwsd.type.FeatureType.Dep;
/**
* VerbNet experiment builder.
*
* @author jamesgung
*/
@Slf4j
@Getter
@Setter
@Accessors(fluent = true)
public class VerbNetExperiment {
    private static final int MIN_COUNT = 10;
    private static boolean single = false;

    private int minPredicateCount;
    private boolean filterTestVerbs;
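
    /**
     * Runs the SemLink experiment: reads the training split, filters out rare and monosemous verbs,
     * restricts the validation and test splits to verbs seen in training, trains and saves a
     * {@link WordSenseClassifier}, and evaluates it on the validation and test data.
     */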
    private static void daisukeSetup() throws IOException {
        List<NlpFocus<DepNode, DepTree>> trainData = new VerbNetReader().readInstances(
                new FileInputStream("data/datasets/semlink/train.ud.dep"));
        trainData = filterMinCount(MIN_COUNT, trainData); // only consider verbs with at least MIN_COUNT occurrences
        trainData = filterPolysemous(trainData);
        Set<String> verbs = getVerbs(trainData); // only test on verbs seen in the training data
        List<NlpFocus<DepNode, DepTree>> validData = new VerbNetReader().readInstances(
                new FileInputStream("data/datasets/semlink/valid.ud.dep"));
        validData = filterByVerb(verbs, validData);
        List<NlpFocus<DepNode, DepTree>> testData = new VerbNetReader().readInstances(
                new FileInputStream("data/datasets/semlink/test.ud.dep"));
        testData = filterByVerb(verbs, testData);
        log.debug("{} test instances and {} dev instances", testData.size(), validData.size());

        // either train a single passive-aggressive classifier over the hand-built feature pipeline,
        // or use the default VerbNet classifier configuration
        Classifier<NlpFocus<DepNode, DepTree>, String> multi;
        if (single) {
            AggregateAnnotator<NlpFocus<DepNode, DepTree>> annotator
                    = new AggregateAnnotator<>(VerbNetClassifierUtils.annotators());
            annotator.initialize(resourceManager());
            trainData.forEach(annotator::annotate);
            validData.forEach(annotator::annotate);
            testData.forEach(annotator::annotate);
            multi = new NlpClassifier<>(new PaClassifier(), initializeFeatures());
        } else {
            multi = new DefaultVerbNetClassifier();
        }

        WordSenseClassifier classifier = new WordSenseClassifier(multi, new VerbNetSenseInventory(), new LemmaDictionary());
        classifier.train(trainData, validData);
        classifier.save(new ObjectOutputStream(new FileOutputStream("data/models/semlink.model")));

        // evaluate classifier on the validation data
        Evaluation evaluation = new Evaluation();
        for (NlpFocus<DepNode, DepTree> instance : validData) {
            evaluation.add(classifier.classify(instance), instance.feature(FeatureType.Gold));
        }
        log.debug("Validation data\n{}", evaluation);

        // ensure that loading the saved classifier gives the same results
        classifier = new WordSenseClassifier(new ObjectInputStream(new FileInputStream("data/models/semlink.model")));
        evaluation = new Evaluation();
        for (NlpFocus<DepNode, DepTree> instance : testData) {
            evaluation.add(classifier.classify(instance), instance.feature(FeatureType.Gold));
        }
        log.debug("Test data\n{}", evaluation);
    }
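
    /**
     * Filters out instances whose focus predicate occurs fewer than minCount times in the given data.
     */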
    private static List<NlpFocus<DepNode, DepTree>> filterMinCount(int minCount,
                                                                   List<NlpFocus<DepNode, DepTree>> data) {
        Map<String, Integer> counts = new HashMap<>();
        for (NlpFocus<DepNode, DepTree> instance : data) {
            String lemma = instance.focus().feature(FeatureType.Predicate);
            counts.merge(lemma, 1, (prev, one) -> prev + one);
        }
        return data.stream()
                .filter(instance -> {
                    String predicate = instance.focus().feature(FeatureType.Predicate);
                    return counts.get(predicate) >= minCount;
                })
                .collect(Collectors.toList());
    }
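
    /**
     * Filters out instances of monosemous predicates, keeping only predicates with more than one
     * observed sense label in the data.
     */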
    private static List<NlpFocus<DepNode, DepTree>> filterPolysemous(
            List<NlpFocus<DepNode, DepTree>> data) {
        Multimap<String, String> labelMap = HashMultimap.create();
        for (NlpFocus<DepNode, DepTree> instance : data) {
            String predicate = instance.focus().feature(FeatureType.Predicate);
            labelMap.put(predicate, instance.feature(FeatureType.Gold));
        }
        return data.stream()
                .filter(instance -> labelMap.get(instance.focus().feature(FeatureType.Predicate)).size() > 1)
                .collect(Collectors.toList());
    }
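
    /**
     * Keeps only instances whose focus predicate appears in the given set of verbs.
     */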
    private static List<NlpFocus<DepNode, DepTree>> filterByVerb(Set<String> verbs,
                                                                 List<NlpFocus<DepNode, DepTree>> data) {
        return data.stream()
                .filter(instance -> {
                    String predicate = instance.focus().feature(FeatureType.Predicate);
                    return verbs.contains(predicate);
                })
                .collect(Collectors.toList());
    }
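
    /**
     * Returns the set of distinct focus predicates (verb lemmas) in the given data.
     */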
    private static Set<String> getVerbs(List<NlpFocus<DepNode, DepTree>> data) {
        return data.stream()
                .map(i -> (String) i.focus().feature(FeatureType.Predicate))
                .collect(Collectors.toSet());
    }
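
    /**
     * Builds the feature pipeline used with the single {@link PaClassifier} setup: token window
     * features, dependency children features, root path features, cluster membership features,
     * and a bias feature.
     */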
    private static FeaturePipeline<NlpFocus<DepNode, DepTree>> initializeFeatures() {
        List<FeatureFunction<NlpFocus<DepNode, DepTree>>> features = new ArrayList<>();

        // lowercased lexical extractors for the word form, lemma, dependency relation, and POS tag
        StringExtractor<DepNode> text = new StringFunctionExtractor<>(
                new LookupFeatureExtractor<>(FeatureType.Text.name()), new LowercaseFunction());
        StringExtractor<DepNode> lemma = new StringFunctionExtractor<>(
                new LookupFeatureExtractor<>(FeatureType.Lemma.name()), new LowercaseFunction());
        StringExtractor<DepNode> dep = new StringFunctionExtractor<>(
                new LookupFeatureExtractor<>(Dep.name()), new LowercaseFunction());
        StringExtractor<DepNode> pos = new StringFunctionExtractor<>(
                new LookupFeatureExtractor<>(FeatureType.Pos.name()), new LowercaseFunction());

        List<StringExtractor<DepNode>> windowExtractors =
                Arrays.asList(text, lemma, pos);
        List<StringExtractor<DepNode>> depExtractors =
                Stream.of(lemma, pos).map(s -> new ConcatenatingFeatureExtractor<>(
                        Arrays.asList(s, dep))).collect(Collectors.toList());
        List<StringExtractor<DepNode>> depPathExtractors =
                Arrays.asList(lemma, dep, pos);

        // contexts: dependency children (punctuation excluded), a +/-2 token window around the
        // focus, the path to the root, and dependency children restricted to selected relations
        DepChildrenContextFactory depContexts = new DepChildrenContextFactory(
                Sets.newHashSet("punct"), new HashSet<>());
        NlpContextFactory<NlpFocus<DepNode, DepTree>, DepNode> windowContexts =
                new OffsetContextFactory<>(Arrays.asList(-2, -1, 0, 1, 2));
        NlpContextFactory<NlpFocus<DepNode, DepTree>, DepNode> rootPathContext = new RootPathContextFactory(false, 1);
        DepChildrenContextFactory filteredDepContexts = new DepChildrenContextFactory(
                new HashSet<>(), Sets.newHashSet("dobj", "nmod", "xcomp", "advmod"));

        // cluster membership features at multiple granularities, plus Brown clusters
        List<StringListExtractor<DepNode>> clusterExtractors =
                Stream.of("cluster-100", "cluster-320", "cluster-1000", "cluster-3200", "cluster-10000", "brown").map(
                        (Function<String, StringListExtractor<DepNode>>) ListLookupFeatureExtractor::new)
                        .collect(Collectors.toList());
        clusterExtractors = clusterExtractors.stream()
                .map(s -> new ListConcatenatingFeatureExtractor<>(s, dep))
                .collect(Collectors.toList());

        // plus dynamic dependency neighbor (DDN) and "WN" (presumably WordNet) list features
        List<StringListExtractor<DepNode>> filteredDepExtractors = new ArrayList<>(clusterExtractors);
        filteredDepExtractors.add(new ListLookupFeatureExtractor<>(DynamicDependencyNeighborsResource.DDN_KEY));
        filteredDepExtractors.add(new ListLookupFeatureExtractor<>("WN"));

        // conjunction of POS/dependency features over dependency children with themselves
        StringFeatureFunction<NlpFocus<DepNode, DepTree>, DepNode> depFeatures
                = new StringFeatureFunction<>(depContexts, Collections.singletonList(new ConcatenatingFeatureExtractor<>(pos, dep)));
        ConjunctionFunction<NlpFocus<DepNode, DepTree>> function
                = new ConjunctionFunction<>(depFeatures, depFeatures);

        features.add(function);
        features.add(new StringFeatureFunction<>(windowContexts, windowExtractors));
        features.add(new StringFeatureFunction<>(depContexts, depExtractors));
        features.add(new MultiStringFeatureFunction<>(filteredDepContexts, filteredDepExtractors));
        features.add(new MultiStringFeatureFunction<>(new OffsetContextFactory<>(0), clusterExtractors));
        features.add(new StringFeatureFunction<>(rootPathContext, depPathExtractors));
        features.add(new BiasFeatureFunction<>());

        return new DefaultFeaturePipeline<>(new AggregateFeatureFunction<>(features));
    }
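
    /**
     * Entry point. Note that the paths to the SemLink splits and the output model are hard-coded
     * relative to the working directory.
     */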
    public static void main(String... args) throws IOException {
        daisukeSetup();
    }

}