All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.word2vec.Word2Vec Maven / Gradle / Ivy

The newest version!
package org.deeplearning4j.word2vec;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.Stack;
import java.util.TreeSet;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.commons.io.IOUtils;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.nn.Persistable;
import org.deeplearning4j.util.MatrixUtil;
import org.deeplearning4j.word2vec.actor.SentenceActor;
import org.deeplearning4j.word2vec.actor.VocabActor;
import org.deeplearning4j.word2vec.sentenceiterator.CollectionSentenceIterator;
import org.deeplearning4j.word2vec.sentenceiterator.SentenceIterator;
import org.deeplearning4j.word2vec.tokenizer.DefaultTokenizerFactory;
import org.deeplearning4j.word2vec.tokenizer.Tokenizer;
import org.deeplearning4j.word2vec.tokenizer.TokenizerFactory;
import org.deeplearning4j.word2vec.util.Util;
import org.deeplearning4j.word2vec.viterbi.Index;
import org.jblas.DoubleMatrix;
import org.jblas.SimpleBlas;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

import scala.concurrent.Future;
import akka.actor.ActorRef;
import akka.actor.ActorSystem;
import akka.actor.Props;
import akka.dispatch.Futures;
import akka.dispatch.OnComplete;
import akka.routing.RoundRobinPool;


/**
 * Uses a 3-layer neural net with a softmax output layer to convert
 * a word, based on its context in the training examples, into a
 * numeric vector.
 * @author Adam Gibson
 *
 */
public class Word2Vec implements Persistable {


	private static final long serialVersionUID = -2367495638286018038L;
	//vocabulary: token -> vocab entry (looked up by token in getWord/trainSentence)
	private Map<String,VocabWord> vocab = new ConcurrentHashMap<String,VocabWord>();

	//transient: not written by write(); null on a deserialized instance until set again
	private transient TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
	private transient SentenceIterator sentenceIter;
	//maximum number of entries kept by the top-N searches (distance/analogy)
	private int topNSize = 40;
	//matrix row of a given word
	private Index wordIndex = new Index();
	//subsampling of frequent words is applied in trainSentence when this is > 0
	private int sample = 1;
	//learning rate
	private Double alpha = 0.025;
	private int wordCount  = 0;
	//lower bound for the learning rate
	public final static double MIN_ALPHA =  0.001;
	//number of times the word must occur in the vocab to appear in the calculations, otherwise treat as unknown
	private int minWordFrequency = 5;
	//context to use for gathering word frequencies
	private int window = 5;
	private int trainWordsCount = 0;
	//number of neurons per layer
	private int layerSize = 50;
	private static final Logger log = LoggerFactory.getLogger(Word2Vec.class);
	private int size = 0;
	private int words = 0;
	//input layer
	private DoubleMatrix syn0,syn0Norm;
	//hidden layer
	private DoubleMatrix syn1;
	private int allWordsCount = 0;
	private int numSentencesProcessed = 0;
	//shared by all instances; created lazily by train()/buildVocab()
	private static ActorSystem trainingSystem;
	//tokens that are mapped to "STOP" during training
	private List<String> stopWords;
	/* out of vocab */
	private double[] oob;
	//when false, setup() keeps the current weights (see setSentenceIter)
	private boolean shouldReset = true;

	/**
	 * Creates an empty instance. Nothing is initialized beyond the field
	 * defaults: no sentence iterator is set, no stop words are read, the
	 * out-of-vocabulary vector is not allocated and no vocabulary is built.
	 */
	public Word2Vec() {}

	/**
	 * Creates a model over the given sentence iterator and immediately builds
	 * the vocabulary from it, using the default tokenizer factory and the
	 * default minimum word frequency.
	 *
	 * @param sentenceIter the source of the sentences to build the vocab from
	 */
	public Word2Vec(SentenceIterator sentenceIter) {
		this.oob = new double[this.layerSize];
		Arrays.fill(this.oob, 0.0);
		this.readStopWords();
		this.sentenceIter = sentenceIter;
		this.buildVocab();
	}



	/**
	 * Creates a model over the given sentence iterator with a custom minimum
	 * word frequency. Unlike the single-argument constructor, this one does
	 * NOT build the vocabulary; {@link #buildVocab()} has to be called
	 * separately.
	 *
	 * @param sentenceIter the source of the sentences
	 * @param minWordFrequency the minimum word frequency
	 * to be counted in the vocab
	 */
	public Word2Vec(SentenceIterator sentenceIter,int minWordFrequency) {
		this.oob = new double[this.layerSize];
		Arrays.fill(this.oob, 0.0);
		this.readStopWords();
		this.minWordFrequency = minWordFrequency;
		this.sentenceIter = sentenceIter;
	}


	/**
	 * Creates a model over the given sentence iterator and builds the
	 * vocabulary with the given tokenizer factory.
	 *
	 * @param factory the tokenizer factory used to split sentences into tokens
	 * @param sentenceIter the source of the sentences to build the vocab from
	 */
	public Word2Vec(TokenizerFactory factory,SentenceIterator sentenceIter) {
		// The factory must be installed BEFORE buildVocab() runs: buildVocab()
		// hands the current tokenizerFactory to the vocab actors, so assigning
		// it afterwards (as a delegating this(sentenceIter) call would) builds
		// the vocabulary with the default tokenizer instead of the requested one.
		this.tokenizerFactory = factory;
		this.sentenceIter = sentenceIter;
		this.oob = new double[layerSize];
		// buildVocab() reads the stop words itself before tokenizing
		buildVocab();
	}

	/**
	 * Specify a custom tokenizer, sentence iterator
	 * and minimum word frequency. The vocabulary is built immediately
	 * using all three.
	 *
	 * @param factory the tokenizer factory used to split sentences into tokens
	 * @param sentenceIter the source of the sentences to build the vocab from
	 * @param minWordFrequency the minimum word frequency
	 * to be counted in the vocab
	 */
	public Word2Vec(TokenizerFactory factory,SentenceIterator sentenceIter,int minWordFrequency) {
		// Both settings are passed to the vocab actors inside buildVocab(), so
		// they have to be in place before it runs; setting them after a
		// delegating constructor call has no effect on the vocabulary.
		this.tokenizerFactory = factory;
		this.minWordFrequency = minWordFrequency;
		this.sentenceIter = sentenceIter;
		this.oob = new double[layerSize];
		// buildVocab() reads the stop words itself before tokenizing
		buildVocab();
	}



	/**
	 * Assumes whole dataset is passed in.
	 * This is purely meant for batch methods.
	 * Same as calling {@link #Word2Vec(Collection, int)}
	 * with the second argument being 5
	 * @param sentences the sentences to use
	 * to train on
	 */
	public Word2Vec(Collection sentences) {
		// The delegated constructor already loads the stop words (directly and
		// via buildVocab()), so no further readStopWords() call is needed here.
		this(sentences,5);
	}

	/**
	 * Initializes from a whole data set with a custom tokenizer factory and
	 * a minimum word frequency of 5. The vocabulary is built immediately.
	 *
	 * @param sentences the sentences to be used for training
	 * @param factory the tokenizer factory used to split sentences into tokens
	 */
	public Word2Vec(Collection sentences,TokenizerFactory factory) {
		// Install the factory before buildVocab() so the vocabulary is
		// tokenized with it; assigning it after a delegating constructor call
		// would leave the vocabulary built with the default tokenizer.
		this.tokenizerFactory = factory;
		this.minWordFrequency = 5;
		this.sentenceIter = new CollectionSentenceIterator(sentences);
		this.oob = new double[layerSize];
		// buildVocab() reads the stop words itself before tokenizing
		buildVocab();
	}

	/**
	 * Initializes the model from a complete, in-memory data set and builds
	 * the vocabulary right away with the default tokenizer factory.
	 *
	 * @param sentences the sentences to be used for training
	 * @param minWordFrequency the minimum word frequency
	 * to be counted in the vocab
	 */
	public Word2Vec(Collection sentences,int minWordFrequency) {
		this.sentenceIter = new CollectionSentenceIterator(sentences);
		this.minWordFrequency = minWordFrequency;

		buildVocab();

		// out-of-vocabulary vector: all zeros, one entry per layer neuron
		this.oob = new double[this.layerSize];
		Arrays.fill(this.oob, 0.0);
		this.readStopWords();
	}


	/**
	 * Initializes from a whole data set with a custom minimum word frequency
	 * and tokenizer factory. The vocabulary is built immediately.
	 *
	 * @param sentences the sentences to be used for training
	 * @param minWordFrequency the minimum word frequency
	 * to be counted in the vocab
	 * @param factory the tokenizer factory used to split sentences into tokens
	 */
	public Word2Vec(Collection sentences,int minWordFrequency,TokenizerFactory factory) {
		// Install the factory before buildVocab() so the vocabulary is
		// tokenized with it; assigning it after a delegating constructor call
		// would leave the vocabulary built with the default tokenizer.
		this.tokenizerFactory = factory;
		this.minWordFrequency = minWordFrequency;
		this.sentenceIter = new CollectionSentenceIterator(sentences);
		this.oob = new double[layerSize];
		// buildVocab() reads the stop words itself before tokenizing
		buildVocab();
	}


	/**
	 * Returns the row of the normalized input matrix for the given word.
	 * Unknown words fall back to the "STOP" entry and, if that is missing
	 * too, to the out-of-vocabulary vector.
	 *
	 * @param word the word to look up
	 * @return a copy of the word's row in the normalized input matrix, or
	 * the out-of-vocabulary vector
	 */
	public double[] getWordVectorNormalized(String word) {
		int i = this.wordIndex.indexOf(word);
		if(i < 0) {
			i = wordIndex.indexOf("STOP");
			if(i < 0)
				return oob;
		}

		// syn0Norm is only created lazily (similarity() does the same), so it
		// is still null if no similarity query ran before this call.
		// NOTE(review): the cached matrix is not refreshed when syn0 changes.
		if(syn0Norm == null)
			syn0Norm = syn0.div(SimpleBlas.nrm2(syn0));

		return syn0Norm.getRow(i).toArray();
	}

	/**
	 * Returns the input-layer vector of the given word. An unknown word is
	 * replaced by the "STOP" entry; when that is absent as well, the
	 * out-of-vocabulary vector is returned.
	 *
	 * @param word the word to look up
	 * @return the word's row of the input matrix as an array, or the
	 * out-of-vocabulary vector
	 */
	public double[] getWordVector(String word) {
		int row = wordIndex.indexOf(word);
		if(row < 0)
			row = wordIndex.indexOf("STOP");

		return row < 0 ? oob : syn0.getRow(row).toArray();
	}

	/**
	 * Looks up the index of a word in the word index.
	 *
	 * @param word the word to look up
	 * @return whatever the word index reports for the word; other methods
	 * in this class treat a negative value as "not in the index"
	 */
	public int indexOf(String word) {
		return wordIndex.indexOf(word);
	}

	/**
	 * Returns the input-layer vector of the given word as a matrix.
	 *
	 * @param word the word to look up
	 * @return the word's row of the input matrix, or a matrix wrapping the
	 * out-of-vocabulary vector when the word is not in the index
	 */
	public DoubleMatrix getWordVectorMatrix(String word) {
		final int row = wordIndex.indexOf(word);
		return row >= 0 ? syn0.getRow(row) : new DoubleMatrix(oob);
	}

	/**
	 * Returns the unit-length input-layer vector of the given word.
	 * The vector is derived from the current weights on every call, so it
	 * never lags behind training.
	 *
	 * @param word the word to look up
	 * @return the word's row of the input matrix scaled to Euclidean length
	 * 1 (an all-zero row is returned unchanged), or a matrix wrapping the
	 * out-of-vocabulary vector when the word is not in the index
	 */
	public DoubleMatrix getWordVectorMatrixNormalized(String word) {
		int i = this.wordIndex.indexOf(word);
		if(i < 0)
			return new DoubleMatrix(oob);
		// Previously this returned the raw row, identical to
		// getWordVectorMatrix(), despite the "Normalized" name.
		DoubleMatrix row = syn0.getRow(i);
		double norm = row.norm2();
		// an all-zero row has no direction; avoid dividing by zero
		return norm > 0 ? row.div(norm) : row;
	}



	/**
	 * Looks up the vocabulary entry stored under the given key.
	 *
	 * @param key the token to look up
	 * @return the entry from the vocab map, or null when the map has no
	 * mapping for the key
	 */
	public VocabWord getWord(String key) {
		return vocab.get(key);
	}





	/**
	 * Finds the n words most similar to the given word by scoring every
	 * word in the index with {@link #similarity(String, String)}. The word
	 * itself is part of the candidates.
	 *
	 * @param word the word to compare against
	 * @param n the number of keys to keep
	 * @return the keys left in the counter after keeping the top n
	 */
	public Collection wordsNearest(String word,int n) {
		if(this.getWordVectorMatrix(word) == null)
			return new ArrayList();

		Counter scores = new Counter();
		for(int row = 0; row < syn0.rows; row++) {
			String candidate = wordIndex.get(row).toString();
			scores.incrementCount(candidate, similarity(word,candidate));
		}

		scores.keepTopNKeys(n);
		return scores.keySet();
	}


	/**
	 * Returns the words produced by {@link #analogy(String, String, String)}
	 * as plain strings, in the iteration order of the returned set.
	 *
	 * @param w1 the first word
	 * @param w2 the second word
	 * @param w3 the third word
	 * @return the analogy candidates; empty (never null) when the analogy
	 * lookup yields nothing
	 */
	public List<String> analogyWords(String w1,String w2,String w3) {
		TreeSet<VocabWord> analogies = this.analogy(w1, w2, w3);
		List<String> ret = new ArrayList<String>();
		// analogy() is declared to return null when a vector is missing
		if(analogies == null)
			return ret;
		for(VocabWord w : analogies)
			ret.add(wordIndex.get(w.getIndex()).toString());
		return ret;
	}




	/**
	 * Offers a scored word to a bounded "top N" list. While the list holds
	 * fewer than topNSize entries the word is simply appended; afterwards it
	 * replaces the lowest-scored entry, but only if its own score is higher.
	 * The score is carried in the entry's word-frequency slot.
	 *
	 * @param name the candidate word
	 * @param score the candidate's score
	 * @param wordsEntrys the list to update in place
	 */
	private void insertTopN(String name, double score, List<VocabWord> wordsEntrys) {
		if (wordsEntrys.size() < topNSize) {
			VocabWord v = new VocabWord(score,layerSize);
			v.setIndex(wordIndex.indexOf(name));
			wordsEntrys.add(v);
			return;
		}

		// locate the current lowest-scored entry
		double min = Double.MAX_VALUE;
		int minOffe = 0;
		for (int i = 0; i < topNSize; i++) {
			double candidate = wordsEntrys.get(i).getWordFrequency();
			if (min > candidate) {
				min = candidate;
				minOffe = i;
			}
		}

		if (score > min) {
			VocabWord w = new VocabWord(score,layerSize);
			// The replacement must point at the NEW word. Reusing the evicted
			// entry's index kept the old word with the new word's score.
			w.setIndex(wordIndex.indexOf(name));
			wordsEntrys.set(minOffe,w);
		}

	}

	/**
	 * Tells whether the word index knows the given word.
	 *
	 * @param word the word to check
	 * @return true when the word index reports a non-negative index
	 */
	public boolean hasWord(String word) {
		return wordIndex.indexOf(word) >= 0;
	}

	/**
	 * Trains the model on every sentence of the sentence iterator.
	 * Each sentence is processed asynchronously on the actor system's
	 * dispatcher. Completion is detected heuristically: the method returns
	 * once no sentence has finished for one minute (polled every 15 seconds),
	 * then shuts the shared actor system down.
	 *
	 * @throws IllegalStateException if the input weights are missing or do
	 * not have one row per vocab entry
	 */
	public void train() {
		// Fail fast, before any actors are created or work is queued. The
		// null check avoids an NPE when no weights were ever set.
		if(syn0 == null || syn0.rows != this.vocab.size())
			throw new IllegalStateException("We appear to be missing vectors here. Unable to train. Please ensure vectors were loaded properly.");

		if(trainingSystem == null)
			trainingSystem = ActorSystem.create();

		if(stopWords == null)
			readStopWords();
		log.info("Training word2vec multithreaded");

		final Counter<String> totalWords = Util.parallelCounter();

		getSentenceIter().reset();

		// wall-clock time of the most recently completed sentence
		final AtomicLong changed = new AtomicLong(System.currentTimeMillis());

		// NOTE(review): this router's ActorRef is discarded and nothing below
		// sends it messages; kept as-is because its creation may be relied on.
		trainingSystem.actorOf(new RoundRobinPool(Runtime.getRuntime().availableProcessors() *3 ).props(Props.create(new SentenceActor.SentenceActorCreator(this)).withDispatcher("akka.actor.worker-dispatcher")));

		while(getSentenceIter().hasNext()) {
			final String sentence = sentenceIter.nextSentence();
			if(sentence == null)
				continue;

			Future<Void> f = Futures.future(new Callable<Void>() {

				@Override
				public Void call() throws Exception {
					processSentence(sentence, totalWords);
					return null;
				}

			},trainingSystem.dispatcher());

			f.onComplete(new OnComplete<Void>() {

				@Override
				public void onComplete(Throwable arg0, Void arg1)
						throws Throwable {
					if(arg0 != null)
						throw arg0;
					// callbacks run concurrently on dispatcher threads; a bare
					// ++ on the shared counter would lose updates
					synchronized(Word2Vec.this) {
						numSentencesProcessed++;
					}
					changed.set(System.currentTimeMillis());

				}

			}, trainingSystem.dispatcher());
		}

		// Wait until no sentence has completed for the idle timeout. An
		// interrupt is remembered and re-asserted at the end instead of being
		// re-set immediately: with the flag set, every following sleep() would
		// throw at once and the loop would spin at full CPU.
		boolean interrupted = false;
		final long idleTimeout = TimeUnit.MINUTES.toMillis(1);
		while(Math.abs(System.currentTimeMillis() - changed.get()) < idleTimeout) {
			try {
				Thread.sleep(15000);
			} catch (InterruptedException e) {
				interrupted = true;
			}
		}

		log.info("Shutting down system; done training");

		if(trainingSystem != null) {
			trainingSystem.shutdown();
			// drop the reference so the next train()/buildVocab() creates a
			// fresh system instead of reusing a terminated one
			trainingSystem = null;
		}

		if(interrupted)
			Thread.currentThread().interrupt();

	}

	/**
	 * Trains on one sentence and, whenever the processed-sentence counter is
	 * a multiple of 10000, decays the learning rate by the fraction of words
	 * seen so far, never going below {@link #MIN_ALPHA}.
	 *
	 * @param sentence the sentence to train on
	 * @param totalWords running count of the words trained on so far
	 */
	public void processSentence(final String sentence,final Counter totalWords) {
		trainSentence(sentence, totalWords);
		if(numSentencesProcessed % 10000 == 0) {
			// Without a known total the ratio is x/0: that yields Infinity
			// (alpha collapses to MIN_ALPHA on the first sentence) or NaN
			// (alpha becomes NaN). Leave alpha untouched in that case.
			if(allWordsCount > 0)
				alpha = Math.max(MIN_ALPHA, alpha * (1 - 1.0 * totalWords.totalCount() / allWordsCount));
			log.info("Alpha updated " + alpha + " progress " + numSentencesProcessed);
		}
	}




	/**
	 * Tokenizes a sentence, maps stop words to the "STOP" token, drops every
	 * token without a vocab entry, counts the remaining tokens and trains on
	 * the resulting list of vocab entries.
	 *
	 * @param sentence the raw sentence
	 * @param totalWords counter incremented by 1.0 for every token kept
	 * @return the vocab entries of the sentence that were trained on
	 */
	public List trainSentence(String sentence,Counter totalWords) {
		List known = new ArrayList();
		for(Tokenizer tokens = tokenizerFactory.create(sentence); tokens.hasMoreTokens(); ) {
			String token = tokens.nextToken();
			String key = stopWords.contains(token) ? "STOP" : token;
			VocabWord entry = vocab.get(key);
			if(entry != null) {
				known.add(entry);
				totalWords.incrementCount(key, 1.0);
			}
		}

		trainSentence(known);
		return known;
	}


	/**
	 * Collects the top-scoring words for the given word, scoring every other
	 * word in the index by the dot product of its input-layer row with the
	 * given word's vector.
	 *
	 * @param word the word to compare against
	 * @return a sorted set built from the collected entries, or null when
	 * no vector is available for the word
	 */
	public Set distance(String word) {
		final DoubleMatrix target = getWordVectorMatrix(word);
		if (target == null)
			return null;

		List nearest = new ArrayList(topNSize);
		for (int row = 0; row < syn0.rows; row++) {
			String candidate = wordIndex.get(row).toString();
			// the word itself is not a candidate
			if (!candidate.equals(word))
				insertTopN(candidate, target.dot(syn0.getRow(row)), nearest);
		}
		return new TreeSet(nearest);
	}

	/**
	 * Scores every word in the index against the vector
	 * {@code word1 - word0 + word2} (dot product with the word's input-layer
	 * row) and collects the top-scoring ones. The three query words
	 * themselves are skipped.
	 *
	 * @param word0 the word whose vector is subtracted
	 * @param word1 the word whose vector is the starting point
	 * @param word2 the word whose vector is added
	 * @return a sorted set built from the collected entries, or null when
	 * one of the three vectors is unavailable
	 */
	public TreeSet<VocabWord> analogy(String word0, String word1, String word2) {
		DoubleMatrix wv0 = getWordVectorMatrix(word0);
		DoubleMatrix wv1 = getWordVectorMatrix(word1);
		DoubleMatrix wv2 = getWordVectorMatrix(word2);

		// must run before the vectors are combined; checking afterwards would
		// only be reached once the dereference had already succeeded
		if (wv1 == null || wv2 == null || wv0 == null)
			return null;

		DoubleMatrix wordVector = wv1.sub(wv0).add(wv2);

		DoubleMatrix tempVector;
		String name;
		List<VocabWord> wordEntrys = new ArrayList<VocabWord>(topNSize);
		for (int i = 0; i < syn0.rows; i++) {
			name = wordIndex.get(i).toString();

			if (name.equals(word0) || name.equals(word1) || name.equals(word2)) {
				continue;
			}
			tempVector = syn0.getRow(i);
			double dist = wordVector.dot(tempVector);
			insertTopN(name, dist, wordEntrys);
		}
		return new TreeSet<VocabWord>(wordEntrys);
	}


	/**
	 * Prepares the model for training: builds the binary tree over the
	 * vocabulary and, unless a training continuation was requested,
	 * re-initializes the weights.
	 */
	public void setup() {
		log.info("Building binary tree");
		this.buildBinaryTree();

		log.info("Resetting weights");
		if(this.shouldReset) {
			this.resetWeights();
		}
	}


	/**
	 * Builds the vocabulary from the sentence iterator. Every sentence is
	 * sent to a pool of vocab actors; the method then waits until the shared
	 * timestamp handed to those actors has not changed for one minute
	 * (polled every 15 seconds) and finally calls {@link #setup()}.
	 */
	public void buildVocab() {
		readStopWords();

		if(trainingSystem == null)
			trainingSystem = ActorSystem.create();

		final Counter<String> rawVocab = Util.parallelCounter();
		// NOTE(review): presumably the vocab actors refresh this timestamp as
		// they make progress — confirm against VocabActor
		final AtomicLong semaphore = new AtomicLong(System.currentTimeMillis());
		int queued = 0;

		final ActorRef vocabActor = trainingSystem.actorOf(new RoundRobinPool(Runtime.getRuntime().availableProcessors()).props(Props.create(VocabActor.class,tokenizerFactory,wordIndex,minWordFrequency,vocab,layerSize,stopWords,rawVocab,semaphore)));

		/* all words; including those not in the actual ending index */
		while(getSentenceIter().hasNext()) {
			String sentence = getSentenceIter().nextSentence();
			if(sentence == null)
				continue;

			vocabActor.tell(sentence, vocabActor);
			log.info("Sent " + queued);
			queued++;
		}

		// Wait until the timestamp has been idle for the timeout. An interrupt
		// is remembered and re-asserted at the end instead of being re-set
		// immediately: with the flag set, every following sleep() would throw
		// at once and the loop would spin at full CPU.
		boolean interrupted = false;
		final long idleTimeout = TimeUnit.MINUTES.toMillis(1);
		while(true) {
			long idle = Math.abs(System.currentTimeMillis() - semaphore.get());
			log.info("Waiting on setup...");
			if(idle >= idleTimeout)
				break;

			try {
				Thread.sleep(15000);
			} catch (InterruptedException e) {
				interrupted = true;
			}
		}

		setup();

		if(interrupted)
			Thread.currentThread().interrupt();

	}

	/**
	 * Trains on a tokenized sentence. Each word may first be dropped by the
	 * frequency-based subsampling (when sample is positive); every remaining
	 * word is trained with skip-gram over a randomly shrunk context window.
	 * The pseudo-random sequence is seeded identically for every sentence.
	 *
	 * @param sentence the vocab entries of the sentence, in order
	 */
	public void trainSentence(List<VocabWord> sentence) {
		long nextRandom = 5;
		for(int i = 0; i < sentence.size(); i++) {
			VocabWord entry = sentence.get(i);
			// The subsampling randomly discards frequent words while keeping the ranking same
			// NOTE(review): trainWordsCount is never assigned in this class; with
			// a zero divisor this expression is not a meaningful threshold — confirm
			if (sample > 0) {
				double ran = (Math.sqrt(entry.getWordFrequency() / (sample * trainWordsCount)) + 1)
						* (sample * trainWordsCount) / entry.getWordFrequency();
				nextRandom = nextRandom * 25214903917L + 11;
				if (ran < (nextRandom & 0xFFFF) / (double) 65536) {
					continue;
				}
			}
			nextRandom = nextRandom * 25214903917L + 11;
			// Window shrink must lie in [0, window). The former
			// "(int) nextRandom % window" truncated to int first and could be
			// negative, which made skipGram scan MORE than the full window.
			int b = (int) ((nextRandom & Long.MAX_VALUE) % window);
			skipGram(i,sentence,b);
		}
	}



	/**
	 * Trains the word at position i against each word in its context window.
	 * The window covers positions {@code i - window + b} to
	 * {@code i + window - b}, excluding i itself and any position outside
	 * the sentence.
	 *
	 * @param i the position of the center word
	 * @param sentence the vocab entries of the sentence, in order
	 * @param b how many positions to cut from each side of the window
	 */
	public void skipGram(int i,List<VocabWord> sentence,int b) {
		VocabWord word = sentence.get(i);
		if(word == null)
			return;

		// walk the (shrunk) context window around position i
		for(int j = b; j < window * 2 + 1 - b; j++) {
			// j == window is the center word itself
			if(j == window)
				continue;
			int c1 = i - window + j;

			if (c1 < 0 || c1 >= sentence.size())
				continue;

			VocabWord word2 = sentence.get(c1);
			// same tolerance for null entries as for the center word
			if(word2 == null)
				continue;
			iterate(word,word2);
		}
	}

	/**
	 * Performs one training step for a (word, context word) pair: updates
	 * the hidden-layer rows on w1's tree path and the input-layer row of w2.
	 *
	 * @param w1 the word whose tree codes/points drive the update
	 * @param w2 the word whose input-layer row is updated
	 */
	public void  iterate(VocabWord w1,VocabWord w2) {
		DoubleMatrix l1 = syn0.getRow(w2.getIndex());
		// The hidden-layer rows for w1 are addressed by its POINTS (inner tree
		// node indices), the same indices the write-back loop below uses.
		// Selecting by getCodes() read rows 0/1 only, because codes hold the
		// 0/1 branch decisions rather than row indices.
		DoubleMatrix l2a = syn1.getRows(w1.getPoints());
		DoubleMatrix fa = MatrixUtil.sigmoid(MatrixUtil.dot(l1, l2a.transpose()));
		// ga = (1 - word.code - fa) * alpha  # vector of error gradients multiplied by the learning rate
		DoubleMatrix ga = DoubleMatrix.ones(fa.length).sub(MatrixUtil.toMatrix(w1.getCodes())).sub(fa).mul(alpha);
		DoubleMatrix outer = ga.mmul(l1);
		// write the updated hidden-layer rows back to their tree positions
		for(int i = 0; i < w1.getPoints().length; i++) {
			DoubleMatrix toAdd = l2a.getRow(i).add(outer.getRow(i));
			syn1.putRow(w1.getPoints()[i],toAdd);
		}

		// l2a still holds the pre-update hidden rows here
		DoubleMatrix updatedInput = l1.add(MatrixUtil.dot(ga, l2a));
		syn0.putRow(w2.getIndex(),updatedInput);
	}




	/*
	 * Builds the binary tree for the word relationships.
	 *
	 * Phase 1 repeatedly merges the two smallest entries of a priority queue
	 * over the vocab into a new parent whose frequency is their sum; parents
	 * get indices starting at vocab.size(). Phase 2 walks the tree from the
	 * root and gives every leaf (index < vocab.size()) its codes (the 0/1
	 * branch decisions) and points (the inner nodes passed, expressed as
	 * index - vocab.size()).
	 */
	private void buildBinaryTree() {
		PriorityQueue<VocabWord> heap = new PriorityQueue<VocabWord>(vocab.values());
		int i = 0;
		while(heap.size() > 1) {
			VocabWord min1 = heap.poll();
			VocabWord min2 = heap.poll();


			VocabWord add = new VocabWord(min1.getWordFrequency() + min2.getWordFrequency(),layerSize);
			int index = (vocab.size() + i);

			add.setIndex(index);
			add.setLeft(min1);
			add.setRight(min2);
			min1.setCode(0);
			min2.setCode(1);
			min1.setParent(add);
			min2.setParent(add);
			heap.add(add);
			i++;
		}

		// (node, codes so far, points so far); the root starts with empty paths.
		// heap.poll() is null for an empty vocab, which the loop below logs and skips.
		Triple<VocabWord,int[],int[]> triple = new Triple<VocabWord,int[],int[]>(heap.poll(),new int[]{},new int[]{});
		Stack<Triple<VocabWord,int[],int[]>> stack = new Stack<>();
		stack.add(triple);
		while(!stack.isEmpty())  {
			triple = stack.pop();
			int[] codes = triple.getSecond();
			int[] points = triple.getThird();
			VocabWord node = triple.getFirst();
			if(node == null) {
				log.info("Node was null");
				continue;
			}
			if(node.getIndex() < vocab.size()) {
				// leaf: a real vocab word receives its finished path
				node.setCodes(codes);
				node.setPoints(points);
			}
			else {
				// inner node: extend the path and descend into both children
				points = plus(points,node.getIndex() - vocab.size());
				stack.add(new Triple<VocabWord,int[],int[]>(node.getLeft(),plus(codes,0),points));
				stack.add(new Triple<VocabWord,int[],int[]>(node.getRight(),plus(codes,1),points));

			}
		}



		log.info("Built tree");
	}

	/* Returns a new array: the contents of addTo followed by add. addTo is left untouched. */
	private int[] plus (int[] addTo,int add) {
		final int oldLength = addTo.length;
		int[] extended = Arrays.copyOf(addTo, oldLength + 1);
		extended[oldLength] = add;
		return extended;
	}


	/*
	 * Re-initializes both weight matrices with one row per vocab entry:
	 * the hidden layer is zeroed, the input layer is filled row by row with
	 * values drawn from a generator seeded with 1, shifted by -0.5 and
	 * divided by the layer size.
	 */
	private void resetWeights() {
		syn1 = DoubleMatrix.zeros(vocab.size(), layerSize);
		syn0 = DoubleMatrix.zeros(vocab.size(),layerSize);
		org.jblas.util.Random.seed(1);
		for(int row = 0; row < syn0.rows; row++) {
			for(int col = 0; col < syn0.columns; col++) {
				double initial = (org.jblas.util.Random.nextDouble() - 0.5) / layerSize;
				syn0.put(row,col,initial);
			}
		}
	}


	/**
	 * Computes how similar two words are as the dot product of their unit
	 * vectors, with negative results clamped to 0. Identical words score 1.0
	 * without any vector lookup.
	 * @param word the first word
	 * @param word2 the second word
	 * @return the clamped cosine similarity, or -1 when a vector is missing
	 */
	public double similarity(String word,String word2) {
		if(word.equals(word2))
			return 1.0;

		// lazily create the cached normalized input matrix
		if(syn0Norm == null)
			this.syn0Norm = syn0.div(SimpleBlas.nrm2(syn0));

		DoubleMatrix first = getWordVectorMatrixNormalized(word);
		DoubleMatrix second = getWordVectorMatrixNormalized(word2);
		if(first == null || second == null)
			return -1;

		double cosine = MatrixUtil.unitVec(first).dot(MatrixUtil.unitVec(second));
		return cosine < 0 ? 0 : cosine;
	}




	/**
	 * Loads the stop-word list, one entry per line, from the classpath
	 * resource "/stopwords".
	 *
	 * @throws RuntimeException wrapping the IOException if the resource
	 * cannot be read
	 */
	@SuppressWarnings("unchecked")
	private void readStopWords() {
		// try-with-resources: IOUtils.readLines does not close the stream, so
		// it was previously leaked on every call
		try (InputStream in = new ClassPathResource("/stopwords").getInputStream()) {
			stopWords = IOUtils.readLines(in);
		} catch (IOException e) {
			throw new RuntimeException(e);
		}

	}


	/**
	 * Sets the minimum word frequency. The value is read when
	 * {@link #buildVocab()} runs, so changing it afterwards does not alter
	 * a vocabulary that is already built.
	 *
	 * @param minWordFrequency the new minimum word frequency
	 */
	public void setMinWordFrequency(int minWordFrequency) {
		this.minWordFrequency = minWordFrequency;
	}


	/** @return the number of neurons per layer */
	public int getLayerSize() {
		return layerSize;
	}
	/**
	 * Sets the number of neurons per layer. Only the field is updated;
	 * existing weight matrices and the out-of-vocabulary vector are not
	 * resized by this call.
	 *
	 * @param layerSize the new layer size
	 */
	public void setLayerSize(int layerSize) {
		this.layerSize = layerSize;
	}


	/** @return the train-words count used by the subsampling in training */
	public int getTrainWordsCount() {
		return trainWordsCount;
	}


	/** @return the index mapping words to their matrix rows (the live instance, not a copy) */
	public Index getWordIndex() {
		return wordIndex;
	}
	/**
	 * Replaces the index mapping words to their matrix rows.
	 *
	 * @param wordIndex the new word index
	 */
	public void setWordIndex(Index wordIndex) {
		this.wordIndex = wordIndex;
	}


	/** @return the input-layer weight matrix (the live instance, not a copy) */
	public DoubleMatrix getSyn0() {
		return syn0;
	}

	/** @return the hidden-layer weight matrix (the live instance, not a copy) */
	public DoubleMatrix getSyn1() {
		return syn1;
	}


	/** @return the vocabulary map (the live instance, not a copy) */
	public Map getVocab() {
		return vocab;
	}

	/** @return the current learning rate */
	public double getAlpha() {
		return alpha;
	}




	/**
	 * @return the value of the wordCount field
	 * (NOTE(review): nothing in this class assigns it — confirm its purpose)
	 */
	public int getWordCount() {
		return wordCount;
	}




	/** @return the number of times a word must occur to be kept in the vocab */
	public int getMinWordFrequency() {
		return minWordFrequency;
	}




	/** @return the context window size */
	public int getWindow() {
		return window;
	}


	/** @return the maximum number of entries kept by the top-N searches */
	public int getTopNSize() {
		return topNSize;
	}

	/** @return the subsampling setting; subsampling is applied in training when it is positive */
	public int getSample() {
		return sample;
	}



	/**
	 * @return the value of the size field
	 * (NOTE(review): only ever copied in load() within this class — confirm its purpose)
	 */
	public int getSize() {
		return size;
	}

	/** @return the out-of-vocabulary vector (the internal array itself, not a copy) */
	public double[]	 getOob() {
		return oob;
	}

	/**
	 * @return the value of the words field
	 * (NOTE(review): nothing in this class assigns it — confirm its purpose)
	 */
	public int getWords() {
		return words;
	}



	/** @return the total word count used as the denominator of the learning-rate decay */
	public int getAllWordsCount() {
		return allWordsCount;
	}

	/**
	 * @return the actor system shared by all instances; null until
	 * train() or buildVocab() has created it
	 */
	public static ActorSystem getTrainingSystem() {
		return trainingSystem;
	}


	/**
	 * Replaces the input-layer weight matrix. A previously cached
	 * normalized matrix is not cleared by this call.
	 *
	 * @param syn0 the new input-layer weights
	 */
	public void setSyn0(DoubleMatrix syn0) {
		this.syn0 = syn0;
	}

	/**
	 * Replaces the hidden-layer weight matrix.
	 *
	 * @param syn1 the new hidden-layer weights
	 */
	public void setSyn1(DoubleMatrix syn1) {
		this.syn1 = syn1;
	}

	/**
	 * Sets the context window size. Training takes a remainder modulo this
	 * value, so it must not be 0.
	 *
	 * @param window the new window size
	 */
	public void setWindow(int window) {
		this.window = window;
	}




	/** @return the stop-word list (the live instance); null until the stop words have been read */
	public List getStopWords() {
		return stopWords;
	}

	/** @return the sentence iterator currently used as the training source */
	public synchronized SentenceIterator getSentenceIter() {
		return sentenceIter;
	}



	/** @return the factory used to create a tokenizer per sentence */
	public synchronized TokenizerFactory getTokenizerFactory() {
		return tokenizerFactory;
	}
	
	
	

	/**
	 * Replaces the factory used to create a tokenizer per sentence.
	 *
	 * @param tokenizerFactory the new tokenizer factory
	 */
	public synchronized void setTokenizerFactory(TokenizerFactory tokenizerFactory) {
		this.tokenizerFactory = tokenizerFactory;
	}

	/**
	 * Replaces the sentence iterator.
	 * Note that calling this setter is taken to mean that training is being
	 * continued, and therefore the weights will not be reset by a later
	 * {@link #setup()}.
	 * @param sentenceIter the new source of training sentences
	 */
	public void setSentenceIter(SentenceIterator sentenceIter) {
		this.sentenceIter = sentenceIter;
		// a new iterator on an existing model means "continue training": keep the weights
		this.shouldReset = false;
	}

	/**
	 * Serializes this model to the given stream using Java object
	 * serialization. The stream is flushed but left open; closing it is the
	 * caller's responsibility.
	 *
	 * @param os the stream to write to
	 * @throws RuntimeException wrapping the IOException if writing fails
	 */
	@Override
	public void write(OutputStream os) {
		try {
			ObjectOutputStream dos = new ObjectOutputStream(os);

			dos.writeObject(this);
			// ObjectOutputStream buffers internally; without a flush the tail
			// of the object may never reach the underlying stream
			dos.flush();

		} catch (IOException e) {
			throw new RuntimeException(e);
		}

	}

	/**
	 * Restores this model's state from a stream written by
	 * {@link #write(OutputStream)}. The tokenizer factory and the sentence
	 * iterator of this instance are left as they are. The stream is not
	 * closed.
	 *
	 * @param is the stream to read from
	 * @throws RuntimeException wrapping any failure while reading
	 */
	@Override
	public void load(InputStream is) {
		try {
			// SECURITY NOTE(review): Java-native deserialization executes code
			// chosen by the stream's contents; only load from trusted sources.
			ObjectInputStream ois = new ObjectInputStream(is);
			Word2Vec vec = (Word2Vec) ois.readObject();
			this.allWordsCount = vec.allWordsCount;
			this.alpha = vec.alpha;
			this.minWordFrequency = vec.minWordFrequency;
			this.numSentencesProcessed = vec.numSentencesProcessed;
			this.oob = vec.oob;
			this.sample = vec.sample;
			this.size = vec.size;
			this.wordIndex = vec.wordIndex;
			this.stopWords = vec.stopWords;
			this.syn0 = vec.syn0;
			this.syn1 = vec.syn1;
			this.topNSize = vec.topNSize;
			this.trainWordsCount = vec.trainWordsCount;
			this.window = vec.window;
			// Previously not restored. Without the vocab, train() rejects the
			// loaded model because syn0 has rows but the vocab is empty, and
			// the layer size must match the loaded matrices.
			this.vocab = vec.vocab;
			this.layerSize = vec.layerSize;
			this.wordCount = vec.wordCount;
			this.words = vec.words;
			// take the stored cache (possibly null, then rebuilt lazily) so a
			// cache computed from the pre-load weights is not reused
			this.syn0Norm = vec.syn0Norm;

		}catch(Exception e) {
			throw new RuntimeException(e);
		}



	}


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy