
// cc.mallet.fst.ThreadedOptimizable, as distributed in the jcore-mallet-2.0.9
// artifact (Maven / Gradle / Ivy).
package cc.mallet.fst;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
/**
* An adaptor for optimizables based on batch values/gradients.
*
* Computes values, gradients for each batch in multiple threads and combines
* them in the end.
*
* @author Gaurav Chandalia
* @see CRFOptimizableByBatchLabelLikelihood
*/
public class ThreadedOptimizable implements Optimizable.ByGradientValue {
private static Logger logger = MalletLogger.getLogger(ThreadedOptimizable.class.getName());
/** Data */
protected InstanceList trainingSet;
/** optimizable to be parallelized */
protected Optimizable.ByCombiningBatchGradient optimizable;
/** Value obtained from the optimizable for each batch */
protected double[] batchCachedValue;
/** Gradient obtained from the optimizable for each batch */
protected List batchCachedGradient;
// determine when value/gradient become stale
protected CacheStaleIndicator cacheIndicator;
// tasks to be executed in individual threads, each task is instantiated only
// once but executed in every iteration
private transient Collection> valueTasks;
private transient Collection> gradientTasks;
// thread pool to compute value/gradient for one batch of data
private transient ThreadPoolExecutor executor;
// milliseconds
public static final int SLEEP_TIME = 100;
/**
* Initializes the optimizable and starts new threads.
*
* @param optimizable Optimizable to be parallelized
* @param numFactors Number of factors in model's parameters, used to
* initialize the gradient
* @param cacheIndicator Determines when value/gradient become stale
*/
public ThreadedOptimizable(Optimizable.ByCombiningBatchGradient optimizable,
InstanceList trainingSet, int numFactors,
CacheStaleIndicator cacheIndicator) {
// set up
this.trainingSet = trainingSet;
this.optimizable = optimizable;
int numBatches = optimizable.getNumBatches();
assert(numBatches > 0) : "Invalid number of batches: " + numBatches;
batchCachedValue = new double[numBatches];
batchCachedGradient = new ArrayList(numBatches);
for (int i = 0; i < numBatches; ++i) {
batchCachedGradient.add(new double[numFactors]);
}
this.cacheIndicator = cacheIndicator;
logger.info("Creating " + numBatches + " threads for updating gradient...");
executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numBatches);
this.createTasks();
}
public Optimizable.ByCombiningBatchGradient getOptimizable() {
return optimizable;
}
/**
* Shuts down the executor used to start and run threads to compute values
* and gradients.
*
* *Note*: For a clean exit of all the threads, it is recommended to call
* this method after training finishes.
*/
public void shutdown() {
// fix submitted by Mark Dredze ([email protected])
executor.shutdown();
try {
executor.awaitTermination(30, TimeUnit.SECONDS);
} catch (InterruptedException e) {
e.printStackTrace();
}
assert(executor.shutdownNow().size() == 0) : "All tasks didn't finish";
}
public double getValue () {
if (cacheIndicator.isValueStale()) {
// compute values again
try {
// run all threads and wait for them to finish
List> results = executor.invokeAll(valueTasks);
// compute final log probability
int batch = 0;
for (Future f : results) {
try {
batchCachedValue[batch++] = f.get();
} catch (ExecutionException ee) {
ee.printStackTrace();
}
}
} catch (InterruptedException ie) {
ie.printStackTrace();
}
double cachedValue = MatrixOps.sum(batchCachedValue);
logger.info("getValue() (loglikelihood, optimizable by label likelihood) =" + cachedValue);
return cachedValue;
}
return MatrixOps.sum(batchCachedValue);
}
/**
* Returns the gradient, re-computes if gradient is stale.
*
* *Note*: Assumes that buffer is already initialized.
*/
public void getValueGradient (double[] buffer) {
if (cacheIndicator.isGradientStale()) {
// compute values again if required
this.getValue();
// compute gradients again
try {
// run all threads and wait for them to finish
executor.invokeAll(gradientTasks);
} catch (InterruptedException ie) {
ie.printStackTrace();
}
}
optimizable.combineGradients(batchCachedGradient, buffer);
}
/**
* Creates tasks to be executed in parallel, each task looks at a batch of
* data.
*/
protected void createTasks() {
int numBatches = optimizable.getNumBatches();
valueTasks = new ArrayList>(numBatches);
gradientTasks = new ArrayList>(numBatches);
// number of instances per batch
int numBatchInstances = trainingSet.size() / numBatches;
// batch assignments
int start = -1, end = -1;
for (int i = 0; i < numBatches; ++i) {
// get the indices of batch
if (i == 0) {
start = 0;
end = start + numBatchInstances;
} else if (i == numBatches-1) {
start = end;
end = trainingSet.size();
} else {
start = end;
end = start + numBatchInstances;
}
valueTasks.add(new ValueHandler(i, new int[]{start, end}));
gradientTasks.add(new GradientHandler(i, new int[]{start, end}));
}
}
public int getNumParameters () { return optimizable.getNumParameters(); }
public void getParameters (double[] buffer) {
optimizable.getParameters(buffer);
}
public double getParameter (int index) {
return optimizable.getParameter(index);
}
public void setParameters (double [] buff) {
optimizable.setParameters(buff);
}
public void setParameter (int index, double value) {
optimizable.setParameter(index, value);
}
/**
* Computes value in a separate thread for a batch of data.
*/
private class ValueHandler implements Callable {
private int batchIndex;
private int[] batchAssignments;
public ValueHandler(int batchIndex, int[] batchAssignments) {
this.batchIndex = batchIndex;
this.batchAssignments = batchAssignments;
}
/**
* Returns the value for a batch.
*/
public Double call() {
return optimizable.getBatchValue(batchIndex, batchAssignments);
}
}
/**
* Computes gradient in a separate thread for a batch of data.
*/
private class GradientHandler implements Callable {
private int batchIndex;
private int[] batchAssignments;
public GradientHandler(int batchIndex, int[] batchAssignments) {
this.batchIndex = batchIndex;
this.batchAssignments = batchAssignments;
}
/**
* Computes the gradient for a batch, always returns true.
*/
public Boolean call() {
optimizable.getBatchValueGradient(batchCachedGradient.get(batchIndex),
batchIndex, batchAssignments);
return true;
}
}
}