edu.berkeley.nlp.classify.LinearRegression Maven / Gradle / Ivy
The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).
package edu.berkeley.nlp.classify;
import java.util.Collection;
import java.util.List;
import edu.berkeley.nlp.math.CachingDifferentiableFunction;
import edu.berkeley.nlp.math.GradientMinimizer;
import edu.berkeley.nlp.math.LBFGSMinimizer;
import edu.berkeley.nlp.util.CollectionUtils;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Pair;
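/**
 * Linear regression over string-valued features. A model is built through the
 * nested {@link Factory}, which indexes the features seen in training, minimizes
 * a squared-error objective with L-BFGS, and returns a trained
 * {@code LinearRegression} that scores new inputs by a weighted feature sum.
 */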
public class LinearRegression<I> {
	private FeatureExtractor<I, String> featureExtractor;
	private double[] weights;
	private FeatureManager featureManager;

	public static class Factory<I> {
		double[] weights;
		FeatureManager featureManager;
		FeatureExtractor<I, String> featureExtractor;
		Collection<Pair<I, Double>> trainingData;

		public Factory(FeatureExtractor<I, String> featureExtractor) {
			this.featureExtractor = featureExtractor;
			this.featureManager = new FeatureManager();
		}

		// Map an input's string-valued feature counts to indexed Feature objects.
		private Counter<Feature> getFeatures(I input) {
			Counter<String> strCounts = featureExtractor.extractFeatures(input);
			Counter<Feature> featCounts = new Counter<Feature>();
			for (String f : strCounts.keySet()) {
				double count = strCounts.getCount(f);
				Feature feat = featureManager.getFeature(f);
				featCounts.setCount(feat, count);
			}
			return featCounts;
		}

		// Dot product of the feature counts with the current weight vector.
		private double getScore(Counter<Feature> featureCounts) {
			double score = 0.0;
			for (Feature feat : featureCounts.keySet()) {
				double count = featureCounts.getCount(feat);
				score += count * weights[feat.getIndex()];
			}
			return score;
		}
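
		/*
		 * The objective minimized below is ordinary least squares:
		 *   L(w) = 0.5 * sum_i (w . f(x_i) - y_i)^2
		 * with gradient component, for each feature j,
		 *   dL/dw_j = sum_i f_j(x_i) * (w . f(x_i) - y_i),
		 * which is exactly what calculate() accumulates per training datum.
		 */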
		private class ObjectiveFunction extends CachingDifferentiableFunction {
			@Override
			protected Pair<Double, double[]> calculate(double[] x) {
				weights = x;
				double objective = 0.0;
				double[] gradient = new double[dimension()];
				for (Pair<I, Double> datum : trainingData) {
					I input = datum.getFirst();
					Counter<Feature> featCounts = getFeatures(input);
					double guessResponse = getScore(featCounts);
					double goldResponse = datum.getSecond();
					double diff = (guessResponse - goldResponse);
					objective += 0.5 * diff * diff;
					for (Feature feat : featCounts.keySet()) {
						double count = featCounts.getCount(feat);
						gradient[feat.getIndex()] += count * diff;
					}
				}
				return Pair.newPair(objective, gradient);
			}

			@Override
			public int dimension() {
				return featureManager.getNumFeatures();
			}

			public double[] unregularizedDerivativeAt(double[] x) {
				// Left unimplemented in the original source; returns null.
				return null;
			}
		}

		// Register every feature string seen in the training data, then freeze the feature set.
		private void extractAllFeatures() {
			for (Pair<I, Double> datum : trainingData) {
				Counter<String> counts = featureExtractor.extractFeatures(datum.getFirst());
				for (String f : counts.keySet()) {
					featureManager.getFeature(f);
				}
			}
			featureManager.lock();
		}

		// Pretty-print the learned weight of each feature for inspection.
		private String examineWeights() {
			Counter<Feature> counts = new Counter<Feature>();
			for (int i = 0; i < weights.length; ++i) {
				Feature feat = featureManager.getFeature(i);
				counts.setCount(feat, weights[i]);
			}
			return counts.toString();
		}
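
		/*
		 * Training: index all features, then run L-BFGS from a zero weight
		 * vector (tolerance 1.0e-4) on the squared-error objective above.
		 */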
		public LinearRegression<I> train(Collection<Pair<I, Double>> trainingData) {
			this.trainingData = trainingData;
			extractAllFeatures();
			ObjectiveFunction objFn = new ObjectiveFunction();
			GradientMinimizer gradMinimizer = new LBFGSMinimizer();
			double[] initial = new double[objFn.dimension()];
			this.weights = gradMinimizer.minimize(objFn, initial, 1.0e-4);
			return new LinearRegression<I>(featureExtractor, featureManager, weights);
		}
	}

	private LinearRegression(FeatureExtractor<I, String> featureExtractor, FeatureManager featureManager, double[] weights) {
		this.featureExtractor = featureExtractor;
		this.featureManager = featureManager;
		this.weights = weights;
	}

	// Predicted response: weighted sum of the input's feature counts.
	public double getResponse(I input) {
		Counter<String> featCounts = featureExtractor.extractFeatures(input);
		double score = 0.0;
		for (String f : featCounts.keySet()) {
			double count = featCounts.getCount(f);
			Feature feat = featureManager.getFeature(f);
			score += count * weights[feat.getIndex()];
		}
		return score;
	}
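
	/*
	 * Small self-contained demo: fits responses 3.0 and 2.0 to the token bags
	 * {a, b, c} and {a, b}, using token counts as features, then prints the
	 * model's prediction for the first bag.
	 */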
	public static void main(String[] args) {
		List<String> elem1 = CollectionUtils.makeList("a", "b", "c");
		List<String> elem2 = CollectionUtils.makeList("a", "b");
		Pair<List<String>, Double> d1 = Pair.newPair(elem1, 3.0);
		Pair<List<String>, Double> d2 = Pair.newPair(elem2, 2.0);
		FeatureExtractor<List<String>, String> featExtractor = new FeatureExtractor<List<String>, String>() {
			public Counter<String> extractFeatures(List<String> instance) {
				// Each distinct token is a feature; its value is the token count.
				Counter<String> counts = new Counter<String>();
				for (String elem : instance) {
					counts.incrementCount(elem, 1.0);
				}
				return counts;
			}
		};
		LinearRegression.Factory<List<String>> factory = new LinearRegression.Factory<List<String>>(featExtractor);
		List<Pair<List<String>, Double>> datums = CollectionUtils.makeList(d1, d2);
		LinearRegression<List<String>> linearRegressionModel = factory.train(datums);
		double guess = linearRegressionModel.getResponse(elem1);
		System.out.println("guess: " + guess);
	}
}