// All downloads are free. Search and download functionality uses the official Maven repository.
//
// org.deeplearning4j.models.sequencevectors.SequenceVectors (Maven / Gradle / Ivy)
//
// There is a newer version: 1.0.0-M2.1
// Show newest version
package org.deeplearning4j.models.sequencevectors;

import com.google.common.util.concurrent.AtomicDouble;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.learning.SequenceLearningAlgorithm;
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.embeddings.learning.impl.sequence.DBOW;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.sequencevectors.enums.ListenerEvent;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.VocabConstructor;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

/**
 * SequenceVectors implements abstract features extraction for Sequences and SequenceElements, using SkipGram, CBOW or DBOW (for Sequence features extraction).
 *
 *
 * @author [email protected]
 */
public class SequenceVectors extends WordVectorsImpl implements WordVectors {
    @Getter protected transient SequenceIterator iterator;

    protected transient ElementsLearningAlgorithm elementsLearningAlgorithm;
    protected transient SequenceLearningAlgorithm sequenceLearningAlgorithm;

    @Getter protected VectorsConfiguration configuration;

    protected static final Logger log = LoggerFactory.getLogger(SequenceVectors.class);

    protected transient WordVectors existingModel;
    protected transient T unknownElement;
    protected transient AtomicDouble scoreElements = new AtomicDouble(0.0);
    protected transient AtomicDouble scoreSequences = new AtomicDouble(0.0);



    @Setter protected transient Set> eventListeners;


    /**
     * Returns the value currently held by the elements-training score holder.
     *
     * @return the score last stored by {@code trainSequence} for elements learning; 0.0 until training stores one
     */
    public double getElementsScore() {
        final double currentScore = scoreElements.get();
        return currentScore;
    }

    /**
     * Returns the value currently held by the sequences-training score holder.
     *
     * @return the score last stored by {@code trainSequence} for sequence learning; 0.0 until training stores one
     */
    public double getSequencesScore() {
        final double currentScore = scoreSequences.get();
        return currentScore;
    }

    /**
     * Builds the vocabulary from the configured SequenceIterator.
     *
     * When an existing model was supplied and both lookup tables are InMemoryLookupTable instances,
     * the existing vocabulary is merged in and its weights are consumed by the current lookup table.
     * Otherwise a joint vocabulary is built from the iterator, the UNK element is appended when enabled
     * and not yet present, and a warning is logged if the elements-per-sequence ratio exceeds 1000.
     */
    public void buildVocab() {
        VocabConstructor constructor = new VocabConstructor.Builder()
                .addSource(iterator, minWordFrequency)
                .setTargetVocabCache(vocab)
                .fetchLabels(trainSequenceVectors)
                .setStopWords(stopWords)
                .build();

        if (existingModel != null && lookupTable instanceof InMemoryLookupTable && existingModel.lookupTable() instanceof InMemoryLookupTable) {
            log.info("Merging existing vocabulary into the current one...");
            /*
                if we have existing model defined, we're forced to fetch labels only.
                the rest of vocabulary & weights should be transferred from existing model
             */
            constructor.buildMergedVocabulary(existingModel, true);

            /*
                Now we have vocab transferred, and we should transfer syn0 values into lookup table
             */
            ((InMemoryLookupTable) lookupTable).consume((InMemoryLookupTable) existingModel.lookupTable());
        } else {
            log.info("Starting vocabulary building...");
            // if we don't have existing model defined, we just build vocabulary
            constructor.buildJointVocabulary(false, true);

            if (useUnknown && unknownElement != null && !vocab.containsWord(unknownElement.getLabel())) {
                log.info("Adding UNK element...");
                unknownElement.setSpecial(true);
                unknownElement.markAsLabel(false);
                // the UNK element takes the next free index at the end of the vocabulary
                unknownElement.setIndex(vocab.numWords());
                vocab.addToken(unknownElement);
            }

            // check for malformed inputs. if numWords/numSentences ratio is huge, then user is passing something weird.
            // When no sequences were counted the ratio is undefined, so the check is skipped instead of dividing by zero.
            if (constructor.getNumberOfSequences() > 0
                    && vocab.numWords() / constructor.getNumberOfSequences() > 1000) {
                log.warn("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!");
                log.warn("!                                                                                       !");
                log.warn("! Your input looks malformed: number of sentences is too low, model accuracy may suffer !");
                log.warn("!                                                                                       !");
                log.warn("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!");
            }
        }
    }



    /**
     * Starts training over the sequences supplied by the configured SequenceIterator.
     *
     * The vocabulary is (re)built first when resetModel is set or the current vocabulary is empty, the
     * lookup table weights are reset, the active learning algorithms are configured and pretrained, and
     * then each epoch runs one AsyncSequencer feeding {@code workers} VectorCalculationsThread instances.
     * Registered listeners are notified after every epoch that completes without early termination.
     *
     * @throws IllegalStateException if no training goal is enabled, no iterator is defined, or the
     *                               vocabulary / lookup table is missing or empty after vocabulary building
     * @throws RuntimeException      if the calling thread is interrupted while waiting for an epoch's threads
     */
    public void fit() {
        if (!trainElementsVectors && !trainSequenceVectors) throw new IllegalStateException("You should define at least one training goal 'trainElementsRepresentation' or 'trainSequenceRepresentation'");
        if (iterator == null) throw new IllegalStateException("You can't fit() data without SequenceIterator defined");

        if (resetModel || (lookupTable != null && vocab != null && vocab.numWords() == 0)) {
            // build vocabulary from scratch
            buildVocab();
        }

        if (vocab == null || lookupTable == null || vocab.numWords() == 0) throw new IllegalStateException("You can't fit() model with empty Vocabulary or WeightLookupTable");

        // A forced weight reset is requested only for a model that is rebuilt from scratch without an existing
        // model; in every other case resetWeights(false) is called.
        lookupTable.resetWeights(resetModel && existingModel == null);

        log.info("Building learning algorithms:");
        if (trainElementsVectors && elementsLearningAlgorithm != null) {
            log.info("          building ElementsLearningAlgorithm: [" +elementsLearningAlgorithm.getCodeName()+ "]");
            elementsLearningAlgorithm.configure(vocab, lookupTable, configuration);
            elementsLearningAlgorithm.pretrain(iterator);
        }
        if (trainSequenceVectors && sequenceLearningAlgorithm != null) {
            log.info("          building SequenceLearningAlgorithm: [" +sequenceLearningAlgorithm.getCodeName()+ "]");
            sequenceLearningAlgorithm.configure(vocab, lookupTable, configuration);
            sequenceLearningAlgorithm.pretrain(this.iterator);
        }

        log.info("Starting learning process...");
        if (this.stopWords == null) this.stopWords = new ArrayList<>();
        for (int currentEpoch = 1; currentEpoch <= numEpochs; currentEpoch++) {
            final AtomicLong linesCounter = new AtomicLong(0);
            final AtomicLong wordsCounter = new AtomicLong(0);

            AsyncSequencer sequencer = new AsyncSequencer(this.iterator, this.stopWords);
            sequencer.start();

            // typed list: the workers are started and joined below, which a raw List would not allow
            final List<VectorCalculationsThread> threads = new ArrayList<>();
            for (int x = 0; x < workers; x++) {
                VectorCalculationsThread worker = new VectorCalculationsThread(x, currentEpoch, wordsCounter, vocab.totalWordOccurrences(), linesCounter,  sequencer);
                threads.add(worker);
                worker.start();
            }

            try {
                sequencer.join();
                for (VectorCalculationsThread worker : threads) {
                    worker.join();
                }
            } catch (InterruptedException e) {
                // restore the interrupt flag so code further up the stack can still observe the interruption
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }

            // TODO: fix this to non-exclusive termination
            if (trainElementsVectors && elementsLearningAlgorithm != null
                    && (!trainSequenceVectors || sequenceLearningAlgorithm == null)
                    && elementsLearningAlgorithm.isEarlyTerminationHit()) {
                break;
            }

            if (trainSequenceVectors && sequenceLearningAlgorithm != null
                    && (!trainElementsVectors || elementsLearningAlgorithm == null)
                    && sequenceLearningAlgorithm.isEarlyTerminationHit()) {
                break;
            }

            // NOTE(review): the value printed under "learningRate" is minLearningRate - confirm this is intended
            log.info("Epoch: [" + currentEpoch+ "]; Words vectorized so far: [" + wordsCounter.get() + "];  Lines vectorized so far: [" + linesCounter.get() + "]; learningRate: [" + minLearningRate + "]");

            if (eventListeners != null && !eventListeners.isEmpty()) {
                for (VectorsListener listener: eventListeners) {
                    if (listener.validateEvent(ListenerEvent.EPOCH, currentEpoch))
                        listener.processEvent(ListenerEvent.EPOCH, this, currentEpoch);
                }
            }
        }
    }


    /**
     * Trains the model on a single sequence with every enabled learning algorithm.
     *
     * For each enabled goal the random state is advanced once, and unless the algorithm reports early
     * termination the value returned by its {@code learnSequence} call is stored as the current score.
     * A goal whose algorithm is null is skipped, matching the null tolerance of {@code fit()}.
     *
     * @param sequence   sequence to learn from; nothing happens if it has no elements
     * @param nextRandom per-caller random state, updated in place
     * @param alpha      learning rate passed through to the algorithms
     */
    protected void trainSequence(@NonNull Sequence sequence, AtomicLong nextRandom, double alpha) {

        if (sequence.getElements().isEmpty()) return;

        if (trainElementsVectors && elementsLearningAlgorithm != null) {
            // call for ElementsLearningAlgorithm
            nextRandom.set(nextRandom.get() * 25214903917L + 11);
            if (!elementsLearningAlgorithm.isEarlyTerminationHit()) {
                scoreElements.set(elementsLearningAlgorithm.learnSequence(sequence, nextRandom, alpha));
            }
        }

        if (trainSequenceVectors && sequenceLearningAlgorithm != null) {
            // call for SequenceLearningAlgorithm
            nextRandom.set(nextRandom.get() * 25214903917L + 11);
            if (!sequenceLearningAlgorithm.isEarlyTerminationHit()) {
                scoreSequences.set(sequenceLearningAlgorithm.learnSequence(sequence, nextRandom, alpha));
            }
        }
    }


    public static class Builder {
        protected VocabCache vocabCache;
        protected WeightLookupTable lookupTable;
        protected SequenceIterator iterator;
        protected ModelUtils modelUtils = new BasicModelUtils<>();

        protected WordVectors existingVectors;

        protected double sampling = 0;
        protected double negative = 0;
        protected double learningRate = 0.025;
        protected double minLearningRate = 0.0001;
        protected int minWordFrequency = 0;
        protected int iterations = 1;
        protected int numEpochs = 1;
        protected int layerSize = 100;
        protected int window = 5;
        protected boolean hugeModelExpected = false;
        protected int batchSize = 100;
        protected int learningRateDecayWords;
        protected long seed;
        protected boolean useAdaGrad = false;
        protected boolean resetModel = true;
        protected int workers = Runtime.getRuntime().availableProcessors();
        protected boolean useUnknown = false;
        protected int[] variableWindows;

        protected boolean trainSequenceVectors = false;
        protected boolean trainElementsVectors = true;

        protected List stopWords = new ArrayList<>();

        protected VectorsConfiguration configuration = new VectorsConfiguration();

        protected transient T unknownElement;
        protected String UNK = configuration.getUNK();
        protected String STOP = configuration.getSTOP();

        // defaults values for learning algorithms are set here
        protected ElementsLearningAlgorithm elementsLearningAlgorithm = new SkipGram<>();
        protected SequenceLearningAlgorithm sequenceLearningAlgorithm = new DBOW<>();

        protected Set> vectorsListeners = new HashSet<>();

        /**
         * Creates a builder whose settings all come from the field initialisers.
         */
        public Builder() {
            // nothing to do: defaults are assigned where the fields are declared
        }

        public Builder(@NonNull VectorsConfiguration configuration) {
            this.configuration = configuration;
            this.iterations = configuration.getIterations();
            this.numEpochs = configuration.getEpochs();
            this.minLearningRate = configuration.getMinLearningRate();
            this.learningRate = configuration.getLearningRate();
            this.sampling = configuration.getSampling();
            this.negative = configuration.getNegative();
            this.minWordFrequency = configuration.getMinWordFrequency();
            this.seed = configuration.getSeed();
            this.hugeModelExpected = configuration.isHugeModelExpected();
            this.batchSize = configuration.getBatchSize();
            this.layerSize = configuration.getLayersSize();
            this.learningRateDecayWords = configuration.getLearningRateDecayWords();
            this.useAdaGrad = configuration.isUseAdaGrad();
            this.window = configuration.getWindow();
            this.UNK = configuration.getUNK();
            this.STOP = configuration.getSTOP();
            this.variableWindows = configuration.getVariableWindows();

            if (configuration.getElementsLearningAlgorithm() != null && !configuration.getElementsLearningAlgorithm().isEmpty()) {
                this.elementsLearningAlgorithm(configuration.getElementsLearningAlgorithm());
            }

            if (configuration.getSequenceLearningAlgorithm() != null && !configuration.getSequenceLearningAlgorithm().isEmpty()) {
                this.sequenceLearningAlgorithm(configuration.getSequenceLearningAlgorithm());
            }

            if (configuration.getStopList() != null) this.stopWords.addAll(configuration.getStopList());
        }

        /**
         * Registers a pre-built WordVectors model to be transferred into the new model.
         *
         * When an existing model is registered, {@code build()} disables elements training and clears the
         * elements learning algorithm, so only sequence learning remains active.
         *
         * PLEASE NOTE: per the original documentation, a non-normalized model is recommended here.
         *
         * @param vec existing WordVectors model
         * @return this builder
         */
        protected Builder useExistingWordVectors(@NonNull WordVectors vec) {
            this.existingVectors = vec;
            return this;
        }

        /**
         * Sets the SequenceIterator that supplies training sequences to the built model.
         *
         * @param iterator source of sequences
         * @return this builder
         */
        public Builder iterate(@NonNull SequenceIterator iterator) {
            this.iterator = iterator;
            return this;
        }

        /**
         * Instantiates the named class and uses it as the Sequence Learning Algorithm.
         *
         * @param algoName fully qualified name of a SequenceLearningAlgorithm implementation
         *                 with an accessible no-argument constructor
         * @return this builder
         * @throws RuntimeException if the class cannot be loaded, instantiated or cast
         */
        public Builder sequenceLearningAlgorithm(@NonNull String algoName) {
            try {
                Class<?> clazz = Class.forName(algoName);
                // Class.newInstance() is deprecated; invoke the no-arg constructor explicitly instead
                sequenceLearningAlgorithm = (SequenceLearningAlgorithm) clazz.getDeclaredConstructor().newInstance();
            } catch (Exception e) {
                throw new RuntimeException("Unable to instantiate SequenceLearningAlgorithm [" + algoName + "]", e);
            }
            return this;
        }

        /**
         * Uses the given instance as the Sequence Learning Algorithm.
         *
         * @param algorithm SequenceLearningAlgorithm implementation
         * @return this builder
         */
        public Builder sequenceLearningAlgorithm(@NonNull SequenceLearningAlgorithm algorithm) {
            this.sequenceLearningAlgorithm = algorithm;
            return this;
        }

        /**
         * Instantiates the named class and uses it as the Elements Learning Algorithm.
         *
         * @param algoName fully qualified name of an ElementsLearningAlgorithm implementation
         *                 with an accessible no-argument constructor
         * @return this builder
         * @throws RuntimeException if the class cannot be loaded, instantiated or cast
         */
        public Builder elementsLearningAlgorithm(@NonNull String algoName) {
            try {
                Class<?> clazz = Class.forName(algoName);
                // Class.newInstance() is deprecated; invoke the no-arg constructor explicitly instead
                elementsLearningAlgorithm = (ElementsLearningAlgorithm) clazz.getDeclaredConstructor().newInstance();
            } catch (Exception e) {
                throw new RuntimeException("Unable to instantiate ElementsLearningAlgorithm [" + algoName + "]", e);
            }
            return this;
        }

        /**
         * Uses the given instance as the Elements Learning Algorithm.
         *
         * @param algorithm ElementsLearningAlgorithm implementation
         * @return this builder
         */
        public Builder elementsLearningAlgorithm(@NonNull ElementsLearningAlgorithm algorithm) {
            this.elementsLearningAlgorithm = algorithm;
            return this;
        }

        /**
         * Sets the batch size. Per the original documentation this is only relevant when iterations > 1.
         *
         * @param batchSize batch size to store in the model and configuration
         * @return this builder
         */
        public Builder batchSize(int batchSize) {
            this.batchSize = batchSize;
            return this;
        }

        /**
         * Sets how many iterations are performed over each batch of sequences.
         *
         * @param iterations number of iterations (the builder's default is 1)
         * @return this builder
         */
        public Builder iterations(int iterations) {
            this.iterations = iterations;
            return this;
        }

        /**
         * Sets how many passes over the whole training corpus are made during training.
         *
         * @param numEpochs number of epochs (the builder's default is 1)
         * @return this builder
         */
        public Builder epochs(int numEpochs) {
            this.numEpochs = numEpochs;
            return this;
        }

        /**
         * Sets the number of worker threads used for calculations.
         * The builder's default is the number of available processors.
         *
         * @param numWorkers number of worker threads
         * @return this builder
         */
        public Builder workers(int numWorkers) {
            this.workers = numWorkers;
            return this;
        }

        /**
         * Enables or disables Adaptive Gradients (AdaGrad) for calculations.
         *
         * @param reallyUse true to use AdaGrad (the builder's default is false)
         * @return this builder
         */
        public Builder useAdaGrad(boolean reallyUse) {
            this.useAdaGrad = reallyUse;
            return this;
        }

        /**
         * Sets the number of dimensions of the resulting vectors.
         * Please note: the value is only applied to a lookup table created by this builder; an externally
         * supplied lookup table is used as is.
         *
         * @param layerSize vector length (the builder's default is 100)
         * @return this builder
         */
        public Builder layerSize(int layerSize) {
            this.layerSize = layerSize;
            return this;
        }

        /**
         * Sets the initial learning rate.
         * The builder's default is 0.025.
         *
         * @param learningRate initial learning rate
         * @return this builder
         */
        public Builder learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        /**
         * Sets the minimal frequency an element needs in the training corpus; per the original
         * documentation, elements below this threshold are removed before training.
         * Please note: this only has effect if the vocabulary is built internally.
         *
         * @param minWordFrequency minimal element frequency (the builder's default is 0)
         * @return this builder
         */
        public Builder minWordFrequency(int minWordFrequency) {
            this.minWordFrequency = minWordFrequency;
            return this;
        }

        /**
         * Sets the minimum learning rate reached after decay has been applied.
         * The builder's default is 0.0001.
         *
         * @param minLearningRate lower bound for the learning rate
         * @return this builder
         */
        public Builder minLearningRate(double minLearningRate) {
            this.minLearningRate = minLearningRate;
            return this;
        }

        /**
         * Defines whether the model is reset before training. When true, the vocabulary is rebuilt and
         * the lookup table weights are reset at the start of {@code fit()}.
         *
         * @param reallyReset true to rebuild from scratch (the builder's default is true)
         * @return this builder
         */
        public Builder resetModel(boolean reallyReset) {
            this.resetModel = reallyReset;
            return this;
        }

        /**
         * Supplies an externally built VocabCache holding the vocabulary.
         *
         * @param vocabCache vocabulary to use instead of an internally created one
         * @return this builder
         */
        public Builder vocabCache(@NonNull VocabCache vocabCache) {
            this.vocabCache = vocabCache;
            return this;
        }

        /**
         * Supplies an externally built WeightLookupTable holding the model weights.
         * When one is supplied, the builder does not create its own lookup table or vocabulary cache.
         *
         * @param lookupTable lookup table to use
         * @return this builder
         */
        public Builder lookupTable(@NonNull WeightLookupTable lookupTable) {
            this.lookupTable = lookupTable;
            return this;
        }

        /**
         * Sets the sub-sampling threshold.
         *
         * @param sampling threshold value (the builder's default is 0)
         * @return this builder
         */
        public Builder sampling(double sampling) {
            this.sampling = sampling;
            return this;
        }

        /**
         * Sets the negative sampling value used by the skip-gram algorithm.
         *
         * @param negative negative sampling value (the builder's default is 0)
         * @return this builder
         */
        public Builder negativeSample(double negative) {
            this.negative = negative;
            return this;
        }

        /**
         * Appends every entry of the given list to the builder's stop list; matching elements are
         * ignored and excluded from the model.
         *
         * @param stopList entries to add to the stop list
         * @return this builder
         */
        public Builder stopWords(@NonNull List stopList) {
            this.stopWords.addAll(stopList);
            return this;
        }

        /**
         * Enables or disables training of element vectors.
         *
         * @param trainElements true to train element representations (the builder's default is true)
         * @return this builder
         */
        public Builder trainElementsRepresentation(boolean trainElements) {
            this.trainElementsVectors = trainElements;
            return this;
        }

        /**
         * Enables or disables training of sequence vectors.
         *
         * @param trainSequences true to train sequence representations (the builder's default is false)
         * @return this builder
         */
        public Builder trainSequencesRepresentation(boolean trainSequences) {
            this.trainSequenceVectors = trainSequences;
            return this;
        }

        /**
         * Adds the label of every given element to the builder's stop list; elements with these labels
         * are ignored and excluded from the model.
         *
         * @param stopList elements whose labels should be filtered out
         * @return this builder
         */
        public Builder stopWords(@NonNull Collection<T> stopList) {
            // the element type is required here: a raw Collection cannot be iterated as T
            for (T word : stopList) {
                this.stopWords.add(word.getLabel());
            }
            return this;
        }

        /**
         * Sets the window size used for skip-gram training.
         *
         * @param windowSize window size (the builder's default is 5)
         * @return this builder
         */
        public Builder windowSize(int windowSize) {
            this.window = windowSize;
            return this;
        }

        /**
         * Sets the seed for the random number generator.
         * The value is handed to the lookup table created by this builder and stored in the configuration.
         * Please note: it only has effect if the vocabulary and WeightLookupTable are built internally.
         *
         * @param randomSeed seed value
         * @return this builder
         */
        public Builder seed(long randomSeed) {
            // store the value: presetTables() and build() both read this.seed, which would otherwise stay at 0
            this.seed = randomSeed;
            return this;
        }

        /**
         * Sets the ModelUtils implementation used to access the model.
         * Per the original documentation, methods such as similarity, wordsNearest and accuracy are
         * served by it.
         *
         * @param modelUtils model utils to be used
         * @return this builder
         */
        public Builder modelUtils(@NonNull ModelUtils modelUtils) {
            this.modelUtils = modelUtils;
            return this;
        }

        /**
         * Defines whether the UNK element should be used internally.
         *
         * @param reallyUse true to use the UNK element (the builder's default is false)
         * @return this builder
         */
        public Builder useUnknown(boolean reallyUse) {
            this.useUnknown = reallyUse;
            return this;
        }

        /**
         * Sets the SequenceElement used as the UNK element, if UNK is used.
         * The element's label also becomes the builder's UNK label.
         *
         * @param element element representing unknown entries
         * @return this builder
         */
        public Builder unknownElement(@NonNull T element) {
            this.UNK = element.getLabel();
            this.unknownElement = element;
            return this;
        }

        /**
         * Enables variable window sizes. Per the original documentation, every batch is then processed
         * using one of the predefined window sizes.
         *
         * @param windows window sizes to choose from; must contain at least one value
         * @return this builder
         * @throws IllegalStateException if {@code windows} is null or empty
         */
        public Builder useVariableWindow(int... windows) {
            final boolean nothingGiven = (windows == null) || (windows.length == 0);
            if (nothingGiven) {
                throw new IllegalStateException("Variable windows can't be empty");
            }

            this.variableWindows = windows;
            return this;
        }

        /**
         * Fills in whatever the caller did not supply: a WeightLookupTable (together with a VocabCache if
         * that is missing too), default learning algorithms for the enabled training goals, and finally
         * initialises the ModelUtils with the lookup table.
         */
        protected void presetTables() {
            final boolean lookupTableMissing = (lookupTable == null);

            // a vocabulary cache is only created when the lookup table has to be created as well
            if (lookupTableMissing && vocabCache == null) {
                vocabCache = new AbstractCache.Builder()
                        .hugeModelExpected(hugeModelExpected)
                        .scavengerRetentionDelay(this.configuration.getScavengerRetentionDelay())
                        .scavengerThreshold(this.configuration.getScavengerActivationThreshold())
                        .minElementFrequency(minWordFrequency)
                        .build();
            }

            if (lookupTableMissing) {
                lookupTable = new InMemoryLookupTable.Builder()
                        .useAdaGrad(this.useAdaGrad)
                        .cache(vocabCache)
                        .negative(negative)
                        .vectorLength(layerSize)
                        .lr(learningRate)
                        .seed(seed)
                        .build();
            }

            // fall back to the default algorithms for every enabled goal that has none
            if (elementsLearningAlgorithm == null && trainElementsVectors) {
                elementsLearningAlgorithm = new SkipGram<>();
            }
            if (sequenceLearningAlgorithm == null && trainSequenceVectors) {
                sequenceLearningAlgorithm = new DBOW<>();
            }

            this.modelUtils.init(lookupTable);
        }

        /**
         * This method sets VectorsListeners for this SequenceVectors model
         *
         * @param listeners
         * @return
         */
        public Builder setVectorsListeners(@NonNull Collection> listeners) {
            vectorsListeners.addAll(listeners);
            return this;
        }

        /**
         * Builds a SequenceVectors instance with the defined settings/options.
         *
         * Missing tables and default algorithms are created first. If an existing WordVectors model was
         * registered, elements training is switched off. All settings are then copied into the new model
         * and mirrored into the VectorsConfiguration attached to it.
         *
         * @return the configured SequenceVectors model
         */
        public SequenceVectors build() {
            presetTables();

            SequenceVectors vectors = new SequenceVectors<>();

            // with a pre-built model only sequence learning stays active
            if (this.existingVectors != null) {
                this.trainElementsVectors = false;
                this.elementsLearningAlgorithm = null;
            }

            vectors.numEpochs = this.numEpochs;
            vectors.numIterations = this.iterations;
            vectors.vocab = this.vocabCache;
            vectors.minWordFrequency = this.minWordFrequency;
            vectors.learningRate.set(this.learningRate);
            vectors.minLearningRate = this.minLearningRate;
            vectors.sampling = this.sampling;
            vectors.negative = this.negative;
            vectors.layerSize = this.layerSize;
            vectors.batchSize = this.batchSize;
            vectors.learningRateDecayWords = this.learningRateDecayWords;
            vectors.window = this.window;
            vectors.resetModel = this.resetModel;
            vectors.useAdeGrad = this.useAdaGrad;
            vectors.stopWords = this.stopWords;
            vectors.workers = this.workers;

            vectors.iterator = this.iterator;
            vectors.lookupTable = this.lookupTable;
            vectors.modelUtils = this.modelUtils;
            vectors.useUnknown = this.useUnknown;
            vectors.unknownElement = this.unknownElement;
            vectors.variableWindows = this.variableWindows;

            vectors.trainElementsVectors = this.trainElementsVectors;
            vectors.trainSequenceVectors = this.trainSequenceVectors;

            vectors.elementsLearningAlgorithm = this.elementsLearningAlgorithm;
            vectors.sequenceLearningAlgorithm = this.sequenceLearningAlgorithm;

            vectors.existingModel = this.existingVectors;

            // hand over the listeners collected by setVectorsListeners(); without this assignment
            // fit() would never notify them
            vectors.eventListeners = this.vectorsListeners;

            vectors.setUNK(this.UNK);

            // mirror the effective settings into the configuration attached to the model
            this.configuration.setLearningRate(this.learningRate);
            this.configuration.setLayersSize(layerSize);
            this.configuration.setHugeModelExpected(hugeModelExpected);
            this.configuration.setWindow(window);
            this.configuration.setMinWordFrequency(minWordFrequency);
            this.configuration.setIterations(iterations);
            this.configuration.setSeed(seed);
            this.configuration.setBatchSize(batchSize);
            this.configuration.setLearningRateDecayWords(learningRateDecayWords);
            this.configuration.setMinLearningRate(minLearningRate);
            this.configuration.setSampling(this.sampling);
            this.configuration.setUseAdaGrad(useAdaGrad);
            this.configuration.setNegative(negative);
            this.configuration.setEpochs(this.numEpochs);
            this.configuration.setStopList(this.stopWords);
            this.configuration.setUNK(this.UNK);
            this.configuration.setVariableWindows(variableWindows);

            vectors.configuration = this.configuration;

            return vectors;
        }
    }

    /**
     * This class is used to fetch data from iterator in background thread, and convert it to List
     *
     * It becomes very usefull if text processing pipeline behind iterator is complex, and we're not loading data from simple text file with whitespaces as separator.
     * Since this method allows you to hide preprocessing latency in background.
     *
     * This mechanics will be change to PrefetchingSentenceIterator wrapper.
     */
    protected class AsyncSequencer extends Thread implements Runnable {
        private final SequenceIterator iterator;
        private final LinkedBlockingQueue> buffer;
   //     private final AtomicLong linesCounter;
        private final int limitUpper = 10000;
        private final int limitLower = 5000;
        private AtomicBoolean isRunning = new AtomicBoolean(true);
        private AtomicLong nextRandom;
        private List stopList;

        public AsyncSequencer(SequenceIterator iterator, @NonNull List stopList) {
            this.iterator = iterator;
            this.buffer = new LinkedBlockingQueue<>();
//            this.linesCounter = linesCounter;
            this.setName("AsyncSequencer thread");
            this.nextRandom = new AtomicLong(workers + 1);
            this.iterator.reset();
            this.stopList = stopList;
            this.setDaemon(true);
        }

        @Override
        public void run() {
            isRunning.set(true);
            while (this.iterator.hasMoreSequences()) {

                // if buffered level is below limitLower, we're going to fetch limitUpper number of strings from fetcher
                if (buffer.size() < limitLower) {
                    update();
                    AtomicInteger linesLoaded = new AtomicInteger(0);
                    while (linesLoaded.getAndIncrement() < limitUpper && this.iterator.hasMoreSequences() ) {
                        Sequence document = this.iterator.nextSequence();

                        /*
                            We can't hope/assume that underlying iterator contains synchronized elements
                            That's why we're going to rebuild sequence from vocabulary
                          */
                        Sequence newSequence = new Sequence<>();

                        if (document.getSequenceLabel() != null) {
                            T newLabel = vocab.wordFor(document.getSequenceLabel().getLabel());
                            if (newLabel != null) newSequence.setSequenceLabel(newLabel);
                        }

                        for (T element: document.getElements()) {
                            if (stopList.contains(element.getLabel())) continue;
                            T realElement = vocab.wordFor(element.getLabel());

                            // please note: this serquence element CAN be absent in vocab, due to minFreq or stopWord or whatever else
                            if (realElement != null) {
                                newSequence.addElement(realElement);
                            } else if (useUnknown && unknownElement != null) {
                                newSequence.addElement(unknownElement);
                            }
                        }

                        // due to subsampling and null words, new sequence size CAN be 0, so there's no need to insert empty sequence into processing chain
                        if (!newSequence.getElements().isEmpty()) buffer.add(newSequence);

                        linesLoaded.incrementAndGet();
                    }
                } else {
                    try {
                        Thread.sleep(50);
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            }
            isRunning.set(false);
        }

        public boolean hasMoreLines() {
            // statement order does matter here, since there's possible race condition
            return !buffer.isEmpty() || isRunning.get();
        }

        public Sequence nextSentence() {
            try {
                return buffer.poll(3L, TimeUnit.SECONDS);
            } catch (Exception e) {
                return null;
            }
        }
    }

    /**
     * VectorCalculationsThreads are used for vector calculations, and work together with AsyncSequencer.
     * Basically, all they do is transfer sequences taken from the sequencer into the math layer.
     *
     * Please note: each sequence is taken from the sequencer only once; within a batch it is trained
     * numIterations times. Training corpus iteration is implemented in the fit() method.
     */
    private class VectorCalculationsThread extends Thread implements Runnable {
        private final int threadId;
        private final int epochNumber;
        // shared counter of processed elements; drives the learning rate decay
        private final AtomicLong wordsCounter;
        private final long totalWordsCount;
        // shared counter of processed sequences
        private final AtomicLong totalLines;

        private final AsyncSequencer digitizer;
        private final AtomicLong nextRandom;

        /*
                Long constructors suck, so this should be reduced to something reasonable later
         */
        /**
         * @param threadId        used for the thread name and as the initial value of nextRandom
         * @param epoch           epoch number, used only in the progress log message
         * @param wordsCounter    counter increased by the size of every trained sequence
         * @param totalWordsCount denominator of the learning rate decay
         * @param linesCounter    counter increased once per trained sequence
         * @param digitizer       source of sequences to train on
         */
        public VectorCalculationsThread(int threadId, int epoch, AtomicLong wordsCounter, long totalWordsCount, AtomicLong linesCounter, AsyncSequencer digitizer) {
            this.threadId = threadId;
            this.epochNumber = epoch;
            this.wordsCounter = wordsCounter;
            this.totalWordsCount = totalWordsCount;
            this.totalLines = linesCounter;
            this.digitizer = digitizer;
            this.nextRandom = new AtomicLong(this.threadId);
            this.setName("VectorCalculationsThread " + this.threadId);
        }

        /**
         * Takes batches of up to batchSize sequences from the digitizer and trains on each batch
         * numIterations times, until the digitizer reports that nothing is left.
         * Any exception raised while processing a batch is rethrown wrapped in a RuntimeException.
         */
        @Override
        public void run() {
            while (digitizer.hasMoreLines()) {
                try {
                    // collect the next batch; stop early once the digitizer is drained
                    List<Sequence<T>> sequences = new ArrayList<>();
                    for (int x = 0; x < batchSize && digitizer.hasMoreLines(); x++) {
                        Sequence<T> sequence = digitizer.nextSentence();
                        if (sequence != null) {
                            sequences.add(sequence);
                        }
                    }

                    if (sequences.isEmpty()) {
                        continue;
                    }

                    /*
                            TODO: investigate, if fix needed here to become iteration-dependent, not line-position
                      */
                    for (int i = 0; i < numIterations; i++) {
                        for (Sequence<T> sequence : sequences) {
                            double alpha = Math.max(minLearningRate, learningRate.get() * (1 - (1.0 * this.wordsCounter.get() / (double) this.totalWordsCount)));

                            trainSequence(sequence, nextRandom, alpha);

                            // Capture the values returned by the atomic updates: the counters are
                            // shared between worker threads, so reading them again could yield a
                            // different number for the modulo check, the log line and each listener call.
                            long linesSoFar = totalLines.incrementAndGet();
                            // please note: the processed word count affects learningRate decay
                            long wordsSoFar = this.wordsCounter.addAndGet(sequence.getElements().size());

                            if (linesSoFar % 100000 == 0) {
                                log.info("Epoch: [{}]; Words vectorized so far: [{}];  Lines vectorized so far: [{}]; learningRate: [{}]",
                                        this.epochNumber, wordsSoFar, linesSoFar, alpha);
                            }

                            if (eventListeners != null) {
                                for (VectorsListener listener : eventListeners) {
                                    if (listener.validateEvent(ListenerEvent.LINE, linesSoFar)) {
                                        listener.processEvent(ListenerEvent.LINE, SequenceVectors.this, linesSoFar);
                                    }
                                }
                            }
                        }

                        if (eventListeners != null) {
                            for (VectorsListener listener : eventListeners) {
                                if (listener.validateEvent(ListenerEvent.ITERATION, i)) {
                                    listener.processEvent(ListenerEvent.ITERATION, SequenceVectors.this, i);
                                }
                            }
                        }
                    }
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy