All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.fst.semi_supervised.pr.ConstraintsOptimizableByPR Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised.pr;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Logger;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.Transducer.TransitionIterator;
import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint;
import cc.mallet.optimize.Optimizable.ByGradientValue;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;

/**
 * Optimizable for E-step/I-projection in Posterior Regularization (PR).
 *
 * @author Kedar Bellare
 * @author Gregory Druck
 */

public class ConstraintsOptimizableByPR implements Serializable, ByGradientValue {
	private static Logger logger = MalletLogger.getLogger(ConstraintsOptimizableByPR.class.getName());
 
  private static final long serialVersionUID = 1;
	
  protected boolean cacheStale;
	protected int numParameters;
	protected int numThreads;
	protected InstanceList trainingSet;
	protected double cachedValue = -123456789;
	protected double[] cachedGradient;
	protected CRF crf;
	protected ThreadPoolExecutor executor;
	protected double[][][][] cachedDots;
	PRAuxiliaryModel model;

	public ConstraintsOptimizableByPR(CRF crf, InstanceList ilist, PRAuxiliaryModel model) {
		this(crf,ilist,model,1);
	}
	
	public ConstraintsOptimizableByPR(CRF crf, InstanceList ilist, PRAuxiliaryModel model, int numThreads) {
		this.crf = crf;
		this.trainingSet = ilist;

		this.model = model;
		this.numParameters = model.numParameters();
		cachedGradient = new double[numParameters];
    this.cacheStale = true;
		
		this.numThreads = numThreads;
		this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numThreads);
	
		cacheDotProducts();
		
	}

	public void cacheDotProducts() {
	  cachedDots = new double[trainingSet.size()][][][];
	  for (int i = 0; i < trainingSet.size(); i++) {
	    FeatureVectorSequence input = (FeatureVectorSequence)trainingSet.get(i).getData();

	    cachedDots[i] = new double[input.size()][crf.numStates()][crf.numStates()];
	    for (int j = 0; j < input.size(); j++) {
	      for (int k = 0; k < crf.numStates(); k++) {
	        for (int l = 0; l < crf.numStates(); l++) {
	          cachedDots[i][j][k][l] = Transducer.IMPOSSIBLE_WEIGHT;
	        }
	      }
	    }

	    for (int j = 0; j < input.size(); j++) {
	      for (int k = 0; k < crf.numStates(); k++) {
	        TransitionIterator iter = crf.getState(k).transitionIterator(input, j);
	        while (iter.hasNext()) {
	          int l = iter.next().getIndex();
	          cachedDots[i][j][k][l] = iter.getWeight();
	        }
	      }
	    }
	  }
	}
	
	public int getNumParameters() {
		return numParameters;
	}

	public void getParameters(double[] params) {
		model.getParameters(params);
	}

	public double getParameter(int index) {
		return model.getParameter(index);
	}

	public void setParameters(double[] params) {
	  cacheStale = true;
		model.setParameters(params);
	}

	public void setParameter(int index, double value) {
	  cacheStale = true;
		model.setParameter(index, value);
	}

	protected double getExpectationValue() {
		model.zeroExpectations();

		// updating tasks
		ArrayList> tasks = new ArrayList>();
		int increment = trainingSet.size() / numThreads;
		int start = 0;
		int end = increment;
	  for (int taskIndex = 0; taskIndex < numThreads; taskIndex++) {
			tasks.add(new ExpectationTask(start,end,model.copy()));
			start = end;
			if (taskIndex == numThreads - 2) {
				end = trainingSet.size();
			}
			else {
				end = start + increment;
			}
		}
		
		double value = 0;
		try {
		  List> results = executor.invokeAll(tasks);
		
			// compute value
			for (Future f : results) {
				try {
					value += f.get();
				} catch (ExecutionException ee) {
					ee.printStackTrace();
				}
			}
	  } catch (InterruptedException ie) {
		  ie.printStackTrace();
	  }

	  // combine results
	  combine(model,tasks);

		// mu*b - w*||mu||^2
		value += model.getValue();
		return value;
	}

	/**
	 * Returns the log probability of the training sequence labels and the prior
	 * over parameters.
	 */
	public double getValue() {
		if (cacheStale) {
		  cachedValue = getExpectationValue();
		  model.getValueGradient(cachedGradient);
		  cacheStale = false;
			logger.info("getValue (auxiliary distribution) = " + cachedValue);
		}
		return cachedValue;
	}

	public double getCompleteValueContribution() {
		if (cacheStale) {
			getValue();
		}
		double value = model.getCompleteValueContribution();
		return value;
	}
	
	public void getValueGradient(double[] buffer) {
		if (cacheStale) {
			getValue();
		}
		System.arraycopy(cachedGradient, 0, buffer, 0, cachedGradient.length);
	}

	private void combine(PRAuxiliaryModel orig, ArrayList> tasks) {
	  for (int i = 0; i < tasks.size(); i++) {
	  	ExpectationTask task = (ExpectationTask)tasks.get(i);
	  	PRAuxiliaryModel model = task.getModelCopy();
	  	for (int ci = 0; ci < model.numConstraints(); ci++) {
	  		PRConstraint origConstraint = orig.getConstraint(ci);
	  		PRConstraint copyConstraint = model.getConstraint(ci);
	  		double[] expectation = new double[origConstraint.numDimensions()];
	  		copyConstraint.getExpectations(expectation);
	  		origConstraint.addExpectations(expectation);
	  	}
	  }
	}

  public void shutdown() {
    executor.shutdown();
  }
  
  public double[][][][] getCachedDots() { 
    return cachedDots;
  }
  
  public PRAuxiliaryModel getAuxModel() { 
    return model;
  }
	
	private class ExpectationTask implements Callable {
		
		private int start;
		private int end;
    private PRAuxiliaryModel modelCopy;
    
    public ExpectationTask(int start, int end, PRAuxiliaryModel modelCopy) {
    	this.start = start;
    	this.end = end;
    	this.modelCopy = modelCopy;
    }
		
		public Double call() throws Exception {
			double value = 0;
	    for (int ii = start; ii < end; ii++) {
	      Instance inst = trainingSet.get(ii);
				Sequence input = (Sequence) inst.getData();
				// logZ			
				value -= new SumLatticePR(crf, ii, input, null, modelCopy, cachedDots[ii], true, null, null, false).getTotalWeight();
			}
	    return value;
		}
		
		public PRAuxiliaryModel getModelCopy() {
			return modelCopy;
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy