All downloads are free. The search and download functionalities use the official Maven repository.

gov.sandia.cognition.learning.experiment.ParallelLearnerValidationExperiment Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                ParallelLearnerValidationExperiment.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright October 04, 2008, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government. 
 * See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.learning.experiment;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.data.PartitionedDataset;
import gov.sandia.cognition.learning.performance.PerformanceEvaluator;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Summarizer;
import java.util.Collection;
import java.util.LinkedList;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Parallel version of the LearnerValidationExperiment class that executes
 * the validations experiments across available cores and hyperthreads.
 * 
 * @param   
 *          The type of the data to perform the experiment with.
 *          This will be passed to the fold creator to create a number of folds
 *          on which to validate the performance of the learning algorithm.
 * @param   
 *          The type of data created by the fold creator that will go into
 *          the learning algorithm. Typically, this is the same as the
 *          InputDataType, but it does not need to be. It just needs to match
 *          the output of the fold creator and the input of the learning
 *          algorithm.
 * @param    The type of the output produced by the learning 
 *          algorithm whose performance will be evaluated on each fold of data.
 * @param    The type of the statistic generated by the 
 *          performance evaluator on the learned object for each fold. It is
 *          created by passing the learned object plus the test data for the
 *          fold into the performance evaluator.
 * @param    The type produced by the summarizer at the end of
 *          the experiment from a collection of the given statistics (one for
 *          each fold). This represents the performance result for the learning
 *          algorithm for the whole experiment.
 * @author  Justin Basilico
 * @since   3.0
 */
public class ParallelLearnerValidationExperiment
    extends LearnerValidationExperiment
    implements ParallelAlgorithm
{

    /**
     * Thread pool used to split the computation across multiple cores
     */
    private transient ThreadPoolExecutor threadPool;
    
    /**
     * Default constructor
     */
    public ParallelLearnerValidationExperiment()
    {
        this(null, null, null);
    }
    
    /**
     * Creates a new instance of ParallelLearnerValidationExperiment.
     *
     * @param  foldCreator The object to use for creating the folds.
     * @param  performanceEvaluator The evaluator to use to compute the 
     *         performance of the learned object on each fold.
     * @param  summarizer The summarizer for summarizing the result of the 
     *         performance evaluator from all the folds.
     */
    public ParallelLearnerValidationExperiment(
        final ValidationFoldCreator foldCreator,
        final PerformanceEvaluator
            , ? extends StatisticType>
            performanceEvaluator,
        final Summarizer summarizer)
    {
        super(foldCreator, performanceEvaluator, summarizer);
    }

    @Override
    protected void runExperiment(
        final Collection> folds)
    {
        // The number of trials is the number of folds.
        this.setNumTrials(folds.size());
        
        this.fireExperimentStarted();
        
        
        LinkedList> trials = 
            new LinkedList>();
        // Go through the folds and run the trial for each fold.
        for (PartitionedDataset fold : folds)
        {
            final TrialTask trial = new TrialTask(fold);
            trials.add(trial);
        }
        
        Collection results = null;
        try
        {
            results = ParallelUtil.executeInParallel( trials, this.getThreadPool() );
        }
        catch (Exception ex)
        {
            Logger.getLogger(ParallelLearnerValidationExperiment.class.getName()).log(Level.SEVERE, null, ex);
        }

        this.getStatistics().addAll( results );
        
        this.fireExperimentEnded();
    }
    
    public ThreadPoolExecutor getThreadPool()
    {
        if (this.threadPool == null)
        {
            this.threadPool = ParallelUtil.createThreadPool();
        }
        
        return this.threadPool;
    }
    
    public void setThreadPool(
        final ThreadPoolExecutor threadPool)
    {
        this.threadPool = threadPool;
    }
    
    public int getNumThreads()
    {
        return ParallelUtil.getNumThreads( this );
    }
    
    /**
     * Callable task for a single evaluation trial
     */
    private class TrialTask
        extends Object
        implements Callable
    {
        /**
         * Dataset partition
         */
        private PartitionedDataset fold;
        
        /**
         * Creates a new instance of TrialTask
         * @param fold
         * Dataset partition
         */
        public TrialTask(
            final PartitionedDataset fold)
        {
            super();
            this.fold = fold;
        }
        
        @Override
        public StatisticType call()
            throws Exception
        {
            fireTrialStarted();
            // Perform the learning algorithm on this fold.
            final BatchLearner, ? extends LearnedType>
                learnerClone = ObjectUtil.cloneSmart(getLearner());

            final LearnedType learned = learnerClone.learn(fold.getTrainingSet());

            // Compute the statistic of the learned object on the testing set.
            final Collection testingSet = fold.getTestingSet();
            final StatisticType statistic =
                getPerformanceEvaluator().evaluatePerformance(
                    learned, testingSet);
            fireTrialEnded();
            return statistic;
        }
        
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy