All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.models.glove.count.BinaryCoOccurrenceReader Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.models.glove.count;

import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Binary implementation of CoOccurenceReader interface, used to provide off-memory storage for cooccurrence maps generated for GloVe
 *
 * @author [email protected]
 */
public class BinaryCoOccurrenceReader implements CoOccurenceReader {
    private VocabCache vocabCache;
    private InputStream inputStream;
    private File file;
    private ArrayBlockingQueue> buffer;
    int workers = Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);
    private StreamReaderThread readerThread;
    private CountMap countMap;


    protected static final Logger logger = LoggerFactory.getLogger(BinaryCoOccurrenceReader.class);

    public BinaryCoOccurrenceReader(@NonNull File file, @NonNull VocabCache vocabCache, CountMap map) {
        this.vocabCache = vocabCache;
        this.file = file;
        this.countMap = map;
        buffer = new ArrayBlockingQueue<>(200000);

        try {
            inputStream = new BufferedInputStream(new FileInputStream(this.file), 100 * 1024 * 1024);
            //inputStream = new BufferedInputStream(new FileInputStream(file), 1024 * 1024);
            readerThread = new StreamReaderThread(inputStream);
            readerThread.start();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public boolean hasMoreObjects() {

        if (!buffer.isEmpty())
            return true;

        try {
            return readerThread.hasMoreObjects() || !buffer.isEmpty();
        } catch (Exception e) {
            throw new RuntimeException(e);
            //return false;
        }
    }

    @Override
    public CoOccurrenceWeight nextObject() {
        if (!buffer.isEmpty()) {
            return buffer.poll();
        } else {
            // buffer can be starved, or we're already at the end of file.
            if (readerThread.hasMoreObjects()) {
                try {
                    return buffer.poll(3, TimeUnit.SECONDS);
                } catch (Exception e) {
                    return null;
                }
            }
        }


        return null;
        /*
        try {
            CoOccurrenceWeight ret = new CoOccurrenceWeight<>();
            ret.setElement1(vocabCache.elementAtIndex(inputStream.readInt()));
            ret.setElement2(vocabCache.elementAtIndex(inputStream.readInt()));
            ret.setWeight(inputStream.readDouble());
        
            return ret;
        } catch (Exception e) {
            return null;
        }
        */
    }

    @Override
    public void finish() {
        try {
            if (inputStream != null)
                inputStream.close();
        } catch (Exception e) {
            //
        }
    }

    private class StreamReaderThread extends Thread implements Runnable {
        private InputStream stream;
        private AtomicBoolean isReading = new AtomicBoolean(false);

        public StreamReaderThread(@NonNull InputStream stream) {
            this.stream = stream;
            isReading.set(false);
        }

        @Override
        public void run() {
            try {
                // we read pre-defined number of objects as byte array
                byte[] array = new byte[16 * 500000];
                while (true) {
                    int count = stream.read(array);

                    isReading.set(true);
                    if (count == 0)
                        break;

                    // now we deserialize them in separate threads to gain some speedup, if possible
                    List threads = new ArrayList<>();
                    AtomicInteger internalPosition = new AtomicInteger(0);

                    for (int t = 0; t < workers; t++) {
                        threads.add(t, new AsyncDeserializationThread(t, array, buffer, internalPosition, count));
                        threads.get(t).start();
                    }

                    // we'll block this cycle untill all objects are fit into queue
                    for (int t = 0; t < workers; t++) {
                        try {
                            threads.get(t).join();
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                    }

                    isReading.set(false);
                    if (count < array.length)
                        break;
                }

            } catch (Exception e) {
                isReading.set(false);
                throw new RuntimeException(e);
            }
        }

        public boolean hasMoreObjects() {
            try {
                return stream.available() > 0 || isReading.get();
            } catch (Exception e) {
                return false;
            } finally {
            }
        }
    }

    /**
     * Utility class that accepts byte array as input, and deserialize it into set of CoOccurrenceWeight objects
     */
    private class AsyncDeserializationThread extends Thread implements Runnable {
        private int threadId;
        private byte[] arrayReference;
        private ArrayBlockingQueue> targetBuffer;
        private AtomicInteger pointer;
        private int limit;

        public AsyncDeserializationThread(int threadId, @NonNull byte[] array,
                        @NonNull ArrayBlockingQueue> targetBuffer,
                        @NonNull AtomicInteger sharedPointer, int limit) {
            this.threadId = threadId;
            this.arrayReference = array;
            this.targetBuffer = targetBuffer;
            this.pointer = sharedPointer;
            this.limit = limit;


            setName("AsynDeserialization thread " + this.threadId);
        }

        @Override
        public void run() {
            ByteBuffer bB = ByteBuffer.wrap(arrayReference);
            int position = 0;
            while ((position = pointer.getAndAdd(16)) < this.limit) {
                if (position >= limit) {
                    continue;
                }


                int e1idx = bB.getInt(position);
                int e2idx = bB.getInt(position + 4);
                double eW = bB.getDouble(position + 8);


                CoOccurrenceWeight object = new CoOccurrenceWeight<>();
                object.setElement1(vocabCache.elementAtIndex(e1idx));
                object.setElement2(vocabCache.elementAtIndex(e2idx));

                if (countMap != null) {
                    double mW = countMap.getCount(object.getElement1(), object.getElement2());

                    if (mW > 0) {
                        eW += mW;
                        countMap.removePair(object.getElement1(), object.getElement2());
                    }
                }
                object.setWeight(eW);

                try {
                    targetBuffer.put(object);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy