/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.sequencevectors;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.val;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.learning.SequenceLearningAlgorithm;
import org.deeplearning4j.models.embeddings.learning.impl.elements.BatchSequences;
import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.embeddings.learning.impl.sequence.DBOW;
import org.deeplearning4j.models.embeddings.learning.impl.sequence.DM;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.sequencevectors.enums.ListenerEvent;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.VocabConstructor;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.util.ThreadUtils;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
/**
 * SequenceVectors implements abstract feature extraction for Sequences and SequenceElements, using SkipGram, CBOW or DBOW (for Sequence feature extraction).
*
*
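 * A minimal usage sketch (hypothetical wiring; assumes {@code iterator} is a {@code SequenceIterator<VocabWord>} you provide):
 * <pre>{@code
 * SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration())
 *                 .iterate(iterator)          // data source
 *                 .minWordFrequency(5)        // drop elements seen fewer than 5 times
 *                 .layerSize(100)             // vector dimensionality
 *                 .epochs(1)
 *                 .build();
 * vectors.fit();
 * }</pre>
 *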
* @author [email protected]
*/
public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<T> implements WordVectors {
private static final long serialVersionUID = 78249242142L;
@Getter
    protected transient SequenceIterator<T> iterator;
    @Setter
    protected transient ElementsLearningAlgorithm<T> elementsLearningAlgorithm;
    protected transient SequenceLearningAlgorithm<T> sequenceLearningAlgorithm;
@Getter
protected VectorsConfiguration configuration = new VectorsConfiguration();
protected static final Logger log = LoggerFactory.getLogger(SequenceVectors.class);
protected transient WordVectors existingModel;
protected transient WordVectors intersectModel;
protected transient T unknownElement;
protected transient AtomicDouble scoreElements = new AtomicDouble(0.0);
protected transient AtomicDouble scoreSequences = new AtomicDouble(0.0);
protected transient boolean configured = false;
protected transient boolean lockFactor = false;
protected boolean enableScavenger = false;
protected int vocabLimit = 0;
    private BatchSequences<T> batchSequences;
@Setter
    protected transient Set<VectorsListener<T>> eventListeners;
@Override
public String getUNK() {
return configuration.getUNK();
}
@Override
public void setUNK(String UNK) {
configuration.setUNK(UNK);
super.setUNK(UNK);
}
public double getElementsScore() {
return scoreElements.get();
}
public double getSequencesScore() {
return scoreSequences.get();
}
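    /**
     * Returns the vector for the given element label. If useUnknown is enabled and the label is
     * absent from the vocabulary, the UNK element's vector is returned instead.
     */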
@Override
public INDArray getWordVectorMatrix(String word) {
if (configuration.isUseUnknown() && !hasWord(word)) {
return super.getWordVectorMatrix(getUNK());
} else
return super.getWordVectorMatrix(word);
}
/**
* Builds vocabulary from provided SequenceIterator instance
*/
public void buildVocab() {
        val constructor = new VocabConstructor.Builder<T>().addSource(iterator, minWordFrequency)
.setTargetVocabCache(vocab).fetchLabels(trainSequenceVectors).setStopWords(stopWords)
.enableScavenger(enableScavenger).setEntriesLimit(vocabLimit)
.allowParallelTokenization(configuration.isAllowParallelTokenization())
.setUnk(useUnknown && unknownElement != null ? unknownElement : null).build();
if (existingModel != null && lookupTable instanceof InMemoryLookupTable
&& existingModel.lookupTable() instanceof InMemoryLookupTable) {
log.info("Merging existing vocabulary into the current one...");
/*
if we have existing model defined, we're forced to fetch labels only.
the rest of vocabulary & weights should be transferred from existing model
*/
constructor.buildMergedVocabulary(existingModel, true);
/*
Now we have vocab transferred, and we should transfer syn0 values into lookup table
*/
            ((InMemoryLookupTable<T>) lookupTable)
                            .consume((InMemoryLookupTable<T>) existingModel.lookupTable());
} else {
log.info("Starting vocabulary building...");
// if we don't have existing model defined, we just build vocabulary
constructor.buildJointVocabulary(false, true);
/*
if (useUnknown && unknownElement != null && !vocab.containsWord(unknownElement.getLabel())) {
log.info("Adding UNK element...");
unknownElement.setSpecial(true);
unknownElement.markAsLabel(false);
unknownElement.setIndex(vocab.numWords());
vocab.addToken(unknownElement);
}
*/
// check for malformed inputs. if numWords/numSentences ratio is huge, then user is passing something weird
if (vocab.numWords() / constructor.getNumberOfSequences() > 1000) {
log.warn("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!");
log.warn("! !");
log.warn("! Your input looks malformed: number of sentences is too low, model accuracy may suffer !");
log.warn("! !");
log.warn("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!");
}
}
}
protected synchronized void initLearners() {
if (!configured) {
log.info("Building learning algorithms:");
if (trainElementsVectors && elementsLearningAlgorithm != null && !trainSequenceVectors) {
log.info(" building ElementsLearningAlgorithm: [" + elementsLearningAlgorithm.getCodeName()
+ "]");
elementsLearningAlgorithm.configure(vocab, lookupTable, configuration);
elementsLearningAlgorithm.pretrain(iterator);
}
if (trainSequenceVectors && sequenceLearningAlgorithm != null) {
log.info(" building SequenceLearningAlgorithm: [" + sequenceLearningAlgorithm.getCodeName()
+ "]");
sequenceLearningAlgorithm.configure(vocab, lookupTable, configuration);
sequenceLearningAlgorithm.pretrain(this.iterator);
// we'll use the ELA compatible with selected SLA
if (trainElementsVectors) {
elementsLearningAlgorithm = sequenceLearningAlgorithm.getElementsLearningAlgorithm();
log.info(" building ElementsLearningAlgorithm: [" + elementsLearningAlgorithm.getCodeName()
+ "]");
}
}
configured = true;
}
}
private void initIntersectVectors() {
if (intersectModel != null && intersectModel.vocab().numWords() > 0) {
            List<Integer> indexes = new ArrayList<>();
for (int i = 0; i < intersectModel.vocab().numWords(); ++i) {
String externalWord = intersectModel.vocab().wordAtIndex(i);
int index = this.vocab.indexOf(externalWord);
if (index >= 0) {
this.vocab.wordFor(externalWord).setLocked(lockFactor);
indexes.add(index);
}
}
if (indexes.size() > 0) {
int[] intersectIndexes = Ints.toArray(indexes);
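            // ASSIGN scatter update: overwrite the matching rows of this model's syn0 with the
            // corresponding rows from the intersecting model's syn0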
            Nd4j.scatterUpdate(org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate.UpdateOp.ASSIGN,
                            ((InMemoryLookupTable<T>) lookupTable).getSyn0(),
                            Nd4j.createFromArray(intersectIndexes),
                            ((InMemoryLookupTable<T>) intersectModel.lookupTable()).getSyn0(),
                            1);
}
}
}
/**
     * Starts the training process over the configured SequenceIterator
*/
public void fit() {
val props = Nd4j.getExecutioner().getEnvironmentInformation();
if (props.getProperty("backend").equals("CUDA")) {
if (Nd4j.getAffinityManager().getNumberOfDevices() > 1)
throw new IllegalStateException("Multi-GPU word2vec/doc2vec isn't available atm");
//if (!NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())
//throw new IllegalStateException("Running Word2Vec on multi-gpu system requires P2P support between GPUs, which looks to be unavailable on your system.");
}
Nd4j.getRandom().setSeed(configuration.getSeed());
AtomicLong timeSpent = new AtomicLong(0);
if (!trainElementsVectors && !trainSequenceVectors)
throw new IllegalStateException(
"You should define at least one training goal 'trainElementsRepresentation' or 'trainSequenceRepresentation'");
if (iterator == null)
throw new IllegalStateException("You can't fit() data without SequenceIterator defined");
if (resetModel || (lookupTable != null && vocab != null && vocab.numWords() == 0)) {
            // build vocabulary from scratch
buildVocab();
}
WordVectorSerializer.printOutProjectedMemoryUse(vocab.numWords(), configuration.getLayersSize(),
configuration.isUseHierarchicSoftmax() && configuration.getNegative() > 0 ? 3 : 2);
if (vocab == null || lookupTable == null || vocab.numWords() == 0)
throw new IllegalStateException("You can't fit() model with empty Vocabulary or WeightLookupTable");
        // if the model's vocab and lookupTable were built externally, we basically just check that the lookupTable was properly initialized
if (!resetModel || existingModel != null) {
lookupTable.resetWeights(false);
} else {
// otherwise we reset weights, independent of actual current state of lookup table
lookupTable.resetWeights(true);
// if preciseWeights used, we roll over data once again
if (configuration.isPreciseWeightInit()) {
log.info("Using precise weights init...");
iterator.reset();
while (iterator.hasMoreSequences()) {
val sequence = iterator.nextSequence();
// initializing elements, only once
for (T element : sequence.getElements()) {
T realElement = vocab.tokenFor(element.getLabel());
if (realElement != null && !realElement.isInit()) {
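                            // deterministic init: the RNG seed is derived from the global seed and the
                            // element's hashCode, so each element gets a reproducible initial vector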
val rng = Nd4j.getRandomFactory().getNewRandomInstance(
configuration.getSeed() * realElement.hashCode(),
configuration.getLayersSize() + 1);
val randArray = Nd4j.rand(new int[] {1, configuration.getLayersSize()}, rng).subi(0.5)
.divi(configuration.getLayersSize());
lookupTable.getWeights().getRow(realElement.getIndex(), true).assign(randArray);
realElement.setInit(true);
}
}
// initializing labels, only once
for (T label : sequence.getSequenceLabels()) {
T realElement = vocab.tokenFor(label.getLabel());
if (realElement != null && !realElement.isInit()) {
Random rng = Nd4j.getRandomFactory().getNewRandomInstance(
configuration.getSeed() * realElement.hashCode(),
configuration.getLayersSize() + 1);
INDArray randArray = Nd4j.rand(new int[] {1, configuration.getLayersSize()}, rng).subi(0.5)
.divi(configuration.getLayersSize());
lookupTable.getWeights().getRow(realElement.getIndex(), true).assign(randArray);
realElement.setInit(true);
}
}
}
this.iterator.reset();
}
}
initLearners();
initIntersectVectors();
log.info("Starting learning process...");
timeSpent.set(System.currentTimeMillis());
if (this.stopWords == null)
this.stopWords = new ArrayList<>();
val wordsCounter = new AtomicLong(0);
for (int currentEpoch = 1; currentEpoch <= numEpochs; currentEpoch++) {
val linesCounter = new AtomicLong(0);
val sequencer = new AsyncSequencer(this.iterator, this.stopWords);
sequencer.start();
val timer = new AtomicLong(System.currentTimeMillis());
val thread = new VectorCalculationsThread(0, currentEpoch, wordsCounter, vocab.totalWordOccurrences(),
linesCounter, sequencer, timer, numEpochs);
thread.start();
try {
sequencer.join();
} catch (Exception e) {
throw new RuntimeException(e);
}
try {
thread.join();
} catch (Exception e) {
throw new RuntimeException(e);
}
// TODO: fix this to non-exclusive termination
if (trainElementsVectors && elementsLearningAlgorithm != null
&& (!trainSequenceVectors || sequenceLearningAlgorithm == null)
&& elementsLearningAlgorithm.isEarlyTerminationHit()) {
break;
}
if (trainSequenceVectors && sequenceLearningAlgorithm != null
&& (!trainElementsVectors || elementsLearningAlgorithm == null)
&& sequenceLearningAlgorithm.isEarlyTerminationHit()) {
break;
}
log.info("Epoch [" + currentEpoch + "] finished; Elements processed so far: [" + wordsCounter.get()
+ "]; Sequences processed: [" + linesCounter.get() + "]");
if (eventListeners != null && !eventListeners.isEmpty()) {
            for (VectorsListener<T> listener : eventListeners) {
if (listener.validateEvent(ListenerEvent.EPOCH, currentEpoch))
listener.processEvent(ListenerEvent.EPOCH, this, currentEpoch);
}
}
}
log.info("Time spent on training: {} ms", System.currentTimeMillis() - timeSpent.get());
}
    protected void trainSequence(@NonNull Sequence<T> sequence, AtomicLong nextRandom, double alpha) {
if (sequence.getElements().isEmpty())
return;
        /*
            we do NOT train elements separately if the sequenceLearningAlgorithm is PV-DM,
            because PV-DM already includes CBOW-style elements training
         */
if (trainElementsVectors && !(trainSequenceVectors && sequenceLearningAlgorithm instanceof DM)) {
// call for ElementsLearningAlgorithm
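            // same linear congruential update as in the reference word2vec C implementation:
            // next_random = next_random * 25214903917 + 11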
nextRandom.set(nextRandom.get() * 25214903917L + 11);
if (!elementsLearningAlgorithm.isEarlyTerminationHit()) {
scoreElements.set(elementsLearningAlgorithm.learnSequence(sequence, nextRandom, alpha, batchSequences));
}
else
scoreElements.set(elementsLearningAlgorithm.learnSequence(sequence, nextRandom, alpha));
}
if (trainSequenceVectors) {
// call for SequenceLearningAlgorithm
nextRandom.set(nextRandom.get() * 25214903917L + 11);
if (!sequenceLearningAlgorithm.isEarlyTerminationHit())
scoreSequences.set(sequenceLearningAlgorithm.learnSequence(sequence, nextRandom, alpha, batchSequences));
}
}
    public static class Builder<T extends SequenceElement> {
        protected VocabCache<T> vocabCache;
        protected WeightLookupTable<T> lookupTable;
        protected SequenceIterator<T> iterator;
        protected ModelUtils<T> modelUtils = new BasicModelUtils<>();
        protected WordVectors existingVectors;
        protected SequenceVectors<T> intersectVectors;
protected boolean lockFactor = false;
protected double sampling = 0;
protected double negative = 0;
protected double learningRate = 0.025;
protected double minLearningRate = 0.0001;
protected int minWordFrequency = 0;
protected int iterations = 1;
protected int numEpochs = 1;
protected int layerSize = 100;
protected int window = 5;
protected boolean hugeModelExpected = false;
protected int batchSize = 512;
protected int learningRateDecayWords;
protected long seed;
protected boolean useAdaGrad = false;
protected boolean resetModel = true;
protected int workers = Runtime.getRuntime().availableProcessors();
protected boolean useUnknown = false;
protected boolean useHierarchicSoftmax = true;
protected int[] variableWindows;
protected boolean trainSequenceVectors = false;
protected boolean trainElementsVectors = true;
protected boolean preciseWeightInit = false;
        protected Collection<String> stopWords = new ArrayList<>();
protected VectorsConfiguration configuration = new VectorsConfiguration();
protected transient T unknownElement;
protected String UNK = configuration.getUNK();
protected String STOP = configuration.getSTOP();
protected boolean enableScavenger = false;
protected int vocabLimit;
/**
* Experimental field. Switches on precise mode for batch operations.
*/
protected boolean preciseMode = false;
        // default values for learning algorithms are set here
        protected ElementsLearningAlgorithm<T> elementsLearningAlgorithm = new SkipGram<>();
        protected SequenceLearningAlgorithm<T> sequenceLearningAlgorithm = new DBOW<>();
        protected Set<VectorsListener<T>> vectorsListeners = new HashSet<>();
public Builder() {
}
public Builder(@NonNull VectorsConfiguration configuration) {
this.configuration = configuration;
this.iterations = configuration.getIterations();
this.numEpochs = configuration.getEpochs();
this.minLearningRate = configuration.getMinLearningRate();
this.learningRate = configuration.getLearningRate();
this.sampling = configuration.getSampling();
this.negative = configuration.getNegative();
this.minWordFrequency = configuration.getMinWordFrequency();
this.seed = configuration.getSeed();
this.hugeModelExpected = configuration.isHugeModelExpected();
this.batchSize = configuration.getBatchSize();
this.layerSize = configuration.getLayersSize();
this.learningRateDecayWords = configuration.getLearningRateDecayWords();
this.useAdaGrad = configuration.isUseAdaGrad();
this.window = configuration.getWindow();
this.UNK = configuration.getUNK();
this.STOP = configuration.getSTOP();
this.variableWindows = configuration.getVariableWindows();
this.useHierarchicSoftmax = configuration.isUseHierarchicSoftmax();
this.preciseMode = configuration.isPreciseMode();
if (configuration.getModelUtils() != null && !configuration.getModelUtils().isEmpty()) {
try {
                    this.modelUtils = (ModelUtils<T>) Class.forName(configuration.getModelUtils()).newInstance();
                } catch (Exception e) {
                    log.error("Got {} trying to instantiate ModelUtils, falling back to BasicModelUtils instead", e.toString());
this.modelUtils = new BasicModelUtils<>();
}
}
if (configuration.getElementsLearningAlgorithm() != null
&& !configuration.getElementsLearningAlgorithm().isEmpty()) {
this.elementsLearningAlgorithm(configuration.getElementsLearningAlgorithm());
}
if (configuration.getSequenceLearningAlgorithm() != null
&& !configuration.getSequenceLearningAlgorithm().isEmpty()) {
this.sequenceLearningAlgorithm(configuration.getSequenceLearningAlgorithm());
}
if (configuration.getStopList() != null)
this.stopWords.addAll(configuration.getStopList());
}
/**
* This method allows you to use pre-built WordVectors model (SkipGram or GloVe) for DBOW sequence learning.
* Existing model will be transferred into new model before training starts.
*
         * PLEASE NOTE: This model has no effect on elements learning algorithms. Only sequence learning is affected.
         * PLEASE NOTE: Using a non-normalized model is recommended here.
*
* @param vec existing WordVectors model
* @return
*/
        protected Builder<T> useExistingWordVectors(@NonNull WordVectors vec) {
this.existingVectors = vec;
return this;
}
/**
* This method defines SequenceIterator to be used for model building
* @param iterator
* @return
*/
        public Builder<T> iterate(@NonNull SequenceIterator<T> iterator) {
this.iterator = iterator;
return this;
}
/**
* Sets specific LearningAlgorithm as Sequence Learning Algorithm
*
* @param algoName fully qualified class name
* @return
*/
        public Builder<T> sequenceLearningAlgorithm(@NonNull String algoName) {
            try {
                Class clazz = Class.forName(algoName);
                sequenceLearningAlgorithm = (SequenceLearningAlgorithm<T>) clazz.newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
return this;
}
/**
* Sets specific LearningAlgorithm as Sequence Learning Algorithm
*
* @param algorithm SequenceLearningAlgorithm implementation
* @return
*/
        public Builder<T> sequenceLearningAlgorithm(@NonNull SequenceLearningAlgorithm<T> algorithm) {
this.sequenceLearningAlgorithm = algorithm;
return this;
}
/**
         * Sets specific LearningAlgorithm as Elements Learning Algorithm
*
* @param algoName fully qualified class name
* @return
*/
        public Builder<T> elementsLearningAlgorithm(@NonNull String algoName) {
            try {
                Class clazz = Class.forName(algoName);
                elementsLearningAlgorithm = (ElementsLearningAlgorithm<T>) clazz.newInstance();
this.configuration.setElementsLearningAlgorithm(elementsLearningAlgorithm.getClass().getCanonicalName());
} catch (Exception e) {
throw new RuntimeException(e);
}
return this;
}
/**
         * Sets specific LearningAlgorithm as Elements Learning Algorithm
*
* @param algorithm ElementsLearningAlgorithm implementation
* @return
*/
        public Builder<T> elementsLearningAlgorithm(@NonNull ElementsLearningAlgorithm<T> algorithm) {
this.elementsLearningAlgorithm = algorithm;
return this;
}
/**
* This method defines batchSize option, viable only if iterations > 1
*
* @param batchSize
* @return
*/
        public Builder<T> batchSize(int batchSize) {
this.batchSize = batchSize;
return this;
}
/**
         * This method defines how many iterations should be done over batched sequences.
*
* @param iterations
* @return
*/
        public Builder<T> iterations(int iterations) {
this.iterations = iterations;
return this;
}
/**
         * This method defines how many epochs should be done over the whole training corpus during modelling
* @param numEpochs
* @return
*/
        public Builder<T> epochs(int numEpochs) {
this.numEpochs = numEpochs;
return this;
}
/**
* Sets number of worker threads to be used in calculations
*
* @param numWorkers
* @return
*/
        public Builder<T> workers(int numWorkers) {
this.workers = numWorkers;
return this;
}
/**
* Enable/disable hierarchic softmax
*
* @param reallyUse
* @return
*/
        public Builder<T> useHierarchicSoftmax(boolean reallyUse) {
this.useHierarchicSoftmax = reallyUse;
return this;
}
/**
* This method defines if Adaptive Gradients should be used in calculations
*
* @param reallyUse
* @return
*/
@Deprecated
        public Builder<T> useAdaGrad(boolean reallyUse) {
this.useAdaGrad = reallyUse;
return this;
}
/**
         * This method defines the number of dimensions for output vectors.
         * Please note: This option has effect only if lookupTable wasn't defined during the building process.
*
* @param layerSize
* @return
*/
        public Builder<T> layerSize(int layerSize) {
this.layerSize = layerSize;
return this;
}
/**
* This method defines initial learning rate.
* Default value is 0.025
*
* @param learningRate
* @return
*/
        public Builder<T> learningRate(double learningRate) {
this.learningRate = learningRate;
return this;
}
/**
* This method defines minimal element frequency for elements found in the training corpus. All elements with frequency below this threshold will be removed before training.
* Please note: this method has effect only if vocabulary is built internally.
*
* @param minWordFrequency
* @return
*/
        public Builder<T> minWordFrequency(int minWordFrequency) {
this.minWordFrequency = minWordFrequency;
return this;
}
/**
* This method sets vocabulary limit during construction.
*
         * Default value: 0 (no limit)
*
* @param limit
* @return
*/
        public Builder<T> limitVocabularySize(int limit) {
            if (limit < 0)
                throw new DL4JInvalidConfigException("Vocabulary limit must be a non-negative number");
this.vocabLimit = limit;
return this;
}
/**
         * This method defines the minimum learning rate after decay is applied.
         * Default value is 0.0001
*
* @param minLearningRate
* @return
*/
        public Builder<T> minLearningRate(double minLearningRate) {
this.minLearningRate = minLearningRate;
return this;
}
/**
         * This method defines whether the whole model should be reset before training. If set to true, vocabulary and WeightLookupTable will be reset before training and built from scratch
*
* @param reallyReset
* @return
*/
        public Builder<T> resetModel(boolean reallyReset) {
this.resetModel = reallyReset;
return this;
}
/**
         * You can pass an externally built vocabCache object, containing the vocabulary
*
* @param vocabCache
* @return
*/
        public Builder<T> vocabCache(@NonNull VocabCache<T> vocabCache) {
this.vocabCache = vocabCache;
return this;
}
/**
         * You can pass an externally built WeightLookupTable, containing model weights and vocabulary.
*
* @param lookupTable
* @return
*/
        public Builder<T> lookupTable(@NonNull WeightLookupTable<T> lookupTable) {
this.lookupTable = lookupTable;
this.layerSize(lookupTable.layerSize());
return this;
}
/**
         * This method defines the threshold for frequency-based subsampling of elements. A value of 0 (default) disables subsampling.
*
* @param sampling
* @return
*/
        public Builder<T> sampling(double sampling) {
this.sampling = sampling;
return this;
}
/**
         * This method defines the number of negative samples used for the skip-gram algorithm. A value of 0 disables negative sampling.
*
* @param negative
* @return
*/
        public Builder<T> negativeSample(double negative) {
this.negative = negative;
return this;
}
/**
         * You can provide a collection of element labels to be ignored and excluded from the model.
         * Please note: element labels will be used for filtering.
*
* @param stopList
* @return
*/
        public Builder<T> stopWords(@NonNull List<String> stopList) {
this.stopWords.addAll(stopList);
return this;
}
        /**
         * This method defines whether element (e.g. word) representations should be trained.
         *
         * @param trainElements
         * @return
         */
        public Builder<T> trainElementsRepresentation(boolean trainElements) {
this.trainElementsVectors = trainElements;
return this;
}
        public Builder<T> trainSequencesRepresentation(boolean trainSequences) {
this.trainSequenceVectors = trainSequences;
return this;
}
/**
         * You can provide a collection of elements to be ignored and excluded from the model.
         * Please note: element labels will be used for filtering.
*
* @param stopList
* @return
*/
        public Builder<T> stopWords(@NonNull Collection<T> stopList) {
for (T word : stopList) {
this.stopWords.add(word.getLabel());
}
return this;
}
/**
         * Sets window size for skip-gram training
*
* @param windowSize
* @return
*/
        public Builder<T> windowSize(int windowSize) {
this.window = windowSize;
return this;
}
/**
         * Sets seed for the random number generator.
         * Please note: this has effect only if vocabulary and WeightLookupTable are built internally
*
* @param randomSeed
* @return
*/
        public Builder<T> seed(long randomSeed) {
// has no effect in original w2v actually
this.seed = randomSeed;
return this;
}
/**
         * ModelUtils implementation that will be used to access the model.
         * Methods like similarity, wordsNearest, and accuracy are provided by the user-defined ModelUtils
*
* @param modelUtils model utils to be used
* @return
*/
        public Builder<T> modelUtils(@NonNull ModelUtils<T> modelUtils) {
this.modelUtils = modelUtils;
this.configuration.setModelUtils(modelUtils.getClass().getCanonicalName());
return this;
}
/**
         * This method allows you to specify whether the UNK element should be used internally
* @param reallyUse
* @return
*/
        public Builder<T> useUnknown(boolean reallyUse) {
this.useUnknown = reallyUse;
this.configuration.setUseUnknown(reallyUse);
return this;
}
/**
         * This method allows you to specify the SequenceElement that will be used as the UNK element, if UNK is used
* @param element
* @return
*/
        public Builder<T> unknownElement(@NonNull T element) {
this.unknownElement = element;
this.UNK = element.getLabel();
this.configuration.setUNK(this.UNK);
return this;
}
/**
         * This method allows you to use variable window sizes. In this case, every batch gets processed using one of the predefined window sizes
*
* @param windows
* @return
*/
        public Builder<T> useVariableWindow(int... windows) {
if (windows == null || windows.length == 0)
throw new IllegalStateException("Variable windows can't be empty");
variableWindows = windows;
return this;
}
/**
         * If set to true, initial weights for elements/sequences will be derived from the elements themselves.
         * However, this implies an additional pass through the input iterator.
*
* Default value: FALSE
*
* @param reallyUse
* @return
*/
        public Builder<T> usePreciseWeightInit(boolean reallyUse) {
this.preciseWeightInit = reallyUse;
this.configuration.setPreciseWeightInit(reallyUse);
return this;
}
        public Builder<T> usePreciseMode(boolean reallyUse) {
this.preciseMode = reallyUse;
this.configuration.setPreciseMode(reallyUse);
return this;
}
/**
         * This method creates new WeightLookupTable and VocabCache if none were set
*/
protected void presetTables() {
if (lookupTable == null) {
if (vocabCache == null) {
                vocabCache = new AbstractCache.Builder<T>().hugeModelExpected(hugeModelExpected)
.scavengerRetentionDelay(this.configuration.getScavengerRetentionDelay())
.scavengerThreshold(this.configuration.getScavengerActivationThreshold())
.minElementFrequency(minWordFrequency).build();
}
            lookupTable = new InMemoryLookupTable.Builder<T>().useAdaGrad(this.useAdaGrad).cache(vocabCache)
.negative(negative).useHierarchicSoftmax(useHierarchicSoftmax).vectorLength(layerSize)
.lr(learningRate).seed(seed).build();
}
if (this.configuration.getElementsLearningAlgorithm() != null) {
try {
                    elementsLearningAlgorithm = (ElementsLearningAlgorithm<T>) Class
                                    .forName(this.configuration.getElementsLearningAlgorithm()).newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
if (this.configuration.getSequenceLearningAlgorithm() != null) {
try {
                    sequenceLearningAlgorithm = (SequenceLearningAlgorithm<T>) Class
                                    .forName(this.configuration.getSequenceLearningAlgorithm()).newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
if (trainElementsVectors && elementsLearningAlgorithm == null) {
// create default implementation of ElementsLearningAlgorithm
elementsLearningAlgorithm = new SkipGram<>();
}
if (trainSequenceVectors && sequenceLearningAlgorithm == null) {
sequenceLearningAlgorithm = new DBOW<>();
}
this.modelUtils.init(lookupTable);
}
/**
* This method sets VectorsListeners for this SequenceVectors model
*
* @param listeners
* @return
*/
        public Builder<T> setVectorsListeners(@NonNull Collection<VectorsListener<T>> listeners) {
vectorsListeners.addAll(listeners);
return this;
}
/**
         * This method enables/disables periodical vocab truncation during construction
*
* Default value: disabled
*
* @param reallyEnable
* @return
*/
        public Builder<T> enableScavenger(boolean reallyEnable) {
this.enableScavenger = reallyEnable;
return this;
}
        public Builder<T> intersectModel(@NonNull SequenceVectors<T> intersectVectors, boolean lockFactor) {
this.intersectVectors = intersectVectors;
this.lockFactor = lockFactor;
return this;
}
/**
* Build SequenceVectors instance with defined settings/options
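         *
         * A hypothetical sketch reusing an externally built vocabulary and lookup table (the names
         * {@code existingVocab} and {@code existingTable} are assumed to be constructed elsewhere):
         * <pre>{@code
         * SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration())
         *                 .iterate(iterator)
         *                 .vocabCache(existingVocab)
         *                 .lookupTable(existingTable)   // layerSize is taken from the table
         *                 .resetModel(false)            // keep the externally initialized weights
         *                 .build();
         * }</pre>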
* @return
*/
        public SequenceVectors<T> build() {
            presetTables();
            SequenceVectors<T> vectors = new SequenceVectors<>();
if (this.existingVectors != null) {
this.trainElementsVectors = false;
this.elementsLearningAlgorithm = null;
}
vectors.numEpochs = this.numEpochs;
vectors.numIterations = this.iterations;
vectors.vocab = this.vocabCache;
vectors.minWordFrequency = this.minWordFrequency;
vectors.learningRate.set(this.learningRate);
vectors.minLearningRate = this.minLearningRate;
vectors.sampling = this.sampling;
vectors.negative = this.negative;
vectors.layerSize = this.layerSize;
vectors.batchSize = this.batchSize;
vectors.learningRateDecayWords = this.learningRateDecayWords;
vectors.window = this.window;
vectors.resetModel = this.resetModel;
vectors.useAdeGrad = this.useAdaGrad;
vectors.stopWords = this.stopWords;
vectors.workers = this.workers;
vectors.iterator = this.iterator;
vectors.lookupTable = this.lookupTable;
vectors.modelUtils = this.modelUtils;
vectors.useUnknown = this.useUnknown;
vectors.unknownElement = this.unknownElement;
vectors.variableWindows = this.variableWindows;
vectors.vocabLimit = this.vocabLimit;
vectors.trainElementsVectors = this.trainElementsVectors;
vectors.trainSequenceVectors = this.trainSequenceVectors;
vectors.elementsLearningAlgorithm = this.elementsLearningAlgorithm;
vectors.sequenceLearningAlgorithm = this.sequenceLearningAlgorithm;
vectors.existingModel = this.existingVectors;
vectors.intersectModel = this.intersectVectors;
vectors.enableScavenger = this.enableScavenger;
vectors.lockFactor = this.lockFactor;
this.configuration.setLearningRate(this.learningRate);
this.configuration.setLayersSize(layerSize);
this.configuration.setHugeModelExpected(hugeModelExpected);
this.configuration.setWindow(window);
this.configuration.setMinWordFrequency(minWordFrequency);
this.configuration.setIterations(iterations);
this.configuration.setSeed(seed);
this.configuration.setBatchSize(batchSize);
this.configuration.setLearningRateDecayWords(learningRateDecayWords);
this.configuration.setMinLearningRate(minLearningRate);
this.configuration.setSampling(this.sampling);
this.configuration.setUseAdaGrad(useAdaGrad);
this.configuration.setNegative(negative);
this.configuration.setEpochs(this.numEpochs);
this.configuration.setStopList(this.stopWords);
this.configuration.setVariableWindows(variableWindows);
this.configuration.setUseHierarchicSoftmax(this.useHierarchicSoftmax);
this.configuration.setPreciseWeightInit(this.preciseWeightInit);
this.configuration.setModelUtils(this.modelUtils.getClass().getCanonicalName());
vectors.configuration = this.configuration;
return vectors;
}
}
    /**
     * This class is used to fetch data from the iterator in a background thread, and convert it to a buffer of Sequence<T>.
     *
     * It becomes very useful if the text processing pipeline behind the iterator is complex, i.e. when we're not just
     * loading data from a simple text file with whitespace separators, since it allows preprocessing latency to be
     * hidden in the background thread.
     *
     * This mechanic will be changed to a PrefetchingSentenceIterator wrapper.
     */
protected class AsyncSequencer extends Thread implements Runnable {
        private final SequenceIterator<T> iterator;
        private final LinkedBlockingQueue<Sequence<T>> buffer;
// private final AtomicLong linesCounter;
private final int limitUpper;
private final int limitLower;
private AtomicBoolean isRunning = new AtomicBoolean(true);
private AtomicLong nextRandom;
        private Collection<String> stopList;
private static final int DEFAULT_BUFFER_SIZE = 512;
        public AsyncSequencer(SequenceIterator<T> iterator, @NonNull Collection<String> stopList) {
this.iterator = iterator;
// this.linesCounter = linesCounter;
this.setName("AsyncSequencer thread");
this.nextRandom = new AtomicLong(workers + 1);
this.iterator.reset();
this.stopList = stopList;
this.setDaemon(true);
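            // low/high watermarks for the prefetch buffer: run() refills whenever the buffer drops
            // below limitLower, and the queue itself is capped at limitUpper entries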
limitLower = workers * (batchSize < DEFAULT_BUFFER_SIZE ? DEFAULT_BUFFER_SIZE : batchSize);
limitUpper = limitLower * 2;
this.buffer = new LinkedBlockingQueue<>(limitUpper);
}
        // Preserve order of input sequences to guarantee order of output tokens
@Override
public void run() {
isRunning.set(true);
while (this.iterator.hasMoreSequences()) {
// if buffered level is below limitLower, we're going to fetch limitUpper number of strings from fetcher
if (buffer.size() < limitLower) {
update();
AtomicInteger linesLoaded = new AtomicInteger(0);
while (linesLoaded.getAndIncrement() < limitUpper && this.iterator.hasMoreSequences()) {
                        Sequence<T> document = this.iterator.nextSequence();
/*
We can't hope/assume that underlying iterator contains synchronized elements
That's why we're going to rebuild sequence from vocabulary
*/
                        Sequence<T> newSequence = new Sequence<>();
if (document.getSequenceLabel() != null) {
T newLabel = vocab.wordFor(document.getSequenceLabel().getLabel());
if (newLabel != null)
newSequence.setSequenceLabel(newLabel);
}
for (T element : document.getElements()) {
if (stopList.contains(element.getLabel()))
continue;
T realElement = vocab.wordFor(element.getLabel());
                            // please note: this sequence element CAN be absent in vocab, due to minFreq or stopWord or whatever else
if (realElement != null) {
newSequence.addElement(realElement);
} else if (useUnknown && unknownElement != null) {
newSequence.addElement(unknownElement);
}
}
// due to subsampling and null words, new sequence size CAN be 0, so there's no need to insert empty sequence into processing chain
if (!newSequence.getElements().isEmpty())
try {
buffer.put(newSequence);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
linesLoaded.incrementAndGet();
}
} else {
ThreadUtils.uncheckedSleep(50);
}
}
isRunning.set(false);
}
public boolean hasMoreLines() {
// statement order does matter here, since there's possible race condition
return !buffer.isEmpty() || isRunning.get();
}
        public Sequence<T> nextSentence() {
try {
return buffer.poll(3L, TimeUnit.SECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
return null;
}
}
}
    /**
     * VectorCalculationsThreads are used for vector calculations, and work together with the AsyncSequencer.
     * Basically, all they do is transfer digitized sentences into the math layer.
     *
     * Please note: they do not iterate over sentences repeatedly; each sentence is processed only once.
     * Training corpus iteration is implemented in the fit() method.
     */
private class VectorCalculationsThread extends Thread implements Runnable {
private final int threadId;
private final int epochNumber;
private final AtomicLong wordsCounter;
private final long totalWordsCount;
private final AtomicLong totalLines;
private final AsyncSequencer digitizer;
private final AtomicLong nextRandom;
private final AtomicLong timer;
private final long startTime;
private final int totalEpochs;
/*
Long constructors suck, so this should be reduced to something reasonable later
*/
public VectorCalculationsThread(int threadId, int epoch, AtomicLong wordsCounter, long totalWordsCount,
AtomicLong linesCounter, AsyncSequencer digitizer, AtomicLong timer, int totalEpochs) {
this.threadId = threadId;
this.totalEpochs = totalEpochs;
this.epochNumber = epoch;
this.wordsCounter = wordsCounter;
this.totalWordsCount = totalWordsCount;
this.totalLines = linesCounter;
this.digitizer = digitizer;
this.timer = timer;
this.startTime = timer.get();
this.nextRandom = new AtomicLong(this.threadId);
this.setName("VectorCalculationsThread " + this.threadId);
}
@Override
public void run() {
            // small workspace, just to handle temporary allocations inside the training loop
val conf = WorkspaceConfiguration.builder()
.policyLearning(LearningPolicy.OVER_TIME)
.cyclesBeforeInitialization(3)
.initialSize(25L * 1024L * 1024L)
.build();
val workspace_id = "sequence_vectors_training_" + java.util.UUID.randomUUID().toString();
Nd4j.getAffinityManager().getDeviceForCurrentThread();
while (digitizer.hasMoreLines()) {
try {
// get current sentence as list of VocabularyWords
                    List<Sequence<T>> sequences = new ArrayList<>();
for (int x = 0; x < batchSize; x++) {
if (digitizer.hasMoreLines()) {
                            Sequence<T> sequence = digitizer.nextSentence();
if (sequence != null) {
sequences.add(sequence);
}
}
}
double alpha = 0.025;
if (sequences.isEmpty()) {
continue;
}
// getting back number of iterations
for (int i = 0; i < numIterations; i++) {
batchSequences = new BatchSequences<>(configuration.getBatchSize());
// we roll over sequences derived from digitizer, it's NOT window loop
for (int x = 0; x < sequences.size(); x++) {
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(conf, workspace_id)) {
                                Sequence<T> sequence = sequences.get(x);
//log.info("LR before: {}; wordsCounter: {}; totalWordsCount: {}", learningRate.get(), this.wordsCounter.get(), this.totalWordsCount);
alpha = Math.max(minLearningRate,
learningRate.get() * (1 - (1.0 * this.wordsCounter.get()
/ ((double) this.totalWordsCount) / (numIterations
* totalEpochs))));
trainSequence(sequence, nextRandom, alpha);
// increment processed word count, please note: this affects learningRate decay
totalLines.incrementAndGet();
this.wordsCounter.addAndGet(sequence.getElements().size());
if (totalLines.get() % 100000 == 0) {
long currentTime = System.currentTimeMillis();
long timeSpent = currentTime - timer.get();
timer.set(currentTime);
long totalTimeSpent = currentTime - startTime;
double seqSec = (100000.0 / ((double) timeSpent / 1000.0));
double wordsSecTotal = this.wordsCounter.get() / ((double) totalTimeSpent / 1000.0);
log.info("Epoch: [{}]; Words vectorized so far: [{}]; Lines vectorized so far: [{}]; Seq/sec: [{}]; Words/sec: [{}]; learningRate: [{}]",
this.epochNumber, this.wordsCounter.get(), this.totalLines.get(),
String.format("%.2f", seqSec), String.format("%.2f", wordsSecTotal),
alpha);
}
if (eventListeners != null && !eventListeners.isEmpty()) {
                                    for (VectorsListener<T> listener : eventListeners) {
if (listener.validateEvent(ListenerEvent.LINE, totalLines.get()))
listener.processEvent(ListenerEvent.LINE, SequenceVectors.this,
totalLines.get());
}
}
}
}
                    if (elementsLearningAlgorithm instanceof SkipGram)
                        ((SkipGram<T>) elementsLearningAlgorithm).setWorkers(workers);
                    else if (elementsLearningAlgorithm instanceof CBOW)
                        ((CBOW<T>) elementsLearningAlgorithm).setWorkers(workers);
int batchSize = configuration.getBatchSize();
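                    // flush the accumulated batch in chunks of batchSize; a trailing partial
                    // chunk covers the remainder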
if (batchSize > 1 && batchSequences != null) {
int rest = batchSequences.size() % batchSize;
int chunks = ((batchSequences.size() >= batchSize) ? batchSequences.size() / batchSize : 0) + ((rest > 0)? 1 : 0);
for (int j = 0; j < chunks; ++j) {
if (trainElementsVectors) {
                                if (elementsLearningAlgorithm instanceof SkipGram)
                                    ((SkipGram<T>) elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
                                else if (elementsLearningAlgorithm instanceof CBOW)
                                    ((CBOW<T>) elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
}
if (trainSequenceVectors) {
                                if (sequenceLearningAlgorithm instanceof DBOW)
                                    ((SkipGram<T>) sequenceLearningAlgorithm.getElementsLearningAlgorithm()).iterateSample(batchSequences.get(j));
                                else if (sequenceLearningAlgorithm instanceof DM)
                                    ((CBOW<T>) sequenceLearningAlgorithm.getElementsLearningAlgorithm()).iterateSample(batchSequences.get(j));
}
}
batchSequences.clear();
batchSequences = null;
}
if (eventListeners != null && !eventListeners.isEmpty()) {
for (VectorsListener listener : eventListeners) {
if (listener.validateEvent(ListenerEvent.ITERATION, i))
listener.processEvent(ListenerEvent.ITERATION, SequenceVectors.this, i);
}
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
if (trainElementsVectors) {
elementsLearningAlgorithm.finish();
}
if (trainSequenceVectors) {
sequenceLearningAlgorithm.finish();
}
}
}
}