// Source: org.deeplearning4j.models.glove.count.BinaryCoOccurrenceReader (Maven / Gradle / Ivy artifact)
package org.deeplearning4j.models.glove.count;
import lombok.NonNull;

import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Binary implementation of CoOccurenceReader interface, used to provide off-memory storage for cooccurrence maps generated for GloVe
*
* @author [email protected]
*/
public class BinaryCoOccurrenceReader implements CoOccurenceReader {
private VocabCache vocabCache;
private InputStream inputStream;
private File file;
private ArrayBlockingQueue> buffer;
int workers = Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);
private StreamReaderThread readerThread;
private CountMap countMap;
protected static final Logger logger = LoggerFactory.getLogger(BinaryCoOccurrenceReader.class);
public BinaryCoOccurrenceReader(@NonNull File file, @NonNull VocabCache vocabCache, CountMap map) {
this.vocabCache = vocabCache;
this.file = file;
this.countMap = map;
buffer = new ArrayBlockingQueue<>(200000);
try {
inputStream = new BufferedInputStream(new FileInputStream(this.file), 100 * 1024 * 1024);
//inputStream = new BufferedInputStream(new FileInputStream(file), 1024 * 1024);
readerThread = new StreamReaderThread(inputStream);
readerThread.start();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public boolean hasMoreObjects() {
if (!buffer.isEmpty())
return true;
try {
return readerThread.hasMoreObjects() || !buffer.isEmpty();
} catch (Exception e) {
throw new RuntimeException(e);
//return false;
}
}
@Override
public CoOccurrenceWeight nextObject() {
if (!buffer.isEmpty()) {
return buffer.poll();
} else {
// buffer can be starved, or we're already at the end of file.
if (readerThread.hasMoreObjects()) {
try {
return buffer.poll(3, TimeUnit.SECONDS);
} catch (Exception e) {
return null;
}
}
}
return null;
/*
try {
CoOccurrenceWeight ret = new CoOccurrenceWeight<>();
ret.setElement1(vocabCache.elementAtIndex(inputStream.readInt()));
ret.setElement2(vocabCache.elementAtIndex(inputStream.readInt()));
ret.setWeight(inputStream.readDouble());
return ret;
} catch (Exception e) {
return null;
}
*/
}
@Override
public void finish() {
try {
if (inputStream != null)
inputStream.close();
} catch (Exception e) {
//
}
}
private class StreamReaderThread extends Thread implements Runnable {
private InputStream stream;
private AtomicBoolean isReading = new AtomicBoolean(false);
public StreamReaderThread(@NonNull InputStream stream) {
this.stream = stream;
isReading.set(false);
}
@Override
public void run() {
try {
// we read pre-defined number of objects as byte array
byte[] array = new byte[16 * 500000];
while (true) {
int count = stream.read(array);
isReading.set(true);
if (count == 0)
break;
// now we deserialize them in separate threads to gain some speedup, if possible
List threads = new ArrayList<>();
AtomicInteger internalPosition = new AtomicInteger(0);
for (int t = 0; t < workers; t++) {
threads.add(t, new AsyncDeserializationThread(t, array, buffer, internalPosition, count));
threads.get(t).start();
}
// we'll block this cycle untill all objects are fit into queue
for (int t = 0; t < workers; t++) {
try {
threads.get(t).join();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
isReading.set(false);
if (count < array.length)
break;
}
} catch (Exception e) {
isReading.set(false);
throw new RuntimeException(e);
}
}
public boolean hasMoreObjects() {
try {
return stream.available() > 0 || isReading.get();
} catch (Exception e) {
return false;
} finally {
}
}
}
/**
* Utility class that accepts byte array as input, and deserialize it into set of CoOccurrenceWeight objects
*/
private class AsyncDeserializationThread extends Thread implements Runnable {
private int threadId;
private byte[] arrayReference;
private ArrayBlockingQueue> targetBuffer;
private AtomicInteger pointer;
private int limit;
public AsyncDeserializationThread(int threadId, @NonNull byte[] array,
@NonNull ArrayBlockingQueue> targetBuffer,
@NonNull AtomicInteger sharedPointer, int limit) {
this.threadId = threadId;
this.arrayReference = array;
this.targetBuffer = targetBuffer;
this.pointer = sharedPointer;
this.limit = limit;
setName("AsynDeserialization thread " + this.threadId);
}
@Override
public void run() {
ByteBuffer bB = ByteBuffer.wrap(arrayReference);
int position = 0;
while ((position = pointer.getAndAdd(16)) < this.limit) {
if (position >= limit) {
continue;
}
int e1idx = bB.getInt(position);
int e2idx = bB.getInt(position + 4);
double eW = bB.getDouble(position + 8);
CoOccurrenceWeight object = new CoOccurrenceWeight<>();
object.setElement1(vocabCache.elementAtIndex(e1idx));
object.setElement2(vocabCache.elementAtIndex(e2idx));
if (countMap != null) {
double mW = countMap.getCount(object.getElement1(), object.getElement2());
if (mW > 0) {
eW += mW;
countMap.removePair(object.getElement1(), object.getElement2());
}
}
object.setWeight(eW);
try {
targetBuffer.put(object);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
}
}