gov.sandia.cognition.learning.experiment.ParallelLearnerValidationExperiment Maven / Gradle / Ivy
/*
* File: ParallelLearnerValidationExperiment.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright October 04, 2008, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.experiment;
import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.data.PartitionedDataset;
import gov.sandia.cognition.learning.performance.PerformanceEvaluator;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Summarizer;
import java.util.Collection;
import java.util.LinkedList;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* Parallel version of the LearnerValidationExperiment class that executes
* the validations experiments across available cores and hyperthreads.
*
* @param
* The type of the data to perform the experiment with.
* This will be passed to the fold creator to create a number of folds
* on which to validate the performance of the learning algorithm.
* @param
* The type of data created by the fold creator that will go into
* the learning algorithm. Typically, this is the same as the
* InputDataType, but it does not need to be. It just needs to match
* the output of the fold creator and the input of the learning
* algorithm.
* @param The type of the output produced by the learning
* algorithm whose performance will be evaluated on each fold of data.
* @param The type of the statistic generated by the
* performance evaluator on the learned object for each fold. It is
* created by passing the learned object plus the test data for the
* fold into the performance evaluator.
* @param The type produced by the summarizer at the end of
* the experiment from a collection of the given statistics (one for
* each fold). This represents the performance result for the learning
* algorithm for the whole experiment.
* @author Justin Basilico
* @since 3.0
*/
public class ParallelLearnerValidationExperiment
extends LearnerValidationExperiment
implements ParallelAlgorithm
{
/**
* Thread pool used to split the computation across multiple cores
*/
private transient ThreadPoolExecutor threadPool;
/**
* Default constructor
*/
public ParallelLearnerValidationExperiment()
{
this(null, null, null);
}
/**
* Creates a new instance of ParallelLearnerValidationExperiment.
*
* @param foldCreator The object to use for creating the folds.
* @param performanceEvaluator The evaluator to use to compute the
* performance of the learned object on each fold.
* @param summarizer The summarizer for summarizing the result of the
* performance evaluator from all the folds.
*/
public ParallelLearnerValidationExperiment(
final ValidationFoldCreator foldCreator,
final PerformanceEvaluator
super LearnedType, ? super Collection extends FoldDataType>, ? extends StatisticType>
performanceEvaluator,
final Summarizer super StatisticType, ? extends SummaryType> summarizer)
{
super(foldCreator, performanceEvaluator, summarizer);
}
@Override
protected void runExperiment(
final Collection> folds)
{
// The number of trials is the number of folds.
this.setNumTrials(folds.size());
this.fireExperimentStarted();
LinkedList> trials =
new LinkedList>();
// Go through the folds and run the trial for each fold.
for (PartitionedDataset fold : folds)
{
final TrialTask trial = new TrialTask(fold);
trials.add(trial);
}
Collection results = null;
try
{
results = ParallelUtil.executeInParallel( trials, this.getThreadPool() );
}
catch (Exception ex)
{
Logger.getLogger(ParallelLearnerValidationExperiment.class.getName()).log(Level.SEVERE, null, ex);
}
this.getStatistics().addAll( results );
this.fireExperimentEnded();
}
public ThreadPoolExecutor getThreadPool()
{
if (this.threadPool == null)
{
this.threadPool = ParallelUtil.createThreadPool();
}
return this.threadPool;
}
public void setThreadPool(
final ThreadPoolExecutor threadPool)
{
this.threadPool = threadPool;
}
public int getNumThreads()
{
return ParallelUtil.getNumThreads( this );
}
/**
* Callable task for a single evaluation trial
*/
private class TrialTask
extends Object
implements Callable
{
/**
* Dataset partition
*/
private PartitionedDataset fold;
/**
* Creates a new instance of TrialTask
* @param fold
* Dataset partition
*/
public TrialTask(
final PartitionedDataset fold)
{
super();
this.fold = fold;
}
@Override
public StatisticType call()
throws Exception
{
fireTrialStarted();
// Perform the learning algorithm on this fold.
final BatchLearner super Collection extends FoldDataType>, ? extends LearnedType>
learnerClone = ObjectUtil.cloneSmart(getLearner());
final LearnedType learned = learnerClone.learn(fold.getTrainingSet());
// Compute the statistic of the learned object on the testing set.
final Collection testingSet = fold.getTestingSet();
final StatisticType statistic =
getPerformanceEvaluator().evaluatePerformance(
learned, testingSet);
fireTrialEnded();
return statistic;
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy