All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.dataset.MiniBatchFileDataSetIterator Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.dataset;

import org.apache.commons.io.FileUtils;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.UUID;

/**
 * A {@link DataSetIterator} that partitions a single in-memory {@link DataSet} into fixed-size
 * mini batches, writes every batch to disk up front, and reads the batches back one at a time.
 * <p>
 * Each batch is stored as two files inside a freshly created, randomly named sub-directory of the
 * supplied root directory: {@code <id>.bin} holds the features and {@code <id>.labels.bin} holds
 * the labels.
 * <p>
 * Only whole batches are written: when the number of examples is not a multiple of the batch size,
 * the trailing {@code numExamples % batchSize} examples are not served by this iterator.
 */
public class MiniBatchFileDataSetIterator implements DataSetIterator {
    private final int batchSize;
    /** One entry per batch, in order: {@code [featuresFilePath, labelsFilePath]}. */
    private final List<String[]> paths;
    /** Index of the batch that the next call to {@link #next()} returns. */
    private int currIdx;
    private File rootDir;
    private final int totalExamples;
    private final int totalLabels;
    private final int totalBatches;
    private DataSetPreProcessor dataSetPreProcessor;

    /**
     * Creates an iterator whose batch files live under {@code java.io.tmpdir} and are deleted
     * when the JVM shuts down.
     *
     * @param baseData the base dataset
     * @param batchSize the batch size to split by; must be positive
     * @throws IOException if the batch directory or a batch file cannot be created/written
     */
    public MiniBatchFileDataSetIterator(DataSet baseData, int batchSize) throws IOException {
        this(baseData, batchSize, true);
    }

    /**
     * Creates an iterator that writes its batch files into a new, randomly named sub-directory
     * of {@code rootDir}.
     *
     * @param baseData the base dataset
     * @param batchSize the batch size to split by; must be positive
     * @param delete whether to register a JVM shutdown hook that deletes the created sub-directory
     * @param rootDir the parent directory under which the batch sub-directory is created
     * @throws IOException if the batch directory or a batch file cannot be created/written
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public MiniBatchFileDataSetIterator(DataSet baseData, int batchSize, boolean delete, File rootDir)
            throws IOException {
        if (batchSize <= 0)
            throw new IllegalArgumentException("Batch size must be positive, got " + batchSize);
        this.batchSize = batchSize;
        // Assign the FIELD (not the parameter) so writeData() and the shutdown hook see this directory.
        this.rootDir = new File(rootDir, UUID.randomUUID().toString());
        if (!this.rootDir.mkdirs() && !this.rootDir.isDirectory())
            throw new IOException("Unable to create directory " + this.rootDir.getAbsolutePath());
        if (delete) {
            // Capture the directory created here so a later setRootDir() cannot redirect the deletion.
            final File createdDir = this.rootDir;
            Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
                @Override
                public void run() {
                    try {
                        FileUtils.deleteDirectory(createdDir);
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
            }));
        }
        currIdx = 0;
        totalExamples = baseData.numExamples();
        totalLabels = baseData.numOutcomes();
        // Integer division: a trailing partial batch is intentionally dropped.
        totalBatches = totalExamples / batchSize;
        paths = new ArrayList<>(totalBatches);
        for (int i = 0; i < totalBatches; i++) {
            // i * batchSize + batchSize <= totalExamples, so this cannot overflow.
            int offset = i * batchSize;
            paths.add(writeData(new DataSet(
                    baseData.getFeatureMatrix().get(NDArrayIndex.interval(offset, offset + batchSize)),
                    baseData.getLabels().get(NDArrayIndex.interval(offset, offset + batchSize)))));
        }
    }

    /**
     * Creates an iterator whose batch files live under {@code java.io.tmpdir}.
     *
     * @param baseData the base dataset
     * @param batchSize the batch size to split by; must be positive
     * @param delete whether to delete the created batch directory at JVM shutdown
     * @throws IOException if the batch directory or a batch file cannot be created/written
     */
    public MiniBatchFileDataSetIterator(DataSet baseData, int batchSize, boolean delete) throws IOException {
        this(baseData, batchSize, delete, new File(System.getProperty("java.io.tmpdir")));
    }

    /** Not supported: batches are pre-written with a fixed size. */
    @Override
    public DataSet next(int num) {
        throw new UnsupportedOperationException("Unable to load custom number of examples");
    }

    /** @return the number of examples in the base dataset (including any dropped trailing examples) */
    @Override
    public int totalExamples() {
        return totalExamples;
    }

    /** Input column count is not tracked by this iterator; always returns 0. */
    @Override
    public int inputColumns() {
        return 0;
    }

    /** @return the number of outcomes (label columns) of the base dataset */
    @Override
    public int totalOutcomes() {
        return totalLabels;
    }

    /** Rewinds to the first batch; the batch files on disk are reused. */
    @Override
    public void reset() {
        currIdx = 0;
    }

    @Override
    public int batch() {
        return batchSize;
    }

    /** @return the index of the next batch (a batch index, not an example index) */
    @Override
    public int cursor() {
        return currIdx;
    }

    /** @return the number of examples in the base dataset */
    @Override
    public int numExamples() {
        return totalExamples;
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.dataSetPreProcessor = preProcessor;
    }

    @Override
    public boolean hasNext() {
        return currIdx < totalBatches;
    }

    /** Intentionally a no-op: batches cannot be removed. */
    @Override
    public void remove() {
        //no opt;
    }

    /**
     * Reads the next batch from disk, applying the pre-processor if one is set.
     *
     * @throws NoSuchElementException if all batches have been consumed
     * @throws IllegalStateException if the batch files cannot be read
     */
    @Override
    public DataSet next() {
        if (!hasNext())
            throw new NoSuchElementException("No more batches: all " + totalBatches + " have been returned");
        try {
            DataSet ret = read(currIdx);
            if (dataSetPreProcessor != null)
                dataSetPreProcessor.preProcess(ret);
            currIdx++;
            return ret;
        } catch (IOException e) {
            // Preserve the cause so the underlying I/O failure is visible to callers.
            throw new IllegalStateException("Unable to read dataset", e);
        }
    }

    /** Loads batch {@code idx} from its feature and label files; both streams are always closed. */
    private DataSet read(int idx) throws IOException {
        String[] files = paths.get(idx);
        try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(files[0])));
             DataInputStream labelDis = new DataInputStream(new BufferedInputStream(new FileInputStream(files[1])))) {
            return new DataSet(Nd4j.read(dis), Nd4j.read(labelDis));
        }
    }

    /**
     * Writes one batch to {@code rootDir} as {@code <uuid>.bin} (features) and
     * {@code <uuid>.labels.bin} (labels).
     *
     * @return the absolute paths {@code [featuresPath, labelsPath]}
     */
    private String[] writeData(DataSet write) throws IOException {
        String dataSetId = UUID.randomUUID().toString();
        File featureFile = new File(rootDir, dataSetId + ".bin");
        File labelFile = new File(rootDir, dataSetId + ".labels.bin");

        // close() flushes the buffered stream; try-with-resources closes it even on failure.
        try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(featureFile)))) {
            Nd4j.write(write.getFeatureMatrix(), dos);
        }
        try (DataOutputStream dosLabels = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(labelFile)))) {
            Nd4j.write(write.getLabels(), dosLabels);
        }
        return new String[] {featureFile.getAbsolutePath(), labelFile.getAbsolutePath()};
    }

    /** @return the batch directory created by the constructor (unless replaced via {@link #setRootDir}) */
    public File getRootDir() {
        return rootDir;
    }

    /**
     * Replaces the stored directory reference. Batch files already written are not moved and
     * are still read from their original locations.
     */
    public void setRootDir(File rootDir) {
        this.rootDir = rootDir;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy