All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.models.glove.CoOccurrences Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.models.glove;

import akka.actor.ActorRef;
import akka.actor.ActorSystem;
import akka.actor.Props;
import akka.routing.RoundRobinPool;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.CounterMap;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.glove.actor.CoOccurrenceActor;
import org.deeplearning4j.models.glove.actor.SentenceWork;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.deeplearning4j.text.movingwindow.Util;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicInteger;

/**
 *
 * Co occurrence counts
 *
 * @author Adam Gibson
 */
public class CoOccurrences implements Serializable {
    private transient  TokenizerFactory tokenizerFactory;
    private transient SentenceIterator sentenceIterator;
    private int windowSize = 15;
    protected transient VocabCache cache;
    protected InvertedIndex index;
    protected transient ActorSystem trainingSystem;
    protected boolean symmetric = true;
    private Counter sentenceOccurrences = Util.parallelCounter();
    private CounterMap coOCurreneCounts = Util.parallelCounterMap();
    private static final Logger log = LoggerFactory.getLogger(CoOccurrences.class);
    private List> coOccurrences;


    private CoOccurrences() {}

    public CoOccurrences(TokenizerFactory tokenizerFactory, SentenceIterator sentenceIterator, int windowSize, VocabCache cache, CounterMap coOCurreneCounts,boolean symmetric) {
        this.tokenizerFactory = tokenizerFactory;
        this.sentenceIterator = sentenceIterator;
        this.windowSize = windowSize;
        this.cache = cache;
        this.coOCurreneCounts = coOCurreneCounts;
        this.symmetric = symmetric;
    }

    /**
     *
     */
    public void fit() {
        if(trainingSystem == null)
            trainingSystem = ActorSystem.create();

        final AtomicInteger processed = new AtomicInteger(0);

        final ActorRef actor = trainingSystem.actorOf(
                new RoundRobinPool(Runtime.getRuntime().availableProcessors()).props(
                        Props.create(
                                CoOccurrenceActor.class,
                                processed,
                                tokenizerFactory,
                                windowSize,
                                cache,
                                coOCurreneCounts,symmetric,sentenceOccurrences)));


        sentenceIterator.reset();

        final AtomicInteger queued = new AtomicInteger(0);
        int id = 0;
        while(sentenceIterator.hasNext()) {
            actor.tell(new SentenceWork(id,sentenceIterator.nextSentence()),actor);
            id++;
            queued.incrementAndGet();
        }


        try {
            Thread.sleep(5000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        while(processed.get() < queued.get()) {
            try {
                Thread.sleep(10000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }

        trainingSystem.shutdown();
        trainingSystem = null;


        log.info("Done processing co occurrences: ended with " + numCoOccurrences());


    }



    public class CoOccurrenceBatchIterator implements Iterator>> {
        private Iterator> iter = coOccurrenceIteratorVocab();
        private int batchSize = 100;

        public CoOccurrenceBatchIterator(int batchSize) {
            this.batchSize = batchSize;
        }

        public CoOccurrenceBatchIterator() {
            this(100);
        }


        @Override
        public boolean hasNext() {
            return iter.hasNext();
        }

        @Override
        public List> next() {
            List> list = new ArrayList<>(batchSize);
            for(int i = 0; i < batchSize; i++) {
                if(!iter.hasNext())
                    break;
                Pair next = iter.next();
                list.add(next);
            }

            return list;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }


    public class CoOccurrenceIterator implements Iterator> {
        private Iterator> iter = coOccurrenceIterator();

        @Override
        public boolean hasNext() {
            return iter.hasNext();
        }

        @Override
        public Pair next() {
            Pair next = iter.next();
            Pair ret = new Pair<>(cache.wordFor(next.getFirst()),cache.wordFor(next.getSecond()));
            return ret;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }

    public Iterator>> coOccurrenceIteratorVocabBatch(int batchSize) {
        return new CoOccurrenceBatchIterator(batchSize);
    }
    public Iterator> coOccurrenceIteratorVocab() {
        return new CoOccurrenceIterator();
    }

    /**
     * Load from an input stream with the following format:
     * w1 w2 score
     * @param from the input stream to read from
     * @return the co occurrences based on the input stream
     */
    public static CoOccurrences load(InputStream from) {
        CoOccurrences ret = new CoOccurrences();
        ret.coOccurrences = new ArrayList<>();
        CounterMap counter = new CounterMap<>();
        Reader inputStream = new InputStreamReader(from);
        LineIterator iter = IOUtils.lineIterator(inputStream);
        String line;
        while((iter.hasNext())) {
            line = iter.nextLine();
            String[] split = line.split(" ");
            if(split.length < 3)
                continue;
            //no empty keys
            if(split[0].isEmpty() || split[1].isEmpty())
                continue;

            ret.coOccurrences.add(new Pair<>(split[0],split[1]));

            counter.incrementCount(split[0], split[1], Double.parseDouble(split[2]));

        }

        ret.coOCurreneCounts = counter;
        return ret;

    }

    public Counter getSentenceOccurrences() {
        return sentenceOccurrences;
    }

    public void setSentenceOccurrences(Counter sentenceOccurrences) {
        this.sentenceOccurrences = sentenceOccurrences;
    }

    /**
     * Return a list of all of the co occurrences
     * @return a list of all of the co occurrences
     */
    public List> coOccurrenceList() {
        if(coOccurrences != null)
            return coOccurrences;
        Iterator> pairIter = coOccurrenceIterator();
        final List> pairList = new ArrayList<>();

        while(pairIter.hasNext())
            pairList.add(pairIter.next());
        return pairList;

    }

    /**
     * Return a randomized list of the co occurrences
     * @return
     */
    public List> randomizedList() {
        List> coOccurrences = coOccurrenceList();
        Collections.shuffle(coOccurrences);
        return coOccurrences;
    }

    /**
     * The number of co occurrences
     * @return
     */
    public int numCoOccurrences() {
        return coOCurreneCounts.totalSize();
    }


    public double count(String w1,String w2) {
        return coOCurreneCounts.getCount(w1, w2);
    }


    /**
     * Get an iterator over all possible non zero
     * co occurrences
     * @return the iterator
     */
    public Iterator> coOccurrenceIterator() {
        return coOCurreneCounts.getPairIterator();
    }

    public CounterMap getCoOCurreneCounts() {
        return coOCurreneCounts;
    }

    public void setCoOCurreneCounts(CounterMap coOCurreneCounts) {
        this.coOCurreneCounts = coOCurreneCounts;
    }

    public static class Builder {
        private TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        private SentenceIterator sentenceIterator;
        private int windowSize = 15;
        private VocabCache cache;
        private CounterMap coOCurreneCounts = Util.parallelCounterMap();
        private boolean symmetric = true;


        public Builder symmetric(boolean symmetric) {
            this.symmetric = symmetric;
            return this;
        }

        public Builder tokenizer(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        public Builder iterate(SentenceIterator sentenceIterator) {
            this.sentenceIterator = sentenceIterator;
            return this;
        }

        public Builder windowSize(int windowSize) {
            this.windowSize = windowSize;
            return this;
        }

        public Builder cache(VocabCache cache) {
            this.cache = cache;
            return this;
        }

        public Builder coOCurreneCounts(CounterMap coOCurreneCounts) {
            this.coOCurreneCounts = coOCurreneCounts;
            return this;
        }

        public CoOccurrences build() {
            if(cache == null)
                throw new IllegalArgumentException("Vocab cache must not be null!");

            if(sentenceIterator == null)
                throw new IllegalArgumentException("Sentence iterator must not be null");

            return new CoOccurrences(tokenizerFactory, sentenceIterator, windowSize, cache, coOCurreneCounts,symmetric);
        }
    }




}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy