// gov.sandia.cognition.learning.experiment.LearnerRepeatExperiment Maven / Gradle / Ivy
// Go to download
// Show more of this group Show more artifacts with this name
// Show all versions of cognitive-foundry Show documentation
// Show all versions of cognitive-foundry Show documentation
// A single jar with all the Cognitive Foundry components.
/*
* File: LearnerRepeatExperiment.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright January 07, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.experiment;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.data.PartitionedDataset;
import gov.sandia.cognition.learning.performance.PerformanceEvaluator;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.Summarizer;
import java.util.ArrayList;
import java.util.Collection;
/**
* Runs an experiment where the same learner is evaluated multiple times on
* the same data. Useful for cases where there is a randomized algorithm used
* in conjunction with a known training/test split. Should only be used with
* stochastic algorithms
*
* @param
* The type of the data to perform the experiment with. It will be
* passed to the algorithm.
* @param The type of the output produced by the learning
* algorithm whose performance will be evaluated on each fold of data.
* @param The type of the statistic generated by the
* performance evaluator on the learned object for each fold. It is
* created by passing the learned object plus the test data for the
* fold into the performance evaluator.
* @param The type produced by the summarizer at the end of
* the experiment from a collection of the given statistics (one for
* each fold). This represents the performance result for the learning
* algorithm for the whole experiment.
* @author Justin Basilico
* @since 3.1.1
*/
public class LearnerRepeatExperiment
extends AbstractLearningExperiment
implements PerformanceEvaluator, ? extends LearnedType>, PartitionedDataset extends InputDataType>, SummaryType>
{
/** The default number of trials is {@value}. */
public static final int DEFAULT_NUM_TRIALS = 10;
/** The number of trials to repeat the learning. */
protected int numTrials;
/** The evaluator to use to compute the performance of the learned object on
* each fold. */
protected PerformanceEvaluator
super LearnedType, ? super Collection extends InputDataType>,
? extends StatisticType>
performanceEvaluator;
/** The summarizer for summarizing the result of the performance evaluator
* from all the folds. */
protected Summarizer super StatisticType, ? extends SummaryType>
summarizer;
/** The learner that the experiment is run on. */
private BatchLearner
super Collection extends InputDataType>, ? extends LearnedType>
learner;
/** The performance evaluations made during the experiment. */
protected ArrayList statistics;
/** The summary of the performance evaluations made at the end of the
* experiment. */
protected SummaryType summary;
/**
* Creates a new instance of LearnerRepeatExperiment.
*/
public LearnerRepeatExperiment()
{
this(DEFAULT_NUM_TRIALS, null, null);
}
/**
* Creates a new instance of LearnerRepeatExperiment.
*
* @param numTrials The number of repeated trials to run.
* @param performanceEvaluator The evaluator to use to compute the
* performance of the learned object on each fold.
* @param summarizer The summarizer for summarizing the result of the
* performance evaluator from all the folds.
*/
public LearnerRepeatExperiment(
final int numTrials,
final PerformanceEvaluator
super LearnedType, ? super Collection extends InputDataType>, ? extends StatisticType>
performanceEvaluator,
final Summarizer super StatisticType, ? extends SummaryType> summarizer)
{
super();
this.setNumTrials(numTrials);
this.setPerformanceEvaluator(performanceEvaluator);
this.setSummarizer(summarizer);
// The initial number of trials is unknown.
this.setStatistics(null);
this.setSummary(null);
}
/**
* Performs the experiment.
*
* @param data The data to use.
* @param learner The learner to perform the experiment on.
* @return The summary of the experiment.
*/
public SummaryType evaluatePerformance(
final BatchLearner
super Collection extends InputDataType>, ? extends LearnedType>
learner,
final PartitionedDataset extends InputDataType> data)
{
this.setLearner(learner);
// Initialize the collection where we will store the statistics
// generated from the data.
this.setStatistics(new ArrayList(this.getNumTrials()));
this.setSummary(null);
this.runExperiment(data);
// Summarize the statistics.
this.setSummary(this.getSummarizer().summarize(this.getStatistics()));
return this.getSummary();
}
/**
* Runs the experiment.
*
* @param data The data to use.
*/
protected void runExperiment(
final PartitionedDataset extends InputDataType> data)
{
this.fireExperimentStarted();
// Go through the folds and run the trial for each fold.
for (int i = 1; i <= this.getNumTrials(); i++)
{
this.fireTrialStarted();
// Run the trial on this fold.
this.runTrial(data);
this.fireTrialEnded();
}
this.fireExperimentEnded();
}
/**
* Runs one trial in the experiment.
*
* @param data The data to use.
*/
protected void runTrial(
final PartitionedDataset extends InputDataType> data)
{
// Perform the learning algorithm on this fold.
final LearnedType learned = getLearner().learn(data.getTrainingSet());
// Compute the statistic of the learned object on the testing set.
final Collection extends InputDataType> testingSet =
data.getTestingSet();
final StatisticType statistic =
this.getPerformanceEvaluator().evaluatePerformance(
learned, testingSet);
statistics.add(statistic);
}
/**
* Gets the performance evaluator to apply to each fold.
*
* @return The performance evaluator to apply to each fold.
*/
public PerformanceEvaluator
super LearnedType, ? super Collection extends InputDataType>, ? extends StatisticType>
getPerformanceEvaluator()
{
return this.performanceEvaluator;
}
/**
* Sets the performance evaluator to apply to each fold.
*
* @param performanceEvaluator
* The performance evaluator to apply to each fold.
*/
public void setPerformanceEvaluator(
final PerformanceEvaluator
super LearnedType, ? super Collection extends InputDataType>, ? extends StatisticType>
performanceEvaluator)
{
this.performanceEvaluator = performanceEvaluator;
}
/**
* Gets the summarizer of the performance evaluations.
*
* @return The summarizer of the performance evaluations.
*/
public Summarizer super StatisticType, ? extends SummaryType> getSummarizer()
{
return this.summarizer;
}
/**
* Sets the summarizer of the performance evaluations.
*
* @param summarizer The summarizer of the performance evaluations.
*/
public void setSummarizer(
final Summarizer super StatisticType, ? extends SummaryType> summarizer)
{
this.summarizer = summarizer;
}
/**
* Gets the learner the experiment is being run on.
*
* @return The learner.
*/
public BatchLearner
super Collection extends InputDataType>, ? extends LearnedType>
getLearner()
{
return this.learner;
}
/**
* Sets the learner the experiment is being run on.
*
* @param learner The learner.
*/
protected void setLearner(
final BatchLearner
super Collection extends InputDataType>, ? extends LearnedType>
learner)
{
this.learner = learner;
}
/**
* Gets the performance evaluations for the trials of the experiment.
*
* @return The performance evaluations for the trials of the experiment.
*/
public ArrayList getStatistics()
{
return this.statistics;
}
/**
* Sets the performance evaluations for the trials of the experiment.
*
* @param statistics
* The performance evaluations for the trials of the experiment.
*/
protected void setStatistics(
final ArrayList statistics)
{
this.statistics = statistics;
}
/**
* Gets the summary of the experiment.
*
* @return The summary of the experiment.
*/
public SummaryType getSummary()
{
return this.summary;
}
/**
* Sets the summary of the experiment.
*
* @param summary The summary of the experiment.
*/
protected void setSummary(
final SummaryType summary)
{
this.summary = summary;
}
@Override
public int getNumTrials()
{
return this.numTrials;
}
/**
* Sets the number of trials for the experiment to repeatedly call the
* learning algorithm.
*
* @param numTrials
* The number of trials. Must be positive.
*/
public void setNumTrials(
final int numTrials)
{
ArgumentChecker.assertIsPositive("numTrials", numTrials);
this.numTrials = numTrials;
}
}