All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.optimize.AGIS Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */




/**
 * Implementation of Salakhutdinov and Roweis Adaptive Overrelaxed GIS (2003)
   @author Ryan McDonald [email protected]
 */

package cc.mallet.optimize;

import java.util.logging.*;

import cc.mallet.optimize.Optimizable;
import cc.mallet.types.MatrixOps;

/**
 * Optimizer that applies GIS updates scaled by an adaptive factor {@code eta}
 * (Adaptive Overrelaxed GIS, Salakhutdinov and Roweis, 2003).
 *
 * <p>Every iteration asks the {@link Optimizable.ByGISUpdate} for its update
 * vector, applies that vector scaled by {@code eta}, and compares the value
 * reported before and after the step. A step that raises the value is kept
 * and {@code eta} is multiplied by {@code alpha}. A step that does not is
 * either retried with a repeatedly halved {@code eta} (back-tracking mode) or
 * replaced by the unscaled update with {@code eta} reset to 1.
 */
public class AGIS implements Optimizer {

	private static final Logger logger =
		Logger.getLogger("edu.umass.cs.mallet.base.minimize.AGIS");

	/** Upper bound that {@code eta} is never allowed to reach by growth. */
	private static final double ETA_CAP = 99999999.0;

	/** Not read anywhere in this class. */
	double initialStepSize = 1;

	/** Factor by which {@code eta} is multiplied after an improving step. */
	double alpha;

	/** Current scale applied to the update vector. */
	double eta = 1.0;

	/** Relative-change threshold used by the convergence test. */
	double tolerance = 0.0001;

	/** Iteration budget used by {@link #optimize()}. */
	int maxIterations = 200;

	/** The function being optimized. */
	Optimizable.ByGISUpdate maxable;

	/** Outcome of the most recent {@link #optimize(int)} call. */
	boolean converged = false;

	/** Whether a non-improving step is retried with smaller {@code eta}. */
	boolean backTrack;

	// "eps" is a small number to rectify the special case of converging
	// to exactly zero function value
	final double eps = 1.0e-10;

	/**
	 * Creates an optimizer with back-tracking enabled.
	 *
	 * @param maxable function to optimize
	 * @param alph    growth factor applied to {@code eta} after an improving step
	 */
	public AGIS (Optimizable.ByGISUpdate maxable, double alph) {
		this(maxable, alph, true);
	}

	/**
	 * Creates an optimizer.
	 *
	 * @param maxable   function to optimize
	 * @param alph      growth factor applied to {@code eta} after an improving step
	 * @param backTrack {@code true} to retry a non-improving step with halved
	 *                  {@code eta}; {@code false} to fall back to the unscaled
	 *                  update and reset {@code eta} to 1
	 */
	public AGIS (Optimizable.ByGISUpdate maxable, double alph, boolean backTrack) {
		this.backTrack = backTrack;
		this.alpha = alph;
		this.maxable = maxable;
	}

	/** @return the function handed to the constructor */
	public Optimizable getOptimizable () {
		return maxable;
	}

	/** @return whether the most recent {@link #optimize(int)} call met the convergence test */
	public boolean isConverged () {
		return converged;
	}

	/**
	 * Runs {@link #optimize(int)} with the {@code maxIterations} budget.
	 *
	 * @return {@code true} if the convergence test was met
	 */
	public boolean optimize () {
		return optimize (maxIterations);
	}

	/**
	 * Runs at most {@code numIterations} update steps.
	 *
	 * @param numIterations maximum number of steps to take
	 * @return {@code true} as soon as twice the absolute change in value is no
	 *         larger than {@code tolerance} times the sum of the absolute
	 *         values (plus {@code eps}); {@code false} if the budget runs out
	 */
	public boolean optimize (int numIterations) {
		// Working buffers, each sized to the parameter count.
		double[] candidate = new double[maxable.getNumParameters()];       // point being evaluated
		double[] plainGisPoint = new double[maxable.getNumParameters()];   // start plus unscaled update
		double[] iterationStart = new double[maxable.getNumParameters()];  // parameters at start of the step
		double[] gisStep = new double[maxable.getNumParameters()];         // update vector from maxable

		maxable.getParameters(candidate);
		maxable.getParameters(plainGisPoint);
		maxable.getParameters(iterationStart);

		for (int iteration = 0; iteration < numIterations; iteration++) {
			final double valueBefore = maxable.getValue();

			// Take the over-relaxed step, remembering where the plain step would land.
			maxable.getGISUpdate(gisStep);
			MatrixOps.plusEquals(plainGisPoint, gisStep);
			MatrixOps.plusEquals(candidate, gisStep, eta);
			maxable.setParameters(candidate);
			double valueAfter = maxable.getValue();

			// Per the original author's note this differs from normal AGIS:
			// the scaled step is always accepted when the value went up.
			boolean improved = valueAfter > valueBefore;

			if (improved) {
				// Grow eta, but never up to the cap.
				final double grown = eta * alpha;
				if (grown < ETA_CAP) {
					eta = grown;
				}
			} else if (backTrack) {
				// Overshot. Unlike Roweis et al., halve eta until a step
				// improves instead of jumping straight back to eta = 1.
				// The loop can end without an improving step; the last
				// parameters tried then stay in place.
				while (!improved && eta > 1.0) {
					eta = eta / 2.0;

					MatrixOps.set(candidate, iterationStart);
					MatrixOps.plusEquals(candidate, gisStep, eta);
					maxable.setParameters(candidate);
					valueAfter = maxable.getValue();

					improved = valueAfter > valueBefore;
				}
			} else {
				// No back-tracking: use the unscaled update and reset eta.
				maxable.setParameters(plainGisPoint);
				eta = 1.0;
				valueAfter = maxable.getValue();
			}

			logger.info("eta: " + eta);

			// Relative-change convergence test; eps keeps the bound positive
			// when both values are exactly zero.
			final double change = 2.0 * Math.abs(valueAfter - valueBefore);
			final double allowed =
				tolerance * (Math.abs(valueAfter) + Math.abs(valueBefore) + eps);
			if (change <= allowed) {
				converged = true;
				return true;
			}

			// Re-sync every buffer with the accepted parameters before the next step.
			if (numIterations > 1) {
				maxable.getParameters(candidate);
				maxable.getParameters(iterationStart);
				maxable.getParameters(plainGisPoint);
			}
		}

		converged = false;
		return false;
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy