org.nd4j.linalg.dataset.BalanceMinibatches Maven / Gradle / Ivy
package org.nd4j.linalg.dataset;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* Auto balance mini batches by label.
* @author Adam Gibson
*/
@AllArgsConstructor
@Builder
@Data
public class BalanceMinibatches {
private DataSetIterator dataSetIterator;
private int numLabels;
private Map> paths = Maps.newHashMap();
private int miniBatchSize = -1;
private File rootDir = new File("minibatches");
private File rootSaveDir = new File("minibatchessave");
private List labelRootDirs = new ArrayList<>();
private DataNormalization dataNormalization;
/**
* Generate a balanced
* dataset minibatch fileset.
*/
public void balance() {
if(!rootDir.exists())
rootDir.mkdirs();
if(!rootSaveDir.exists())
rootDir.mkdirs();
if(paths == null)
paths = Maps.newHashMap();
if(labelRootDirs == null)
labelRootDirs = Lists.newArrayList();
for(int i = 0; i < numLabels; i++) {
paths.put(i,new ArrayList());
labelRootDirs.add(new File(rootDir,String.valueOf(i)));
}
//lay out each example in their respective label directories tracking the paths along the way
while(dataSetIterator.hasNext()) {
DataSet next = dataSetIterator.next();
//infer minibatch size from iterator
if(miniBatchSize < 0)
miniBatchSize = next.numExamples();
for(int i = 0; i < next.numExamples(); i++) {
DataSet currExample = next.get(i);
if(!labelRootDirs.get(currExample.outcome()).exists())
labelRootDirs.get(currExample.outcome()).mkdirs();
//individual example will be saved to: labelrootdir/examples.size()
File example = new File(labelRootDirs.get(currExample.outcome()),String.valueOf(paths.get(currExample.outcome()).size()));
currExample.save(example);
paths.get(currExample.outcome()).add(example);
}
}
int numsSaved = 0;
//loop till all file paths have been removed
while(!paths.isEmpty()) {
List miniBatch = new ArrayList<>();
while(miniBatch.size() < miniBatchSize && !paths.isEmpty()) {
for(int i = 0; i < numLabels; i++) {
if(paths.get(i) != null && !paths.get(i).isEmpty()) {
DataSet d = new DataSet();
d.load(paths.get(i).remove(0));
miniBatch.add(d);
}
else
paths.remove(i);
}
}
if(!rootSaveDir.exists())
rootSaveDir.mkdirs();
//save with an incremental count of the number of minibatches saved
if(!miniBatch.isEmpty()) {
DataSet merge = DataSet.merge(miniBatch);
if(dataNormalization != null)
dataNormalization.transform(merge);
merge.save(new File(rootSaveDir,String.format("dataset-%d.bin",numsSaved++)));
}
}
}
}