All downloads are free. The search and download functionality uses the official Maven repository.

cc.mallet.pipe.iterator.RandomTokenSequenceIterator Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

There is a newer version: 2.0.12
Show newest version
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */




/** 
   @author Andrew McCallum [email protected]
 */

package cc.mallet.pipe.iterator;

import java.net.URI;
import java.util.Iterator;
import java.util.logging.*;

import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.iterator.PipeInputIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.Instance;
import cc.mallet.types.Label;
import cc.mallet.types.Multinomial;
import cc.mallet.types.TokenSequence;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;

/**
 * Iterator that generates synthetic, labeled {@link Instance}s whose data are
 * randomly sampled {@link TokenSequence}s.
 *
 * <p>At construction time one Dirichlet "centroid" is drawn per class name from
 * <code>classCentroidDistribution</code>.  Each call to {@link #reset()} draws a
 * Poisson-distributed instance count for every class.  Iteration then walks the
 * classes from the last class name to the first and, within a class, from the
 * highest instance index down to 0, sampling one TokenSequence per instance
 * from that class's centroid.
 *
 * <p>Classes whose drawn instance count is zero contribute no instances.
 */
public class RandomTokenSequenceIterator implements Iterator<Instance>
{
	private static final Logger logger = MalletLogger.getLogger(RandomTokenSequenceIterator.class.getName());

	Randoms r;
	Dirichlet classCentroidDistribution;
	double classCentroidAvergeAlphaMean;
	double classCentroidAvergeAlphaVariance;
	double featureVectorSizePoissonLambda;
	double classInstanceCountPoissonLamba;
	String[] classNames;

	int[] numInstancesPerClass;						// indexed over classes
	Dirichlet[] classCentroid;						// indexed over classes
	int currentClassIndex;							// class currently being emitted; counts down to 0
	int currentInstanceIndex;						// next instance index within that class; < 0 means class exhausted

	/**
	 * Creates the iterator, draws one Dirichlet centroid per class and then
	 * calls {@link #reset()} to draw the per-class instance counts.
	 *
	 * @param r the generator of all randomness used here
	 * @param classCentroidDistribution Dirichlet from which each class centroid
	 *        is drawn; it also supplies the Alphabet
	 * @param classCentroidAvergeAlphaMean Gaussian mean on the sum of alphas
	 * @param classCentroidAvergeAlphaVariance Gaussian variance on the sum of alphas
	 * @param featureVectorSizePoissonLambda truncated to an int and used as the
	 *        length requested for every generated TokenSequence
	 * @param classInstanceCountPoissonLamba Poisson parameter for the number of
	 *        instances generated per class
	 * @param classNames one entry per class; an empty array yields an iterator
	 *        with no elements
	 */
	public RandomTokenSequenceIterator (Randoms r,
																			// the generator of all random-ness used here
																			Dirichlet classCentroidDistribution,
																			// includes a Alphabet
																			double classCentroidAvergeAlphaMean,
																			// Gaussian mean on the sum of alphas
																			double classCentroidAvergeAlphaVariance,
																			// Gaussian variance on the sum of alphas
																			double featureVectorSizePoissonLambda,
																			double classInstanceCountPoissonLamba,
																			String[] classNames)
	{
		this.r = r;
		this.classCentroidDistribution = classCentroidDistribution;
		// instanceof is false for null, so this doubles as a non-null check on the alphabet.
		assert (classCentroidDistribution.getAlphabet() instanceof Alphabet);
		this.classCentroidAvergeAlphaMean = classCentroidAvergeAlphaMean;
		this.classCentroidAvergeAlphaVariance = classCentroidAvergeAlphaVariance;
		this.featureVectorSizePoissonLambda = featureVectorSizePoissonLambda;
		this.classInstanceCountPoissonLamba = classInstanceCountPoissonLamba;
		this.classNames = classNames;
		this.numInstancesPerClass = new int[classNames.length];
		this.classCentroid = new Dirichlet[classNames.length];
		for (int i = 0; i < classNames.length; i++) {
			logger.fine ("classCentroidAvergeAlphaMean = "+classCentroidAvergeAlphaMean);
			double aveAlpha = r.nextGaussian (classCentroidAvergeAlphaMean,
																				classCentroidAvergeAlphaVariance);
			logger.fine ("aveAlpha = "+aveAlpha);
			classCentroid[i] = classCentroidDistribution.randomDirichlet (r, aveAlpha);
		}
		reset ();
	}

	/**
	 * Convenience constructor using a <code>Dirichlet(vocab, 2.0)</code> centroid
	 * distribution, alpha mean 30 with variance 0, token-sequence size 10 and
	 * an instance-count Poisson parameter of 20.
	 *
	 * @param r source of randomness
	 * @param vocab alphabet the generated tokens are drawn over
	 * @param classnames one entry per class
	 */
	public RandomTokenSequenceIterator (Randoms r, Alphabet vocab, String[] classnames)
	{
		this (r, new Dirichlet(vocab, 2.0),
					30, 0,
					10, 20, classnames);
	}

	/** Returns the Alphabet of the class-centroid distribution. */
	public Alphabet getAlphabet () { return classCentroidDistribution.getAlphabet(); }

	/** Builds an Alphabet holding the entries "feature0" .. "feature(size-1)". */
	private static Alphabet dictOfSize (int size)
	{
		Alphabet ret = new Alphabet ();
		for (int i = 0; i < size; i++) {
			ret.lookupIndex ("feature"+i);
		}
		return ret;
	}

	/** Builds the class-name array "class0" .. "class(size-1)". */
	private static String[] classNamesOfSize (int size)
	{
		String[] ret = new String[size];
		for (int i = 0; i < size; i++) {
			ret[i] = "class"+i;
		}
		return ret;
	}

	/**
	 * Convenience constructor that synthesizes both the vocabulary
	 * ("feature0", "feature1", ...) and the class names ("class0", "class1", ...),
	 * using the same default parameters as the (Randoms, Alphabet, String[])
	 * constructor.
	 *
	 * @param r source of randomness
	 * @param vocabSize number of synthetic vocabulary entries
	 * @param numClasses number of synthetic classes
	 */
	public RandomTokenSequenceIterator (Randoms r, int vocabSize, int numClasses)
	{
		this (r, new Dirichlet(dictOfSize(vocabSize), 2.0),
					30, 0,
					10, 20, classNamesOfSize(numClasses));
	}

	/**
	 * Draws a fresh Poisson instance count for every class and rewinds the
	 * iteration to the last instance of the last class.  The class centroids
	 * drawn by the constructor are kept.
	 */
	public void reset ()
	{
		for (int i = 0; i < classNames.length; i++) {
			this.numInstancesPerClass[i] = r.nextPoisson (classInstanceCountPoissonLamba);
			logger.fine ("Class "+classNames[i]+" will have "
									 +numInstancesPerClass[i]+" instances.");
		}
		this.currentClassIndex = classNames.length - 1;
		// With no classes there is nothing to index; mark the iterator as exhausted.
		this.currentInstanceIndex = (currentClassIndex >= 0)
			? numInstancesPerClass[currentClassIndex] - 1
			: -1;
	}

	/**
	 * Generates the next random instance.  The TokenSequence, the class name
	 * and a URI of the form <code>random:&lt;className&gt;/&lt;instanceIndex&gt;</code>
	 * are passed as the first three Instance constructor arguments; the fourth
	 * is null.
	 *
	 * @return a newly generated Instance
	 * @throws IllegalStateException if no instances remain, or if the class
	 *         name cannot be embedded in a syntactically valid URI
	 */
	public Instance next ()
	{
		// Step down past exhausted classes.  This must be a loop: a class may have
		// drawn a count of zero, and such a class must not produce an instance.
		while (currentInstanceIndex < 0) {
			if (currentClassIndex <= 0)
				throw new IllegalStateException ("No next TokenSequence.");
			currentClassIndex--;
			currentInstanceIndex = numInstancesPerClass[currentClassIndex] - 1;
		}
		URI uri;
		try {
			uri = new URI ("random:" + classNames[currentClassIndex] + "/" + currentInstanceIndex);
		} catch (java.net.URISyntaxException e) {
			// Keep the cause so the offending class name is diagnosable by the caller.
			throw new IllegalStateException ("Cannot build URI for class \""
																			 + classNames[currentClassIndex] + "\"", e);
		}
		// The size is the truncated lambda rather than r.nextPoisson(featureVectorSizePoissonLambda);
		// the original author disabled the Poisson draw with the note "Producing small numbers?".
		int randomSize = (int)featureVectorSizePoissonLambda;
		TokenSequence ts = classCentroid[currentClassIndex].randomTokenSequence (r, randomSize);
		currentInstanceIndex--;
		return new Instance (ts, classNames[currentClassIndex], uri, null);
	}

	/**
	 * Reports whether {@link #next()} can produce another instance: either the
	 * current class still has instances left, or some not-yet-visited class
	 * (lower index) drew a non-zero instance count.
	 */
	public boolean hasNext ()
	{
		if (currentInstanceIndex >= 0)
			return true;
		for (int c = currentClassIndex - 1; c >= 0; c--) {
			if (numInstancesPerClass[c] > 0)
				return true;
		}
		return false;
	}

	/**
	 * Not supported.
	 *
	 * @throws IllegalStateException always (kept as the exception type this
	 *         class has historically thrown here)
	 */
	public void remove () {
		throw new IllegalStateException ("This Iterator does not support remove().");
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy