cc.mallet.grmm.util.CachingOptimizable Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
The newest version!
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.util;
import cc.mallet.optimize.Optimizable;
import cc.mallet.util.MalletLogger;
import java.util.logging.Logger;
/**
* Created: Aug 27, 2004
*
* Sets one parameter of the maximizable object. This default implementation
* inefficiently calls through to getParameters(double[]) and setParameters(double[]).
* Subclasses may override this method for more efficient implementations.
*
* @param index
* @param value
*/
public void setParameter (int index, double value)
{
  // Changing a parameter invalidates whatever value/gradient was cached.
  cachedGradientStale = true;
  cachedValueStale = true;

  // Round-trip through the full parameter vector: read it out, patch the
  // requested slot, and hand the whole vector back.
  final double[] current = new double[getNumParameters ()];
  getParameters (current);
  current[index] = value;
  setParametersInternal (current);
}
/**
* Returns one parameter of the maximizable object. This default implementation
* inefficiently calls through to getParameters(double[]).
* Subclasses may override this method for more efficient implementations.
*
* @param index
* @return The value of parameter index
*/
public double getParameter (int index)
{
  // Copy out the whole parameter vector just to read a single slot.
  final double[] snapshot = new double[getNumParameters ()];
  getParameters (snapshot);
  final double requested = snapshot[index];
  return requested;
}
public void forceStale ()
{
  // Flag both caches as stale so the next value or gradient request recomputes them.
  cachedGradientStale = true;
  cachedValueStale = true;
}
}
/**/
public static abstract class ByGradient extends Base implements Optimizable.ByGradientValue {

  /**
   * Computes the value to be cached. Called from getValue() and
   * getValueGradient() only while the cached value is flagged stale.
   */
  protected abstract double computeValue ();

  /**
   * Fills <tt>buffer</tt> with the gradient to be cached. Called from
   * getValueGradient() only while the cached gradient is flagged stale.
   *
   * @param buffer the cache array the gradient is written into
   */
  protected abstract void computeValueGradient (double[] buffer);

  /**
   * Copies the (possibly freshly recomputed) cached gradient into <tt>buffer</tt>.
   * A stale value is recomputed first, then a stale gradient.
   *
   * @param buffer destination array; its length must equal getNumParameters()
   * @throws IllegalArgumentException if <tt>buffer</tt> has the wrong length
   */
  public void getValueGradient (double[] buffer)
  {
    if (buffer.length != getNumParameters ()) {
      throw new IllegalArgumentException ("Argument is not of the " +
          " correct dimensions");
    }

    // The value is refreshed before the gradient: computeValue() may call
    // setCachedGradient(), which clears the gradient's stale flag.
    if (cachedValueStale) {
      cachedValue = computeValue ();
      cachedValueStale = false;
    }

    if (cachedGradientStale) {
      computeValueGradient (lazilyAllocatedGradient ());
      cachedGradientStale = false;
    }

    System.arraycopy (cachedGradient, 0, buffer, 0, cachedGradient.length);
  }

  /**
   * Returns the cached value, recomputing it (and logging the elapsed time
   * and the new value) when the cache is flagged stale.
   */
  public double getValue ()
  {
    if (!cachedValueStale) {
      return cachedValue;
    }

    final long begin = System.currentTimeMillis();
    cachedValue = computeValue ();
    final long elapsedMillis = System.currentTimeMillis() - begin;

    logger.info ("Optimizable computeValue time (ms) =" + elapsedMillis);
    logger.info ("computeValue() = " + cachedValue);
    cachedValueStale = false;
    return cachedValue;
  }

  /**
   * Stores <tt>gradient</tt> as the cached gradient and marks it fresh.
   * Intended for subclasses that produce value and gradient in one pass:
   * calling this from computeValue() lets getValueGradient() skip
   * computeValueGradient().
   *
   * @param gradient the gradient to copy into the cache
   */
  protected void setCachedGradient (double[] gradient)
  {
    final double[] target = lazilyAllocatedGradient ();
    System.arraycopy (gradient, 0, target, 0, gradient.length);
    cachedGradientStale = false;
  }

  /** Returns the gradient cache array, allocating it on first use. */
  private double[] lazilyAllocatedGradient ()
  {
    if (cachedGradient == null) {
      cachedGradient = new double[getNumParameters ()];
    }
    return cachedGradient;
  }
}
public static abstract class ByBatchGradient extends Base implements Optimizable.ByBatchGradient {

  // Batch index and assignment array seen on the most recent call; used to
  // decide whether the caches still belong to the requested batch.
  private int lastIndex;
  private int[] lastAssns;

  /**
   * Computes the value for the given batch. Called only while the cached
   * value is flagged stale.
   */
  protected abstract double computeBatchValue (int batchIndex, int[] batchAssignments);

  /**
   * Fills <tt>buffer</tt> with the gradient for the given batch. Called only
   * while the cached gradient is flagged stale.
   */
  protected abstract void computeBatchGradient (double[] buffer, int batchIndex, int[] batchAssignments);

  /**
   * Copies the cached gradient for the requested batch into <tt>buffer</tt>,
   * recomputing value and gradient first if they are stale or belong to a
   * different batch.
   *
   * @param buffer destination array; its length must equal getNumParameters()
   * @param batchIndex index of the requested batch
   * @param batchAssignments batch assignment array passed through to the compute methods
   * @throws IllegalArgumentException if <tt>buffer</tt> has the wrong length
   */
  public void getBatchValueGradient (double[] buffer, int batchIndex, int[] batchAssignments)
  {
    if (buffer.length != getNumParameters ()) {
      throw new IllegalArgumentException ("Argument is not of the " +
          " correct dimensions");
    }

    invalidateOnBatchChange (batchIndex, batchAssignments);

    if (cachedValueStale) {
      cachedValue = computeBatchValue (batchIndex, batchAssignments);
      cachedValueStale = false;
    }

    if (cachedGradientStale) {
      if (cachedGradient == null) {
        cachedGradient = new double[getNumParameters ()];
      }
      computeBatchGradient (cachedGradient, batchIndex, batchAssignments);
      cachedGradientStale = false;
    }

    System.arraycopy (cachedGradient, 0, buffer, 0, cachedGradient.length);
  }

  /**
   * Returns the cached value for the requested batch, recomputing (and
   * logging) it if it is stale or belongs to a different batch.
   *
   * @param batchIndex index of the requested batch
   * @param batchAssignments batch assignment array passed through to computeBatchValue
   */
  public double getBatchValue (int batchIndex, int[] batchAssignments)
  {
    invalidateOnBatchChange (batchIndex, batchAssignments);

    if (!cachedValueStale) {
      return cachedValue;
    }

    cachedValue = computeBatchValue (batchIndex, batchAssignments);
    logger.info ("computeValue() = " + cachedValue);
    cachedValueStale = false;
    return cachedValue;
  }

  /**
   * Forces both caches stale and remembers the new batch when the request
   * differs from the previous one. The assignment array is compared by
   * reference, so in-place edits to the same array are not detected here.
   */
  private void invalidateOnBatchChange (int batchIndex, int[] batchAssignments)
  {
    final boolean sameBatch = (batchIndex == lastIndex) && (batchAssignments == lastAssns);
    if (sameBatch) {
      return;
    }
    forceStale ();
    lastIndex = batchIndex;
    lastAssns = batchAssignments;
  }
}
}