package cmu.arktweetnlp;
import java.io.IOException;
import java.util.ArrayList;
import cmu.arktweetnlp.impl.Model;
import cmu.arktweetnlp.impl.ModelSentence;
import cmu.arktweetnlp.impl.OWLQN;
import cmu.arktweetnlp.impl.Sentence;
import cmu.arktweetnlp.impl.OWLQN.WeightsPrinter;
import cmu.arktweetnlp.impl.features.FeatureExtractor;
import cmu.arktweetnlp.io.CoNLLReader;
import cmu.arktweetnlp.util.Util;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.DiffFunction;
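/**
 * Entry point for training a tagging model. The pipeline: read CoNLL-format
 * training sentences, build the label vocabulary, extract features, then fit
 * weights by minimizing the negative log-likelihood with OWL-QN (which handles
 * the L1 penalty) plus an explicit L2 penalty, and finally save the model as text.
 */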
public class Train {
    public double l2penalty = 2;
    public double l1penalty = 0.25;
    public double tol = 1e-7;
    public int maxIter = 500;

    public String modelLoadFilename = null;
    public String examplesFilename = null;
    public String modelSaveFilename = null;
    public boolean dumpFeatures = false;
    // Data structures (generic type parameters restored; the raw types would
    // not compile against the for-each loops below)
    private ArrayList<Sentence> lSentences;
    private ArrayList<ModelSentence> mSentences;
    private int numTokens = 0;

    private Model model;

    Train() {
        lSentences = new ArrayList<Sentence>();
        mSentences = new ArrayList<ModelSentence>();
        model = new Model();
    }
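    /**
     * Runs the front half of the pipeline (read examples, build label vocab,
     * extract features) and prints the extracted features instead of training.
     */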
    public void doFeatureDumping() throws IOException {
        readTrainingSentences(examplesFilename);
        constructLabelVocab();
        extractFeatures();
        dumpFeatures();
    }
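    /**
     * Full training run: read examples, extract features, optionally warm-start
     * from an existing model, optimize, and save the result as a text model file.
     */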
    public void doTraining() throws IOException {
        readTrainingSentences(examplesFilename);
        constructLabelVocab();
        extractFeatures();
        model.lockdownAfterFeatureExtraction();
        if (modelLoadFilename != null) {
            readWarmStartModel();
        }
        optimizationLoop();
        model.saveModelAsText(modelSaveFilename);
    }
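    /** Loads CoNLL-format training sentences and tallies the total token count. */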
    public void readTrainingSentences(String filename) throws IOException {
        lSentences = CoNLLReader.readFile(filename);
        for (Sentence sent : lSentences)
            numTokens += sent.T();
    }
    public void constructLabelVocab() {
        for (Sentence s : lSentences) {
            for (String l : s.labels) {
                model.labelVocab.num(l);
            }
        }
        model.labelVocab.lock();
        model.numLabels = model.labelVocab.size();
    }
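    /**
     * Prints each sentence's features via the extractor's dump mode; nothing
     * is stored, so this is purely for inspecting the feature extractors.
     */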
    public void dumpFeatures() throws IOException {
        FeatureExtractor fe = new FeatureExtractor(model, true);
        fe.dumpMode = true;
        for (Sentence lSent : lSentences) {
            ModelSentence mSent = new ModelSentence(lSent.T());
            fe.computeFeatures(lSent, mSent);
        }
    }
    public void extractFeatures() throws IOException {
        System.out.println("Extracting features");
        FeatureExtractor fe = new FeatureExtractor(model, true);
        for (Sentence lSent : lSentences) {
            ModelSentence mSent = new ModelSentence(lSent.T());
            fe.computeFeatures(lSent, mSent);
            mSentences.add(mSent);
        }
    }
    public void readWarmStartModel() throws IOException {
        assert model.featureVocab.isLocked();
        Model warmModel = Model.loadModelFromText(modelLoadFilename);
        Model.copyCoefsForIntersectingFeatures(warmModel, model);
    }
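    /**
     * OWL-QN minimizes the objective and applies the L1 penalty itself (it is
     * passed in here as a parameter), so only the L2 term appears in the
     * value/gradient computations of GradientCalculator below.
     */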
    public void optimizationLoop() {
        OWLQN minimizer = new OWLQN();
        minimizer.setMaxIters(maxIter);
        minimizer.setQuiet(false);
        minimizer.setWeightsPrinting(new MyWeightsPrinter());
        double[] initialWeights = model.convertCoefsToFlat();
        double[] finalWeights = minimizer.minimize(
                new GradientCalculator(),
                initialWeights, l1penalty, tol, 5);
        model.setCoefsFromFlat(finalWeights);
    }
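    /**
     * Objective for the minimizer: negative log-likelihood of the training data
     * plus the L2 regularizer, with the matching gradient (the accumulated
     * gradient is negated via multiplyInPlace before the L2 term is added).
     */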
    private class GradientCalculator implements DiffFunction {
        @Override
        public int domainDimension() {
            return model.flatIDsize();
        }

        @Override
        public double valueAt(double[] flatCoefs) {
            model.setCoefsFromFlat(flatCoefs);
            double loglik = 0;
            for (ModelSentence s : mSentences) {
                loglik += model.computeLogLik(s);
            }
            return -loglik + regularizerValue(flatCoefs);
        }

        @Override
        public double[] derivativeAt(double[] flatCoefs) {
            double[] g = new double[model.flatIDsize()];
            model.setCoefsFromFlat(flatCoefs);
            for (ModelSentence s : mSentences) {
                model.computeGradient(s, g);
            }
            ArrayMath.multiplyInPlace(g, -1);
            addL2regularizerGradient(g, flatCoefs);
            return g;
        }
    }
    private void addL2regularizerGradient(double[] grad, double[] flatCoefs) {
        assert grad.length == flatCoefs.length;
        for (int f = 0; f < flatCoefs.length; f++) {
            grad[f] += l2penalty * flatCoefs[f];
        }
    }
    /**
     * Full penalty: lambda_2 * (1/2) * sum_j (beta_j)^2  +  lambda_1 * sum_j |beta_j|.
     * Our OWL-QN applies the L1 term itself, so only the first (L2) term is
     * computed here.
     */
    private double regularizerValue(double[] flatCoefs) {
        double l2_term = 0;
        for (int f = 0; f < flatCoefs.length; f++) {
            l2_term += Math.pow(flatCoefs[f], 2);
        }
        return 0.5 * l2penalty * l2_term;
    }
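    /** Progress callback for OWL-QN: reports average per-token log-likelihood. */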
    public class MyWeightsPrinter implements WeightsPrinter {
        @Override
        public void printWeights() {
            double loglik = 0;
            for (ModelSentence s : mSentences) {
                loglik += model.computeLogLik(s);
            }
            System.out.printf("\tTokLL %.6f\t", loglik / numTokens);
        }
    }
    //////////////////////////////////////////////////////////////

    public static void main(String[] args) throws IOException {
        Train trainer = new Train();
        if (args.length < 2 || args[0].equals("-h") || args[0].equals("--help")) {
            usage();
        }
        int i = 0;
        while (i < args.length) {
            // Util.p(args[i]);
            if (!args[i].startsWith("-")) {
                break;
            }
            else if (args[i].equals("--warm-start")) {
                trainer.modelLoadFilename = args[i+1];
                i += 2;
            }
            else if (args[i].equals("--max-iter")) {
                trainer.maxIter = Integer.parseInt(args[i+1]);
                i += 2;
            }
            else if (args[i].equals("--dump-feat")) {
                trainer.dumpFeatures = true;
                i += 1;
            }
            else if (args[i].equals("--l2")) {
                trainer.l2penalty = Double.parseDouble(args[i+1]);
                i += 2;
            }
            else if (args[i].equals("--l1")) {
                trainer.l1penalty = Double.parseDouble(args[i+1]);
                i += 2;
            }
            else {
                usage();
            }
        }
        if (trainer.dumpFeatures) {
            trainer.examplesFilename = args[i];
            trainer.doFeatureDumping();
            System.exit(0);
        }
        if (args.length - i < 2) usage();
        trainer.examplesFilename = args[i];
        trainer.modelSaveFilename = args[i+1];
        trainer.doTraining();
    }
    public static void usage() {
        System.out.println(
            "Train [options] <ExamplesFilename> <ModelOutputFilename>\n" +
            "Options:" +
            "\n  --max-iter <n>        Maximum number of optimization iterations (default 500)." +
            "\n  --warm-start <model>  Initializes at weights of this model; discards base features that aren't in the training set." +
            "\n  --l1 <penalty>        L1 regularization strength (default 0.25)." +
            "\n  --l2 <penalty>        L2 regularization strength (default 2)." +
            "\n  --dump-feat           Show extracted features, instead of training. Useful for debugging/analyzing feature extractors." +
            "\n"
        );
        System.exit(1);
    }
}
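/*
 * Example invocation (a sketch; the jar name and file paths below are
 * assumptions, not part of this file):
 *
 *   java -cp ark-tweet-nlp.jar cmu.arktweetnlp.Train \
 *       --max-iter 100 --l1 0.25 --l2 2.0 \
 *       train_examples.conll tagger.model
 *
 * This reads CoNLL-format examples from train_examples.conll, trains with the
 * given penalties, and writes the fitted model as text to tagger.model.
 */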