All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.datasets.iterator.ExistingDataSetIterator Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.datasets.iterator;


import lombok.Getter;
import lombok.NonNull;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import java.util.Iterator;
import java.util.List;

/**
 * Adapts an existing Java {@link Iterable} or {@link Iterator} of {@link DataSet} objects to the
 * {@link DataSetIterator} interface.
 *
 * <p>When constructed from an {@link Iterable}, {@link #reset()} is supported: each reset obtains a
 * fresh {@link Iterator} from the iterable. When constructed from a bare {@link Iterator}, reset is
 * not supported and {@link #reset()} throws {@link IllegalStateException}.
 *
 * <p>{@link #asyncSupported()} returns {@code false}, so callers should not wrap this iterator in an
 * asynchronous prefetcher.
 *
 * @author [email protected]
 */
public class ExistingDataSetIterator implements DataSetIterator {
    /** Optional pre-processor applied by {@link #next()}; {@code null} means no pre-processing. */
    @Getter
    private DataSetPreProcessor preProcessor;

    /** Source iterable; {@code null} when constructed from a bare Iterator (reset unsupported). */
    private transient Iterable<DataSet> iterable;
    /** Iterator currently being consumed. */
    private transient Iterator<DataSet> iterator;
    /** Values reported by the metadata accessors; 0 unless supplied via the 4-arg constructor. */
    private int totalExamples = 0;
    private int numFeatures = 0;
    private int numLabels = 0;
    /** Optional label names; when set, their count is reported by {@link #totalOutcomes()}. */
    private List<String> labels;


    /**
     * Wraps a one-shot iterator. {@link #reset()} is not supported for this instance.
     *
     * @param iterator source of DataSets; must not be null
     */
    public ExistingDataSetIterator(@NonNull Iterator<DataSet> iterator) {
        this.iterator = iterator;
    }

    /**
     * Wraps a one-shot iterator and attaches label names.
     *
     * @param iterator source of DataSets; must not be null
     * @param labels   label names returned by {@link #getLabels()}; must not be null
     */
    public ExistingDataSetIterator(@NonNull Iterator<DataSet> iterator, @NonNull List<String> labels) {
        this(iterator);
        this.labels = labels;
    }

    /**
     * Wraps an iterable. {@link #reset()} restarts iteration by calling {@code iterable.iterator()}.
     *
     * @param iterable source of DataSets; must not be null
     */
    public ExistingDataSetIterator(@NonNull Iterable<DataSet> iterable) {
        this.iterable = iterable;
        this.iterator = iterable.iterator();
    }

    /**
     * Wraps an iterable and attaches label names.
     *
     * @param iterable source of DataSets; must not be null
     * @param labels   label names returned by {@link #getLabels()}; must not be null
     */
    public ExistingDataSetIterator(@NonNull Iterable<DataSet> iterable, @NonNull List<String> labels) {
        this(iterable);
        this.labels = labels;
    }


    /**
     * Wraps an iterable and records metadata that the wrapper cannot compute by itself.
     *
     * @param iterable      source of DataSets; must not be null
     * @param totalExamples value returned by {@link #totalExamples()} and {@link #numExamples()}
     * @param numFeatures   value returned by {@link #inputColumns()}
     * @param numLabels     value returned by {@link #totalOutcomes()} when no label list is set
     */
    public ExistingDataSetIterator(@NonNull Iterable<DataSet> iterable, int totalExamples, int numFeatures,
                    int numLabels) {
        this(iterable);

        this.totalExamples = totalExamples;
        this.numFeatures = numFeatures;
        this.numLabels = numLabels;
    }

    /** Not supported: the wrapped source yields whole DataSets and cannot re-batch them. */
    @Override
    public DataSet next(int num) {
        // TODO: this might be changed
        throw new UnsupportedOperationException("next(int) isn't supported");
    }

    /** @return the total example count passed to the constructor, or 0 if none was given */
    @Override
    public int totalExamples() {
        return totalExamples;
    }

    /** @return the feature count passed to the constructor, or 0 if none was given */
    @Override
    public int inputColumns() {
        return numFeatures;
    }

    /** @return the number of label names if a label list was supplied, else the configured count */
    @Override
    public int totalOutcomes() {
        return labels != null ? labels.size() : numLabels;
    }

    /** @return {@code true} only when constructed from an {@link Iterable} */
    @Override
    public boolean resetSupported() {
        return iterable != null;
    }

    @Override
    public boolean asyncSupported() {
        //No need to asynchronously prefetch here: already in memory
        return false;
    }

    /**
     * Restarts iteration by requesting a new iterator from the wrapped iterable.
     *
     * @throws IllegalStateException if this instance was constructed from a bare Iterator
     */
    @Override
    public void reset() {
        if (iterable == null) {
            throw new IllegalStateException(
                            "To use reset() method you need to provide Iterable, not Iterator");
        }
        this.iterator = iterable.iterator();
    }

    /** @return always 0; the batch size of the wrapped DataSets is not tracked */
    @Override
    public int batch() {
        return 0;
    }

    /** @return always 0; position within the wrapped source is not tracked */
    @Override
    public int cursor() {
        return 0;
    }

    /** @return the same value as {@link #totalExamples()} */
    @Override
    public int numExamples() {
        return totalExamples;
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /** @return the label names supplied at construction, or {@code null} if none were given */
    @Override
    public List<String> getLabels() {
        return labels;
    }

    @Override
    public boolean hasNext() {
        return iterator != null && iterator.hasNext();
    }

    /**
     * Returns the next DataSet from the wrapped source. If a pre-processor is set, it is applied to
     * that DataSet unless the DataSet is already marked as pre-processed.
     */
    @Override
    public DataSet next() {
        DataSet ds = iterator.next();
        // The isPreProcessed() guard avoids applying the pre-processor twice to the same object
        // (presumably relevant when reset() makes an Iterable yield the same instances again).
        if (preProcessor != null && !ds.isPreProcessed()) {
            preProcessor.preProcess(ds);
            ds.markAsPreProcessed();
        }
        return ds;
    }

    /** Intentionally a no-op: elements of the wrapped source are never removed. */
    @Override
    public void remove() {
        // no-op
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy