// org.nd4j.linalg.dataset.BalanceMinibatches Maven / Gradle / Ivy
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.dataset;
import org.nd4j.shade.guava.collect.Lists;
import org.nd4j.shade.guava.collect.Maps;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* Auto balance mini batches by label.
* @author Adam Gibson
*/
@AllArgsConstructor
@Builder
@Data
public class BalanceMinibatches {
private DataSetIterator dataSetIterator;
private int numLabels;
private Map> paths = Maps.newHashMap();
private int miniBatchSize = -1;
private File rootDir = new File("minibatches");
private File rootSaveDir = new File("minibatchessave");
private List labelRootDirs = new ArrayList<>();
private DataNormalization dataNormalization;
/**
* Generate a balanced
* dataset minibatch fileset.
*/
public void balance() {
if (!rootDir.exists())
rootDir.mkdirs();
if (!rootSaveDir.exists())
rootSaveDir.mkdirs();
if (paths == null)
paths = Maps.newHashMap();
if (labelRootDirs == null)
labelRootDirs = Lists.newArrayList();
for (int i = 0; i < numLabels; i++) {
paths.put(i, new ArrayList());
labelRootDirs.add(new File(rootDir, String.valueOf(i)));
}
//lay out each example in their respective label directories tracking the paths along the way
while (dataSetIterator.hasNext()) {
DataSet next = dataSetIterator.next();
//infer minibatch size from iterator
if (miniBatchSize < 0)
miniBatchSize = next.numExamples();
for (int i = 0; i < next.numExamples(); i++) {
DataSet currExample = next.get(i);
if (!labelRootDirs.get(currExample.outcome()).exists())
labelRootDirs.get(currExample.outcome()).mkdirs();
//individual example will be saved to: labelrootdir/examples.size()
File example = new File(labelRootDirs.get(currExample.outcome()),
String.valueOf(paths.get(currExample.outcome()).size()));
currExample.save(example);
paths.get(currExample.outcome()).add(example);
}
}
int numsSaved = 0;
//loop till all file paths have been removed
while (!paths.isEmpty()) {
List miniBatch = new ArrayList<>();
while (miniBatch.size() < miniBatchSize && !paths.isEmpty()) {
for (int i = 0; i < numLabels; i++) {
if (paths.get(i) != null && !paths.get(i).isEmpty()) {
DataSet d = new DataSet();
d.load(paths.get(i).remove(0));
miniBatch.add(d);
} else
paths.remove(i);
}
}
if (!rootSaveDir.exists())
rootSaveDir.mkdirs();
//save with an incremental count of the number of minibatches saved
if (!miniBatch.isEmpty()) {
DataSet merge = DataSet.merge(miniBatch);
if (dataNormalization != null)
dataNormalization.transform(merge);
merge.save(new File(rootSaveDir, String.format("dataset-%d.bin", numsSaved++)));
}
}
}
}