All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.dataset.api.iterator.KFoldIterator Maven / Gradle / Ivy

package org.nd4j.linalg.dataset.api.iterator;

import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;

import java.util.ArrayList;
import java.util.List;

/**
 * Splits a dataset into k folds.
 * DataSet is duplicated in memory once
 * call .next() to get the k-1 folds to train on and call .testfold() to get the corresponding kth fold for testing
 * @author Susan Eraly
 */
public class KFoldIterator implements DataSetIterator {
    private DataSet singleFold;
    private int k;
    private int batch;
    private int lastBatch;
    private int kCursor = 0;
    private DataSet test;
    private DataSet train;
    protected DataSetPreProcessor preProcessor;

    public KFoldIterator(DataSet singleFold) {
        this(10, singleFold);
    }

    /**Create an iterator given the dataset and a value of k (optional, defaults to 10)
     * If number of samples in the dataset is not a multiple of k, the last fold will have less samples with the rest having the same number of samples.
     *
     * @param k number of folds (optional, defaults to 10)
     * @param singleFold DataSet to split into k folds
     */

    public KFoldIterator(int k, DataSet singleFold) {
        this.k = k;
        this.singleFold = singleFold.copy();
        if (k <= 1)
            throw new IllegalArgumentException();
        if (singleFold.numExamples() % k != 0) {
            if (k != 2) {
                this.batch = singleFold.numExamples() / (k - 1);
                this.lastBatch = singleFold.numExamples() % (k - 1);
            } else {
                this.lastBatch = singleFold.numExamples() / 2;
                this.batch = this.lastBatch + 1;
            }
        } else {
            this.batch = singleFold.numExamples() / k;
            this.lastBatch = singleFold.numExamples() / k;
        }
    }

    @Override
    public DataSet next(int num) throws UnsupportedOperationException {
        return null;
    }

    /**
     * Returns total number of examples in the dataset (all k folds)
     *
     * @return total number of examples in the dataset including all k folds
     */
    @Override
    public int totalExamples() {
        return singleFold.getLabels().size(0);
    }

    @Override
    public int inputColumns() {
        return singleFold.getFeatures().size(1);
    }

    @Override
    public int totalOutcomes() {
        return singleFold.getLabels().size(1);
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return false;
    }

    /**
     * Shuffles the dataset and resets to the first fold
     *
     * @return void
     */
    @Override
    public void reset() {
        //shuffle and return new k folds
        singleFold.shuffle();
        kCursor = 0;
    }


    /**
     * The number of examples in every fold, except the last if totalexamples % k !=0
     *
     * @return examples in a fold
     */
    @Override
    public int batch() {
        return batch;
    }

    /**
     * The number of examples in the last fold
     * if totalexamples % k == 0 same as the number of examples in every other fold
     *
     * @return examples in the last fold
     */
    public int lastBatch() {
        return lastBatch;
    }

    /**
     * cursor value of zero indicates no iterations and the very first fold will be held out as test
     * cursor value of 1 indicates the the next() call will return all but second fold which will be held out as test
     * curson value of k-1 indicates the next() call will return all but the last fold which will be held out as test
     *
     * @return the value of the cursor which indicates which fold is held out as a test set
     */

    @Override
    public int cursor() {
        return kCursor;
    }

    @Override
    public int numExamples() {
        return totalExamples();
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return preProcessor;
    }

    @Override
    public List getLabels() {
        return singleFold.getLabelNamesList();
    }

    @Override
    public boolean hasNext() {
        return kCursor < k;
    }

    @Override
    public DataSet next() {
        nextFold();
        return train;
    }

    @Override
    public void remove() {
        // no-op
    }

    private void nextFold() {
        int left;
        int right;
        if (kCursor == k - 1) {
            left = totalExamples() - lastBatch;
            right = totalExamples();
        } else {
            left = kCursor * batch;
            right = left + batch;
        }

        List kMinusOneFoldList = new ArrayList();
        if (right < totalExamples()) {
            if (left > 0) {
                kMinusOneFoldList.add((DataSet) singleFold.getRange(0, left));
            }
            kMinusOneFoldList.add((DataSet) singleFold.getRange(right, totalExamples()));
            train = DataSet.merge(kMinusOneFoldList);
        } else {
            train = (DataSet) singleFold.getRange(0, left);
        }
        test = (DataSet) singleFold.getRange(left, right);

        kCursor++;

    }

    /**
     * @return the held out fold as a dataset
     */
    public DataSet testFold() {
        return test;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy