All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.hops.rewrite.ProgramRewriter Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.FunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.LanguageException;
import org.apache.sysml.parser.ParForStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;

/**
 * This program rewriter applies a variety of rule-based rewrites
 * on all hop dags of the given program in one pass over the entire
 * program. 
 * 
 */
public class ProgramRewriter 
{
	private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName());
	
	//internal local debug level
	private static final boolean LDEBUG = false; 
	//if enabled, every hop dag is validated after each applied dag rewrite rule
	private static final boolean CHECK = false;
	
	//hop dag rewrite rules, applied in insertion order
	private ArrayList<HopRewriteRule> _dagRuleSet = null;
	//statement block rewrite rules, applied in insertion order
	private ArrayList<StatementBlockRewriteRule> _sbRuleSet = null;
	
	static{
		// for internal debugging only
		if( LDEBUG ) {
			Logger.getLogger("org.apache.sysml.hops.rewrite")
				  .setLevel(Level.DEBUG);
		}
		
	}
	
	/**
	 * Construct a program rewriter with all static and dynamic rewrites enabled.
	 */
	public ProgramRewriter()
	{
		// by default which is used during initial compile 
		// apply all (static and dynamic) rewrites
		this( true, true );
	}
	
	/**
	 * Construct a program rewriter with the built-in rule sets. The rules are
	 * registered in a fixed order; individual rules are additionally guarded by
	 * the respective OptimizerUtils flags.
	 * 
	 * @param staticRewrites if true, register the rewrites which do not rely on size information
	 * @param dynamicRewrites if true, register the rewrites which do require size information
	 */
	public ProgramRewriter( boolean staticRewrites, boolean dynamicRewrites )
	{
		//initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
		_dagRuleSet = new ArrayList<HopRewriteRule>();
		
		//initialize StatementBlock rewrite ruleSet (with fixed rewrite order)
		_sbRuleSet = new ArrayList<StatementBlockRewriteRule>();
		
		
		//STATIC REWRITES (which do not rely on size information)
		if( staticRewrites )
		{
			//add static HOP DAG rewrite rules
			_dagRuleSet.add(     new RewriteTransientWriteParentHandling()       );
			_dagRuleSet.add(     new RewriteRemoveReadAfterWrite()               ); //dependency: before blocksize
			_dagRuleSet.add(     new RewriteBlockSizeAndReblock()                );
			_dagRuleSet.add(     new RewriteRemoveUnnecessaryCasts()             );		
			if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
				_dagRuleSet.add( new RewriteCommonSubexpressionElimination()     );
			if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
				_dagRuleSet.add( new RewriteConstantFolding()                    ); //dependency: cse
			if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
				_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic()      ); //dependencies: cse
			if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )             //dependency: simplifications (no need to merge leafs again)
				_dagRuleSet.add( new RewriteCommonSubexpressionElimination()     ); 
			if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
				_dagRuleSet.add( new RewriteIndexingVectorization()              ); //dependency: cse, simplifications
			_dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing()          ); //dependency: reblock
			
			//add statement block rewrite rules
 			if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )			
				_sbRuleSet.add(  new RewriteRemoveUnnecessaryBranches()          ); //dependency: constant folding		
 			if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS )
 				_sbRuleSet.add(  new RewriteSplitDagUnknownCSVRead()             ); //dependency: reblock	
 			if( OptimizerUtils.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS )
 				_sbRuleSet.add(  new RewriteSplitDagDataDependentOperators()     );
 			if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
				_sbRuleSet.add(  new RewriteForLoopVectorization()               ); //dependency: reblock (reblockop)
 			_sbRuleSet.add( new RewriteInjectSparkLoopCheckpointing(true)        ); //dependency: reblock (blocksizes)
		}
		
		// DYNAMIC REWRITES (which do require size information)
		if( dynamicRewrites )
		{
			_dagRuleSet.add(     new RewriteMatrixMultChainOptimization()         ); //dependency: cse 
			
			if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
			{
				_dagRuleSet.add( new RewriteAlgebraicSimplificationDynamic()      ); //dependencies: cse
				_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic()       ); //dependencies: cse
			}
			
			//reapply cse after rewrites because (1) applied rewrites on operators w/ multiple parents, and
			//(2) newly introduced operators potentially created redundancy (incl leaf merge to allow for cse)
			if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )             
				_dagRuleSet.add( new RewriteCommonSubexpressionElimination(true) ); //dependency: simplifications 			
		}
	}
	
	/**
	 * Construct a program rewriter for a given hop dag rewrite which is passed
	 * from outside. The statement block rule set stays empty.
	 * 
	 * @param rewrite the single hop dag rewrite rule to apply
	 */
	public ProgramRewriter( HopRewriteRule rewrite )
	{
		//initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
		_dagRuleSet = new ArrayList<HopRewriteRule>();
		_dagRuleSet.add( rewrite );		
		
		_sbRuleSet = new ArrayList<StatementBlockRewriteRule>();
	}
	
	/**
	 * Construct a program rewriter for a given statement block rewrite which is
	 * passed from outside. The hop dag rule set stays empty.
	 * 
	 * @param rewrite the single statement block rewrite rule to apply
	 */
	public ProgramRewriter( StatementBlockRewriteRule rewrite )
	{
		//initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
		_dagRuleSet = new ArrayList<HopRewriteRule>();
		
		_sbRuleSet = new ArrayList<StatementBlockRewriteRule>();
		_sbRuleSet.add( rewrite );
	}
	
	/**
	 * Construct a program rewriter for the given rewrite sets which are passed from outside.
	 * The rules are copied into internal lists, i.e., later changes of the passed
	 * lists do not affect this rewriter.
	 * 
	 * @param hRewrites hop dag rewrite rules, applied in list order
	 * @param sbRewrites statement block rewrite rules, applied in list order
	 */
	public ProgramRewriter( ArrayList<HopRewriteRule> hRewrites, ArrayList<StatementBlockRewriteRule> sbRewrites )
	{
		//initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
		_dagRuleSet = new ArrayList<HopRewriteRule>();
		_dagRuleSet.addAll( hRewrites );
		
		_sbRuleSet = new ArrayList<StatementBlockRewriteRule>();
		_sbRuleSet.addAll( sbRewrites );
	}
	
	/**
	 * Applies all registered rewrites to the entire program: first to the function
	 * statement blocks of every namespace, then to the statement blocks of the main
	 * program. For each of them, the hop dag rewrites are applied before the
	 * statement block rewrites.
	 * 
	 * @param dmlp the program to rewrite (modified in place)
	 * @return the rewrite status that was shared across all applied rewrites
	 * @throws LanguageException
	 * @throws HopsException
	 */
	public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp) 
		throws LanguageException, HopsException
	{	
		ProgramRewriteStatus state = new ProgramRewriteStatus();
		
		// for each namespace, handle function statement blocks
		for (String namespaceKey : dmlp.getNamespaces().keySet())
			for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet())
			{
				FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey,fname);
				rewriteStatementBlockHopDAGs(fsblock, state);
				//returned block list is intentionally ignored; the function body is updated in place
				rewriteStatementBlock(fsblock, state);
			}
		
		// handle regular statement blocks in "main" method
		for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) 
		{
			StatementBlock current = dmlp.getStatementBlock(i);
			rewriteStatementBlockHopDAGs(current, state);
		}
		dmlp.setStatementBlocks( rewriteStatementBlocks(dmlp.getStatementBlocks(), state) );
		
		return state;
	}
	
	/**
	 * Recursively applies the hop dag rewrites to the given statement block.
	 * For function, while, if, and for/parfor blocks, the predicate (or from/to/increment)
	 * hops are rewritten and the child blocks are processed recursively; for all
	 * other blocks, the hop dags of the block itself are rewritten.
	 * 
	 * @param current the statement block to rewrite (modified in place)
	 * @param state the rewrite status; if null, a new status is created
	 * @throws LanguageException
	 * @throws HopsException
	 */
	public void rewriteStatementBlockHopDAGs(StatementBlock current, ProgramRewriteStatus state) 
		throws LanguageException, HopsException
	{	
		//ensure robustness for calls from outside
		if( state == null )
			state = new ProgramRewriteStatus();
		
		if (current instanceof FunctionStatementBlock)
		{
			FunctionStatementBlock fsb = (FunctionStatementBlock)current;
			FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
			for (StatementBlock sb : fstmt.getBody())
				rewriteStatementBlockHopDAGs(sb, state);
		}
		else if (current instanceof WhileStatementBlock)
		{
			WhileStatementBlock wsb = (WhileStatementBlock) current;
			WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
			wsb.setPredicateHops(rewriteHopDAG(wsb.getPredicateHops(), state));
			for (StatementBlock sb : wstmt.getBody())
				rewriteStatementBlockHopDAGs(sb, state);
		}	
		else if (current instanceof IfStatementBlock)
		{
			IfStatementBlock isb = (IfStatementBlock) current;
			IfStatement istmt = (IfStatement)isb.getStatement(0);
			isb.setPredicateHops(rewriteHopDAG(isb.getPredicateHops(), state));
			for (StatementBlock sb : istmt.getIfBody())
				rewriteStatementBlockHopDAGs(sb, state);
			for (StatementBlock sb : istmt.getElseBody())
				rewriteStatementBlockHopDAGs(sb, state);
		}
		else if (current instanceof ForStatementBlock) //incl parfor
		{
			ForStatementBlock fsb = (ForStatementBlock) current;
			ForStatement fstmt = (ForStatement)fsb.getStatement(0);
			fsb.setFromHops(rewriteHopDAG(fsb.getFromHops(), state));
			fsb.setToHops(rewriteHopDAG(fsb.getToHops(), state));
			fsb.setIncrementHops(rewriteHopDAG(fsb.getIncrementHops(), state));
			for (StatementBlock sb : fstmt.getBody())
				rewriteStatementBlockHopDAGs(sb, state);
		}
		else //generic (last-level)
		{
			current.set_hops( rewriteHopDAGs(current.get_hops(), state) );
		}
	}
	
	/**
	 * Applies all registered hop dag rewrite rules, in order, to the given
	 * list of dag roots. The visit status is reset before each rule.
	 * 
	 * @param roots the roots of the hop dags to rewrite
	 * @param state the rewrite status handed to every rule
	 * @return the roots as returned by the last applied rule (the input if no rule is registered)
	 * @throws HopsException
	 */
	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) 
		throws HopsException
	{	
		for( HopRewriteRule r : _dagRuleSet )
		{
			Hop.resetVisitStatus( roots ); //reset for each rule
			roots = r.rewriteHopDAGs(roots, state);
		
			if( CHECK ) {		
				LOG.info("Validation after: "+r.getClass().getName());
				HopDagValidator.validateHopDag(roots);
			}
		}
		
		return roots;
	}
	
	/**
	 * Applies all registered hop dag rewrite rules, in order, to the hop dag
	 * of the given root. The visit status is reset before each rule.
	 * 
	 * @param root the root of the hop dag to rewrite; a null root is returned as is
	 * @param state the rewrite status handed to every rule
	 * @return the root as returned by the last applied rule (the input if no rule is registered)
	 * @throws HopsException
	 */
	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) 
		throws HopsException
	{	
		//nothing to rewrite; avoids a NullPointerException on the visit status reset below
		if( root == null )
			return root;
		
		for( HopRewriteRule r : _dagRuleSet )
		{
			root.resetVisitStatus(); //reset for each rule
			root = r.rewriteHopDAG(root, state);
		
			if( CHECK ) {
				LOG.info("Validation after: "+r.getClass().getName());
				HopDagValidator.validateHopDag(root);
			}
		}
		
		return root;
	}
	
	/**
	 * Applies the statement block rewrites to every block of the given list.
	 * Since a single block can be rewritten into multiple blocks, the results
	 * are collected and then written back into the passed list.
	 * 
	 * @param sbs the statement blocks to rewrite (the list is cleared and refilled in place)
	 * @param state the rewrite status; if null, a new status is created
	 * @return the passed list, now containing the rewritten statement blocks
	 * @throws HopsException 
	 */
	public ArrayList<StatementBlock> rewriteStatementBlocks( ArrayList<StatementBlock> sbs, ProgramRewriteStatus state ) 
		throws HopsException
	{
		//ensure robustness for calls from outside
		if( state == null )
			state = new ProgramRewriteStatus();
				
		
		ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>();
		
		//rewrite statement blocks (with potential expansion)
		for( StatementBlock sb : sbs )
			tmp.addAll( rewriteStatementBlock(sb, state) );
		
		//copy results into original collection
		sbs.clear();
		sbs.addAll( tmp );
		
		return sbs;
	}
	
	/**
	 * Applies the statement block rewrites to a single block. Child blocks of
	 * function, while, if, and for/parfor blocks are rewritten first (recursively),
	 * then all registered statement block rules are applied in order, each rule
	 * operating on the output blocks of the previous rule.
	 * 
	 * @param sb the statement block to rewrite
	 * @param status the rewrite status handed to every rule (must not be null)
	 * @return the list of statement blocks that replace the given block
	 * @throws HopsException
	 */
	private ArrayList<StatementBlock> rewriteStatementBlock( StatementBlock sb, ProgramRewriteStatus status ) 
		throws HopsException
	{
		ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>();
		ret.add(sb);
		
		//recursive invocation
		if (sb instanceof FunctionStatementBlock)
		{
			FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
			FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
			fstmt.setBody( rewriteStatementBlocks(fstmt.getBody(), status) );			
		}
		else if (sb instanceof WhileStatementBlock)
		{
			WhileStatementBlock wsb = (WhileStatementBlock) sb;
			WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
			wstmt.setBody( rewriteStatementBlocks( wstmt.getBody(), status ) );
		}	
		else if (sb instanceof IfStatementBlock)
		{
			IfStatementBlock isb = (IfStatementBlock) sb;
			IfStatement istmt = (IfStatement)isb.getStatement(0);
			istmt.setIfBody( rewriteStatementBlocks( istmt.getIfBody(), status ) );
			istmt.setElseBody( rewriteStatementBlocks( istmt.getElseBody(), status ) );
		}
		else if (sb instanceof ForStatementBlock) //incl parfor
		{
			//maintain parfor context information (e.g., for checkpointing)
			boolean prestatus = status.isInParforContext();
			if( sb instanceof ParForStatementBlock )
				status.setInParforContext(true);
			
			ForStatementBlock fsb = (ForStatementBlock) sb;
			ForStatement fstmt = (ForStatement)fsb.getStatement(0);
			fstmt.setBody( rewriteStatementBlocks(fstmt.getBody(), status) );
			
			//restore the context of the enclosing block
			status.setInParforContext(prestatus);
		}
		
		//apply rewrite rules
		for( StatementBlockRewriteRule r : _sbRuleSet )
		{
			ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>();			
			for( StatementBlock sbc : ret )
				tmp.addAll( r.rewriteStatementBlock(sbc, status) );
			
			//take over set of rewritten sbs		
			ret.clear();
			ret.addAll(tmp);
		}
		
		return ret;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy