/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.models.word2vec;
import java.io.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;
import akka.actor.ActorSystem;
import com.google.common.base.Function;
import com.google.common.util.concurrent.AtomicDouble;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.bagofwords.vectorizer.TextVectorizer;
import org.deeplearning4j.bagofwords.vectorizer.TfidfVectorizer;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.parallel.Parallelization;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.deeplearning4j.text.invertedindex.LuceneInvertedIndex;
import org.deeplearning4j.text.documentiterator.DocumentIterator;
import org.deeplearning4j.text.stopwords.StopWords;
import org.deeplearning4j.text.tokenization.tokenizerfactory.UimaTokenizerFactory;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Leveraging a three-layer neural net with a softmax output layer,
 * Word2Vec converts a word, based on its context and the training examples,
 * into a numeric vector.
 * @author Adam Gibson
 *
 */
public class Word2Vec extends WordVectorsImpl {
protected static final long serialVersionUID = -2367495638286018038L;
protected transient TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
protected transient SentenceIterator sentenceIter;
protected transient DocumentIterator docIter;
protected int batchSize = 1000;
protected double sample = 0;
protected long totalWords = 1;
//learning rate
protected AtomicDouble alpha = new AtomicDouble(0.025);
//context to use for gathering word frequencies
protected int window = 5;
protected transient RandomGenerator g;
protected static final Logger log = LoggerFactory.getLogger(Word2Vec.class);
protected boolean shouldReset = true;
//number of iterations to run
protected int numIterations = 1;
public final static String UNK = "UNK";
protected long seed = 123;
protected boolean saveVocab = false;
protected double minLearningRate = 0.01;
protected transient TextVectorizer vectorizer;
protected int learningRateDecayWords = 10000;
protected InvertedIndex invertedIndex;
protected boolean useAdaGrad = false;
protected int workers = Runtime.getRuntime().availableProcessors();
public Word2Vec() {}
public TextVectorizer getVectorizer() {
return vectorizer;
}
public void setVectorizer(TextVectorizer vectorizer) {
this.vectorizer = vectorizer;
}
/**
* Train the model
*/
public void fit() throws IOException {
boolean loaded = buildVocab();
//save vocab after building
if (!loaded && saveVocab)
vocab().saveVocab();
if (stopWords == null)
readStopWords();
log.info("Training word2vec multithreaded");
if (sentenceIter != null)
sentenceIter.reset();
if (docIter != null)
docIter.reset();
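// if the inverted index holds no documents yet, run the vectorizer first so it tokenizes the corpus and populates both the index and the vocabulary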
int[] docs = vectorizer.index().allDocs();
if(docs.length < 1) {
vectorizer.fit();
}
docs = vectorizer.index().allDocs();
if(docs.length < 1) {
throw new IllegalStateException("No documents found");
}
totalWords = vectorizer.numWordsEncountered();
if(totalWords < 1)
throw new IllegalStateException("Unable to train, total words less than 1");
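// scale the total word count by the number of iterations so the learning-rate decay below spans the whole training run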
totalWords *= numIterations;
log.info("Processing sentences...");
AtomicLong numWordsSoFar = new AtomicLong(0);
final AtomicLong nextRandom = new AtomicLong(5);
ExecutorService exec = new ThreadPoolExecutor(Runtime.getRuntime().availableProcessors(),
Runtime.getRuntime().availableProcessors(),
0L, TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<Runnable>(), new RejectedExecutionHandler() {
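// if the executor cannot accept a task, back off briefly and resubmit it instead of dropping the document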
@Override
public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
executor.submit(r);
}
});
final Queue<List<VocabWord>> batch2 = new ConcurrentLinkedDeque<>();
vectorizer.index().eachDoc(new Function<List<VocabWord>, Void>() {
@Override
public Void apply(List<VocabWord> input) {
List<VocabWord> batch = new ArrayList<>();
addWords(input, nextRandom, batch);
if(!batch.isEmpty()) {
batch2.add(batch);
}
return null;
}
},exec);
exec.shutdown();
try {
exec.awaitTermination(1,TimeUnit.DAYS);
} catch (InterruptedException e) {
e.printStackTrace();
}
ActorSystem actorSystem = ActorSystem.create();
for(int i = 0; i < numIterations; i++)
doIteration(batch2,numWordsSoFar,nextRandom,actorSystem);
actorSystem.shutdown();
}
private void doIteration(Collection<List<VocabWord>> batch2, final AtomicLong numWordsSoFar, final AtomicLong nextRandom, ActorSystem actorSystem) {
final AtomicLong lastReported = new AtomicLong(System.currentTimeMillis());
Parallelization.iterateInParallel(batch2, new Parallelization.RunnableWithParams<List<VocabWord>>() {
@Override
public void run(List<VocabWord> sentence, Object[] args) {
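// decay the learning rate linearly with training progress, never dropping below minLearningRate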
double alpha = Math.max(minLearningRate, Word2Vec.this.alpha.get() *
(1 - (1.0 * numWordsSoFar.get() / (double) totalWords)));
long now = System.currentTimeMillis();
long diff = Math.abs(now - lastReported.get());
if (numWordsSoFar.get() > 0 && diff > 1000) {
lastReported.set(now);
log.info("Words so far " + numWordsSoFar.get() + " with alpha at " + alpha);
}
trainSentence(sentence, nextRandom, alpha);
numWordsSoFar.set(numWordsSoFar.get() + sentence.size());
}
},actorSystem);
}
protected void addWords(List<VocabWord> sentence, AtomicLong nextRandom, List<VocabWord> currMiniBatch) {
for (VocabWord word : sentence) {
if(word == null)
continue;
// The subsampling randomly discards frequent words while keeping the ranking the same
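// a word survives with probability (sqrt(f / (sample * N)) + 1) * (sample * N) / f, where f is the word's frequency
// and N the number of documents; this has the same form as the subsampling formula in the reference word2vec
// implementation, but is computed against the document count rather than the total word count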
if (sample > 0) {
double numDocs = vectorizer.index().numDocuments();
double ran = (Math.sqrt(word.getWordFrequency() / (sample * numDocs)) + 1)
* (sample * numDocs) / word.getWordFrequency();
if (ran < (nextRandom.get() & 0xFFFF) / (double) 65536) {
continue;
}
currMiniBatch.add(word);
}
else
currMiniBatch.add(word);
}
}
/**
* Build the binary tree
* Reset the weights
*/
public void setup() {
log.info("Building binary tree");
buildBinaryTree();
log.info("Resetting weights");
if(shouldReset)
resetWeights();
}
/**
* Builds the vocabulary for training
*/
public boolean buildVocab() {
readStopWords();
if(vocab().vocabExists()) {
log.info("Loading vocab...");
vocab().loadVocab();
lookupTable.resetWeights();
return true;
}
if(invertedIndex == null)
invertedIndex = new LuceneInvertedIndex.Builder()
.cache(vocab()).stopWords(stopWords)
.build();
//vectorizer will handle setting up vocab meta data
if(vectorizer == null) {
vectorizer = new TfidfVectorizer.Builder().index(invertedIndex)
.cache(vocab()).iterate(docIter).iterate(sentenceIter).batchSize(batchSize)
.minWords(minWordFrequency).stopWords(stopWords)
.tokenize(tokenizerFactory).build();
vectorizer.fit();
}
//fewer than 2 words means the vocab holds at most the UNK token, so run the vectorizer to populate it
else if(vocab().numWords() < 2)
vectorizer.fit();
setup();
return false;
}
/**
* Train on a list of vocab words
* @param sentence the list of vocab words to train on
*/
public void trainSentence(final List<VocabWord> sentence, AtomicLong nextRandom, double alpha) {
if(sentence == null || sentence.isEmpty())
return;
for(int i = 0; i < sentence.size(); i++) {
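// advance the pseudo-random state with the same linear congruential constants (multiplier 25214903917, increment 11)
// used by the reference word2vec implementation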
nextRandom.set(nextRandom.get() * 25214903917L + 11);
skipGram(i, sentence, (int) nextRandom.get() % window,nextRandom,alpha);
}
}
/**
 * Train via skip gram
 * @param i the index of the current (center) word in the sentence
 * @param sentence the sentence of vocab words to train on
 */
public void skipGram(int i, List<VocabWord> sentence, int b, AtomicLong nextRandom, double alpha) {
final VocabWord word = sentence.get(i);
if(word == null || sentence.isEmpty())
return;
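// b randomly shrinks the effective window, so context words close to the center word are sampled more often than distant ones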
int end = window * 2 + 1 - b;
for(int a = b; a < end; a++) {
if(a != window) {
int c = i - window + a;
if(c >= 0 && c < sentence.size()) {
VocabWord lastWord = sentence.get(c);
iterate(word,lastWord,nextRandom,alpha);
}
}
}
}
/**
 * Train the word vector
 * on the given words
 * @param w1 the first word to fit
 * @param w2 the context word to fit against
 */
public void iterate(VocabWord w1, VocabWord w2,AtomicLong nextRandom,double alpha) {
lookupTable.iterateSample(w1,w2,nextRandom,alpha);
}
/* Builds the binary tree for the word relationships */
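// Huffman coding assigns short binary codes to frequent words; these codes drive the hierarchical softmax updates in the lookup table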
protected void buildBinaryTree() {
log.info("Constructing priority queue");
Huffman huffman = new Huffman(vocab().vocabWords());
huffman.build();
log.info("Built tree");
}
/* reinit weights */
protected void resetWeights() {
lookupTable.resetWeights();
}
@SuppressWarnings("unchecked")
protected void readStopWords() {
if(this.stopWords != null)
return;
this.stopWords = StopWords.getStopWords();
}
/**
 * Note that calling this setter
 * assumes that this is a training continuation
 * and therefore weights should not be reset.
 * @param sentenceIter
 */
public void setSentenceIter(SentenceIterator sentenceIter) {
this.sentenceIter = sentenceIter;
this.shouldReset = false;
}
/**
 * Restart training on the next fit().
 * Use when a new sentence iterator is set for further training.
 */
public void resetWeightsOnSetup() {
this.shouldReset = true;
}
public int getWindow() {
return window;
}
public List<String> getStopWords() {
return stopWords;
}
public synchronized SentenceIterator getSentenceIter() {
return sentenceIter;
}
public TokenizerFactory getTokenizerFactory() {
return tokenizerFactory;
}
public void setTokenizerFactory(TokenizerFactory tokenizerFactory) {
this.tokenizerFactory = tokenizerFactory;
}
public static class Builder {
protected int minWordFrequency = 1;
protected int layerSize = 50;
protected SentenceIterator iter;
protected List<String> stopWords = StopWords.getStopWords();
protected int window = 5;
protected TokenizerFactory tokenizerFactory;
protected VocabCache vocabCache;
protected DocumentIterator docIter;
protected double lr = 2.5e-2;
protected int iterations = 1;
protected long seed = 123;
protected boolean saveVocab = false;
protected int batchSize = 1000;
protected int learningRateDecayWords = 10000;
protected boolean useAdaGrad = false;
protected TextVectorizer textVectorizer;
protected double minLearningRate = 1e-2;
protected double negative = 0;
protected double sampling = 1e-5;
protected int workers = Runtime.getRuntime().availableProcessors();
protected InvertedIndex index;
protected WeightLookupTable lookupTable;
public Builder lookupTable(WeightLookupTable lookupTable) {
this.lookupTable = lookupTable;
return this;
}
public Builder index(InvertedIndex index) {
this.index = index;
return this;
}
public Builder workers(int workers) {
this.workers = workers;
return this;
}
public Builder sampling(double sample) {
this.sampling = sample;
return this;
}
public Builder negativeSample(double negative) {
this.negative = negative;
return this;
}
public Builder minLearningRate(double minLearningRate) {
this.minLearningRate = minLearningRate;
return this;
}
public Builder useAdaGrad(boolean useAdaGrad) {
this.useAdaGrad = useAdaGrad;
return this;
}
public Builder vectorizer(TextVectorizer textVectorizer) {
this.textVectorizer = textVectorizer;
return this;
}
public Builder learningRateDecayWords(int learningRateDecayWords) {
this.learningRateDecayWords = learningRateDecayWords;
return this;
}
public Builder batchSize(int batchSize) {
this.batchSize = batchSize;
return this;
}
public Builder saveVocab(boolean saveVocab){
this.saveVocab = saveVocab;
return this;
}
public Builder seed(long seed) {
this.seed = seed;
return this;
}
public Builder iterations(int iterations) {
this.iterations = iterations;
return this;
}
public Builder learningRate(double lr) {
this.lr = lr;
return this;
}
public Builder iterate(DocumentIterator iter) {
this.docIter = iter;
return this;
}
public Builder vocabCache(VocabCache cache) {
this.vocabCache = cache;
return this;
}
public Builder minWordFrequency(int minWordFrequency) {
this.minWordFrequency = minWordFrequency;
return this;
}
public Builder tokenizerFactory(TokenizerFactory tokenizerFactory) {
this.tokenizerFactory = tokenizerFactory;
return this;
}
public Builder layerSize(int layerSize) {
this.layerSize = layerSize;
return this;
}
public Builder stopWords(List<String> stopWords) {
this.stopWords = stopWords;
return this;
}
public Builder windowSize(int window) {
this.window = window;
return this;
}
public Builder iterate(SentenceIterator iter) {
this.iter = iter;
return this;
}
public Word2Vec build() {
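// two nearly identical configuration paths: one for DocumentIterator/vectorizer-driven training (no sentence iterator supplied), one for SentenceIterator-driven training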
if(iter == null) {
Word2Vec ret = new Word2Vec();
ret.window = window;
ret.alpha.set(lr);
ret.vectorizer = textVectorizer;
ret.stopWords = stopWords;
ret.setVocab(vocabCache);
ret.numIterations = iterations;
ret.minWordFrequency = minWordFrequency;
ret.seed = seed;
ret.saveVocab = saveVocab;
ret.batchSize = batchSize;
ret.useAdaGrad = useAdaGrad;
ret.minLearningRate = minLearningRate;
ret.sample = sampling;
ret.workers = workers;
ret.invertedIndex = index;
ret.lookupTable = lookupTable;
try {
if (tokenizerFactory == null)
tokenizerFactory = new UimaTokenizerFactory();
}catch(Exception e) {
throw new RuntimeException(e);
}
if(vocabCache == null) {
vocabCache = new InMemoryLookupCache();
ret.setVocab(vocabCache);
}
if(lookupTable == null) {
lookupTable = new InMemoryLookupTable.Builder().negative(negative)
.useAdaGrad(useAdaGrad).lr(lr).cache(vocabCache)
.vectorLength(layerSize).build();
}
ret.docIter = docIter;
ret.lookupTable = lookupTable;
ret.tokenizerFactory = tokenizerFactory;
return ret;
}
else {
Word2Vec ret = new Word2Vec();
ret.alpha.set(lr);
ret.sentenceIter = iter;
ret.window = window;
ret.useAdaGrad = useAdaGrad;
ret.minLearningRate = minLearningRate;
ret.vectorizer = textVectorizer;
ret.stopWords = stopWords;
ret.minWordFrequency = minWordFrequency;
ret.setVocab(vocabCache);
ret.docIter = docIter;
ret.numIterations = iterations;
ret.seed = seed;
ret.saveVocab = saveVocab;
ret.batchSize = batchSize;
ret.sample = sampling;
ret.workers = workers;
ret.invertedIndex = index;
ret.lookupTable = lookupTable;
try {
if (tokenizerFactory == null)
tokenizerFactory = new UimaTokenizerFactory();
}catch(Exception e) {
throw new RuntimeException(e);
}
if(vocabCache == null) {
vocabCache = new InMemoryLookupCache();
ret.setVocab(vocabCache);
}
if(lookupTable == null) {
lookupTable = new InMemoryLookupTable.Builder().negative(negative)
.useAdaGrad(useAdaGrad).lr(lr).cache(vocabCache)
.vectorLength(layerSize).build();
}
ret.lookupTable = lookupTable;
ret.tokenizerFactory = tokenizerFactory;
return ret;
}
}
}
}
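/*
 * A minimal usage sketch (not part of this class): the corpus path is hypothetical, and
 * LineSentenceIterator plus wordsNearest(...) are assumed to be available from the
 * deeplearning4j-nlp module (WordVectorsImpl supplies the similarity queries).
 *
 * SentenceIterator iter = new LineSentenceIterator(new File("corpus.txt"));
 * TokenizerFactory tokenizer = new DefaultTokenizerFactory();
 * Word2Vec vec = new Word2Vec.Builder()
 *     .minWordFrequency(5)
 *     .layerSize(100)
 *     .windowSize(5)
 *     .iterate(iter)
 *     .tokenizerFactory(tokenizer)
 *     .build();
 * vec.fit();
 * Collection<String> nearest = vec.wordsNearest("day", 10);
 */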