org.deeplearning4j.topicmodeling.CountVectorizer Maven / Gradle / Ivy

package org.deeplearning4j.topicmodeling;

import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.util.MatrixUtil;
import org.deeplearning4j.word2vec.sentenceiterator.SentenceIterator;
import org.deeplearning4j.word2vec.tokenizer.Tokenizer;
import org.deeplearning4j.word2vec.tokenizer.TokenizerFactory;
import org.deeplearning4j.word2vec.viterbi.Index;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Given a {@link SentenceIterator}, a {@link TokenizerFactory},
 * and an {@link Index}, constructs a word-count vector over the
 * indexed vocabulary.
 * @author Adam Gibson
 *
 */
public class CountVectorizer {

	private SentenceIterator iter;
	private TokenizerFactory tokenizerFactory;
	private Index wordsToCount;
	private static final Logger log = LoggerFactory.getLogger(CountVectorizer.class);
	
	/**
	 * @param iter source of the sentences to tokenize
	 * @param tokenizerFactory creates a {@link Tokenizer} per sentence
	 * @param wordsToCount vocabulary index defining the vector columns
	 */
	public CountVectorizer(SentenceIterator iter,
			TokenizerFactory tokenizerFactory, Index wordsToCount) {
		this.iter = iter;
		this.tokenizerFactory = tokenizerFactory;
		this.wordsToCount = wordsToCount;
	}
	

	/**
	 * Uses an empty vocabulary {@link Index}; unless it is populated before
	 * vectorizing, the resulting vectors have zero columns.
	 */
	public CountVectorizer(SentenceIterator iter,
			TokenizerFactory tokenizerFactory) {
		this.iter = iter;
		this.tokenizerFactory = tokenizerFactory;
		this.wordsToCount = new Index();
	}
	
	
	/**
	 * @return the count vector from {@link #toVector()} normalized by its
	 * row sum, i.e. relative term frequencies
	 */
	public DoubleMatrix toNormalizedVector() {
		return MatrixUtil.normalizeByRowSums(toVector());
	}
	
	/**
	 * Counts the tokens produced by the sentence iterator and returns a
	 * 1 x vocabulary-size row vector of raw term frequencies: column i holds
	 * the count of the i-th word in the index. Note that this pass drains the
	 * underlying {@link SentenceIterator}.
	 * @return a row vector of word counts
	 */
	public DoubleMatrix toVector() {
		DoubleMatrix d = new DoubleMatrix(1,wordsToCount.size());
		Counter<String> wordFrequencies = new Counter<String>();
		while(iter.hasNext()) {
			String sentence = iter.nextSentence();
			if(sentence == null)
				continue;
			Tokenizer t = tokenizerFactory.create(sentence);
			while(t.hasMoreTokens()) {
				String token = t.nextToken();
				log.info("Token {}",token);
				//count the token if it belongs to the vocabulary index,
				//or count everything when no vocabulary was supplied
				if(wordsToCount.indexOf(token) >= 0 || wordsToCount.size() < 1)
					wordFrequencies.incrementCount(token, 1.0);
			}
		}
		
		//copy the accumulated counts into the vector, one column per indexed word
		for(int i = 0; i < wordsToCount.size(); i++) {
			d.put(i,wordFrequencies.getCount(wordsToCount.get(i).toString()));
		}
		
		return d;
	}
	
	/**
	 * Same counting pass as {@link #toVector()}, but returns a binary
	 * presence/absence vector: column i is 1 if the i-th indexed word
	 * occurred at least once, 0 otherwise. This pass also drains the
	 * underlying {@link SentenceIterator}.
	 * @return a binary row vector of word occurrences
	 */
	public DoubleMatrix toBinaryVector() {
		DoubleMatrix d = new DoubleMatrix(1,wordsToCount.size());
		Counter<String> wordFrequencies = new Counter<String>();
		while(iter.hasNext()) {
			String sentence = iter.nextSentence();
			if(sentence == null)
				continue;
			Tokenizer t = tokenizerFactory.create(sentence);
			while(t.hasMoreTokens()) {
				String token = t.nextToken();
				log.info("Token {}",token);
				//count the token if it belongs to the vocabulary index,
				//or count everything when no vocabulary was supplied
				if(wordsToCount.indexOf(token) >= 0 || wordsToCount.size() < 1)
					wordFrequencies.incrementCount(token, 1.0);
			}
		}
		
		//clamp each count to 0/1
		for(int i = 0; i < wordsToCount.size(); i++) {
			double count = wordFrequencies.getCount(wordsToCount.get(i).toString());
			d.put(i,count > 0 ? 1 : 0);
		}
		
		return d;
	}
	
	
	
}
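
A minimal usage sketch follows. It is illustrative only: CollectionSentenceIterator, DefaultTokenizerFactory, and Index#add are assumed to exist in this version's word2vec packages and are not confirmed by the listing above, so substitute whatever SentenceIterator and TokenizerFactory implementations your classpath actually provides.

import java.util.Arrays;

import org.deeplearning4j.topicmodeling.CountVectorizer;
import org.deeplearning4j.word2vec.sentenceiterator.CollectionSentenceIterator; //assumed implementation
import org.deeplearning4j.word2vec.sentenceiterator.SentenceIterator;
import org.deeplearning4j.word2vec.tokenizer.DefaultTokenizerFactory;           //assumed implementation
import org.deeplearning4j.word2vec.tokenizer.TokenizerFactory;
import org.deeplearning4j.word2vec.viterbi.Index;
import org.jblas.DoubleMatrix;

public class CountVectorizerExample {
	public static void main(String[] args) {
		//vocabulary index: one vector column per word, in insertion order
		Index vocab = new Index();
		vocab.add("deep");      //assumes Index exposes add(Object); adjust if the API differs
		vocab.add("learning");
		vocab.add("java");

		//sentence source and tokenizer (assumed concrete classes)
		SentenceIterator iter = new CollectionSentenceIterator(
				Arrays.asList("deep learning in java", "java java java"));
		TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();

		CountVectorizer vectorizer = new CountVectorizer(iter, tokenizerFactory, vocab);

		//1 x 3 row vector of raw counts for "deep", "learning", "java"
		DoubleMatrix counts = vectorizer.toVector();
		System.out.println(counts);
	}
}

Because each to*Vector() call drains the SentenceIterator, request only one vector per iterator, or reset the iterator (if the implementation supports it) before building another.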



