cc.mallet.fst.semi_supervised.pr.ConstraintsOptimizableByPR Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
The newest version!
/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.fst.semi_supervised.pr;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Logger;
import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.Transducer.TransitionIterator;
import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint;
import cc.mallet.optimize.Optimizable.ByGradientValue;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
/**
* Optimizable for E-step/I-projection in Posterior Regularization (PR).
*
* @author Kedar Bellare
* @author Gregory Druck
*/
public class ConstraintsOptimizableByPR implements Serializable, ByGradientValue {
private static Logger logger = MalletLogger.getLogger(ConstraintsOptimizableByPR.class.getName());
private static final long serialVersionUID = 1;
protected boolean cacheStale;
protected int numParameters;
protected int numThreads;
protected InstanceList trainingSet;
protected double cachedValue = -123456789;
protected double[] cachedGradient;
protected CRF crf;
protected ThreadPoolExecutor executor;
protected double[][][][] cachedDots;
PRAuxiliaryModel model;
public ConstraintsOptimizableByPR(CRF crf, InstanceList ilist, PRAuxiliaryModel model) {
this(crf,ilist,model,1);
}
public ConstraintsOptimizableByPR(CRF crf, InstanceList ilist, PRAuxiliaryModel model, int numThreads) {
this.crf = crf;
this.trainingSet = ilist;
this.model = model;
this.numParameters = model.numParameters();
cachedGradient = new double[numParameters];
this.cacheStale = true;
this.numThreads = numThreads;
this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numThreads);
cacheDotProducts();
}
public void cacheDotProducts() {
cachedDots = new double[trainingSet.size()][][][];
for (int i = 0; i < trainingSet.size(); i++) {
FeatureVectorSequence input = (FeatureVectorSequence)trainingSet.get(i).getData();
cachedDots[i] = new double[input.size()][crf.numStates()][crf.numStates()];
for (int j = 0; j < input.size(); j++) {
for (int k = 0; k < crf.numStates(); k++) {
for (int l = 0; l < crf.numStates(); l++) {
cachedDots[i][j][k][l] = Transducer.IMPOSSIBLE_WEIGHT;
}
}
}
for (int j = 0; j < input.size(); j++) {
for (int k = 0; k < crf.numStates(); k++) {
TransitionIterator iter = crf.getState(k).transitionIterator(input, j);
while (iter.hasNext()) {
int l = iter.next().getIndex();
cachedDots[i][j][k][l] = iter.getWeight();
}
}
}
}
}
public int getNumParameters() {
return numParameters;
}
public void getParameters(double[] params) {
model.getParameters(params);
}
public double getParameter(int index) {
return model.getParameter(index);
}
public void setParameters(double[] params) {
cacheStale = true;
model.setParameters(params);
}
public void setParameter(int index, double value) {
cacheStale = true;
model.setParameter(index, value);
}
protected double getExpectationValue() {
model.zeroExpectations();
// updating tasks
ArrayList> tasks = new ArrayList>();
int increment = trainingSet.size() / numThreads;
int start = 0;
int end = increment;
for (int taskIndex = 0; taskIndex < numThreads; taskIndex++) {
tasks.add(new ExpectationTask(start,end,model.copy()));
start = end;
if (taskIndex == numThreads - 2) {
end = trainingSet.size();
}
else {
end = start + increment;
}
}
double value = 0;
try {
List> results = executor.invokeAll(tasks);
// compute value
for (Future f : results) {
try {
value += f.get();
} catch (ExecutionException ee) {
ee.printStackTrace();
}
}
} catch (InterruptedException ie) {
ie.printStackTrace();
}
// combine results
combine(model,tasks);
// mu*b - w*||mu||^2
value += model.getValue();
return value;
}
/**
* Returns the log probability of the training sequence labels and the prior
* over parameters.
*/
public double getValue() {
if (cacheStale) {
cachedValue = getExpectationValue();
model.getValueGradient(cachedGradient);
cacheStale = false;
logger.info("getValue (auxiliary distribution) = " + cachedValue);
}
return cachedValue;
}
public double getCompleteValueContribution() {
if (cacheStale) {
getValue();
}
double value = model.getCompleteValueContribution();
return value;
}
public void getValueGradient(double[] buffer) {
if (cacheStale) {
getValue();
}
System.arraycopy(cachedGradient, 0, buffer, 0, cachedGradient.length);
}
private void combine(PRAuxiliaryModel orig, ArrayList> tasks) {
for (int i = 0; i < tasks.size(); i++) {
ExpectationTask task = (ExpectationTask)tasks.get(i);
PRAuxiliaryModel model = task.getModelCopy();
for (int ci = 0; ci < model.numConstraints(); ci++) {
PRConstraint origConstraint = orig.getConstraint(ci);
PRConstraint copyConstraint = model.getConstraint(ci);
double[] expectation = new double[origConstraint.numDimensions()];
copyConstraint.getExpectations(expectation);
origConstraint.addExpectations(expectation);
}
}
}
public void shutdown() {
executor.shutdown();
}
public double[][][][] getCachedDots() {
return cachedDots;
}
public PRAuxiliaryModel getAuxModel() {
return model;
}
private class ExpectationTask implements Callable {
private int start;
private int end;
private PRAuxiliaryModel modelCopy;
public ExpectationTask(int start, int end, PRAuxiliaryModel modelCopy) {
this.start = start;
this.end = end;
this.modelCopy = modelCopy;
}
public Double call() throws Exception {
double value = 0;
for (int ii = start; ii < end; ii++) {
Instance inst = trainingSet.get(ii);
Sequence input = (Sequence) inst.getData();
// logZ
value -= new SumLatticePR(crf, ii, input, null, modelCopy, cachedDots[ii], true, null, null, false).getTotalWeight();
}
return value;
}
public PRAuxiliaryModel getModelCopy() {
return modelCopy;
}
}
}