All downloads are free. The search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.experiment.RandomByTwoFoldCreator Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                RandomByTwoFoldCreator.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright January 20, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government. Export
 * of this program may require a license from the United States Government.
 * See CopyrightHistory.txt for complete details.
 */

package gov.sandia.cognition.learning.experiment;

import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.data.DefaultPartitionedDataset;
import gov.sandia.cognition.learning.data.PartitionedDataset;
import gov.sandia.cognition.math.Permutation;
import gov.sandia.cognition.util.AbstractRandomized;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

/**
 * A validation fold creator that takes a given collection of data and randomly
 * splits it in half a given number of times, returning two folds for each
 * split, using one half as training and the other half as testing. The number
 * of folds is thus twice the parameterized number of splits. The data is
 * reordered as a result of each split, so this should not be used for data
 * whose sequence order matters. The default setup is a 5x2 cross-fold
 * creation, which is a common validation technique.
 *
 * @param   
 *      The type of data to create folds over.
 * @author  Justin Basilico
 * @since   3.0
 */
public class RandomByTwoFoldCreator
    extends AbstractRandomized
    implements ValidationFoldCreator
{
    /** The default number of splits is {@value}. */
    public static final int DEFAULT_NUM_SPLITS = 5;
    
    /** The number of splits. The number of folds is twice this number. */
    protected int numSplits;

    /**
     * Creates a new {@code RandomByTwoFoldCreator} with a default number of
     * splits.
     */
    public RandomByTwoFoldCreator()
    {
        this(DEFAULT_NUM_SPLITS);
    }

    /**
     * Creates a new {@code RandomByTwoFoldCreator} with a given number of
     * splits.
     *
     * @param   numSplits
     *      The number of splits to create. The number of folds created is
     *      twice this number. It must be positive.
     */
    public RandomByTwoFoldCreator(
        final int numSplits)
    {
        this(numSplits, new Random());
    }

    /**
     * Creates a new {@code RandomByTwoFoldCreator} with a given number of
     * splits.
     *
     * @param   numSplits
     *      The number of splits to create. The number of folds created is
     *      twice this number. It must be positive.
     * @param   random
     *      The random number generator to use.
     */
    public RandomByTwoFoldCreator(
        final int numSplits,
        final Random random)
    {
        super(random);

        this.setNumSplits(numSplits);
    }

    public List> createFolds(
        final Collection data)
    {
        final int size = CollectionUtil.size(data);
        if (size < 2)
        {
            // Need at least two elements.
            throw new IllegalArgumentException(
                "data must have at least 2 elements.");
        }

        // Figure out the actual number of splits and folds
        final int actualNumSplits = Math.min(size, this.getNumSplits());
        final int actualNumFolds = 2 * actualNumSplits;

        // We are going to have twice as many partitions as number of splits.
        final ArrayList> result =
            new ArrayList>(actualNumFolds);

        final int halfSize = Math.max(size / 2, 1);

        // Create the splits.
        for (int i = 0; i < actualNumSplits; i++)
        {
            // Create a random ordering.
            final ArrayList reordering =
                Permutation.createReordering(data, this.getRandom());

            // Get the two halves.
            final List firstHalf = reordering.subList(0, halfSize);
            final List secondHalf = reordering.subList(halfSize, size);

            // Add the two datasets.
            result.add(DefaultPartitionedDataset.create(firstHalf, secondHalf));
            result.add(DefaultPartitionedDataset.create(secondHalf, firstHalf));
        }

        // Return the resulting partitions.
        return result;
    }

    /**
     * Gets the number of splits to perform. When a dataset is given, two times
     * this number of partitions is returned. Must be positive.
     *
     * @return
     *      The number of splits to perform. Must be positive.
     */
    public int getNumSplits()
    {
        return this.numSplits;
    }

    /**
     * Sets the number of splits to perform. When a dataset is given, two times
     * this number of partitions is returned. Must be positive.
     *
     * @param   numSplits
     *      The number of splits to perform. Must be positive.
     */
    public void setNumSplits(
        final int numSplits)
    {
        if (numSplits <= 0)
        {
            throw new IllegalArgumentException(
                "numSplits must be positive");
        }
        this.numSplits = numSplits;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy