org.apache.sysml.hops.rewrite.ProgramRewriter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.FunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.LanguageException;
import org.apache.sysml.parser.ParForStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;
/**
* This program rewriter applies a variety of rule-based rewrites
* on all hop dags of the given program in one pass over the entire
* program.
*
*/
public class ProgramRewriter
{
private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName());
//internal local debug level
private static final boolean LDEBUG = false;
private static final boolean CHECK = false;
private ArrayList _dagRuleSet = null;
private ArrayList _sbRuleSet = null;
static{
// for internal debugging only
if( LDEBUG ) {
Logger.getLogger("org.apache.sysml.hops.rewrite")
.setLevel((Level) Level.DEBUG);
}
}
public ProgramRewriter()
{
// by default which is used during initial compile
// apply all (static and dynamic) rewrites
this( true, true );
}
public ProgramRewriter( boolean staticRewrites, boolean dynamicRewrites )
{
//initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
_dagRuleSet = new ArrayList();
//initialize StatementBlock rewrite ruleSet (with fixed rewrite order)
_sbRuleSet = new ArrayList();
//STATIC REWRITES (which do not rely on size information)
if( staticRewrites )
{
//add static HOP DAG rewrite rules
_dagRuleSet.add( new RewriteTransientWriteParentHandling() );
_dagRuleSet.add( new RewriteRemoveReadAfterWrite() ); //dependency: before blocksize
_dagRuleSet.add( new RewriteBlockSizeAndReblock() );
_dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) //dependency: simplifications (no need to merge leafs again)
_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
_dagRuleSet.add( new RewriteIndexingVectorization() ); //dependency: cse, simplifications
_dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing() ); //dependency: reblock
//add statment block rewrite rules
if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )
_sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding
if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS )
_sbRuleSet.add( new RewriteSplitDagUnknownCSVRead() ); //dependency: reblock
if( OptimizerUtils.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS )
_sbRuleSet.add( new RewriteSplitDagDataDependentOperators() );
if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
_sbRuleSet.add( new RewriteForLoopVectorization() ); //dependency: reblock (reblockop)
_sbRuleSet.add( new RewriteInjectSparkLoopCheckpointing(true) ); //dependency: reblock (blocksizes)
}
// DYNAMIC REWRITES (which do require size information)
if( dynamicRewrites )
{
_dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
{
_dagRuleSet.add( new RewriteAlgebraicSimplificationDynamic() ); //dependencies: cse
_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse
}
//reapply cse after rewrites because (1) applied rewrites on operators w/ multiple parents, and
//(2) newly introduced operators potentially created redundancy (incl leaf merge to allow for cse)
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
_dagRuleSet.add( new RewriteCommonSubexpressionElimination(true) ); //dependency: simplifications
}
}
/**
* Construct a program rewriter for a given rewrite which is passed from outside.
*
* @param rewrite
*/
public ProgramRewriter( HopRewriteRule rewrite )
{
//initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
_dagRuleSet = new ArrayList();
_dagRuleSet.add( rewrite );
_sbRuleSet = new ArrayList();
}
/**
* Construct a program rewriter for a given rewrite which is passed from outside.
*
* @param rewrite
*/
public ProgramRewriter( StatementBlockRewriteRule rewrite )
{
//initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
_dagRuleSet = new ArrayList();
_sbRuleSet = new ArrayList();
_sbRuleSet.add( rewrite );
}
/**
* Construct a program rewriter for the given rewrite sets which are passed from outside.
*
* @param rewrite
*/
public ProgramRewriter( ArrayList hRewrites, ArrayList sbRewrites )
{
//initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
_dagRuleSet = new ArrayList();
_dagRuleSet.addAll( hRewrites );
_sbRuleSet = new ArrayList();
_sbRuleSet.addAll( sbRewrites );
}
/**
*
* @param dmlp
* @return
* @throws LanguageException
* @throws HopsException
*/
public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp)
throws LanguageException, HopsException
{
ProgramRewriteStatus state = new ProgramRewriteStatus();
// for each namespace, handle function statement blocks
for (String namespaceKey : dmlp.getNamespaces().keySet())
for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet())
{
FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey,fname);
rewriteStatementBlockHopDAGs(fsblock, state);
rewriteStatementBlock(fsblock, state);
}
// handle regular statement blocks in "main" method
for (int i = 0; i < dmlp.getNumStatementBlocks(); i++)
{
StatementBlock current = dmlp.getStatementBlock(i);
rewriteStatementBlockHopDAGs(current, state);
}
dmlp.setStatementBlocks( rewriteStatementBlocks(dmlp.getStatementBlocks(), state) );
return state;
}
/**
*
* @param current
* @throws LanguageException
* @throws HopsException
*/
public void rewriteStatementBlockHopDAGs(StatementBlock current, ProgramRewriteStatus state)
throws LanguageException, HopsException
{
//ensure robustness for calls from outside
if( state == null )
state = new ProgramRewriteStatus();
if (current instanceof FunctionStatementBlock)
{
FunctionStatementBlock fsb = (FunctionStatementBlock)current;
FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
for (StatementBlock sb : fstmt.getBody())
rewriteStatementBlockHopDAGs(sb, state);
}
else if (current instanceof WhileStatementBlock)
{
WhileStatementBlock wsb = (WhileStatementBlock) current;
WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
wsb.setPredicateHops(rewriteHopDAG(wsb.getPredicateHops(), state));
for (StatementBlock sb : wstmt.getBody())
rewriteStatementBlockHopDAGs(sb, state);
}
else if (current instanceof IfStatementBlock)
{
IfStatementBlock isb = (IfStatementBlock) current;
IfStatement istmt = (IfStatement)isb.getStatement(0);
isb.setPredicateHops(rewriteHopDAG(isb.getPredicateHops(), state));
for (StatementBlock sb : istmt.getIfBody())
rewriteStatementBlockHopDAGs(sb, state);
for (StatementBlock sb : istmt.getElseBody())
rewriteStatementBlockHopDAGs(sb, state);
}
else if (current instanceof ForStatementBlock) //incl parfor
{
ForStatementBlock fsb = (ForStatementBlock) current;
ForStatement fstmt = (ForStatement)fsb.getStatement(0);
fsb.setFromHops(rewriteHopDAG(fsb.getFromHops(), state));
fsb.setToHops(rewriteHopDAG(fsb.getToHops(), state));
fsb.setIncrementHops(rewriteHopDAG(fsb.getIncrementHops(), state));
for (StatementBlock sb : fstmt.getBody())
rewriteStatementBlockHopDAGs(sb, state);
}
else //generic (last-level)
{
current.set_hops( rewriteHopDAGs(current.get_hops(), state) );
}
}
/**
*
* @param roots
* @throws LanguageException
* @throws HopsException
*/
public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus state)
throws HopsException
{
for( HopRewriteRule r : _dagRuleSet )
{
Hop.resetVisitStatus( roots ); //reset for each rule
roots = r.rewriteHopDAGs(roots, state);
if( CHECK ) {
LOG.info("Validation after: "+r.getClass().getName());
HopDagValidator.validateHopDag(roots);
}
}
return roots;
}
/**
*
* @param root
* @throws LanguageException
* @throws HopsException
*/
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state)
throws HopsException
{
for( HopRewriteRule r : _dagRuleSet )
{
root.resetVisitStatus(); //reset for each rule
root = r.rewriteHopDAG(root, state);
if( CHECK ) {
LOG.info("Validation after: "+r.getClass().getName());
HopDagValidator.validateHopDag(root);
}
}
return root;
}
/**
*
* @param sbs
* @return
* @throws HopsException
*/
public ArrayList rewriteStatementBlocks( ArrayList sbs, ProgramRewriteStatus state )
throws HopsException
{
//ensure robustness for calls from outside
if( state == null )
state = new ProgramRewriteStatus();
ArrayList tmp = new ArrayList();
//rewrite statement blocks (with potential expansion)
for( StatementBlock sb : sbs )
tmp.addAll( rewriteStatementBlock(sb, state) );
//copy results into original collection
sbs.clear();
sbs.addAll( tmp );
return sbs;
}
/**
*
* @param sb
* @return
* @throws HopsException
*/
private ArrayList rewriteStatementBlock( StatementBlock sb, ProgramRewriteStatus status )
throws HopsException
{
ArrayList ret = new ArrayList();
ret.add(sb);
//recursive invocation
if (sb instanceof FunctionStatementBlock)
{
FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
fstmt.setBody( rewriteStatementBlocks(fstmt.getBody(), status) );
}
else if (sb instanceof WhileStatementBlock)
{
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
wstmt.setBody( rewriteStatementBlocks( wstmt.getBody(), status ) );
}
else if (sb instanceof IfStatementBlock)
{
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement)isb.getStatement(0);
istmt.setIfBody( rewriteStatementBlocks( istmt.getIfBody(), status ) );
istmt.setElseBody( rewriteStatementBlocks( istmt.getElseBody(), status ) );
}
else if (sb instanceof ForStatementBlock) //incl parfor
{
//maintain parfor context information (e.g., for checkpointing)
boolean prestatus = status.isInParforContext();
if( sb instanceof ParForStatementBlock )
status.setInParforContext(true);
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement)fsb.getStatement(0);
fstmt.setBody( rewriteStatementBlocks(fstmt.getBody(), status) );
status.setInParforContext(prestatus);
}
//apply rewrite rules
for( StatementBlockRewriteRule r : _sbRuleSet )
{
ArrayList tmp = new ArrayList();
for( StatementBlock sbc : ret )
tmp.addAll( r.rewriteStatementBlock(sbc, status) );
//take over set of rewritten sbs
ret.clear();
ret.addAll(tmp);
}
return ret;
}
}