All downloads are free. Search and download functionality uses the official Maven repository.

cc.mallet.grmm.util.CachingOptimizable Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.util;


import cc.mallet.optimize.Optimizable;
import cc.mallet.util.MalletLogger;

import java.util.Arrays;
import java.util.logging.Logger;

/**
 * Created: Aug 27, 2004
 *
 * @author getParameters(double[]) and setParameters(double[]).
     * Subclasses may override this method for more efficient implementations.
     *
     * @param index
     * @param value
     */
    public void setParameter (int index, double value)
    {
      // The parameter vector is about to change, so neither cached quantity is valid any more.
      cachedGradientStale = true;
      cachedValueStale = true;

      // Read-modify-write of the full vector: copy everything out, patch one slot, push it back.
      final double[] snapshot = new double[getNumParameters ()];
      getParameters (snapshot);
      snapshot[index] = value;
      setParametersInternal (snapshot);
    }

    /**
     * Reads a single parameter of the optimizable object.  The default
     * implementation is inefficient: it copies the whole parameter vector
     * out via getParameters(double[]) and then picks one entry.
     * Subclasses may override this method with a more efficient implementation.
     *
     * @param index position of the parameter to read
     * @return the value of the parameter at {@code index}
     */
    public double getParameter (int index)
    {
      final int count = getNumParameters ();
      final double[] all = new double[count];
      getParameters (all);
      return all[index];
    }

    /**
     * Marks both the cached value and the cached gradient as out of date.
     * The caching getters in the subclasses then recompute them on their next call.
     */
    public void forceStale ()
    {
      cachedGradientStale = true;
      cachedValueStale = true;
    }

  }

   /**
    * Gradient-based optimizable that caches its value and gradient.
    * Subclasses implement computeValue() and computeValueGradient(double[]).
    * This class calls them only when the corresponding cache has been
    * marked stale.
    */
   public static abstract class ByGradient extends Base implements Optimizable.ByGradientValue {

    /**
     * Computes the objective value at the current parameters.
     * This is called only when the cached value is stale.
     *
     * @return the freshly computed value
     */
    protected abstract double computeValue ();

    /**
     * Writes the gradient at the current parameters into {@code buffer}.
     * This is called only when the cached gradient is stale.
     *
     * @param buffer destination of length getNumParameters()
     */
    protected abstract void computeValueGradient (double[] buffer);

    /**
     * Copies the gradient at the current parameters into {@code buffer}.
     * The value is brought up to date first, so a subclass whose
     * computeValue() calls setCachedGradient(double[]) never pays for a
     * separate computeValueGradient() call.
     *
     * @param buffer destination array; its length must equal getNumParameters()
     * @throws IllegalArgumentException if {@code buffer} has the wrong length
     */
    public void getValueGradient (double[] buffer)
    {
      int numParams = getNumParameters ();
      if (buffer.length != numParams)
        throw new IllegalArgumentException ("Argument is not of the correct dimensions: expected "
                + numParams + " but got " + buffer.length);

      if (cachedValueStale) {
        cachedValue = computeValue ();
        cachedValueStale = false;
      }
      // computeValue() may already have refreshed the gradient via setCachedGradient().
      if (cachedGradientStale) {
        ensureGradientBuffer ();
        computeValueGradient (cachedGradient);
        cachedGradientStale = false;
      }
      System.arraycopy (cachedGradient, 0, buffer, 0, cachedGradient.length);
    }

    /**
     * Returns the objective value at the current parameters.  When the cached
     * value is stale, it is recomputed, and the computation time and result
     * are logged at INFO level.
     *
     * @return the (possibly cached) objective value
     */
    public double getValue ()
    {
      if (cachedValueStale) {
        // nanoTime is monotonic, so the measured duration cannot be skewed by wall-clock adjustments.
        long startNanos = System.nanoTime ();
        cachedValue = computeValue ();
        long elapsedMillis = (System.nanoTime () - startNanos) / 1000000L;
        logger.info ("Optimizable computeValue time (ms) ="+elapsedMillis);
        logger.info ("computeValue() = " + cachedValue);
        cachedValueStale = false;
      }
      return cachedValue;
    }

    /**
     * Sets the cached gradient.  This is useful for subclasses that
     * need to compute the value and the gradient at the same time.
     * If they call this method in computeValue(), then
     * their computeValueGradient() will never be called.
     *
     * @param gradient the gradient at the current parameters; its length must equal getNumParameters()
     * @throws IllegalArgumentException if {@code gradient} has the wrong length.  Without this check,
     *         a short array would only partially overwrite the cache and leave stale entries behind.
     */
    protected void setCachedGradient (double[] gradient)
    {
      int numParams = getNumParameters ();
      if (gradient.length != numParams)
        throw new IllegalArgumentException ("Gradient is not of the correct dimensions: expected "
                + numParams + " but got " + gradient.length);
      ensureGradientBuffer ();
      System.arraycopy (gradient, 0, cachedGradient, 0, gradient.length);
      cachedGradientStale = false;
    }

    /** Lazily allocates the cached-gradient array on first use. */
    private void ensureGradientBuffer ()
    {
      if (cachedGradient == null) {
        cachedGradient = new double[getNumParameters ()];
      }
    }

  }

  /**
   * Optimizable whose value and gradient are computed one batch at a time.
   * The cached value and gradient belong to the batch most recently requested.
   * Asking for a different batch index, or for different batch assignments,
   * invalidates both.
   */
  public static abstract class ByBatchGradient extends Base implements Optimizable.ByBatchGradient {

    // Batch index the current cache was computed for.
    private int lastIndex;
    // Private copy of the assignments the current cache was computed for.  A copy is kept
    // (rather than the caller's array) so that in-place edits by the caller are detected.
    private int[] lastAssns;

    /**
     * Copies the gradient for batch {@code batchIndex} into {@code buffer},
     * computing the value and gradient for that batch if they are not cached.
     *
     * @param buffer destination array; its length must equal getNumParameters()
     * @param batchIndex which batch to evaluate
     * @param batchAssignments batch assignments passed through to the compute methods
     * @throws IllegalArgumentException if {@code buffer} has the wrong length
     */
    public void getBatchValueGradient (double[] buffer, int batchIndex, int[] batchAssignments)
    {
      int numParams = getNumParameters ();
      if (buffer.length != numParams)
        throw new IllegalArgumentException ("Argument is not of the correct dimensions: expected "
                + numParams + " but got " + buffer.length);

      selectBatch (batchIndex, batchAssignments);

      if (cachedValueStale) {
        cachedValue = computeBatchValue (batchIndex, batchAssignments);
        cachedValueStale = false;
      }
      if (cachedGradientStale) {
        if (cachedGradient == null) {
          cachedGradient = new double[numParams];
        }
        computeBatchGradient (cachedGradient, batchIndex, batchAssignments);
        cachedGradientStale = false;
      }
      System.arraycopy (cachedGradient, 0, buffer, 0, cachedGradient.length);
    }

    /**
     * Returns the value for batch {@code batchIndex}.  It is computed, and
     * logged at INFO level, only when no cached value exists for that batch.
     *
     * @param batchIndex which batch to evaluate
     * @param batchAssignments batch assignments passed through to computeBatchValue
     * @return the (possibly cached) batch value
     */
    public double getBatchValue (int batchIndex, int[] batchAssignments)
    {
      selectBatch (batchIndex, batchAssignments);

      if (cachedValueStale) {
        cachedValue = computeBatchValue (batchIndex, batchAssignments);
        logger.info ("computeValue() = " + cachedValue);
        cachedValueStale = false;
      }
      return cachedValue;
    }

    /**
     * Invalidates the caches when the requested batch differs from the cached one.
     * Assignments are compared by content, not identity.
     */
    private void selectBatch (int batchIndex, int[] batchAssignments)
    {
      if ((batchIndex != lastIndex) || !Arrays.equals (batchAssignments, lastAssns)) {
        forceStale ();
        lastIndex = batchIndex;
        lastAssns = (batchAssignments == null) ? null : batchAssignments.clone ();
      }
    }

    /**
     * Computes the objective value restricted to one batch.
     *
     * @param batchIndex which batch to evaluate
     * @param batchAssignments batch assignments as supplied by the caller
     * @return the batch value
     */
    protected abstract double computeBatchValue (int batchIndex, int[] batchAssignments);

    /**
     * Writes the gradient restricted to one batch into {@code buffer}.
     *
     * @param buffer destination of length getNumParameters()
     * @param batchIndex which batch to evaluate
     * @param batchAssignments batch assignments as supplied by the caller
     */
    protected abstract void computeBatchGradient (double[] buffer, int batchIndex, int[] batchAssignments);

  }
}