All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.models.glove.Glove Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.models.glove;

import akka.actor.ActorSystem;
import com.google.common.collect.Lists;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.bagofwords.vectorizer.TextVectorizer;
import org.deeplearning4j.bagofwords.vectorizer.TfidfVectorizer;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.parallel.Parallelization;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.deeplearning4j.text.invertedindex.LuceneInvertedIndex;
import org.deeplearning4j.text.movingwindow.Util;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.stopwords.StopWords;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Glove by socher et. al
 *
 * @author Adam Gibson
 */
public class Glove  extends WordVectorsImpl {

    private transient SentenceIterator sentenceIterator;
    private transient TextVectorizer textVectorizer;
    private transient TokenizerFactory tokenizerFactory;
    private double learningRate = 0.05;
    private double xMax = 0.75;
    private int windowSize = 15;
    private CoOccurrences coOccurrences;
    private boolean stem = false;
    protected Queue>>> jobQueue = new LinkedBlockingDeque<>();
    private int batchSize = 1000;
    private int minWordFrequency = 5;
    private double maxCount = 100;
    public final static String UNK = Word2Vec.UNK;
    private int iterations = 5;
    private static final Logger log = LoggerFactory.getLogger(Glove.class);
    private boolean symmetric = true;
    private transient Random gen;
    private boolean shuffle = true;
    private transient Random shuffleRandom;
    private int numWorkers = Runtime.getRuntime().availableProcessors();

    private Glove(){}

    public Glove(VocabCache cache, SentenceIterator sentenceIterator, TextVectorizer textVectorizer, TokenizerFactory tokenizerFactory, GloveWeightLookupTable lookupTable, int layerSize, double learningRate, double xMax, int windowSize, CoOccurrences coOccurrences, List stopWords, boolean stem,int batchSize,int minWordFrequency,double maxCount,int iterations,boolean symmetric,Random gen,boolean shuffle,long seed,int numWorkers) {
        this.numWorkers = numWorkers;
        this.gen = gen;
        this.vocab = cache;
        this.layerSize = layerSize;
        this.shuffle = shuffle;
        this.sentenceIterator = sentenceIterator;
        this.textVectorizer = textVectorizer;
        this.tokenizerFactory = tokenizerFactory;
        this.lookupTable = lookupTable;
        this.learningRate = learningRate;
        this.xMax = xMax;
        this.windowSize = windowSize;
        this.coOccurrences = coOccurrences;
        this.stopWords = stopWords;
        this.stem = stem;
        this.batchSize = batchSize;
        this.minWordFrequency = minWordFrequency;
        this.maxCount = maxCount;
        this.iterations = iterations;
        this.symmetric = symmetric;
        shuffleRandom = Nd4j.getRandom();
    }

    public void fit() {
        boolean cacheFresh = false;

        if(vocab() == null) {
            cacheFresh  = true;
            setVocab(new InMemoryLookupCache());
        }

        if(textVectorizer == null && cacheFresh) {
            InvertedIndex index = new LuceneInvertedIndex(vocab(),false,"glove-index");
            textVectorizer = new TfidfVectorizer.Builder().tokenize(tokenizerFactory).index(index)
                    .cache(vocab()).iterate(sentenceIterator).minWords(minWordFrequency)
                    .stopWords(stopWords).stem(stem).build();

            textVectorizer.fit();
        }

        if(sentenceIterator != null)
            sentenceIterator.reset();

        if(coOccurrences == null) {
            coOccurrences = new CoOccurrences.Builder()
                    .cache(vocab()).iterate(sentenceIterator).symmetric(symmetric)
                    .tokenizer(tokenizerFactory).windowSize(windowSize)
                    .build();

            coOccurrences.fit();

        }

        if(lookupTable == null) {
            lookupTable = new GloveWeightLookupTable.Builder()
                    .cache(textVectorizer.vocab()).lr(learningRate)
                    .vectorLength(layerSize).maxCount(maxCount)
                   .build();
        }


        if(lookupTable().getSyn0() == null)
            lookupTable().resetWeights();
        final List> pairList = coOccurrences.coOccurrenceList();
        if(shuffle)
            Collections.shuffle(pairList,new java.util.Random());



        final AtomicInteger countUp = new AtomicInteger(0);
        final Counter errorPerIteration = Util.parallelCounter();
        log.info("Processing # of co occurrences " + coOccurrences.numCoOccurrences());
        for(int i = 0; i < iterations; i++) {
            final AtomicInteger processed = new AtomicInteger(coOccurrences.numCoOccurrences());
            doIteration(i, pairList, errorPerIteration, processed, countUp);
            log.info("Processed " + countUp.doubleValue() + " out of " + (pairList.size() * iterations) + " error was " + errorPerIteration.getCount(i));

        }


    }


    public void doIteration(final int i,List> pairList, final Counter errorPerIteration,final AtomicInteger processed,final AtomicInteger countUp) {
        log.info("Iteration " + i);
        if(shuffle)
            Collections.shuffle(pairList,new java.util.Random());
        List>> miniBatches = Lists.partition(pairList,batchSize);
        ActorSystem actor = ActorSystem.create();
        Parallelization.iterateInParallel(miniBatches,new Parallelization.RunnableWithParams>>() {
            @Override
            public void run(List> currentItem, Object[] args) {
                List> send = new ArrayList<>();
                for (Pair next : currentItem) {
                    String w1 = next.getFirst();
                    String w2 = next.getSecond();
                    VocabWord vocabWord = vocab().wordFor(w1);
                    VocabWord vocabWord1 = vocab().wordFor(w2);
                    send.add(new Pair<>(vocabWord, vocabWord1));

                }

                jobQueue.add(new Pair<>(i, send));
            }
        },actor);



        actor.shutdown();

        Parallelization.runInParallel(numWorkers,new Runnable() {
            @Override
            public void run() {
                while(processed.get() > 0 || !jobQueue.isEmpty()) {
                    Pair>> work = jobQueue.poll();
                    if(work == null)
                        continue;
                    List> batch = work.getSecond();

                    for(Pair pair : batch) {
                        VocabWord w1 = pair.getFirst();
                        VocabWord w2 = pair.getSecond();
                        double weight = getCount(w1.getWord(),w2.getWord());
                        if(weight <= 0) {
                            countUp.incrementAndGet();
                            processed.decrementAndGet();
                            continue;

                        }
                        errorPerIteration.incrementCount(work.getFirst(),lookupTable().iterateSample(w1, w2, weight));
                        countUp.incrementAndGet();
                        if(countUp.get() % 10000 == 0)
                            log.info("Processed " + countUp.get() + " co occurrences");
                        processed.decrementAndGet();
                    }




                }
            }
        },true);
    }


    /**
     * Load a glove model from an input stream.
     * The format is:
     * word num1 num2....
     * @param is the input stream to read from for the weights
     * @param biases the bias input stream
     * @return the loaded model
     * @throws IOException if one occurs
     */
    public static Glove load(InputStream is,InputStream biases) throws IOException {
        LineIterator iter = IOUtils.lineIterator(is,"UTF-8");
        Glove glove = new Glove();
        Map wordVectors = new HashMap<>();
        int count = 0;
        while(iter.hasNext()) {
            String line = iter.nextLine().trim();
            if(line.isEmpty())
                continue;
            String[] split = line.split(" ");
            String word = split[0];
            if(glove.vocab() == null)
                glove.setVocab(new InMemoryLookupCache());

            if(glove.lookupTable() == null) {
                glove.lookupTable = new GloveWeightLookupTable.Builder()
                        .cache(glove.vocab()).vectorLength(split.length - 1)
                        .build();

            }

            if(word.isEmpty())
                continue;
            float[] read = read(split,glove.lookupTable().getVectorLength());
            if(read.length < 1)
                continue;

            VocabWord w1 = new VocabWord(1,word);
            w1.setIndex(count);
            glove.vocab().addToken(w1);
            glove.vocab().addWordToIndex(count, word);
            glove.vocab().putVocabWord(word);
            wordVectors.put(word,read);
            count++;



        }

        glove.lookupTable().setSyn0(weights(glove, wordVectors));



        iter.close();

        glove.lookupTable().setBias(Nd4j.read(biases));

        return glove;

    }




    private static INDArray weights(Glove glove,Map data) {
        INDArray ret = Nd4j.create(data.size(),glove.lookupTable().getVectorLength());
        for(String key : data.keySet()) {
            INDArray row = Nd4j.create(Nd4j.createBuffer(data.get(key)));
            if(row.length() != glove.lookupTable().getVectorLength())
                continue;
            if(glove.vocab().indexOf(key) >= data.size())
                continue;
            ret.putRow(glove.vocab().indexOf(key), row);
        }
        return ret;
    }


    private static float[] read(String[] split,int length) {
        float[] ret = new float[length];
        for(int i = 1; i < split.length; i++) {
            ret[i - 1] = Float.parseFloat(split[i]);
        }
        return ret;
    }


    public double getCount(String w1,String w2) {
        return coOccurrences.getCoOCurreneCounts().getCount(w1,w2);
    }

    public CoOccurrences getCoOccurrences() {
        return coOccurrences;
    }

    public void setCoOccurrences(CoOccurrences coOccurrences) {
        this.coOccurrences = coOccurrences;
    }







    @Override
    public GloveWeightLookupTable lookupTable() {
        return (GloveWeightLookupTable) lookupTable;
    }

    public void setLookupTable(GloveWeightLookupTable lookupTable) {
        this.lookupTable = lookupTable;
    }

    public static class Builder {
        private VocabCache vocabCache;
        private SentenceIterator sentenceIterator;
        private TextVectorizer textVectorizer;
        private TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        private GloveWeightLookupTable weightLookupTable;
        private int layerSize = 300;
        private double learningRate = 0.05;
        private double xMax = 0.75;
        private int windowSize = 5;
        private CoOccurrences coOccurrences;
        private List stopWords = StopWords.getStopWords();
        private boolean stem = false;
        private int batchSize = 100;
        private int minWordFrequency = 5;
        private double maxCount = 100;
        private int iterations = 5;
        private boolean symmetric = true;
        private boolean shuffle = true;
        private long seed = 123;
        private int numWorkers = Runtime.getRuntime().availableProcessors();
        private org.nd4j.linalg.api.rng.Random gen = Nd4j.getRandom();


        public Builder numWorkers(int numWorkers) {
            this.numWorkers = numWorkers;
            return this;
        }

        public Builder seed(long seed) {
            this.seed = seed;
            return this;
        }

        public Builder shuffle(boolean shuffle) {
            this.shuffle = shuffle;
            return this;
        }
        public Builder rng(org.nd4j.linalg.api.rng.Random gen) {
            this.gen = gen;
            return this;
        }

        public Builder symmetric(boolean symmetric) {
            this.symmetric = symmetric;
            return this;
        }

        public Builder iterations(int iterations) {
            this.iterations = iterations;
            return this;
        }

        public Builder maxCount(double maxCount) {
            this.maxCount = maxCount;
            return this;
        }

        public Builder minWordFrequency(int minWordFrequency) {
            this.minWordFrequency = minWordFrequency;
            return this;
        }

        public Builder cache(VocabCache vocabCache) {
            this.vocabCache = vocabCache;
            return this;
        }

        public Builder iterate(SentenceIterator sentenceIterator) {
            this.sentenceIterator = sentenceIterator;
            return this;
        }

        public Builder vectorizer(TextVectorizer textVectorizer) {
            this.textVectorizer = textVectorizer;
            return this;
        }

        public Builder tokenizer(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        public Builder weights(GloveWeightLookupTable weightLookupTable) {
            this.weightLookupTable = weightLookupTable;
            return this;
        }

        public Builder layerSize(int layerSize) {
            this.layerSize = layerSize;
            return this;
        }

        public Builder learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        public Builder xMax(double xMax) {
            this.xMax = xMax;
            return this;
        }

        public Builder windowSize(int windowSize) {
            this.windowSize = windowSize;
            return this;
        }

        public Builder coOccurrences(CoOccurrences coOccurrences) {
            this.coOccurrences = coOccurrences;
            return this;
        }

        public Builder stopWords(List stopWords) {
            this.stopWords = stopWords;
            return this;
        }

        public Builder stem(boolean stem) {
            this.stem = stem;
            return this;
        }

        public Builder batchSize(int batchSize) {
            this.batchSize = batchSize;
            return this;
        }

        public Glove build() {
            return new Glove(vocabCache, sentenceIterator, textVectorizer, tokenizerFactory, weightLookupTable, layerSize, learningRate, xMax, windowSize, coOccurrences, stopWords, stem, batchSize,minWordFrequency,maxCount,iterations,symmetric,gen,shuffle,seed,numWorkers);
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy