All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.models.glove.AbstractCoOccurrences Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.models.glove;

import lombok.NonNull;
import org.deeplearning4j.models.glove.count.*;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.iterators.FilteredSequenceIterator;
import org.deeplearning4j.models.sequencevectors.iterators.SynchronizedSequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SynchronizedSentenceIterator;
import org.deeplearning4j.util.DL4JFileUtils;
import org.deeplearning4j.util.ThreadUtils;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * This class implements building cooccurrence map for abstract training corpus.
 * However it's performance rather low, due to exsessive IO that happens in ShadowCopyThread
 *
 * PLEASE NOTE: Current implementation involves massive IO, and it should be rewritter as soon as ND4j gets sparse arrays support
 *
 * @author [email protected]
 */
public class AbstractCoOccurrences implements Serializable {

    protected boolean symmetric;
    protected int windowSize;
    protected VocabCache vocabCache;
    protected SequenceIterator sequenceIterator;

    // please note, we need enough room for ShadowCopy thread, that's why -1 there
    protected int workers = Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);

    // target file, where text with cooccurrencies should be saved
    protected File targetFile;

    protected ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

    protected long memory_threshold = 0;

    private ShadowCopyThread shadowThread;

    //    private Counter sentenceOccurrences = Util.parallelCounter();
    //private CounterMap coOccurrenceCounts = Util.parallelCounterMap();
    private volatile CountMap coOccurrenceCounts = new CountMap<>();
    //private Counter occurrenceAllocations = Util.parallelCounter();
    //private List> coOccurrences;
    private AtomicLong processedSequences = new AtomicLong(0);


    protected static final Logger logger = LoggerFactory.getLogger(AbstractCoOccurrences.class);

    // this method should be private, to avoid non-configured instantiation
    private AbstractCoOccurrences() {}

    /**
     * This method returns cooccurrence distance weights for two SequenceElements
     *
     * @param element1
     * @param element2
     * @return distance weight
     */
    public double getCoOccurrenceCount(@NonNull T element1, @NonNull T element2) {
        return coOccurrenceCounts.getCount(element1, element2);
    }

    /**
     * This method returns estimated memory footrpint, based on current CountMap content
     * @return
     */
    protected long getMemoryFootprint() {
        // TODO: implement this method. It should return approx. memory used by appropriate CountMap
        try {
            lock.readLock().lock();
            return ((long) coOccurrenceCounts.size()) * 24L * 5L;
        } finally {
            lock.readLock().unlock();
        }
    }

    /**
     * This memory returns memory threshold, defined as 1/2 of memory allowed for allocation
     * @return
     */
    protected long getMemoryThreshold() {
        return memory_threshold / 2L;
    }

    public void fit() {
        shadowThread = new ShadowCopyThread();
        shadowThread.start();

        // we should reset iterator before counting cooccurrences
        sequenceIterator.reset();

        List threads = new ArrayList<>();
        for (int x = 0; x < workers; x++) {
            threads.add(x, new CoOccurrencesCalculatorThread(x, new FilteredSequenceIterator<>(
                            new SynchronizedSequenceIterator<>(sequenceIterator), vocabCache), processedSequences));
            threads.get(x).start();
        }

        for (int x = 0; x < workers; x++) {
            try {
                threads.get(x).join();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        shadowThread.finish();
        logger.info("CoOccurrences map was built.");
    }

    /**
     *
     *  This method returns iterator with elements pairs and their weights. Resulting iterator is safe to use in multi-threaded environment.
     *
     * Developer's note: thread safety on received iterator is delegated to PrefetchedSentenceIterator
     * @return
     */
    public Iterator, Double>> iterator() {
        final SentenceIterator iterator;

        try {
            iterator = new SynchronizedSentenceIterator(
                            new PrefetchingSentenceIterator.Builder(new BasicLineIterator(targetFile))
                                            .setFetchSize(500000).build());

        } catch (Exception e) {
            logger.error("Target file was not found on last stage!");
            throw new RuntimeException(e);
        }
        return new Iterator, Double>>() {
            /*
                    iterator should be built on top of current text file with all pairs
             */

            @Override
            public boolean hasNext() {
                return iterator.hasNext();
            }

            @Override
            public Pair, Double> next() {
                String line = iterator.nextSentence();
                String[] strings = line.split(" ");

                T element1 = vocabCache.elementAtIndex(Integer.valueOf(strings[0]));
                T element2 = vocabCache.elementAtIndex(Integer.valueOf(strings[1]));
                Double weight = Double.valueOf(strings[2]);

                return new Pair<>(new Pair<>(element1, element2), weight);
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("remove() method can't be supported on read-only interface");
            }
        };
    }

    public static class Builder {

        protected boolean symmetric;
        protected int windowSize = 5;
        protected VocabCache vocabCache;
        protected SequenceIterator sequenceIterator;
        protected int workers = Runtime.getRuntime().availableProcessors();
        protected File target;
        protected long maxmemory = Runtime.getRuntime().maxMemory();

        public Builder() {

        }

        public Builder symmetric(boolean reallySymmetric) {
            this.symmetric = reallySymmetric;
            return this;
        }

        public Builder windowSize(int windowSize) {
            this.windowSize = windowSize;
            return this;
        }

        public Builder vocabCache(@NonNull VocabCache cache) {
            this.vocabCache = cache;
            return this;
        }

        public Builder iterate(@NonNull SequenceIterator iterator) {
            this.sequenceIterator = new SynchronizedSequenceIterator<>(iterator);
            return this;
        }

        public Builder workers(int numWorkers) {
            this.workers = numWorkers;
            return this;
        }

        /**
         * This method allows you to specify maximum memory available for CoOccurrence map builder.
         *
         * Please note: this option can be considered a debugging method. In most cases setting proper -Xmx argument set to JVM is enough to limit this algorithm.
         * Please note: this option won't override -Xmx JVM value.
         *
         * @param gbytes memory available, in GigaBytes
         * @return
         */
        public Builder maxMemory(int gbytes) {
            if (gbytes > 0) {
                this.maxmemory = Math.max(gbytes - 1, 1) * 1024 * 1024 * 1024L;
            }

            return this;
        }

        /**
         * Path to save cooccurrence map after construction.
         * If targetFile is not specified, temporary file will be used.
         *
         * @param path
         * @return
         */
        public Builder targetFile(@NonNull String path) {
            this.targetFile(new File(path));
            return this;
        }

        /**
         * Path to save cooccurrence map after construction.
         * If targetFile is not specified, temporary file will be used.
         *
         * @param file
         * @return
         */
        public Builder targetFile(@NonNull File file) {
            this.target = file;
            return this;
        }

        public AbstractCoOccurrences build() {
            AbstractCoOccurrences ret = new AbstractCoOccurrences<>();
            ret.sequenceIterator = this.sequenceIterator;
            ret.windowSize = this.windowSize;
            ret.vocabCache = this.vocabCache;
            ret.symmetric = this.symmetric;
            ret.workers = this.workers;

            if (this.maxmemory < 1) {
                this.maxmemory = Runtime.getRuntime().maxMemory();
            }
            ret.memory_threshold = this.maxmemory;


            logger.info("Actual memory limit: [" + this.maxmemory + "]");

            // use temp file, if no target file was specified
            try {
                if (this.target == null) {
                    this.target = DL4JFileUtils.createTempFile("cooccurrence", "map");
                }
                this.target.deleteOnExit();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }

            ret.targetFile = this.target;

            return ret;
        }
    }

    private class CoOccurrencesCalculatorThread extends Thread implements Runnable {

        private final SequenceIterator iterator;
        private final AtomicLong sequenceCounter;
        private int threadId;

        public CoOccurrencesCalculatorThread(int threadId, @NonNull SequenceIterator iterator,
                        @NonNull AtomicLong sequenceCounter) {
            this.iterator = iterator;
            this.sequenceCounter = sequenceCounter;
            this.threadId = threadId;

            this.setName("CoOccurrencesCalculatorThread " + threadId);
        }

        @Override
        public void run() {
            while (iterator.hasMoreSequences()) {
                Sequence sequence = iterator.nextSequence();

                List tokens = new ArrayList<>(sequence.asLabels());
                //            logger.info("Tokens size: " + tokens.size());
                for (int x = 0; x < sequence.getElements().size(); x++) {
                    int wordIdx = vocabCache.indexOf(tokens.get(x));
                    if (wordIdx < 0) {
                        continue;
                    }
                    String w1 = vocabCache.wordFor(tokens.get(x)).getLabel();

                    // THIS iS SAFE TO REMOVE, NO CHANCE WE'll HAVE UNK WORD INSIDE SEQUENCE
                    /*if(w1.equals(Glove.UNK))
                        continue;
                    */

                    int windowStop = Math.min(x + windowSize + 1, tokens.size());
                    for (int j = x; j < windowStop; j++) {
                        int otherWord = vocabCache.indexOf(tokens.get(j));
                        if (otherWord < 0) {
                            continue;
                        }
                        String w2 = vocabCache.wordFor(tokens.get(j)).getLabel();

                        if (w2.equals(Glove.DEFAULT_UNK) || otherWord == wordIdx) {
                            continue;
                        }


                        T tokenX = vocabCache.wordFor(tokens.get(x));
                        T tokenJ = vocabCache.wordFor(tokens.get(j));
                        double nWeight = 1.0 / (j - x + Nd4j.EPS_THRESHOLD);

                        while (getMemoryFootprint() >= getMemoryThreshold()) {
                            shadowThread.invoke();
                            /*lock.readLock().lock();
                            int size = coOccurrenceCounts.size();
                            lock.readLock().unlock();
                            */
                            if (threadId == 0) {
                                logger.debug("Memory consuimption > threshold: {footrpint: [" + getMemoryFootprint()
                                                + "], threshold: [" + getMemoryThreshold() + "] }");
                            }
                            ThreadUtils.uncheckedSleep(10000);
                        }
                        /*
                        if (getMemoryFootprint() == 0) {
                            logger.info("Zero size!");
                        }
                        */

                        try {
                            lock.readLock().lock();
                            if (wordIdx < otherWord) {
                                coOccurrenceCounts.incrementCount(tokenX, tokenJ, nWeight);
                                if (symmetric) {
                                    coOccurrenceCounts.incrementCount(tokenJ, tokenX, nWeight);
                                }
                            } else {
                                coOccurrenceCounts.incrementCount(tokenJ, tokenX, nWeight);

                                if (symmetric) {
                                    coOccurrenceCounts.incrementCount(tokenX, tokenJ, nWeight);
                                }
                            }
                        } finally {
                            lock.readLock().unlock();
                        }
                    }
                }

                sequenceCounter.incrementAndGet();
            }
        }
    }

    /**
     * This class is designed to provide shadow copy functionality for CoOccurence maps, since with proper corpus size you can't fit such a map into memory
     *
     */
    private class ShadowCopyThread extends Thread implements Runnable {

        private AtomicBoolean isFinished = new AtomicBoolean(false);
        private AtomicBoolean isTerminate = new AtomicBoolean(false);
        private AtomicBoolean isInvoked = new AtomicBoolean(false);
        private AtomicBoolean shouldInvoke = new AtomicBoolean(false);

        // file that contains resuts from previous runs
        private File[] tempFiles;
        private RoundCount counter;

        public ShadowCopyThread() {
            try {

                counter = new RoundCount(1);
                tempFiles = new File[2];

                tempFiles[0] = DL4JFileUtils.createTempFile("aco", "tmp");
                tempFiles[1] = DL4JFileUtils.createTempFile("aco", "tmp");

                tempFiles[0].deleteOnExit();
                tempFiles[1].deleteOnExit();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }

            this.setName("ACO ShadowCopy thread");
        }

        @Override
        public void run() {
            /*
                  Basic idea is pretty simple: run quetly, untill memory gets filled up to some high volume.
                  As soon as this happens - execute shadow copy.
            */
            while (!isFinished.get() && !isTerminate.get()) {
                // check used memory. if memory use below threshold - sleep for a while. if above threshold - invoke copier

                if (getMemoryFootprint() > getMemoryThreshold() || (shouldInvoke.get() && !isInvoked.get())) {
                    // we'll just invoke copier, nothing else
                    shouldInvoke.compareAndSet(true, false);
                    invokeBlocking();
                } else {
                    /*
                           commented and left here for future debugging purposes, if needed
                    
                            //lock.readLock().lock();
                            //int size = coOccurrenceCounts.size();
                            //lock.readLock().unlock();
                            //logger.info("Current memory situation: {size: [" +size+ "], footprint: [" + getMemoryFootprint()+"], threshold: ["+ getMemoryThreshold() +"]}");
                     */
                    ThreadUtils.uncheckedSleep(1000);
                }
            }
        }

        /**
         * This methods advises shadow copy process to start
         */
        public void invoke() {
            shouldInvoke.compareAndSet(false, true);
        }

        /**
         * This methods dumps cooccurrence map into save file.
         * Please note: this method is synchronized and will block, until complete
         */
        public synchronized void invokeBlocking() {
            if (getMemoryFootprint() < getMemoryThreshold() && !isFinished.get()) {
                return;
            }

            int numberOfLinesSaved = 0;

            isInvoked.set(true);

            logger.debug("Memory purge started.");

            /*
                Basic plan:
                    1. Open temp file
                    2. Read that file line by line
                    3. For each read line do synchronization in memory > new file direction
             */

            counter.tick();

            CountMap localMap;
            try {
                // in any given moment there's going to be only 1 WriteLock, due to invokeBlocking() being synchronized call
                lock.writeLock().lock();



                // obtain local copy of CountMap
                localMap = coOccurrenceCounts;

                // set new CountMap, and release write lock
                coOccurrenceCounts = new CountMap<>();
            } finally {
                lock.writeLock().unlock();
            }

            try {

                File file = null;
                if (!isFinished.get()) {
                    file = tempFiles[counter.previous()];
                } else
                    file = targetFile;


                //    PrintWriter pw = new PrintWriter(file);

                int linesRead = 0;

                logger.debug("Saving to: [" + counter.get() + "], Reading from: [" + counter.previous() + "]");
                CoOccurenceReader reader =
                                new BinaryCoOccurrenceReader<>(tempFiles[counter.previous()], vocabCache, localMap);
                CoOccurrenceWriter writer = (isFinished.get()) ? new ASCIICoOccurrenceWriter(targetFile)
                                : new BinaryCoOccurrenceWriter(tempFiles[counter.get()]);
                while (reader.hasMoreObjects()) {
                    CoOccurrenceWeight line = reader.nextObject();

                    if (line != null) {
                        writer.writeObject(line);
                        numberOfLinesSaved++;
                        linesRead++;
                    }
                }
                reader.finish();

                logger.debug("Lines read: [" + linesRead + "]");

                //now, we can dump the rest of elements, which were not presented in existing dump
                Iterator> iterator = localMap.getPairIterator();
                while (iterator.hasNext()) {
                    Pair pair = iterator.next();
                    double mWeight = localMap.getCount(pair);
                    CoOccurrenceWeight object = new CoOccurrenceWeight<>();
                    object.setElement1(pair.getFirst());
                    object.setElement2(pair.getSecond());
                    object.setWeight(mWeight);

                    writer.writeObject(object);

                    numberOfLinesSaved++;
                    //      if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
                }

                writer.finish();

                /*
                SentenceIterator sIterator =  new PrefetchingSentenceIterator.Builder(new BasicLineIterator(tempFiles[counter.get()]))
                        .setFetchSize(500000)
                        .build();
                
                
                int linesRead = 0;
                while (sIterator.hasNext()) {
                    //List list = new ArrayList<>(reader.next());
                    String sentence = sIterator.nextSentence();
                    if (sentence == null || sentence.isEmpty()) continue;
                    String[] strings = sentence.split(" ");
                
                
                    // first two elements are integers - vocab indexes
                    //T element1 = vocabCache.wordFor(vocabCache.wordAtIndex(list.get(0).toInt()));
                    //T element2 = vocabCache.wordFor(vocabCache.wordAtIndex(list.get(1).toInt()));
                    T element1 = vocabCache.elementAtIndex(Integer.valueOf(strings[0]));
                    T element2 = vocabCache.elementAtIndex(Integer.valueOf(strings[1]));
                
                    // getting third element, previously stored weight
                    double sWeight = Double.valueOf(strings[2]);  // list.get(2).toDouble();
                
                    // now, since we have both elements ready, we can check this pair against inmemory map
                        double mWeight = localMap.getCount(element1, element2);
                        if (mWeight <= 0) {
                            // this means we have no such pair in memory, so we'll do nothing to sWeight
                        } else {
                            // since we have new weight value in memory, we should update sWeight value before moving it off memory
                            sWeight += mWeight;
                
                            // original pair can be safely removed from CountMap
                            localMap.removePair(element1,element2);
                        }
                
                        StringBuilder builder = new StringBuilder().append(element1.getIndex()).append(" ").append(element2.getIndex()).append(" ").append(sWeight);
                        pw.println(builder.toString());
                        numberOfLinesSaved++;
                        linesRead++;
                
                   // if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
                  //  if (linesRead % 100000 == 0) logger.info("Lines read: [" + linesRead +"]");
                }
                */
                /*
                logger.info("Lines read: [" + linesRead + "]");
                
                //now, we can dump the rest of elements, which were not presented in existing dump
                Iterator> iterator = localMap.getPairIterator();
                while (iterator.hasNext()) {
                    Pair pair = iterator.next();
                    double mWeight = localMap.getCount(pair);
                
                    StringBuilder builder = new StringBuilder().append(pair.getFirst().getIndex()).append(" ").append(pair.getFirst().getIndex()).append(" ").append(mWeight);
                    pw.println(builder.toString());
                    numberOfLinesSaved++;
                
                              //      if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
                }
                
                pw.flush();
                pw.close();
                
                */

                // just a hint for gc
                localMap = null;
                //sIterator.finish();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }

            logger.info("Number of word pairs saved so far: [" + numberOfLinesSaved + "]");
            isInvoked.set(false);
        }

        /**
         * This method provides soft finish ability for shadow copy process.
         * Please note: it's blocking call, since it requires for final merge.
         */
        public void finish() {
            if (this.isFinished.get()) {
                return;
            }

            this.isFinished.set(true);
            invokeBlocking();
        }

        /**
         * This method provides hard fiinish ability for shadow copy process
         */
        public void terminate() {
            this.isTerminate.set(true);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy