gov.sandia.cognition.learning.experiment.OnlineLearnerValidationExperiment Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of gov-sandia-cognition-learning-core Show documentation
Show all versions of gov-sandia-cognition-learning-core Show documentation
Algorithms and components for machine learning and statistics.
The newest version!
/*
* File: OnlineLearnerValidationExperiment.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright June 10, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.experiment;
import gov.sandia.cognition.learning.algorithm.IncrementalLearner;
import gov.sandia.cognition.learning.performance.PerformanceEvaluator;
import gov.sandia.cognition.util.Summarizer;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
/**
* Implements an experiment where an incremental supervised machine learning
* algorithm is evaluated by applying it to a set of data by successively
* testing on each item and then training on it.
*
* @param
* The type of the data to perform the experiment with.
* It will be used as input into the learning algorithm.
* @param The type of the output produced by the learning
* algorithm whose performance will be evaluated on each data item.
* @param The type of the statistic generated by the
* performance evaluator on the learned object for each data item. It
* is created by passing the learned object plus the data item
* into the performance evaluator.
* @param The type produced by the summarizer at the end of
* the experiment from a collection of the given statistics (one for
* each item). This represents the performance result for the learning
* algorithm for the whole experiment.
* @author Justin Basilico
* @since 3.0
*/
public class OnlineLearnerValidationExperiment
extends AbstractLearningExperiment
implements PerformanceEvaluator, Collection extends DataType>, SummaryType>,
Serializable
// TODO: This class is largely copied from LearnerValidationExperiment.
// They should probably be merged into abstract classes.
// --jdbasil (2010-06-10)
{
/** The evaluator to use to compute the performance of the learned object on
* each fold. */
protected PerformanceEvaluator
super LearnedType, ? super Collection extends DataType>,
? extends StatisticType>
performanceEvaluator;
/** The summarizer for summarizing the result of the performance evaluator
* from all the folds. */
protected Summarizer super StatisticType, ? extends SummaryType>
summarizer;
/** The number of trials in the experiment, which is the number of folds
* in the experiment. */
protected int numTrials;
/** The performance evaluations made during the experiment. */
protected ArrayList statistics;
/** The summary of the performance evaluations made at the end of the
* experiment. */
protected SummaryType summary;
/**
* Creates a new instance of IncrementalLearnerValidationExperiment.
*/
public OnlineLearnerValidationExperiment()
{
this(null, null);
}
/**
* Creates a new instance of IncrementalLearnerValidationExperiment.
*
* @param performanceEvaluator The evaluator to use to compute the
* performance of the learned object on each fold.
* @param summarizer The summarizer for summarizing the result of the
* performance evaluator from all the folds.
*/
public OnlineLearnerValidationExperiment(
final PerformanceEvaluator
super LearnedType, ? super Collection extends DataType>, ? extends StatisticType>
performanceEvaluator,
final Summarizer super StatisticType, ? extends SummaryType> summarizer)
{
super();
this.setPerformanceEvaluator(performanceEvaluator);
this.setSummarizer(summarizer);
// The initial number of trials is unknown.
this.setNumTrials(-1);
this.setStatistics(null);
this.setSummary(null);
}
/**
* Performs the experiment.
*
* @param data The data to use.
* @param learner The learner to perform the experiment on.
* @return The summary of the experiment.
*/
public SummaryType evaluatePerformance(
final IncrementalLearner super DataType, LearnedType> learner,
final Collection extends DataType> data)
{
// Initialize the collection where we will store the statistics
// generated from the data.
this.setNumTrials(data.size());
this.setStatistics(new ArrayList(data.size()));
this.setSummary(null);
// We've started the experiment.
this.fireExperimentStarted();
// Initialize learning.
final LearnedType learned = learner.createInitialLearnedObject();
// Go through and evaluate each item and then update the model using
// it.
for (DataType item : data)
{
// Start a new trial.
this.fireTrialStarted();
// Compute the statistic for this item.
final StatisticType statistic =
this.getPerformanceEvaluator().evaluatePerformance(
learned, Collections.singletonList(item));
this.statistics.add(statistic);
// Update the learned value.
learner.update(learned, item);
// The trial has ended.
this.fireTrialEnded();
}
// Summarize the statistics.
this.setSummary(this.getSummarizer().summarize(this.getStatistics()));
// The experiment has ended.
this.fireExperimentEnded();
// The result is the summary statistic.
return this.getSummary();
}
/**
* Gets the performance evaluator to apply to each fold.
*
* @return The performance evaluator to apply to each fold.
*/
public PerformanceEvaluator
super LearnedType, ? super Collection extends DataType>, ? extends StatisticType>
getPerformanceEvaluator()
{
return this.performanceEvaluator;
}
/**
* Sets the performance evaluator to apply to each fold.
*
* @param performanceEvaluator
* The performance evaluator to apply to each fold.
*/
public void setPerformanceEvaluator(
final PerformanceEvaluator
super LearnedType, ? super Collection extends DataType>, ? extends StatisticType>
performanceEvaluator)
{
this.performanceEvaluator = performanceEvaluator;
}
/**
* Gets the summarizer of the performance evaluations.
*
* @return The summarizer of the performance evaluations.
*/
public Summarizer super StatisticType, ? extends SummaryType> getSummarizer()
{
return this.summarizer;
}
/**
* Sets the summarizer of the performance evaluations.
*
* @param summarizer The summarizer of the performance evaluations.
*/
public void setSummarizer(
final Summarizer super StatisticType, ? extends SummaryType> summarizer)
{
this.summarizer = summarizer;
}
/**
* Gets the number of trials. Will be equal to the number of data points
* that the experiment is being run over.
*
* @return
* The number of trials.
*/
public int getNumTrials()
{
return this.numTrials;
}
/**
* Sets the number of trials.
*
* @param numTrials
* The number of trials.
*/
protected void setNumTrials(
final int numTrials)
{
this.numTrials = numTrials;
}
/**
* Gets the performance evaluations for the trials of the experiment.
*
* @return The performance evaluations for the trials of the experiment.
*/
public ArrayList getStatistics()
{
return this.statistics;
}
/**
* Sets the performance evaluations for the trials of the experiment.
*
* @param statistics
* The performance evaluations for the trials of the experiment.
*/
protected void setStatistics(
final ArrayList statistics)
{
this.statistics = statistics;
}
/**
* Gets the summary of the experiment.
*
* @return The summary of the experiment.
*/
public SummaryType getSummary()
{
return this.summary;
}
/**
* Sets the summary of the experiment.
*
* @param summary The summary of the experiment.
*/
protected void setSummary(
final SummaryType summary)
{
this.summary = summary;
}
}