All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.dataset.BalanceMinibatches Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.dataset;

import org.nd4j.shade.guava.collect.Lists;
import org.nd4j.shade.guava.collect.Maps;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Auto balance mini batches by label.
 * @author Adam Gibson
 */
@AllArgsConstructor
@Builder
@Data
public class BalanceMinibatches {
    private DataSetIterator dataSetIterator;
    private int numLabels;
    private Map> paths = Maps.newHashMap();
    private int miniBatchSize = -1;
    private File rootDir = new File("minibatches");
    private File rootSaveDir = new File("minibatchessave");
    private List labelRootDirs = new ArrayList<>();
    private DataNormalization dataNormalization;

    /**
     * Generate a balanced
     * dataset minibatch fileset.
     */
    public void balance() {
        if (!rootDir.exists())
            rootDir.mkdirs();
        if (!rootSaveDir.exists())
            rootSaveDir.mkdirs();

        if (paths == null)
            paths = Maps.newHashMap();
        if (labelRootDirs == null)
            labelRootDirs = Lists.newArrayList();

        for (int i = 0; i < numLabels; i++) {
            paths.put(i, new ArrayList());
            labelRootDirs.add(new File(rootDir, String.valueOf(i)));
        }


        //lay out each example in their respective label directories tracking the paths along the way
        while (dataSetIterator.hasNext()) {
            DataSet next = dataSetIterator.next();
            //infer minibatch size from iterator
            if (miniBatchSize < 0)
                miniBatchSize = next.numExamples();
            for (int i = 0; i < next.numExamples(); i++) {
                DataSet currExample = next.get(i);
                if (!labelRootDirs.get(currExample.outcome()).exists())
                    labelRootDirs.get(currExample.outcome()).mkdirs();

                //individual example will be saved to: labelrootdir/examples.size()
                File example = new File(labelRootDirs.get(currExample.outcome()),
                                String.valueOf(paths.get(currExample.outcome()).size()));
                currExample.save(example);
                paths.get(currExample.outcome()).add(example);
            }
        }

        int numsSaved = 0;
        //loop till all file paths have been removed
        while (!paths.isEmpty()) {
            List miniBatch = new ArrayList<>();
            while (miniBatch.size() < miniBatchSize && !paths.isEmpty()) {
                for (int i = 0; i < numLabels; i++) {
                    if (paths.get(i) != null && !paths.get(i).isEmpty()) {
                        DataSet d = new DataSet();
                        d.load(paths.get(i).remove(0));
                        miniBatch.add(d);
                    } else
                        paths.remove(i);
                }
            }

            if (!rootSaveDir.exists())
                rootSaveDir.mkdirs();
            //save with an incremental count of the number of minibatches saved
            if (!miniBatch.isEmpty()) {
                DataSet merge = DataSet.merge(miniBatch);
                if (dataNormalization != null)
                    dataNormalization.transform(merge);
                merge.save(new File(rootSaveDir, String.format("dataset-%d.bin", numsSaved++)));
            }


        }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy