All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.experiment.OnlineLearnerValidationExperiment Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                OnlineLearnerValidationExperiment.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright June 10, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government. Export
 * of this program may require a license from the United States Government.
 * See CopyrightHistory.txt for complete details.
 *
 */
package gov.sandia.cognition.learning.experiment;

import gov.sandia.cognition.learning.algorithm.IncrementalLearner;
import gov.sandia.cognition.learning.performance.PerformanceEvaluator;
import gov.sandia.cognition.util.Summarizer;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;

/**
 * Implements an experiment where an incremental supervised machine learning
 * algorithm is evaluated by applying it to a set of data by successively
 * testing on each item and then training on it.
 *
 * @param   
 *          The type of the data to perform the experiment with.
 *          It will be used as input into the learning algorithm.
 * @param    The type of the output produced by the learning
 *          algorithm whose performance will be evaluated on each data item.
 * @param    The type of the statistic generated by the
 *          performance evaluator on the learned object for each data item. It
 *          is created by passing the learned object plus the data item
 *          into the performance evaluator.
 * @param    The type produced by the summarizer at the end of
 *          the experiment from a collection of the given statistics (one for
 *          each item). This represents the performance result for the learning
 *          algorithm for the whole experiment.
 * @author  Justin Basilico
 * @since   3.0
 */
public class OnlineLearnerValidationExperiment
    extends AbstractLearningExperiment
    implements PerformanceEvaluator, Collection, SummaryType>,
        Serializable
// TODO: This class is largely copied from LearnerValidationExperiment.
// They should probably be merged into abstract classes.
// --jdbasil (2010-06-10)
{
    /** The evaluator to use to compute the performance of the learned object on
     *  each fold. */
    protected PerformanceEvaluator
        ,
         ? extends StatisticType>
        performanceEvaluator;

    /** The summarizer for summarizing the result of the performance evaluator
     *  from all the folds. */
    protected Summarizer
        summarizer;

    /** The number of trials in the experiment, which is the number of folds
     *  in the experiment. */
    protected int numTrials;

    /** The performance evaluations made during the experiment. */
    protected ArrayList statistics;

    /** The summary of the performance evaluations made at the end of the
     *  experiment. */
    protected SummaryType summary;

    /**
     * Creates a new instance of IncrementalLearnerValidationExperiment.
     */
    public OnlineLearnerValidationExperiment()
    {
        this(null, null);
    }

    /**
     * Creates a new instance of IncrementalLearnerValidationExperiment.
     *
     * @param  performanceEvaluator The evaluator to use to compute the
     *         performance of the learned object on each fold.
     * @param  summarizer The summarizer for summarizing the result of the
     *         performance evaluator from all the folds.
     */
    public OnlineLearnerValidationExperiment(
        final PerformanceEvaluator
            , ? extends StatisticType>
            performanceEvaluator,
        final Summarizer summarizer)
    {
        super();

        this.setPerformanceEvaluator(performanceEvaluator);
        this.setSummarizer(summarizer);

        // The initial number of trials is unknown.
        this.setNumTrials(-1);
        this.setStatistics(null);
        this.setSummary(null);
    }

    /**
     * Performs the experiment.
     *
     * @param  data The data to use.
     * @param  learner The learner to perform the experiment on.
     * @return The summary of the experiment.
     */
    public SummaryType evaluatePerformance(
        final IncrementalLearner learner,
        final Collection data)
    {
        // Initialize the collection where we will store the statistics
        // generated from the data.
        this.setNumTrials(data.size());
        this.setStatistics(new ArrayList(data.size()));
        this.setSummary(null);

        // We've started the experiment.
        this.fireExperimentStarted();

        // Initialize learning.
        final LearnedType learned = learner.createInitialLearnedObject();

        // Go through and evaluate each item and then update the model using
        // it.
        for (DataType item : data)
        {
            // Start a new trial.
            this.fireTrialStarted();

            // Compute the statistic for this item.
            final StatisticType statistic =
                this.getPerformanceEvaluator().evaluatePerformance(
                    learned, Collections.singletonList(item));
            this.statistics.add(statistic);

            // Update the learned value.
            learner.update(learned, item);

            // The trial has ended.
            this.fireTrialEnded();
        }

        // Summarize the statistics.
        this.setSummary(this.getSummarizer().summarize(this.getStatistics()));

        // The experiment has ended.
        this.fireExperimentEnded();

        // The result is the summary statistic.
        return this.getSummary();
    }

    /**
     * Gets the performance evaluator to apply to each fold.
     *
     * @return The performance evaluator to apply to each fold.
     */
    public PerformanceEvaluator
        , ? extends StatisticType>
        getPerformanceEvaluator()
    {
        return this.performanceEvaluator;
    }

    /**
     * Sets the performance evaluator to apply to each fold.
     *
     * @param  performanceEvaluator
     *      The performance evaluator to apply to each fold.
     */
    public void setPerformanceEvaluator(
        final PerformanceEvaluator
            , ? extends StatisticType>
            performanceEvaluator)
    {
        this.performanceEvaluator = performanceEvaluator;
    }

    /**
     * Gets the summarizer of the performance evaluations.
     *
     * @return The summarizer of the performance evaluations.
     */
    public Summarizer getSummarizer()
    {
        return this.summarizer;
    }

    /**
     * Sets the summarizer of the performance evaluations.
     *
     * @param  summarizer The summarizer of the performance evaluations.
     */
    public void setSummarizer(
        final Summarizer summarizer)
    {
        this.summarizer = summarizer;
    }

    /**
     * Gets the number of trials. Will be equal to the number of data points
     * that the experiment is being run over.
     *
     * @return
     *      The number of trials.
     */
    public int getNumTrials()
    {
        return this.numTrials;
    }

    /**
     * Sets the number of trials.
     *
     * @param   numTrials
     *      The number of trials.
     */
    protected void setNumTrials(
        final int numTrials)
    {
        this.numTrials = numTrials;
    }
    
    /**
     * Gets the performance evaluations for the trials of the experiment.
     *
     * @return The performance evaluations for the trials of the experiment.
     */
    public ArrayList getStatistics()
    {
        return this.statistics;
    }

    /**
     * Sets the performance evaluations for the trials of the experiment.
     *
     * @param  statistics
     *      The performance evaluations for the trials of the experiment.
     */
    protected void setStatistics(
        final ArrayList statistics)
    {
        this.statistics = statistics;
    }

    /**
     * Gets the summary of the experiment.
     *
     * @return The summary of the experiment.
     */
    public SummaryType getSummary()
    {
        return this.summary;
    }

    /**
     * Sets the summary of the experiment.
     *
     * @param  summary The summary of the experiment.
     */
    protected void setSummary(
        final SummaryType summary)
    {
        this.summary = summary;
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy