gov.sandia.cognition.learning.function.cost.ParallelizedCostFunctionContainer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: ParallelizedCostFunctionContainer.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Sep 22, 2008, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.function.cost;
import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.SequentialDataMultiPartitioner;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* A cost function that automatically splits a ParallelizableCostFunction
* across multiple cores/processors to speed up computation.
* @author Kevin R. Dixon
* @since 2.1
*/
public class ParallelizedCostFunctionContainer
extends AbstractSupervisedCostFunction
implements DifferentiableCostFunction,
ParallelAlgorithm
{
/**
 * Cost function to parallelize. Each partition built by
 * {@code createPartitions()} works on its own clone of this object.
 */
private ParallelizableCostFunction costFunction;
/**
 * One evaluation task per data partition, built lazily by
 * {@code createPartitions()} and reset to {@code null} by
 * {@code setCostFunction} and {@code setCostParameters}.
 * Transient, so it is excluded from default Java serialization.
 * NOTE(review): the generic type arguments on this declaration appear to
 * have been stripped ({@code ArrayList>} does not compile); the elements
 * are iterated as {@code Callable} objects - TODO restore the exact type.
 */
private transient ArrayList> evaluationComponents;
/**
 * One gradient task per data partition; built and reset together with
 * {@code evaluationComponents}. Transient.
 * NOTE(review): the generic type arguments appear stripped here as well -
 * TODO restore them.
 */
private transient ArrayList> gradientComponents;
/**
 * Thread pool used to parallelize the computation. Transient; a clone
 * receives a newly created pool sized by {@code getNumThreads()}.
 */
private transient ThreadPoolExecutor threadPool;
/**
 * Creates a container that has no cost function yet and uses a thread pool
 * obtained from {@code ParallelUtil.createThreadPool()}. A cost function
 * should be supplied through {@code setCostFunction} before evaluating.
 */
public ParallelizedCostFunctionContainer()
{
    this( (ParallelizableCostFunction) null,
        ParallelUtil.createThreadPool() );
}
/**
 * Creates a container that parallelizes the given cost function on a
 * thread pool obtained from {@code ParallelUtil.createThreadPool()}.
 * @param costFunction
 * Cost function to parallelize; may be {@code null} and set later
 */
public ParallelizedCostFunctionContainer(
    ParallelizableCostFunction costFunction )
{
    this(
        costFunction,
        ParallelUtil.createThreadPool() );
}
/**
 * Creates a container for the given cost function and thread pool.
 * Both values are installed through their public setters, so the
 * per-partition tasks are not built here; they are created lazily later.
 * @param costFunction
 * Cost function to parallelize
 * @param threadPool
 * Thread pool used to parallelize the computation
 */
public ParallelizedCostFunctionContainer(
    ParallelizableCostFunction costFunction,
    ThreadPoolExecutor threadPool )
{
    setCostFunction( costFunction );
    setThreadPool( threadPool );
}
/**
 * Creates a copy of this container. The copy receives its own clone of the
 * wrapped cost function (via {@code ObjectUtil.cloneSafe}); installing it
 * also clears any cached per-partition tasks. The copy also gets a freshly
 * created thread pool sized by {@code getNumThreads()}, so it shares neither
 * the tasks nor the pool with this object.
 * @return
 * Copy of this container
 */
@Override
public ParallelizedCostFunctionContainer clone()
{
    final ParallelizedCostFunctionContainer copy =
        (ParallelizedCostFunctionContainer) super.clone();

    final ParallelizableCostFunction copiedCost =
        ObjectUtil.cloneSafe( this.getCostFunction() );
    copy.setCostFunction( copiedCost );

    final int threadCount = this.getNumThreads();
    copy.setThreadPool( ParallelUtil.createThreadPool( threadCount ) );

    return copy;
}
/**
 * Returns the cost function whose computation this container parallelizes.
 * @return
 * The wrapped cost function; {@code null} if none has been set
 */
public ParallelizableCostFunction getCostFunction()
{
    return costFunction;
}
/**
 * Replaces the cost function to parallelize. Previously built evaluation
 * and gradient tasks wrap clones of the old cost function, so they are
 * discarded here and rebuilt on demand.
 * @param costFunction
 * Cost function to parallelize
 */
public void setCostFunction(
    ParallelizableCostFunction costFunction )
{
    // Drop the cached per-partition tasks; createPartitions() rebuilds them.
    this.evaluationComponents = null;
    this.gradientComponents = null;
    this.costFunction = costFunction;
}
/**
 * Splits the current cost parameters into {@code getNumThreads()}
 * partitions using {@code SequentialDataMultiPartitioner}. For each
 * partition it creates a clone of the wrapped cost function restricted to
 * that partition, plus one evaluation task and one gradient task around it.
 * The new task lists replace {@code evaluationComponents} and
 * {@code gradientComponents}.
 * <p>
 * NOTE(review): throws NullPointerException when no cost function is set
 * (the default constructor passes {@code null}).
 * NOTE(review): the generic type arguments on the {@code ArrayList}
 * declarations below appear to have been stripped ({@code ArrayList>}
 * does not compile) - TODO restore them.
 */
protected void createPartitions()
{
int numThreads = this.getNumThreads();
ArrayList>> partitions =
SequentialDataMultiPartitioner.create(
this.getCostParameters(), numThreads );
// Presized to one task per partition.
this.evaluationComponents = new ArrayList>( numThreads );
this.gradientComponents = new ArrayList>( numThreads );
// NOTE(review): assumes create() returns at least numThreads partitions - confirm.
for( int i = 0; i < numThreads; i++ )
{
// Clone so that assigning this partition's data does not modify the
// wrapped cost function or the sub-costs of other partitions.
ParallelizableCostFunction subcost =
(ParallelizableCostFunction) this.getCostFunction().clone();
subcost.setCostParameters( partitions.get(i) );
// Both tasks share the same sub-cost. Their second argument starts as
// null (presumably filled in per call before the tasks run - confirm).
this.evaluationComponents.add( new SubCostEvaluate( subcost, null ) );
this.gradientComponents.add( new SubCostGradient( subcost, null ) );
}
}
@Override
public void setCostParameters(
Collection extends InputOutputPair extends Vector, Vector>> costParameters )
{
super.setCostParameters( costParameters );
this.evaluationComponents = null;
this.gradientComponents = null;
}
@Override
public Double evaluate(
Evaluator super Vector, ? extends Vector> evaluator )
{
if( this.evaluationComponents == null )
{
this.createPartitions();
}
// Set the subtasks
for( Callable