// gov.sandia.cognition.learning.experiment.RandomByTwoFoldCreator (Maven / Gradle / Ivy)
// Artifact: cognitive-foundry - a single jar with all the Cognitive Foundry components.
// Page navigation residue from the source listing: "Go to download",
// "Show more of this group", "Show more artifacts with this name",
// "Show all versions of cognitive-foundry", "Show documentation".
/*
* File: RandomByTwoFoldCreator.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright January 20, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*/
package gov.sandia.cognition.learning.experiment;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.data.DefaultPartitionedDataset;
import gov.sandia.cognition.learning.data.PartitionedDataset;
import gov.sandia.cognition.math.Permutation;
import gov.sandia.cognition.util.AbstractRandomized;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;
/**
* A validation fold creator that takes a given collection of data and randomly
* splits it in half a given number of times, returning two folds for each
* split, using one half as training and the other half as testing. The number
* of folds is thus twice the parameterized number of splits. The data is
* reordered as a result of each split, so this should not be used for data
* whose sequence order matters. The default setup is a 5x2 cross-fold
* creation, which is a common validation technique.
*
* @param
* The type of data to create folds over.
* @author Justin Basilico
* @since 3.0
*/
public class RandomByTwoFoldCreator
extends AbstractRandomized
implements ValidationFoldCreator
{
/** The default number of splits is {@value}. */
public static final int DEFAULT_NUM_SPLITS = 5;
/** The number of splits. The number of folds is twice this number. */
protected int numSplits;
/**
* Creates a new {@code RandomByTwoFoldCreator} with a default number of
* splits.
*/
public RandomByTwoFoldCreator()
{
this(DEFAULT_NUM_SPLITS);
}
/**
* Creates a new {@code RandomByTwoFoldCreator} with a given number of
* splits.
*
* @param numSplits
* The number of splits to create. The number of folds created is
* twice this number. It must be positive.
*/
public RandomByTwoFoldCreator(
final int numSplits)
{
this(numSplits, new Random());
}
/**
* Creates a new {@code RandomByTwoFoldCreator} with a given number of
* splits.
*
* @param numSplits
* The number of splits to create. The number of folds created is
* twice this number. It must be positive.
* @param random
* The random number generator to use.
*/
public RandomByTwoFoldCreator(
final int numSplits,
final Random random)
{
super(random);
this.setNumSplits(numSplits);
}
public List> createFolds(
final Collection extends DataType> data)
{
final int size = CollectionUtil.size(data);
if (size < 2)
{
// Need at least two elements.
throw new IllegalArgumentException(
"data must have at least 2 elements.");
}
// Figure out the actual number of splits and folds
final int actualNumSplits = Math.min(size, this.getNumSplits());
final int actualNumFolds = 2 * actualNumSplits;
// We are going to have twice as many partitions as number of splits.
final ArrayList> result =
new ArrayList>(actualNumFolds);
final int halfSize = Math.max(size / 2, 1);
// Create the splits.
for (int i = 0; i < actualNumSplits; i++)
{
// Create a random ordering.
final ArrayList reordering =
Permutation.createReordering(data, this.getRandom());
// Get the two halves.
final List firstHalf = reordering.subList(0, halfSize);
final List secondHalf = reordering.subList(halfSize, size);
// Add the two datasets.
result.add(DefaultPartitionedDataset.create(firstHalf, secondHalf));
result.add(DefaultPartitionedDataset.create(secondHalf, firstHalf));
}
// Return the resulting partitions.
return result;
}
/**
* Gets the number of splits to perform. When a dataset is given, two times
* this number of partitions is returned. Must be positive.
*
* @return
* The number of splits to perform. Must be positive.
*/
public int getNumSplits()
{
return this.numSplits;
}
/**
* Sets the number of splits to perform. When a dataset is given, two times
* this number of partitions is returned. Must be positive.
*
* @param numSplits
* The number of splits to perform. Must be positive.
*/
public void setNumSplits(
final int numSplits)
{
if (numSplits <= 0)
{
throw new IllegalArgumentException(
"numSplits must be positive");
}
this.numSplits = numSplits;
}
}