gov.sandia.cognition.learning.performance.AnytimeBatchLearnerValidationPerformanceReporter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: AnytimeBatchLearnerValidationPerformanceReporter.java
* Authors: Justin Basilico
* Project: Cognitive Foundry
*
* Copyright 2013 Cognitive Foundry. All rights reserved.
*/
package gov.sandia.cognition.learning.performance;
import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.algorithm.event.AbstractIterativeAlgorithmListener;
import gov.sandia.cognition.learning.algorithm.AnytimeBatchLearner;
import java.io.PrintStream;
/**
* A performance reporter for a validation set. Can be attached to an anytime
* algorithm to report results after each step. It reports the result of
* evaluating performance on both the training and validation sets.
*
* @param
* The type of data this reports validation performance for.
* @param
* The type of object that is returned by the anytime learner that
* performance can be computed for.
* @author Justin Basilico
* @since 3.4.0
*/
public class AnytimeBatchLearnerValidationPerformanceReporter
extends AbstractIterativeAlgorithmListener
{
/** The default format for reporting performance is {@value}. */
public static final String DEFAULT_FORMAT = "Iteration %d. Train: %s. Validation: %s";
/** The performance evaluator. Produces the metrics. */
protected PerformanceEvaluator super ObjectType, ? super DataType, ?> performanceEvaluator;
/** The validation dataset to use. */
protected DataType validationData;
/** The print stream to report performance to. */
protected PrintStream out;
/** The format for the performance report, passed to String.format. */
protected String format;
/**
* Creates a new {@code AnytimeBatchLearnerValidationPerformanceReporter} that
* reports to the given print stream using the default format.
*
* @param performanceEvaluator
* The performance evaluator. Produces the metrics.
* @param validationData
* The validation dataset to use.
*/
public AnytimeBatchLearnerValidationPerformanceReporter(
final PerformanceEvaluator super ObjectType, ? super DataType, ?> performanceEvaluator,
final DataType validationData)
{
this(performanceEvaluator, validationData, System.out);
}
/**
* Creates a new {@code AnytimeBatchLearnerValidationPerformanceReporter} that
* reports to the given print stream using the default format.
*
* @param performanceEvaluator
* The performance evaluator. Produces the metrics.
* @param validationData
* The validation dataset to use.
* @param out
* Print stream to report the status to.
*/
public AnytimeBatchLearnerValidationPerformanceReporter(
final PerformanceEvaluator super ObjectType, ? super DataType, ?> performanceEvaluator,
final DataType validationData,
final PrintStream out)
{
this(performanceEvaluator, validationData, out, DEFAULT_FORMAT);
}
/**
* Creates a new {@code AnytimeBatchLearnerValidationPerformanceReporter} that
* reports to System.out and the given format.
*
* @param performanceEvaluator
* The performance evaluator. Produces the metrics.
* @param validationData
* The validation dataset to use.
* @param format
* Format string for the status messages, which passed to
* {@code String.format}.
*/
public AnytimeBatchLearnerValidationPerformanceReporter(
final PerformanceEvaluator super ObjectType, ? super DataType, ?> performanceEvaluator,
final DataType validationData,
final String format)
{
this(performanceEvaluator, validationData, System.out, format);
}
/**
* Creates a new {@code AnytimeBatchLearnerValidationPerformanceReporter} that
* reports to the given print stream and format.
*
* @param performanceEvaluator
* The performance evaluator. Produces the metrics.
* @param validationData
* The validation dataset to use.
* @param out
* Print stream to report the status to.
* @param format
* Format string for the status messages, which passed to
* {@code String.format}.
*/
public AnytimeBatchLearnerValidationPerformanceReporter(
final PerformanceEvaluator super ObjectType, ? super DataType, ?> performanceEvaluator,
final DataType validationData,
final PrintStream out,
final String format)
{
super();
this.setPerformanceEvaluator(performanceEvaluator);
this.setValidationData(validationData);
this.out = out;
this.format = format;
}
@Override
public void stepEnded(
final IterativeAlgorithm algorithm)
{
@SuppressWarnings("unchecked")
final AnytimeBatchLearner extends DataType, ? extends ObjectType> anytimeAlgorithm =
(AnytimeBatchLearner extends DataType, ? extends ObjectType>) algorithm;
final Object trainPerformance =
this.performanceEvaluator.evaluatePerformance(
anytimeAlgorithm.getResult(), anytimeAlgorithm.getData());
final Object validationPerformance =
this.performanceEvaluator.evaluatePerformance(
anytimeAlgorithm.getResult(), this.validationData);
this.out.println(String.format(this.format, algorithm.getIteration(),
trainPerformance, validationPerformance));
}
/**
* Gets the performance evaluator. It produces the metrics.
*
* @return
* The performance evaluator.
*/
public PerformanceEvaluator super ObjectType, ? super DataType, ?> getPerformanceEvaluator()
{
return this.performanceEvaluator;
}
/**
* Sets the performance evaluator. It produces the metrics.
*
* @param performanceEvaluator
* The performance evaluator.
*/
public void setPerformanceEvaluator(
final PerformanceEvaluator super ObjectType, ? super DataType, ?> performanceEvaluator)
{
this.performanceEvaluator = performanceEvaluator;
}
/**
* Gets the validation dataset.
*
* @return
* The validation dataset.
*/
public DataType getValidationData()
{
return this.validationData;
}
/**
* Sets the validation dataset.
*
* @param validationData
* The validation dataset.
*/
public void setValidationData(
final DataType validationData)
{
this.validationData = validationData;
}
}