All downloads are FREE. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.function.cost.AbstractSupervisedCostFunction Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                AbstractSupervisedCostFunction.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright Dec 20, 2007, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 * 
 */

package gov.sandia.cognition.learning.function.cost;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultWeightedTargetEstimatePair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.learning.data.WeightedTargetEstimatePair;
import gov.sandia.cognition.learning.performance.AbstractSupervisedPerformanceEvaluator;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;

/**
 * Partial implementation of SupervisedCostFunction
 * @param  Input type of the dataset and Evaluator
 * @param  Output type (labels) of the dataset and Evaluator
 * @author Kevin R. Dixon
 * @since 2.0
 */
public abstract class AbstractSupervisedCostFunction
    extends AbstractSupervisedPerformanceEvaluator
    implements SupervisedCostFunction
{

    /**
     * Labeled dataset to use to evaluate the cost against
     */
    private Collection> costParameters;

    /** 
     * Creates a new instance of AbstractSupervisedCostFunction 
     */
    public AbstractSupervisedCostFunction()
    {
        this.setCostParameters( null );
    }

    /**
     * Creates a new instance of AbstractSupervisedCostFunction 
     * @param costParameters
     * Labeled dataset to use to evaluate the cost against
     */
    public AbstractSupervisedCostFunction(
        Collection> costParameters )
    {
        this.setCostParameters( costParameters );
    }

    @Override
    @SuppressWarnings("unchecked")
    public AbstractSupervisedCostFunction clone()
    {
        AbstractSupervisedCostFunction clone =
            (AbstractSupervisedCostFunction) super.clone();
        clone.setCostParameters(
            ObjectUtil.cloneSmartElementsAsArrayList(this.getCostParameters()) );
        return clone;
    }

    @Override
    public abstract Double evaluatePerformance(
        Collection> data );

    public Double evaluate(
        Evaluator evaluator )
    {
        ArrayList> targetEstimatePairs =
            new ArrayList>( this.getCostParameters().size() );

        for (InputOutputPair io
            : this.getCostParameters())
        {
        	TargetType target = io.getOutput();
            TargetType estimate = evaluator.evaluate(io.getInput());
            targetEstimatePairs.add(DefaultWeightedTargetEstimatePair.create(
                target, estimate, DatasetUtil.getWeight(io)));
        }

        return this.evaluatePerformance( targetEstimatePairs );
    }

    public Collection> getCostParameters()
    {
        return this.costParameters;
    }

    public void setCostParameters(
        Collection> costParameters )
    {
        this.costParameters = costParameters;
    }

    @Override
    public Double summarize(
        Collection> data )
    {
        return this.evaluatePerformance(data);
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy