// Source listing: gov.sandia.cognition.learning.experiment.LearnerValidationExperiment
// Artifact: cognitive-foundry (a single jar with all the Cognitive Foundry components).
/*
* File: LearnerValidationExperiment.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright September 21, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.experiment;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.performance.PerformanceEvaluator;
import gov.sandia.cognition.learning.data.PartitionedDataset;
import gov.sandia.cognition.util.Summarizer;
import java.util.ArrayList;
import java.util.Collection;
/**
* The {@code LearnerValidationExperiment} class implements an experiment where
* a supervised machine learning algorithm is evaluated by applying it to a set
* of folds created from a given set of data.
*
* @param
* The type of the data to perform the experiment with.
* This will be passed to the fold creator to create a number of folds
* on which to validate the performance of the learning algorithm.
* @param
* The type of data created by the fold creator that will go into
* the learning algorithm. Typically, this is the same as the
* InputDataType, but it does not need to be. It just needs to match
* the output of the fold creator and the input of the learning
* algorithm.
* @param The type of the output produced by the learning
* algorithm whose performance will be evaluated on each fold of data.
* @param The type of the statistic generated by the
* performance evaluator on the learned object for each fold. It is
* created by passing the learned object plus the test data for the
* fold into the performance evaluator.
* @param The type produced by the summarizer at the end of
* the experiment from a collection of the given statistics (one for
* each fold). This represents the performance result for the learning
* algorithm for the whole experiment.
* @author Justin Basilico
* @since 2.0
*/
@PublicationReference(
author="Wikipedia",
title="Decriptive statistics",
type=PublicationType.WebPage,
year=2008,
url="http://en.wikipedia.org/wiki/Descriptive_statistics"
)
public class LearnerValidationExperiment
extends AbstractValidationFoldExperiment
implements PerformanceEvaluator, ? extends LearnedType>, Collection extends InputDataType>, SummaryType>
{
/** The evaluator to use to compute the performance of the learned object on
* each fold. */
protected PerformanceEvaluator
super LearnedType, ? super Collection extends FoldDataType>,
? extends StatisticType>
performanceEvaluator;
/** The summarizer for summarizing the result of the performance evaluator
* from all the folds. */
protected Summarizer super StatisticType, ? extends SummaryType>
summarizer;
/** The learner that the experiment is run on. */
private BatchLearner
super Collection extends FoldDataType>, ? extends LearnedType>
learner;
/** The performance evaluations made during the experiment. */
protected ArrayList statistics;
/** The summary of the performance evaluations made at the end of the
* experiment. */
protected SummaryType summary;
/**
* Creates a new instance of SupervisedLearnerExperiment.
*/
public LearnerValidationExperiment()
{
this(null, null, null);
}
/**
* Creates a new instance of SupervisedLearnerExperiment.
*
* @param foldCreator The object to use for creating the folds.
* @param performanceEvaluator The evaluator to use to compute the
* performance of the learned object on each fold.
* @param summarizer The summarizer for summarizing the result of the
* performance evaluator from all the folds.
*/
public LearnerValidationExperiment(
final ValidationFoldCreator foldCreator,
final PerformanceEvaluator
super LearnedType, ? super Collection extends FoldDataType>, ? extends StatisticType>
performanceEvaluator,
final Summarizer super StatisticType, ? extends SummaryType> summarizer)
{
super(foldCreator);
this.setPerformanceEvaluator(performanceEvaluator);
this.setSummarizer(summarizer);
// The initial number of trials is unknown.
this.setStatistics(null);
this.setSummary(null);
}
/**
* @deprecated Use evaluatePerformance instead.
*
* Performs the experiment.
*
* @param data The data to use.
* @param learner The learner to perform the experiment on.
* @return The summary of the experiment.
*/
@Deprecated
public SummaryType evaluate(
final BatchLearner super Collection extends FoldDataType>, ? extends LearnedType>
learner,
final Collection extends InputDataType> data)
{
return this.evaluatePerformance(learner, data);
}
public SummaryType evaluatePerformance(
final BatchLearner super Collection extends FoldDataType>, ? extends LearnedType>
learner,
final Collection extends InputDataType> data)
{
// The first step in the experiment is to create the folds.
final Collection> folds =
this.getFoldCreator().createFolds(data);
this.setLearner(learner);
// Initialize the collection where we will store the statistics
// generated from the data.
this.setStatistics(new ArrayList(folds.size()));
this.setSummary(null);
this.runExperiment(folds);
// Summarize the statistics.
this.setSummary(this.getSummarizer().summarize(this.getStatistics()));
return this.getSummary();
}
protected void runTrial(
final PartitionedDataset fold)
{
// Perform the learning algorithm on this fold.
final LearnedType learned = getLearner().learn(fold.getTrainingSet());
// Compute the statistic of the learned object on the testing set.
final Collection testingSet = fold.getTestingSet();
final StatisticType statistic =
this.getPerformanceEvaluator().evaluatePerformance(
learned, testingSet);
statistics.add(statistic);
}
/**
* Gets the performance evaluator to apply to each fold.
*
* @return The performance evaluator to apply to each fold.
*/
public PerformanceEvaluator
super LearnedType, ? super Collection extends FoldDataType>, ? extends StatisticType>
getPerformanceEvaluator()
{
return this.performanceEvaluator;
}
/**
* Sets the performance evaluator to apply to each fold.
*
* @param performanceEvaluator
* The performance evaluator to apply to each fold.
*/
public void setPerformanceEvaluator(
final PerformanceEvaluator
super LearnedType, ? super Collection extends FoldDataType>, ? extends StatisticType>
performanceEvaluator)
{
this.performanceEvaluator = performanceEvaluator;
}
/**
* Gets the summarizer of the performance evaluations.
*
* @return The summarizer of the performance evaluations.
*/
public Summarizer super StatisticType, ? extends SummaryType> getSummarizer()
{
return this.summarizer;
}
/**
* Sets the summarizer of the performance evaluations.
*
* @param summarizer The summarizer of the performance evaluations.
*/
public void setSummarizer(
final Summarizer super StatisticType, ? extends SummaryType> summarizer)
{
this.summarizer = summarizer;
}
/**
* Gets the learner the experiment is being run on.
*
* @return The learner.
*/
public BatchLearner
super Collection extends FoldDataType>, ? extends LearnedType>
getLearner()
{
return this.learner;
}
/**
* Sets the learner the experiment is being run on.
*
* @param learner The learner.
*/
protected void setLearner(
final BatchLearner
super Collection extends FoldDataType>, ? extends LearnedType>
learner)
{
this.learner = learner;
}
/**
* Gets the performance evaluations for the trials of the experiment.
*
* @return The performance evaluations for the trials of the experiment.
*/
public ArrayList getStatistics()
{
return this.statistics;
}
/**
* Sets the performance evaluations for the trials of the experiment.
*
* @param statistics
* The performance evaluations for the trials of the experiment.
*/
protected void setStatistics(
final ArrayList statistics)
{
this.statistics = statistics;
}
/**
* Gets the summary of the experiment.
*
* @return The summary of the experiment.
*/
public SummaryType getSummary()
{
return this.summary;
}
/**
* Sets the summary of the experiment.
*
* @param summary The summary of the experiment.
*/
protected void setSummary(
final SummaryType summary)
{
this.summary = summary;
}
}