// org.deeplearning4j.iterator.provider.FileLabeledSentenceProvider (Maven / Gradle / Ivy)
package org.deeplearning4j.iterator.provider;
import lombok.NonNull;
import org.apache.commons.io.FileUtils;
import org.datavec.api.util.RandomUtils;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.nd4j.linalg.collection.CompactHeapStringList;
import org.nd4j.linalg.primitives.Pair;

import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.*;
/**
* Iterate over a set of sentences/documents, where the sentences are to be loaded (as required) from the provided files.
*
* @author Alex Black
*/
public class FileLabeledSentenceProvider implements LabeledSentenceProvider {
private final int totalCount;
private final List filePaths;
private final int[] fileLabelIndexes;
private final Random rng;
private final int[] order;
private final List allLabels;
private int cursor = 0;
/**
* @param filesByLabel Key: label. Value: list of files for that label
*/
public FileLabeledSentenceProvider(Map> filesByLabel) {
this(filesByLabel, new Random());
}
/**
*
* @param filesByLabel Key: label. Value: list of files for that label
* @param rng Random number generator. May be null.
*/
public FileLabeledSentenceProvider(@NonNull Map> filesByLabel, Random rng) {
int totalCount = 0;
for (List l : filesByLabel.values()) {
totalCount += l.size();
}
this.totalCount = totalCount;
this.rng = rng;
if (rng == null) {
order = null;
} else {
order = new int[totalCount];
for (int i = 0; i < totalCount; i++) {
order[i] = i;
}
RandomUtils.shuffleInPlace(order, rng);
}
allLabels = new ArrayList<>(filesByLabel.keySet());
Collections.sort(allLabels);
Map labelsToIdx = new HashMap<>();
for (int i = 0; i < allLabels.size(); i++) {
labelsToIdx.put(allLabels.get(i), i);
}
filePaths = new CompactHeapStringList();
fileLabelIndexes = new int[totalCount];
int position = 0;
for (Map.Entry> entry : filesByLabel.entrySet()) {
int labelIdx = labelsToIdx.get(entry.getKey());
for (File f : entry.getValue()) {
filePaths.add(f.getPath());
fileLabelIndexes[position] = labelIdx;
position++;
}
}
}
@Override
public boolean hasNext() {
return cursor < totalCount;
}
@Override
public Pair nextSentence() {
int idx;
if (rng == null) {
idx = cursor++;
} else {
idx = order[cursor++];
}
File f = new File(filePaths.get(idx));
String label = allLabels.get(fileLabelIndexes[idx]);
String sentence;
try {
sentence = FileUtils.readFileToString(f);
} catch (IOException e) {
throw new RuntimeException(e);
}
return new Pair<>(sentence, label);
}
@Override
public void reset() {
cursor = 0;
if (rng != null) {
RandomUtils.shuffleInPlace(order, rng);
}
}
@Override
public int totalNumSentences() {
return totalCount;
}
@Override
public List allLabels() {
return allLabels;
}
@Override
public int numLabelClasses() {
return allLabels.size();
}
}