org.apache.sysml.hops.rewrite.RewriteMatrixMultChainOptimization Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.utils.Explain;
/**
* Rule: Determine the optimal order of execution for a chain of
* matrix multiplications. Solution: Classic dynamic programming
* approach. Currently, the approach is based only on matrix dimensions.
* Goal: To reduce the number of computations in the run-time
* (map-reduce) layer.
*/
public class RewriteMatrixMultChainOptimization extends HopRewriteRule
{
/** Commons-logging handle for this rewrite, registered under the class name. */
private static final Log LOG = LogFactory.getLog(RewriteMatrixMultChainOptimization.class.getName());

/** Compile-time switch; when set to true the static initializer below raises this logger to TRACE. */
private static final boolean LDEBUG = false;

static
{
    // for internal debugging only
    if( LDEBUG ) {
        Logger traceLogger = Logger.getLogger("org.apache.sysml.hops.rewrite.RewriteMatrixMultChainOptimization");
        traceLogger.setLevel(Level.TRACE);
    }
}
/**
 * Applies the matrix-multiplication chain optimization to every DAG root in the given list.
 * The hops are rewired in place; the list itself is returned unchanged.
 *
 * @param roots list of DAG roots, may be null
 * @param state rewrite status (not used by this rule)
 * @return the same list instance that was passed in, or null if {@code roots} is null
 * @throws HopsException if the chain optimization of any root fails
 */
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state)
    throws HopsException
{
    if( roots == null )
        return null;

    // the element type is required here: iterating a raw ArrayList yields Object, not Hop
    for( Hop h : roots )
    {
        // Find the optimal order for the chain whose result is the current HOP
        rule_OptimizeMMChains(h);
    }

    return roots;
}
/**
 * Applies the matrix-multiplication chain optimization to a single DAG root.
 *
 * @param root DAG root, may be null
 * @param state rewrite status (not used by this rule)
 * @return the given root (rewired in place), or null if {@code root} is null
 * @throws HopsException if the chain optimization fails
 */
@Override
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state)
    throws HopsException
{
    // a null root is handed back untouched
    if( root != null ) {
        // Find the optimal order for the chain whose result is the current HOP
        rule_OptimizeMMChains(root);
    }

    return root;
}
/**
 * Walks the DAG below the given hop and triggers the chain optimization for
 * every not-yet-visited matrix multiplication that has no left PM input.
 * Each hop is marked as visited once it and all of its inputs were processed.
 *
 * @param hop high-level operator
 * @throws HopsException if HopsException occurs
 */
private void rule_OptimizeMMChains(Hop hop)
    throws HopsException
{
    // already handled, either by an earlier traversal or as member of a chain
    if( hop.isVisited() )
        return;

    boolean isChainRoot = HopRewriteUtils.isMatrixMultiply(hop)
        && !((AggBinaryOp)hop).hasLeftPMInput()
        && !hop.isVisited();
    if( isChainRoot ) {
        // Try to find and optimize the chain in which the current hop is
        // the last operator
        optimizeMMChain(hop);
    }

    // descend into the (possibly rewired) inputs
    for( Hop child : hop.getInput() ) {
        rule_OptimizeMMChains(child);
    }

    hop.setVisited();
}
/**
 * Optimizes the matrix multiplication chain whose last operator is the given hop.
 * <ol>
 * <li>Identify the chain: the operands (mmChain) and the multiplication
 *     operators (mmOperators), starting from the inputs of {@code hop}.</li>
 * <li>Construct the dimension array; stop without modification if any
 *     dimension is unknown.</li>
 * <li>Clear all links among the hops that are involved in the chain.</li>
 * <li>Find the optimal ordering via dynamic programming.</li>
 * <li>Relink the hops according to that ordering.</li>
 * </ol>
 * Chains with only two operands are left untouched.
 *
 * @param hop high-level operator; callers pass a matrix multiplication (it is cast to AggBinaryOp)
 * @throws HopsException if a chain operator does not have exactly two inputs,
 *         or if HopsException occurs in one of the steps
 */
private void optimizeMMChain( Hop hop ) throws HopsException
{
    if( LOG.isTraceEnabled() ) {
        LOG.trace("MM Chain Optimization for HOP: (" + " " + hop.getClass().getSimpleName() + ", " + hop.getHopID() + ", "
            + hop.getName() + ")");
    }

    ArrayList<Hop> mmChain = new ArrayList<Hop>();
    ArrayList<Hop> mmOperators = new ArrayList<Hop>();

    // Step 1: Identify the chain (mmChain) and its operators (mmOperators).
    mmOperators.add(hop);

    // Initialize mmChain with my inputs
    mmChain.addAll(hop.getInput());

    // expand each Hop in mmChain to find the entire matrix multiplication
    // chain
    int i = 0;
    while (i < mmChain.size()) {
        boolean expandable = false;
        Hop h = mmChain.get(i);

        /*
         * Check if mmChain[i] is expandable:
         * 1) It must be MATMULT
         * 2) It must not have been visited already
         *    (one MATMULT should get expanded only in one chain)
         * 3) Its output should not be used in multiple places
         *    (either within chain or outside the chain)
         */
        // NOTE(review): the left-PM test below is applied to the chain root 'hop',
        // not to the candidate 'h'; kept as in the original - confirm whether 'h' was intended.
        if ( HopRewriteUtils.isMatrixMultiply(h)
            && !((AggBinaryOp)hop).hasLeftPMInput() && !h.isVisited() )
        {
            // check if the output of "h" is used at multiple places. If yes, it can
            // not be expanded (and the chain identification stops here).
            if (h.getParent().size() > 1 || inputCount( (Hop) ((h.getParent().toArray())[0]), h) > 1 ) {
                break;
            }
            expandable = true;
        }

        h.setVisited();

        if ( !expandable ) {
            i = i + 1;
        } else {
            ArrayList<Hop> inputs = h.getInput();
            if (inputs.size() != 2) {
                throw new HopsException(hop.printErrorLocation() + "Hops::rule_OptimizeMMChain(): AggBinary must have exactly two inputs.");
            }

            // add current operator to mmOperators, and its input nodes to mmChain;
            // i is not advanced so that the first input is examined next
            mmOperators.add(h);
            mmChain.set(i, inputs.get(0));
            mmChain.add(i + 1, inputs.get(1));
        }
    }

    // print the MMChain
    if( LOG.isTraceEnabled() ) {
        LOG.trace("Identified MM Chain: ");
        for (Hop h : mmChain) {
            logTraceHop(h, 1);
        }
    }

    if (mmChain.size() == 2) {
        // If the chain size is 2, then there is nothing to optimize.
        return;
    }

    // Step 2: construct dims array
    double[] dimsArray = new double[mmChain.size() + 1];
    boolean dimsKnown = getDimsArray( hop, mmChain, dimsArray );

    if( dimsKnown ) {
        // Step 3: clear the links among Hops within the identified chain
        clearLinksWithinChain ( hop, mmOperators );

        // Step 4: Find the optimal ordering via dynamic programming.
        int size = mmChain.size();
        int[][] split = mmChainDP(dimsArray, size);

        // Step 5: Relink the hops using the optimal ordering (split[][]) found from DP.
        LOG.trace("Optimal MM Chain: ");
        mmChainRelinkHops(mmOperators.get(0), 0, size - 1, mmChain, mmOperators, 1, split, 1);
    }
}
/**
 * mmChainDP(): Runs the classic matrix-chain-order dynamic program over the
 * given dimension array and returns the table of optimal split positions.
 *
 * Matrix m of the chain has dimensions dimArray[m] x dimArray[m+1]. For a
 * sub-chain (i..j), split[i][j] holds the index k such that the sub-chain is
 * computed as (i..k) times (k+1..j); entries that were never computed
 * (including the diagonal) are -1. Ties are resolved in favor of the smallest k.
 *
 * Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest, Clifford Stein
 * Introduction to Algorithms, Third Edition, MIT Press, page 395.
 *
 * @param dimArray dimensions of the chain, at least size+1 entries
 * @param size number of matrices in the chain
 * @return split table of size [size][size]
 */
private int[][] mmChainDP(double[] dimArray, int size)
{
    // min cost table; Java zero-initializes it, so chains of length 1 cost 0
    double[][] minCost = new double[size][size];

    // min cost index table, -1 marks "no split"
    int[][] split = new int[size][size];
    for (int[] row : split) {
        Arrays.fill(row, -1);
    }

    //compute cost-optimal chains for increasing chain sizes
    for (int len = 2; len <= size; len++) {
        for (int first = 0; first + len <= size; first++) {
            int last = first + len - 1;

            // find cost of (first,last)
            double best = Double.MAX_VALUE;
            int bestSplit = -1;
            for (int k = first; k < last; k++)
            {
                //recursive cost computation
                double cost = minCost[first][k] + minCost[k + 1][last]
                    + (dimArray[first] * dimArray[k + 1] * dimArray[last + 1]);

                //prune suboptimal
                if (cost < best) {
                    best = cost;
                    bestSplit = k;
                }
            }
            minCost[first][last] = best;
            split[first][last] = bestSplit;

            if( LOG.isTraceEnabled() ){
                LOG.trace("mmchainopt [i="+(first+1)+",j="+(last+1)+"]: costs = "+minCost[first][last]+", split = "+(split[first][last]+1));
            }
        }
    }

    return split;
}
/**
 * mmChainRelinkHops(): This method gets invoked after finding the optimal
 * order (split[][]) from dynamic programming. It relinks the Hops that are
 * part of the mmChain.
 *
 * mmChain : basic operands in the entire matrix multiplication chain.
 * mmOperators : Hops that store the intermediate results in the chain.
 * For example: A = B %*% (C %*% D) there will be three Hops in mmChain
 * (B,C,D), and two Hops in mmOperators (one for each %*%).
 *
 * @param h hop that produces the result of the sub-chain (i..j)
 * @param i index of the first operand of the sub-chain in mmChain
 * @param j index of the last operand of the sub-chain in mmChain
 * @param mmChain operands of the chain
 * @param mmOperators multiplication operators of the chain
 * @param opIndex index of the next unused operator in mmOperators
 * @param split optimal split positions as computed by mmChainDP
 * @param level recursion depth, only used for trace indentation
 */
private void mmChainRelinkHops(Hop h, int i, int j, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators,
    int opIndex, int[][] split, int level)
{
    // The operator cursor must be shared by the whole recursion. With a plain
    // int (passed by value) both sub-chains of a split would start at the same
    // index, e.g., for ((A B) C) ((D E) F) the same operator would be wired
    // into both halves while another operator would never be used.
    mmChainRelinkHopsShared(h, i, j, mmChain, mmOperators, new int[]{ opIndex }, split, level);
}

/**
 * Recursive worker of mmChainRelinkHops(); nextOp[0] is the index of the
 * next unused operator and is advanced whenever an operator is consumed.
 */
private void mmChainRelinkHopsShared(Hop h, int i, int j, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators,
    int[] nextOp, int[][] split, int level)
{
    //single matrix - end of recursion
    if (i == j) {
        logTraceHop(h, level);
        return;
    }

    if( LOG.isTraceEnabled() ){
        String offset = Explain.getIdentation(level);
        LOG.trace(offset + "(");
    }

    int k = split[i][j];

    // Set Input1 for current Hop h: a single operand, or an operator for (i..k)
    Hop left = (i == k) ? mmChain.get(i) : mmOperators.get(nextOp[0]++);
    h.getInput().add(left);
    left.getParent().add(h);

    // Set Input2 for current Hop h: a single operand, or an operator for (k+1..j)
    Hop right = (k + 1 == j) ? mmChain.get(j) : mmOperators.get(nextOp[0]++);
    h.getInput().add(right);
    right.getParent().add(h);

    // Find children for both the inputs
    mmChainRelinkHopsShared(h.getInput().get(0), i, k, mmChain, mmOperators, nextOp, split, level+1);
    mmChainRelinkHopsShared(h.getInput().get(1), k + 1, j, mmChain, mmOperators, nextOp, split, level+1);

    // Propagate properties of input hops to current hop h
    h.refreshSizeInformation();

    if( LOG.isTraceEnabled() ){
        String offset = Explain.getIdentation(level);
        LOG.trace(offset + ")");
    }
}
/**
 * Removes the input links of all given chain operators, together with the
 * matching parent links held by their former inputs, so that the chain can
 * be relinked in a different order.
 *
 * @param hop chain root, only used for the error location
 * @param operators multiplication operators of the chain; index 0 is the chain root
 * @throws HopsException if an operator does not have exactly two inputs, or
 *         if an operator other than the root has more than one parent
 */
private void clearLinksWithinChain ( Hop hop, ArrayList<Hop> operators )
    throws HopsException
{
    for ( int i=0; i < operators.size(); i++ ) {
        Hop op = operators.get(i);

        // only the chain root (i == 0) may be consumed by more than one parent
        if ( op.getInput().size() != 2 || (i != 0 && op.getParent().size() > 1 ) ) {
            throw new HopsException(hop.printErrorLocation() + "Unexpected error while applying optimization on matrix-mult chain. \n");
        }

        // remember both inputs before the input list is cleared
        Hop input1 = op.getInput().get(0);
        Hop input2 = op.getInput().get(1);

        op.getInput().clear();
        input1.getParent().remove(op);
        input2.getParent().remove(op);
    }
}
/**
 * Obtains all dimension information of the chain and constructs the dimsArray.
 * For a chain of n operands, dimsArray[0] receives the number of rows of the
 * first operand and dimsArray[i+1] the number of columns of operand i.
 * If all dimensions are known it returns true; otherwise dimsArray is left
 * untouched and the mmchain rewrite should be ended without modifications.
 *
 * @param hop high-level operator, only used for the error location
 * @param chain list of high-level operators (operands of the chain)
 * @param dimsArray dimension array to populate, at least chain.size()+1 entries
 * @return true if all dimensions known
 * @throws HopsException if the inner dimensions of two neighboring operands do not match
 */
private boolean getDimsArray( Hop hop, ArrayList<Hop> chain, double[] dimsArray )
    throws HopsException
{
    // check the dimensions in the matrix chain to ensure all dimensions are known
    for( Hop operand : chain ) {
        if( operand.getDim1() <= 0 || operand.getDim2() <= 0 )
            return false;
    }

    // Build the array containing dimensions from all matrices in the chain.
    // All dimensions are known to be positive at this point, so no further
    // validity checks on the individual entries are needed.
    for( int i = 0; i < chain.size(); i++ )
    {
        if( i == 0 ) {
            dimsArray[i] = chain.get(i).getDim1();
        }
        else if( chain.get(i - 1).getDim2() != chain.get(i).getDim1() ) {
            throw new HopsException(hop.printErrorLocation() +
                "Hops::optimizeMMChain() : Matrix Dimension Mismatch: "+chain.get(i - 1).getDim2()+" != "+chain.get(i).getDim1());
        }

        dimsArray[i + 1] = chain.get(i).getDim2();
    }

    return true;
}
/**
 * Counts how often hop h occurs among the inputs of hop p.
 *
 * @param p hop whose input list is scanned
 * @param h hop to look for (compared via equals)
 * @return number of inputs of p that are equal to h
 */
private int inputCount ( Hop p, Hop h ) {
    int occurrences = 0;
    for( Hop candidate : p.getInput() ) {
        if( candidate.equals(h) ) {
            occurrences++;
        }
    }
    return occurrences;
}
/**
 * Writes a one-line description of the given hop (name, class, id and
 * dimensions) to the trace log, indented according to the given level.
 * Does nothing if trace logging is disabled.
 *
 * @param hop high-level operator to describe
 * @param level indentation level
 */
private void logTraceHop( Hop hop, int level )
{
    if( !LOG.isTraceEnabled() )
        return;

    StringBuilder line = new StringBuilder();
    line.append(Explain.getIdentation(level));
    line.append("Hop ").append(hop.getName());
    line.append("(").append(hop.getClass().getSimpleName());
    line.append(", ").append(hop.getHopID()).append(")");
    line.append(" ").append(hop.getDim1()).append("x").append(hop.getDim2());
    LOG.trace(line.toString());
}
}