gov.sandia.cognition.math.matrix.custom.ParallelMatrixMergeFunction Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: ParallelMatrixMergeFunction.java
* Authors: Jeremy D. Wendt
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright 2015, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*/
package gov.sandia.cognition.math.matrix.custom;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
/**
* This package-private class simplifies parallelizing Matrix operations. It
* uses generics for defining the two (possibly different) input types, and the
* result of merging the parallel-solved pieces. This should be used for
* parallel operations where the result is solved in-place on a single value.
* For instance in dot products, the result is a single double. Thus, each
* parallel piece solves on a different double and then merge combines those
* doubles into a single double.
*
* @author Jeremy D. Wendt
* @since 3.4.3
* @param The type of the left operand
* @param The type of the right operand
* @param The type of the final result
*/
abstract class ParallelMatrixMergeFunction
implements Callable
{
/**
* The minimum index for the row that should be used in the operation (row
* on input1, likely column on input2).
*/
protected int minRow;
/**
* The maximum index for the row that should be used in the operation
* (not-inclusive (as in the for loop goes i = minRow; i < maxRow); row
* on input1, likely column on input2).
*/
protected int maxRow;
/**
* The left-part of the operation. For instance, in vector-vector dot
* product, this is the vector type of the left side. Each thread should
* only read from rows between minRow and maxRow (for caching purposes).
* This should not be changed at all during the operations.
*/
protected InputType1 input1;
/**
* The right-part of the operation. For instance, in vector-vector dot
* product, this is the vector type of the right side. This should not be
* changed at all during the operations.
*/
protected InputType2 input2;
/**
* Private because this should never be called. Ever. No matter what.
*/
private ParallelMatrixMergeFunction()
{
throw new UnsupportedOperationException(
"Null constructor not supported.");
}
/**
* Passes in the necessary arguments to initialize an instance. Shallow
* copies of all inputs are made.
*
* @param input1 The first input
* @param input2 The second input
* @param minRow The minimum row for this thread to operate on
* @param maxRow The maximum row (not inclusive) for this thread to operate
* on
*/
public ParallelMatrixMergeFunction(
final InputType1 input1,
final InputType2 input2,
final int minRow,
final int maxRow)
{
this.input1 = input1;
this.input2 = input2;
this.minRow = minRow;
this.maxRow = maxRow;
}
/**
* This needs to be extended by operation-specific classes.
*
* @return The solution for solving this piece of the operation -- the
* results will be merged in merge (not in parallel).
* @throws Exception Part of the interface. Please don't throw exceptions
* unless you really need to.
*/
@Override
abstract public MergeType call()
throws Exception;
/**
* This method will only be called on one instance and won't use internal
* state to merge the results of all of the piece's call methods.
*
* @param pieces The results from all of the pieces
* @return The merged, final result
*/
abstract protected MergeType merge(
final List> pieces);
/**
* This static method handles all the logic of splitting up the chunks of a
* matrix problem, calling the chunks in parallel, and merging the results.
*
* @param The type for the left operand
* @param The type for the right operand
* @param The type for the merged result
* @param input1 The left operand
* @param input2 The right operand
* @param numPieces The number of pieces to split the problem into -- can be
* more than the number of threads if you think the pieces may be non-equal
* in size.
* @param numThreads The number of threads to create for solving the problem
* @param numRows The number of rows in the problem (usually input1's
* numRows)
* @param factory The factory for creating ParallelMatrixFunction instnaces
*/
public static MergeType solve(
final InputType1 input1,
final InputType2 input2,
final int numPieces,
final int numThreads,
final int numRows,
final ParallelMatrixMergeFunction.Factory factory)
{
double numRowsPer = numRows / ((double) numPieces);
numRowsPer = Math.max(numRowsPer, 1.0);
List> pieces =
new ArrayList<>(
numPieces);
int minRow, maxRow;
minRow = 0;
for (int i = 0; i < numPieces; ++i)
{
if (i == (numPieces - 1))
{
maxRow = numRows;
}
else
{
maxRow = (int) Math.round((i + 1) * numRowsPer);
}
maxRow = Math.min(maxRow, numRows);
pieces.add(factory.init(input1, input2, minRow, maxRow));
minRow = maxRow;
// Break out early if there were more pieces than rows
if (minRow >= numRows)
{
break;
}
}
ExecutorService threads = Executors.newFixedThreadPool(numThreads);
try
{
List> results = threads.invokeAll(
pieces);
threads.shutdown();
return pieces.get(0).merge(results);
}
catch (InterruptedException e)
{
threads.shutdown();
throw new RuntimeException("Threads stopped prematurely", e);
}
}
/**
* A factory for creating the necessary parallel-aware solvers
*
* @param The left input's type
* @param The right input's type
* @param The output type
*/
public static interface Factory
{
/**
* Creates an instance of the parallel-aware solver with the input
* values stored for the call method.
*
* @param input1 The left input
* @param input2 The right input
* @param minRow The minimum row to affect (inclusive)
* @param maxRow The maximum row to affect (not inclusive)
* @return A new instance of the correct parallel-aware solver with the
* input values stored for the call method.
*/
ParallelMatrixMergeFunction init(
InputType1 input1,
InputType2 input2,
int minRow,
int maxRow);
}
}