All downloads are free. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.function.cost.ParallelizedCostFunctionContainer Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                ParallelizedCostFunctionContainer.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright Sep 22, 2008, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government. 
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.learning.function.cost;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.SequentialDataMultiPartitioner;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A cost function that automatically splits a ParallelizableCostFunction
 * across multiple cores/processors to speed up computation.
 * <p>
 * The cost parameters are divided by {@code SequentialDataMultiPartitioner}
 * into (nominally) one partition per thread, and each partition is given its
 * own clone of the wrapped cost function. {@link #evaluate} and
 * {@link #computeParameterGradient} run the per-partition work on the thread
 * pool and merge the partial results with the wrapped cost function's
 * {@code evaluateAmalgamate} and {@code computeParameterGradientAmalgamate}.
 * The partitions are built lazily on first use and discarded whenever the
 * wrapped cost function or the cost parameters change.
 * <p>
 * An instance is not safe for concurrent calls to {@code evaluate} or
 * {@code computeParameterGradient}: each call writes its argument into
 * subtask objects that are shared across calls.
 * <p>
 * NOTE(review): the generic type arguments in this listing (for example the
 * {@code ArrayList>} and {@code Collection>} declarations) appear to have
 * been stripped when the source was extracted from a web page; restore them
 * from the upstream Cognitive Foundry source before compiling.
 * @author Kevin R. Dixon
 * @since 2.1
 */
public class ParallelizedCostFunctionContainer
    extends AbstractSupervisedCostFunction
    implements DifferentiableCostFunction,
    ParallelAlgorithm
{
    
    /**
     * Cost function to parallelize; it is cloned once per data partition.
     */
    private ParallelizableCostFunction costFunction;

    /**
     * One evaluation task per data partition. Null until createPartitions()
     * runs, and reset to null whenever the partitions become stale.
     */
    private transient ArrayList> evaluationComponents;
    
    /**
     * One gradient task per data partition, built alongside
     * evaluationComponents and sharing the same per-partition cost functions.
     */
    private transient ArrayList> gradientComponents;
    
    /**
     * Thread pool used to parallelize the computation; recreated lazily by
     * getThreadPool() when it is null.
     */
    private transient ThreadPoolExecutor threadPool;

    /**
     * Creates a container with no cost function and a thread pool from
     * {@code ParallelUtil.createThreadPool()}. A cost function must be set
     * with {@link #setCostFunction} before evaluating.
     */
    public ParallelizedCostFunctionContainer()
    {
        this( (ParallelizableCostFunction) null );
    }
    
    /**
     * Creates a new instance of ParallelizedCostFunctionContainer that uses a
     * thread pool from {@code ParallelUtil.createThreadPool()}.
     * @param costFunction
     * Cost function to parallelize
     */
    public ParallelizedCostFunctionContainer(
        ParallelizableCostFunction costFunction )
    {
        this( costFunction, ParallelUtil.createThreadPool() );
    }
    
    /**
     * Creates a new instance of ParallelizedCostFunctionContainer
     * @param costFunction
     * Cost function to parallelize
     * @param threadPool 
     * Thread pool used to parallelize the computation
     */
    public ParallelizedCostFunctionContainer(
        ParallelizableCostFunction costFunction,
        ThreadPoolExecutor threadPool )
    {
        this.setCostFunction( costFunction );
        this.setThreadPool( threadPool );
    }       
    
    /**
     * Creates a copy of this container. The copy holds a clone of the wrapped
     * cost function (via {@code ObjectUtil.cloneSafe}) and its own thread pool
     * created with the same number of threads; its partitions are rebuilt
     * lazily on first use.
     * @return
     * Copy of this container
     */
    @Override
    public ParallelizedCostFunctionContainer clone()
    {
        ParallelizedCostFunctionContainer clone =
            (ParallelizedCostFunctionContainer) super.clone();
        // setCostFunction also discards the task lists shallow-copied by
        // super.clone(), so the clone never shares subtasks with this object.
        clone.setCostFunction( ObjectUtil.cloneSafe( this.getCostFunction() ) );
        clone.setThreadPool(
            ParallelUtil.createThreadPool( this.getNumThreads() ) );
        return clone;
    }    
    
    /**
     * Getter for costFunction
     * @return
     * Cost function to parallelize
     */
    public ParallelizableCostFunction getCostFunction()
    {
        return this.costFunction;
    }
    
    /**
     * Setter for costFunction. Discards any existing partitions so they are
     * rebuilt from the new cost function on the next evaluation.
     * @param costFunction
     * Cost function to parallelize
     */
    public void setCostFunction(
        ParallelizableCostFunction costFunction )
    {
        this.costFunction = costFunction;
        this.evaluationComponents = null;
        this.gradientComponents = null;
    }
    
    /**
     * Splits the cost parameters into partitions (nominally one per thread)
     * and builds one evaluation task and one gradient task per partition, each
     * pair backed by its own clone of the wrapped cost function.
     * @throws IllegalStateException
     * If no cost function has been set.
     */
    protected void createPartitions()
    {
        final ParallelizableCostFunction prototype = this.getCostFunction();
        if( prototype == null )
        {
            throw new IllegalStateException(
                "costFunction must be set before evaluating" );
        }

        int numThreads = this.getNumThreads();
        ArrayList>> partitions =
            SequentialDataMultiPartitioner.create(
                this.getCostParameters(), numThreads );

        // Iterate over the partitions actually produced rather than assuming
        // there are exactly numThreads of them, so no data is dropped and no
        // index runs past the end of the list.
        final int numPartitions = partitions.size();
        this.evaluationComponents = new ArrayList>( numPartitions );
        this.gradientComponents = new ArrayList>( numPartitions );
        for( int i = 0; i < numPartitions; i++ )
        {
            // Each partition gets its own clone so concurrently running
            // subtasks never share a single cost-function instance.
            ParallelizableCostFunction subcost =
                (ParallelizableCostFunction) prototype.clone();
            subcost.setCostParameters( partitions.get( i ) );
            this.evaluationComponents.add( new SubCostEvaluate( subcost, null ) );
            this.gradientComponents.add( new SubCostGradient( subcost, null ) );
        }
    }

    /**
     * Sets the cost parameters and discards any existing partitions so they
     * are rebuilt from the new data on the next evaluation.
     * @param costParameters
     * Labeled data used to compute the cost
     */
    @Override
    public void setCostParameters(
        Collection> costParameters )
    {
        super.setCostParameters( costParameters );
        this.evaluationComponents = null;
        this.gradientComponents = null;
    }
    
    /**
     * Evaluates the cost of the given evaluator by computing the partial cost
     * of each partition on the thread pool and merging the results with the
     * wrapped cost function's {@code evaluateAmalgamate}.
     * @param evaluator
     * Evaluator whose cost is computed
     * @return
     * Combined cost over all partitions
     * @throws RuntimeException
     * If the parallel execution fails; the original failure is the cause.
     */
    @Override
    public Double evaluate(
        Evaluator evaluator )
    {
        if( this.evaluationComponents == null )
        {
            this.createPartitions();
        }
        
        // Point every per-partition task at the evaluator being scored.
        for( Callable sce : this.evaluationComponents )
        {
            ((SubCostEvaluate) sce).evaluator = evaluator;
        }
        
        Collection partialResults;
        try
        {
            partialResults = ParallelUtil.executeInParallel(
                this.evaluationComponents, this.getThreadPool() );
        }
        catch (Exception ex)
        {
            // Previously the failure was only logged and a null collection
            // was handed to evaluateAmalgamate, hiding the real cause.
            throw wrapParallelFailure( "evaluate", ex );
        }
        
        return this.getCostFunction().evaluateAmalgamate( partialResults );
    }
    
    /**
     * Delegates directly to the wrapped cost function on the calling thread;
     * this method is not parallelized.
     * @param data
     * Target/estimate pairs to score
     * @return
     * Performance computed by the wrapped cost function
     */
    @Override
    public Double evaluatePerformance(
        Collection> data )
    {
        return this.getCostFunction().evaluatePerformance( data );
    }

    /**
     * Computes the gradient of the cost with respect to the parameters of the
     * given function by computing each partition's partial gradient on the
     * thread pool and merging them with the wrapped cost function's
     * {@code computeParameterGradientAmalgamate}.
     * @param function
     * Function whose parameter gradient is computed
     * @return
     * Combined parameter gradient over all partitions
     * @throws RuntimeException
     * If the parallel execution fails; the original failure is the cause.
     */
    @Override
    public Vector computeParameterGradient(
        GradientDescendable function )
    {
        if( this.gradientComponents == null )
        {
            this.createPartitions();
        }

        // Point every per-partition task at the function being differentiated.
        for( Callable eval : this.gradientComponents )
        {
            ((SubCostGradient) eval).evaluator = function;
        }

        Collection results;
        try
        {
            results = ParallelUtil.executeInParallel(
                this.gradientComponents, this.getThreadPool() );
        }
        catch (Exception ex)
        {
            throw wrapParallelFailure( "computeParameterGradient", ex );
        }
        
        return this.getCostFunction().computeParameterGradientAmalgamate( results );
    }

    /**
     * Converts a failure from {@code ParallelUtil.executeInParallel} into an
     * unchecked exception, restoring the thread's interrupt status when the
     * failure was an interruption.
     * @param operation
     * Name of the operation that failed, used in the message
     * @param cause
     * Exception thrown by the parallel execution
     * @return
     * Exception for the caller to throw
     */
    private static RuntimeException wrapParallelFailure(
        final String operation,
        final Exception cause )
    {
        if( cause instanceof InterruptedException )
        {
            // Swallowing the interrupt would hide a cancellation request
            // from code further up the call stack.
            Thread.currentThread().interrupt();
        }
        return new RuntimeException(
            "Parallel " + operation + " failed in "
                + ParallelizedCostFunctionContainer.class.getSimpleName(),
            cause );
    }

    /**
     * Gets the thread pool, creating one with
     * {@code ParallelUtil.createThreadPool()} if none is set (for example
     * after deserialization, since the field is transient).
     * @return
     * Thread pool used to parallelize the computation
     */
    @Override
    public ThreadPoolExecutor getThreadPool()
    {
        if( this.threadPool == null )
        {
            this.setThreadPool( ParallelUtil.createThreadPool() );
        }
        
        return this.threadPool;
    }

    /**
     * Sets the thread pool used to parallelize the computation.
     * @param threadPool
     * Thread pool to use; may be null, in which case getThreadPool() creates
     * a new one on demand
     */
    @Override
    public void setThreadPool(
        ThreadPoolExecutor threadPool )
    {
        this.threadPool = threadPool;
    }

    /**
     * Gets the number of threads, as reported by
     * {@code ParallelUtil.getNumThreads} for this algorithm.
     * @return
     * Number of threads used for the computation
     */
    @Override
    public int getNumThreads()
    {
        return ParallelUtil.getNumThreads( this );
    }
    
    /**
     * Replaces the thread pool with one created by
     * {@code ParallelUtil.createThreadPool()}.
     */
    protected void createThreadPool()
    {
        this.setThreadPool( ParallelUtil.createThreadPool() );
    }

    /**
     * Callable task for the evaluate() method: computes the partial cost of
     * one data partition.
     */
    protected static class SubCostEvaluate
        implements Callable
    {
        
        /**
         * Parallel cost function holding one partition of the data
         */
        private final ParallelizableCostFunction costFunction;
        
        /**
         * Evaluator for which to compute the cost; set by the enclosing
         * container before each parallel run
         */
        private Evaluator evaluator;
        
        /**
         * Creates a new instance of SubCostEvaluate
         * @param costFunction
         * Parallel cost function
         * @param evaluator
         * Evaluator for which to compute the cost
         */
        public SubCostEvaluate(
            ParallelizableCostFunction costFunction,
            Evaluator evaluator )
        {
            this.costFunction = costFunction;
            this.evaluator = evaluator;
        }

        /**
         * Computes this partition's partial cost.
         * @return
         * Result of {@code evaluatePartial} on this partition
         */
        @Override
        public Object call()
        {
            return this.costFunction.evaluatePartial( this.evaluator );
        }
        
    }
    
    /**
     * Callable task for the computeParameterGradient() method: computes the
     * partial gradient of one data partition.
     */
    protected static class SubCostGradient
        implements Callable
    {
        
        /**
         * Parallel cost function holding one partition of the data
         */
        private final ParallelizableCostFunction costFunction;
        
        /**
         * Function for which to compute the gradient; set by the enclosing
         * container before each parallel run
         */
        private GradientDescendable evaluator;
        
        /**
         * Creates a new instance of SubCostGradient
         * @param costFunction
         * Parallel cost function
         * @param evaluator
         * Function for which to compute the gradient
         */
        public SubCostGradient(
            ParallelizableCostFunction costFunction,
            GradientDescendable evaluator )
        {
            this.costFunction = costFunction;
            this.evaluator = evaluator;
        }

        /**
         * Computes this partition's partial parameter gradient.
         * @return
         * Result of {@code computeParameterGradientPartial} on this partition
         */
        @Override
        public Object call()
        {
            return this.costFunction.computeParameterGradientPartial( this.evaluator );
        }
        
    }

}