Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to cover our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it as often as you want.
org.deeplearning4j.text.invertedindex.LuceneInvertedIndex Maven / Gradle / Ivy
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.text.invertedindex;
import com.google.common.base.Function;
import org.apache.commons.io.FileUtils;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.*;
import org.apache.lucene.search.*;
import org.apache.lucene.store.*;
import org.apache.lucene.util.Bits;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.stopwords.StopWords;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
/**
* Lucene based inverted index
*
* @author Adam Gibson
*/
public class LuceneInvertedIndex implements InvertedIndex, IndexReader.ReaderClosedListener,
Iterator> {
// Lucene directory backing the index; opened by ensureDirExists() and reset to null by tryCreateWriter().
private transient Directory dir;
// Only null-checked in initReader() and optionally assigned by Builder.build(); methods in this class use
// local DirectoryReader variables of the same name instead.
private transient IndexReader reader;
// Created in getWriterWithRetry() when still null.
private transient Analyzer analyzer;
// Resolves the word strings stored in the index back to VocabWord instances (via wordFor).
private VocabCache vocabCache;
// Name of the Lucene field holding a document's words.
public final static String WORD_FIELD = "word";
// Name of the Lucene field holding a document's label(s).
public final static String LABEL = "label";
// Written by initReader() and finish(); not read anywhere in this file.
private int numDocs = 0; // FIXME this is never used
// Set to true in getWriterWithRetry(); not read or reset anywhere in this file.
private AtomicBoolean indexBeingCreated = new AtomicBoolean(false);
private static final Logger log = LoggerFactory.getLogger(LuceneInvertedIndex.class);
// Same value as DEFAULT_INDEX_DIR; used as the initial value of indexPath.
public final static String INDEX_PATH = "word2vec-index";
// Set to true by onClose(IndexReader).
private AtomicBoolean readerClosed = new AtomicBoolean(false);
// Running count of words passed to the addWordsToDoc* methods.
private AtomicInteger totalWords = new AtomicInteger(0);
// Number of words collected in currMiniBatch before it is moved to miniBatches (see addWords()).
private int batchSize = 1000;
// Completed mini-batches, consumed by next().
// NOTE(review): the generic type arguments look truncated here - presumably List<List<VocabWord>>; confirm.
private List> miniBatches = new CopyOnWriteArrayList<>();
// Mini-batch currently being filled by addWords().
private List currMiniBatch = Collections.synchronizedList(new ArrayList());
// Sub-sampling parameter; addWords() applies sub-sampling only when this is > 0.
private double sample = 0;
// Read by addWords() for the sub-sampling comparison; never updated in this file.
private AtomicLong nextRandom = new AtomicLong(5);
// File-system path of the index directory; may be replaced by a random UUID in tryCreateWriter().
private String indexPath = INDEX_PATH;
// Nothing in this file adds to this queue; hasNext() only checks whether it is empty.
// NOTE(review): the generic type arguments look truncated here - presumably Queue<List<VocabWord>>; confirm.
private Queue> miniBatchDocs = new ConcurrentLinkedDeque<>(); // FIXME this is never used
// Starts as true and is never set to false in this file; consulted by hasNext() and next().
private AtomicBoolean miniBatchGoing = new AtomicBoolean(true);
// Enables mini-batch collection in addWords() and iteration through hasNext()/next().
private boolean miniBatch = false;
// Default index location used by the constructors that take no path.
public final static String DEFAULT_INDEX_DIR = "word2vec-index";
// Both managers are (re)created by initReader().
private transient SearcherManager searcherManager;
private transient ReaderManager readerManager;
// Created by getWriterWithRetry().
private transient TrackingIndexWriter indexWriter;
// Assigned in unlock() and tryCreateWriter(); not passed to any other object in this file.
private transient LockFactory lockFactory;
/**
 * Creates an index at {@link #DEFAULT_INDEX_DIR} by delegating to the three-argument constructor.
 *
 * @param vocabCache vocabulary used to resolve stored words
 * @param cache passed through unchanged to the three-argument constructor
 */
public LuceneInvertedIndex(VocabCache vocabCache,boolean cache) {
this(vocabCache,cache,DEFAULT_INDEX_DIR);
}
public LuceneInvertedIndex(VocabCache vocabCache,boolean cache,String indexPath) { // FIXME cache is not used
this.vocabCache = vocabCache;
this.indexPath = indexPath;
if(new File(indexPath).exists()) {
indexPath = UUID.randomUUID().toString();
log.warn("Changing index path to" + indexPath);
}
initReader();
}
/**
 * Creates an index at {@link #DEFAULT_INDEX_DIR} with the {@code cache} flag set to {@code false}.
 *
 * @param vocabCache vocabulary used to resolve stored words
 */
public LuceneInvertedIndex(VocabCache vocabCache) {
this(vocabCache,false,DEFAULT_INDEX_DIR);
}
/**
 * Returns a new {@link BatchDocIter} that groups documents into batches of at most {@code batchSize}.
 * NOTE(review): the generic type arguments of the return type look truncated - presumably
 * Iterator&lt;List&lt;List&lt;VocabWord&gt;&gt;&gt;; confirm against the InvertedIndex interface.
 *
 * @param batchSize maximum number of documents per batch
 */
@Override
public Iterator>> batchIter(int batchSize) {
return new BatchDocIter(batchSize);
}
/**
 * Returns a new {@link DocIter} over the documents in the index.
 * NOTE(review): the generic type arguments of the return type look truncated - presumably
 * Iterator&lt;List&lt;VocabWord&gt;&gt;; confirm against the InvertedIndex interface.
 */
@Override
public Iterator> docs() {
return new DocIter();
}
/**
 * Initialises {@code lockFactory} with the default native file-system lock factory if it has not
 * been set yet; otherwise does nothing.
 */
@Override
public void unlock() {
    if (lockFactory != null) {
        return;
    }
    lockFactory = NativeFSLockFactory.getDefault();
}
/**
 * Deletes every document from the index and commits the deletion.
 * Failures are logged and otherwise ignored (best effort).
 */
@Override
public void cleanup() {
    try {
        indexWriter.deleteAll();
        indexWriter.getIndexWriter().commit();
    } catch (IOException e) {
        // Report through the class logger instead of printing to stderr.
        log.error("Failed to delete all documents from the index", e);
    }
}
/**
 * Returns the configured sub-sampling value.
 */
@Override
public double sample() {
return sample;
}
/**
 * Returns this index itself as the mini-batch iterator (see {@code hasNext()} / {@code next()}).
 * NOTE(review): the generic type arguments of the return type look truncated - presumably
 * Iterator&lt;List&lt;VocabWord&gt;&gt;; confirm against the InvertedIndex interface.
 */
@Override
public Iterator> miniBatches() {
return this;
}
/**
 * Loads the words of one document and resolves them through the vocabulary.
 * Words that the vocabulary does not know are skipped. Any failure is logged and whatever was
 * collected so far (possibly nothing) is returned.
 *
 * @param index Lucene document id
 * @return the known words of the document, in stored order; never {@code null}
 */
@Override
public synchronized List<VocabWord> document(int index) {
    List<VocabWord> words = new ArrayList<>();
    try {
        DirectoryReader reader = readerManager.acquire();
        Document doc;
        try {
            doc = reader.document(index);
        } finally {
            // Readers handed out by the manager must be released, not closed: closing drops a
            // reference the manager still owns.
            readerManager.release(reader);
        }
        for (String s : doc.getValues(WORD_FIELD)) {
            VocabWord word = vocabCache.wordFor(s);
            if (word != null)
                words.add(word);
        }
    } catch (Exception e) {
        log.error("Failed to load document " + index, e);
    }
    // Keep the original result type, but fill it in one step instead of copying on every add.
    return new CopyOnWriteArrayList<>(words);
}
/**
 * Finds the ids of all documents whose word field contains the given word.
 *
 * @param vocabWord word to look up; it is lower-cased before querying
 * @return matching Lucene document ids (empty if there are none)
 * @throws RuntimeException if the search fails with an I/O error
 */
@Override
public int[] documents(VocabWord vocabWord) {
    try {
        // Locale.ROOT keeps the lower-casing independent of the JVM's default locale.
        String term = vocabWord.getWord().toLowerCase(Locale.ROOT);
        TermQuery query = new TermQuery(new Term(WORD_FIELD, term));
        searcherManager.maybeRefresh();
        IndexSearcher searcher = searcherManager.acquire();
        try {
            TopDocs topdocs = searcher.search(query, Integer.MAX_VALUE);
            // Size the result by the hits actually returned rather than by the reported total.
            ScoreDoc[] hits = topdocs.scoreDocs;
            int[] ret = new int[hits.length];
            for (int i = 0; i < hits.length; i++) {
                ret[i] = hits[i].doc;
            }
            return ret;
        } finally {
            // Always hand the searcher back, including when the search throws.
            searcherManager.release(searcher);
        }
    }
    catch (AlreadyClosedException e) {
        // Retry when the underlying reader was closed concurrently.
        // NOTE(review): the retry is unbounded - confirm a closed manager cannot make this recurse forever.
        return documents(vocabWord);
    }
    catch (IOException e) {
        throw new RuntimeException(e);
    }
}
/**
 * Reports how many documents the index currently holds, after asking the reader manager to refresh.
 *
 * @return the document count reported by the acquired reader
 * @throws RuntimeException wrapping any {@link IOException} raised while refreshing or acquiring
 */
@Override
public int numDocuments() {
    final int count;
    try {
        readerManager.maybeRefresh();
        final DirectoryReader current = readerManager.acquire();
        count = current.numDocs();
        readerManager.release(current);
    } catch (IOException ioe) {
        throw new RuntimeException(ioe);
    }
    return count;
}
/**
 * Returns the ids of all live (non-deleted) documents in the index.
 *
 * @return live document ids in ascending order; the array holds exactly one entry per live document
 * @throws IllegalStateException if the index contains no documents at all
 * @throws RuntimeException if the reader cannot be refreshed or acquired
 */
@Override
public int[] allDocs() {
    DirectoryReader reader;
    try {
        readerManager.maybeRefreshBlocking();
        reader = readerManager.acquire();
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    try {
        int maxDoc = reader.maxDoc();
        if (maxDoc < 1)
            throw new IllegalStateException("No documents found");
        int[] docIds = new int[maxDoc];
        int count = 0;
        // A null Bits instance means there are no deletions, i.e. every id below maxDoc is live.
        Bits liveDocs = MultiFields.getLiveDocs(reader);
        for (int i = 0; i < maxDoc; i++) {
            if (liveDocs == null || liveDocs.get(i))
                docIds[count++] = i;
        }
        // Trim unused slots so deleted documents do not show up as spurious trailing 0 ids.
        return count == maxDoc ? docIds : Arrays.copyOf(docIds, count);
    } finally {
        try {
            // Release (not close) the managed reader, including on the exception path.
            readerManager.release(reader);
        } catch (IOException e) {
            log.warn("Failed to release index reader", e);
        }
    }
}
/**
 * Adds a word field to the in-memory copy of the given document.
 * NOTE(review): this method never passes the modified document to the index writer, so the change
 * does not appear to be persisted - confirm whether a write is missing.
 *
 * @param doc Lucene document id
 * @param word word to add
 */
@Override
public void addWordToDoc(int doc, VocabWord word) {
    Field f = new TextField(WORD_FIELD, word.getWord(), Field.Store.YES);
    try {
        IndexSearcher searcher = searcherManager.acquire();
        try {
            Document doc2 = searcher.doc(doc);
            if (doc2 != null)
                doc2.add(f);
            else {
                Document d = new Document();
                d.add(f);
            }
        } finally {
            // Hand the searcher back even when loading the document fails.
            searcherManager.release(searcher);
        }
    } catch (IOException e) {
        log.error("Failed to add word to document " + doc, e);
    }
}
/**
 * Ensures the index directory and writer exist, commits the writer, and (re)creates the searcher
 * and reader managers. Does nothing when the {@code reader} field has been set.
 *
 * @throws RuntimeException if an I/O error occurs or the thread is interrupted while waiting for the writer
 * @throws IllegalStateException if the tracking writer has no underlying index writer
 */
private void initReader() {
    if (reader != null)
        return;
    try {
        ensureDirExists();
        if (getWriter() == null) {
            this.indexWriter = null;
            while (getWriter() == null) {
                log.warn("Writer was null...reinitializing");
                Thread.sleep(1000);
            }
        }
        IndexWriter writer = getWriter().getIndexWriter();
        if (writer == null)
            throw new IllegalStateException("index writer was null");
        searcherManager = new SearcherManager(writer, true, new SearcherFactory());
        writer.commit();
        readerManager = new ReaderManager(dir);
        // Acquire exactly once: every acquire() must be paired with a release().
        DirectoryReader current = readerManager.acquire();
        try {
            numDocs = current.numDocs();
        } finally {
            readerManager.release(current);
        }
    } catch (IOException e) {
        throw new RuntimeException(e);
    } catch (InterruptedException e) {
        // Preserve the interrupt for callers further up the stack.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    }
}
/**
 * Indexes the given words as one new document and feeds them to mini-batch collection.
 * A failure to write the document is logged and otherwise ignored.
 *
 * @param doc not used by this method; the document is appended to the index
 * @param words words of the document
 */
@Override
public void addWordsToDoc(int doc, final List<VocabWord> words) {
    Document d = new Document();
    for (VocabWord word : words)
        d.add(new TextField(WORD_FIELD, word.getWord(), Field.Store.YES));
    // Single atomic increment; a separate get() followed by set() can lose concurrent updates.
    totalWords.addAndGet(words.size());
    addWords(words);
    try {
        getWriter().addDocument(d);
    } catch (IOException e) {
        log.error("Failed to add document to the index", e);
    }
}
@Override
public Pair, String> documentWithLabel(int index) {
List ret = new CopyOnWriteArrayList<>();
String label = "NONE";
try {
DirectoryReader reader = readerManager.acquire();
Document doc = reader.document(index);
reader.close();
String[] values = doc.getValues(WORD_FIELD);
label = doc.get(LABEL);
if(label == null)
label = "NONE";
for(String s : values) {
ret.add(vocabCache.wordFor(s));
}
}
catch (Exception e) {
e.printStackTrace();
}
return new Pair<>(ret,label);
}
/**
 * Loads the words of one document together with all of its labels.
 * Exceptions are printed and swallowed; in that case the (possibly empty) partial result is returned.
 * NOTE(review): the generic type arguments of the return type look truncated - presumably
 * Pair&lt;List&lt;VocabWord&gt;, Collection&lt;String&gt;&gt;; confirm against the InvertedIndex interface.
 *
 * @param index Lucene document id
 * @return pair of the document's words (as returned by the vocabulary, unfiltered) and its label values
 */
@Override
public Pair, Collection> documentWithLabels(int index) {
List ret = new CopyOnWriteArrayList<>();
Collection labels = new ArrayList<>();
try {
DirectoryReader reader = readerManager.acquire();
Document doc = reader.document(index);
// The stored document has been loaded, so the reader can be handed back right away.
readerManager.release(reader);
String[] values = doc.getValues(WORD_FIELD);
String[] labels2 = doc.getValues(LABEL);
for(String s : values) {
ret.add(vocabCache.wordFor(s));
}
Collections.addAll(labels, labels2);
}
catch (Exception e) {
e.printStackTrace();
}
return new Pair<>(ret,labels);
}
/**
 * Adds a label to a document, using the word's text as the label.
 *
 * @param doc Lucene document id
 * @param word word whose text becomes the label
 */
@Override
public void addLabelForDoc(int doc, VocabWord word) {
addLabelForDoc(doc,word.getWord());
}
/**
 * Loads the stored document, appends a label field to it and writes it back through the index writer.
 * Exceptions are printed and swallowed.
 * NOTE(review): the write uses updateDocument keyed on the label term rather than on the document
 * id; presumably this replaces whatever documents already carry that label - confirm this is intended.
 *
 * @param doc Lucene document id
 * @param label label text to add
 */
@Override
public void addLabelForDoc(int doc, String label) {
try {
DirectoryReader reader = readerManager.acquire();
Document doc2 = reader.document(doc);
doc2.add(new TextField(LABEL,label,Field.Store.YES));
readerManager.release(reader);
TrackingIndexWriter writer = getWriter();
Term term = new Term(LABEL,label);
writer.updateDocument(term,doc2);
} catch (Exception e) {
e.printStackTrace();
}
}
/**
 * Indexes the given words and a single label as one new document and feeds the words to
 * mini-batch collection. A failure to write the document is logged and otherwise ignored.
 *
 * @param doc not used by this method; the document is appended to the index
 * @param words words of the document
 * @param label label stored with the document
 */
@Override
public void addWordsToDoc(int doc, List<VocabWord> words, String label) {
    Document d = new Document();
    for (VocabWord word : words)
        d.add(new TextField(WORD_FIELD, word.getWord(), Field.Store.YES));
    d.add(new TextField(LABEL, label, Field.Store.YES));
    // Single atomic increment; a separate get() followed by set() can lose concurrent updates.
    totalWords.addAndGet(words.size());
    addWords(words);
    try {
        getWriter().addDocument(d);
    } catch (IOException e) {
        log.error("Failed to add labelled document to the index", e);
    }
}
/**
 * Indexes the given words as a document, using the text of {@code label} as its label.
 *
 * @param doc passed through to the String-label overload
 * @param words words of the document
 * @param label word whose text becomes the label
 */
@Override
public void addWordsToDoc(int doc, List words, VocabWord label) {
addWordsToDoc(doc,words,label.getWord());
}
/**
 * Loads the stored document, appends one label field per given word and adds the result to the index.
 * Note that the result is written with {@code addDocument}, i.e. appended as a further document.
 * Failures are logged and otherwise ignored.
 *
 * @param doc Lucene document id of the document to copy
 * @param label words whose text is used as labels
 */
@Override
public void addLabelsForDoc(int doc, List<VocabWord> label) {
    try {
        DirectoryReader reader = readerManager.acquire();
        Document doc2;
        try {
            doc2 = reader.document(doc);
        } finally {
            // Hand the reader back even when loading the document fails.
            readerManager.release(reader);
        }
        for (VocabWord s : label)
            doc2.add(new TextField(LABEL, s.getWord(), Field.Store.YES));
        getWriter().addDocument(doc2);
    } catch (Exception e) {
        log.error("Failed to add labels to document " + doc, e);
    }
}
/**
 * Loads the stored document, appends one label field per given string and adds the result to the index.
 * Note that the result is written with {@code addDocument}, i.e. appended as a further document.
 * Failures are logged and otherwise ignored.
 *
 * @param doc Lucene document id of the document to copy
 * @param label label texts to add
 */
@Override
public void addLabelsForDoc(int doc, Collection<String> label) {
    try {
        DirectoryReader reader = readerManager.acquire();
        Document doc2;
        try {
            doc2 = reader.document(doc);
        } finally {
            // Hand the reader back even when loading the document fails.
            readerManager.release(reader);
        }
        for (String s : label)
            doc2.add(new TextField(LABEL, s, Field.Store.YES));
        getWriter().addDocument(doc2);
    } catch (Exception e) {
        log.error("Failed to add labels to document " + doc, e);
    }
}
/**
 * Indexes the given words and labels as one new document and feeds the words to mini-batch
 * collection. A failure to write the document is logged and otherwise ignored.
 *
 * @param doc not used by this method; the document is appended to the index
 * @param words words of the document
 * @param label label texts stored with the document
 */
@Override
public void addWordsToDoc(int doc, List<VocabWord> words, Collection<String> label) {
    Document d = new Document();
    for (VocabWord word : words)
        d.add(new TextField(WORD_FIELD, word.getWord(), Field.Store.YES));
    for (String s : label)
        d.add(new TextField(LABEL, s, Field.Store.YES));
    // Single atomic increment; a separate get() followed by set() can lose concurrent updates.
    totalWords.addAndGet(words.size());
    addWords(words);
    try {
        getWriter().addDocument(d);
    } catch (IOException e) {
        log.error("Failed to add labelled document to the index", e);
    }
}
/**
 * Indexes the given words as one new document, labelled with the text of each word in
 * {@code label}, and feeds the words to mini-batch collection. A failure to write the document is
 * logged and otherwise ignored.
 *
 * @param doc not used by this method; the document is appended to the index
 * @param words words of the document
 * @param label words whose text is stored as the document's labels
 */
@Override
public void addWordsToDocVocabWord(int doc, List<VocabWord> words, Collection<VocabWord> label) {
    Document d = new Document();
    for (VocabWord word : words)
        d.add(new TextField(WORD_FIELD, word.getWord(), Field.Store.YES));
    for (VocabWord s : label)
        d.add(new TextField(LABEL, s.getWord(), Field.Store.YES));
    // Single atomic increment; a separate get() followed by set() can lose concurrent updates.
    totalWords.addAndGet(words.size());
    addWords(words);
    try {
        getWriter().addDocument(d);
    } catch (IOException e) {
        log.error("Failed to add labelled document to the index", e);
    }
}
/**
 * Collects words into mini-batches when mini-batch mode is on.
 * With {@code sample > 0} each word first passes a frequency-based sub-sampling check; words that
 * pass (or all words when sampling is off) are appended to the current batch, which is moved to
 * {@code miniBatches} once it reaches {@code batchSize}.
 *
 * @param words words to add
 */
private void addWords(List<VocabWord> words) {
    if (!miniBatch)
        return;
    final boolean subSample = sample > 0;
    // numDocuments() refreshes and acquires a reader, so query it once per call instead of twice per word.
    final double scaled = subSample ? sample * numDocuments() : 0;
    for (VocabWord word : words) {
        // The subsampling randomly discards frequent words while keeping the ranking same
        if (subSample) {
            double ran = (Math.sqrt(word.getWordFrequency() / scaled) + 1)
                    * scaled / word.getWordFrequency();
            // NOTE(review): nextRandom is never advanced in this file, so this threshold is constant -
            // confirm whether an update step is missing.
            if (ran < (nextRandom.get() & 0xFFFF) / (double) 65536) {
                continue;
            }
        }
        currMiniBatch.add(word);
        // Flush on both paths; otherwise no batch is ever emitted while sub-sampling is enabled.
        if (currMiniBatch.size() >= batchSize) {
            miniBatches.add(new ArrayList<>(currMiniBatch));
            currMiniBatch.clear();
        }
    }
}
/**
 * Opens the Lucene directory at {@code indexPath} if none is open yet.
 * Anything already present at that path is deleted first.
 *
 * @throws IOException if the old contents cannot be deleted or the directory cannot be opened or created
 */
private void ensureDirExists() throws IOException {
    if (dir != null)
        return;
    File directory = new File(indexPath);
    log.info("Creating directory " + directory.getCanonicalPath());
    FileUtils.deleteDirectory(directory);
    // File.toPath() copes with spaces and platform separators, which a hand-built "file:" URI does not.
    dir = FSDirectory.open(directory.toPath());
    // mkdirs() also creates missing parents; the isDirectory() check tolerates a concurrent creator.
    if (!directory.exists() && !directory.mkdirs() && !directory.isDirectory()) {
        throw new IOException("Unable to create index directory " + directory.getAbsolutePath());
    }
}
/**
 * Returns the cached tracking writer, creating it first if necessary.
 * Creation is retried via tryCreateWriter() until it yields a writer, sleeping 10000 ms between
 * attempts.
 *
 * @return the tracking index writer
 * @throws IllegalStateException wrapping any exception raised during creation, including an
 *         interrupt while sleeping
 */
private synchronized TrackingIndexWriter getWriterWithRetry() {
if(this.indexWriter != null)
return this.indexWriter;
IndexWriterConfig iwc;
IndexWriter writer;
try {
// The analyzer is built from a reader over an empty byte array.
// NOTE(review): presumably this reader supplies the stop-word list, i.e. no stop words - confirm.
if(analyzer == null)
analyzer = new StandardAnalyzer(new InputStreamReader(new ByteArrayInputStream("".getBytes())));
ensureDirExists();
if(this.indexWriter == null) {
// This flag is not reset in this method.
indexBeingCreated.set(true);
iwc = new IndexWriterConfig(analyzer);
iwc.setOpenMode(IndexWriterConfig.OpenMode.CREATE);
iwc.setWriteLockTimeout(1000);
log.info("Creating new index writer");
while((writer = tryCreateWriter(iwc)) == null) {
log.warn("Failed to create writer...trying again");
// A fresh config with the same settings is built for every new attempt.
iwc = new IndexWriterConfig(analyzer);
iwc.setOpenMode(IndexWriterConfig.OpenMode.CREATE);
iwc.setWriteLockTimeout(1000);
Thread.sleep(10000);
}
this.indexWriter = new TrackingIndexWriter(writer);
}
}
catch(Exception e) {
throw new IllegalStateException(e);
}
return this.indexWriter;
}
/**
 * Makes one attempt to create an index writer on a freshly (re)created directory.
 * On failure the index path is switched to a random UUID so the next attempt uses a new location.
 *
 * @param iwc writer configuration for this attempt
 * @return the new writer, or {@code null} if the attempt failed
 */
private IndexWriter tryCreateWriter(IndexWriterConfig iwc) {
    try {
        // Detach the field before closing. A previous failed attempt may have left dir null, and
        // an unguarded dir.close() would then fail every retry with a NullPointerException.
        Directory previous = dir;
        dir = null;
        if (previous != null)
            previous.close();
        FileUtils.deleteDirectory(new File(indexPath));
        ensureDirExists();
        if (lockFactory == null)
            lockFactory = NativeFSLockFactory.getDefault();
        return new IndexWriter(dir, iwc);
    }
    catch (Exception e) {
        String id = UUID.randomUUID().toString();
        indexPath = id;
        log.warn("Setting index path to " + id);
        log.warn("Couldn't create index ", e);
        return null;
    }
}
/**
 * Returns the tracking index writer, trying to create it with an increasing pause between attempts.
 *
 * @return the tracking index writer, never {@code null}
 * @throws IllegalStateException if no writer could be obtained
 */
private synchronized TrackingIndexWriter getWriter() {
    final int maxRetries = 3;
    int attempts = 0;
    while (getWriterWithRetry() == null && attempts < maxRetries) {
        try {
            Thread.sleep(1000L * attempts);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        attempts++;
    }
    // Fail loudly here rather than handing null to callers that dereference the result.
    if (this.indexWriter == null)
        throw new IllegalStateException("Can't obtain write lock");
    return this.indexWriter;
}
/**
 * Re-runs reader initialisation and records the current document count in {@code numDocs}.
 *
 * @throws RuntimeException wrapping any {@link IOException} raised while acquiring or releasing the reader
 */
@Override
public void finish() {
    try {
        initReader();
        final DirectoryReader current = readerManager.acquire();
        final int count = current.numDocs();
        numDocs = count;
        readerManager.release(current);
    } catch (IOException ioe) {
        throw new RuntimeException(ioe);
    }
}
/**
 * Returns the running total of words passed to the addWordsToDoc* methods.
 */
@Override
public long totalWords() {
return totalWords.get();
}
/**
 * Returns the configured mini-batch size.
 */
@Override
public int batchSize() {
return batchSize;
}
/**
 * Submits one task per document id returned by allDocs(); each task applies {@code func} to the
 * result of documentWithLabels for that id. This method does not wait for the tasks to finish.
 * NOTE(review): the generic type arguments of {@code func} look truncated - presumably
 * Function&lt;Pair&lt;List&lt;VocabWord&gt;, Collection&lt;String&gt;&gt;, Void&gt;; confirm.
 *
 * @param func callback applied to each document and its labels
 * @param exec executor the tasks are submitted to
 */
@Override
public void eachDocWithLabels(final Function, Collection>, Void> func, ExecutorService exec) {
int[] docIds = allDocs();
for(int i : docIds) {
// Final copy so the anonymous Runnable can capture the id.
final int j = i;
exec.execute(new Runnable() {
@Override
public void run() {
func.apply(documentWithLabels(j));
}
});
}
}
/**
 * Submits one task per document id returned by allDocs(); each task applies {@code func} to the
 * result of documentWithLabel for that id. This method does not wait for the tasks to finish.
 * NOTE(review): the generic type arguments of {@code func} look truncated - presumably
 * Function&lt;Pair&lt;List&lt;VocabWord&gt;, String&gt;, Void&gt;; confirm.
 *
 * @param func callback applied to each document and its label
 * @param exec executor the tasks are submitted to
 */
@Override
public void eachDocWithLabel(final Function, String>, Void> func, ExecutorService exec) {
int[] docIds = allDocs();
for(int i : docIds) {
// Final copy so the anonymous Runnable can capture the id.
final int j = i;
exec.execute(new Runnable() {
@Override
public void run() {
func.apply(documentWithLabel(j));
}
});
}
}
/**
 * Submits one task per document id returned by allDocs(); each task applies {@code func} to the
 * result of document(int) for that id. This method does not wait for the tasks to finish.
 * NOTE(review): the generic type arguments of {@code func} look truncated - presumably
 * Function&lt;List&lt;VocabWord&gt;, Void&gt;; confirm.
 *
 * @param func callback applied to each document
 * @param exec executor the tasks are submitted to
 */
@Override
public void eachDoc(final Function, Void> func,ExecutorService exec) {
int[] docIds = allDocs();
for(int i : docIds) {
// Final copy so the anonymous Runnable can capture the id.
final int j = i;
exec.execute(new Runnable() {
@Override
public void run() {
func.apply(document(j));
}
});
}
}
/**
 * Reader-closed callback: records that a reader has been closed by setting {@code readerClosed}.
 *
 * @param reader the reader that was closed (not inspected)
 */
@Override
public void onClose(IndexReader reader) {
readerClosed.set(true);
}
/**
 * Tells whether another mini-batch is available or may still arrive.
 *
 * @return {@code true} if a completed mini-batch is queued or mini-batch production is still flagged as running
 * @throws IllegalStateException if mini-batch mode is turned off
 */
@Override
public boolean hasNext() {
    if (!miniBatch)
        throw new IllegalStateException("Mini batch mode turned off");
    // Check the queue that next() actually consumes from.
    return !miniBatches.isEmpty() || miniBatchGoing.get();
}
/**
 * Returns the next completed mini-batch, waiting for one while production is flagged as running.
 *
 * @return the next mini-batch, or {@code null} if none is available and production has stopped or
 *         the waiting thread was interrupted
 * @throws IllegalStateException if mini-batch mode is turned off
 */
@Override
public List<VocabWord> next() {
    if (!miniBatch)
        throw new IllegalStateException("Mini batch mode turned off");
    if (!miniBatches.isEmpty())
        return miniBatches.remove(0);
    // Re-check the flag on every round so the wait ends once production stops.
    while (miniBatchGoing.get()) {
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            // Stop waiting: with the interrupt flag set, every further sleep() would throw
            // immediately and turn this loop into a busy spin.
            Thread.currentThread().interrupt();
            return null;
        }
        log.warn("Waiting on more data...");
        if (!miniBatches.isEmpty())
            return miniBatches.remove(0);
    }
    // Pick up a batch that was completed just before production stopped.
    if (!miniBatches.isEmpty())
        return miniBatches.remove(0);
    return null;
}
/**
 * Removal is not supported by the mini-batch iterator.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public void remove() {
throw new UnsupportedOperationException();
}
/**
 * Iterator that groups the documents produced by a {@link DocIter} into lists of at most
 * {@code batchSize} documents.
 * NOTE(review): the generic type arguments in this class look truncated - presumably
 * Iterator&lt;List&lt;List&lt;VocabWord&gt;&gt;&gt; and Iterator&lt;List&lt;VocabWord&gt;&gt;; confirm.
 */
public class BatchDocIter implements Iterator>> {
// Default value; always overwritten by the constructor.
private int batchSize = 1000;
private Iterator> docIter = new DocIter();
/**
 * @param batchSize maximum number of documents per batch
 */
public BatchDocIter(int batchSize) {
this.batchSize = batchSize;
}
@Override
public boolean hasNext() {
return docIter.hasNext();
}
/**
 * Collects up to {@code batchSize} documents; the list is shorter when the underlying iterator
 * runs out, and empty when it was already exhausted.
 */
@Override
public List> next() {
List> ret = new ArrayList<>();
for(int i = 0; i < batchSize; i++) {
if(!docIter.hasNext())
break;
ret.add(docIter.next());
}
return ret;
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
}
public class DocIter implements Iterator> {
private int currIndex = 0;
private int[] docs = allDocs();
@Override
public boolean hasNext() {
return currIndex < docs.length;
}
@Override
public List next() {
return document(docs[currIndex++]);
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
}
/**
 * Fluent builder for {@link LuceneInvertedIndex}.
 */
public static class Builder {
    private File indexDir;
    private Directory dir;
    private IndexReader reader;
    private Analyzer analyzer;
    private IndexSearcher searcher;
    private IndexWriter writer;
    private VocabCache vocabCache;
    private List<String> stopWords = StopWords.getStopWords();
    private boolean cache = false;
    private int batchSize = 1000;
    private double sample = 0;
    private boolean miniBatch = false;

    /** Enables or disables mini-batch collection. */
    public Builder miniBatch(boolean miniBatch) {
        this.miniBatch = miniBatch;
        return this;
    }

    /** Sets the cache flag passed to the index constructor. */
    public Builder cacheInRam(boolean cache) {
        this.cache = cache;
        return this;
    }

    /** Sets the sub-sampling value; values greater than 0 enable sub-sampling. */
    public Builder sample(double sample) {
        this.sample = sample;
        return this;
    }

    /** Sets the mini-batch size. */
    public Builder batchSize(int batchSize) {
        this.batchSize = batchSize;
        return this;
    }

    /** Sets the directory the index is stored in; the default location is used when unset. */
    public Builder indexDir(File indexDir) {
        this.indexDir = indexDir;
        return this;
    }

    /** Sets the vocabulary used to resolve stored words. */
    public Builder cache(VocabCache cache) {
        this.vocabCache = cache;
        return this;
    }

    /** Sets the stop words; the value is stored but not used by {@link #build()}. */
    public Builder stopWords(List<String> stopWords) {
        this.stopWords = stopWords;
        return this;
    }

    /** Sets the Lucene directory assigned to the built index. */
    public Builder dir(Directory dir) {
        this.dir = dir;
        return this;
    }

    /** Sets the index reader assigned to the built index. */
    public Builder reader(IndexReader reader) {
        this.reader = reader;
        return this;
    }

    /** Sets an index writer; the value is stored but not used by {@link #build()}. */
    public Builder writer(IndexWriter writer) {
        this.writer = writer;
        return this;
    }

    /** Sets the analyzer assigned to the built index. */
    public Builder analyzer(Analyzer analyzer) {
        this.analyzer = analyzer;
        return this;
    }

    /**
     * Creates the index and applies the configured options.
     * NOTE(review): the constructor already initialises the writer and reader managers, so the
     * directory, reader and analyzer set here are applied after that initialisation - confirm
     * that this ordering is intended.
     *
     * @return the configured index
     */
    public InvertedIndex build() {
        LuceneInvertedIndex ret;
        if (indexDir != null) {
            ret = new LuceneInvertedIndex(vocabCache, cache, indexDir.getAbsolutePath());
        } else {
            ret = new LuceneInvertedIndex(vocabCache, cache);
        }
        ret.batchSize = batchSize;
        ret.miniBatch = miniBatch;
        // Carry the sub-sampling value over; without this, sample(double) has no effect.
        ret.sample = sample;
        if (dir != null)
            ret.dir = dir;
        if (reader != null)
            ret.reader = reader;
        if (analyzer != null)
            ret.analyzer = analyzer;
        return ret;
    }
}
}