All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.sysml.hops.rewrite.RewriteMergeBlockSequence Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;

import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.DataOpTypes;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.parser.ExternalFunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.VariableSet;

/**
 * Rule: Simplify program structure by merging sequences of last-level
 * statement blocks in order to create optimization opportunities.
 * <p>
 * Adjacent blocks are merged pairwise until no further merge applies. The hop
 * DAG of the first block is appended to the second one. Transient writes of the
 * first block are wired directly to matching transient reads of the second one.
 * Common-subexpression elimination is then run on the merged DAG.
 */
public class RewriteMergeBlockSequence extends StatementBlockRewriteRule
{
	/** Rewriter used to run common-subexpression elimination on each merged hop DAG. */
	private final ProgramRewriter rewriter = new ProgramRewriter(
		new RewriteCommonSubexpressionElimination(true));
	
	/**
	 * Single-block entry point. Merging only applies to sequences of blocks,
	 * so the given block is returned unchanged.
	 *
	 * @param sb statement block
	 * @param state rewrite status (unused here)
	 * @return a fixed-size list containing only {@code sb}
	 * @throws HopsException declared by the rewrite-rule contract; not thrown here
	 */
	@Override
	public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, 
			ProgramRewriteStatus state) throws HopsException {
		return Arrays.asList(sb);
	}
	
	/**
	 * Repeatedly merges adjacent statement blocks of the given sequence until an
	 * iteration completes without any merge (fixpoint).
	 * <p>
	 * Merging a pair (sb1, sb2) works as follows:
	 * <ul>
	 * <li>the roots of sb1 are moved into sb2 and sb1's hop list is cleared;</li>
	 * <li>CSE is run on sb2's DAG;</li>
	 * <li>sb2 takes over sb1's live-in set and begin line/column, and combines its
	 * gen/kill/read/updated variable sets with those of sb1;</li>
	 * <li>sb1 is removed from the returned list.</li>
	 * </ul>
	 *
	 * @param sbs sequence of statement blocks; returned as-is if null or empty
	 * @param state rewrite status (not used; the CSE pass uses a fresh status)
	 * @return a new list of (possibly merged) blocks, or {@code sbs} if null/empty
	 * @throws HopsException propagated from the hop utilities and the CSE rewriter
	 */
	@Override
	public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, 
			ProgramRewriteStatus state) throws HopsException 
	{
		if( sbs == null || sbs.isEmpty() )
			return sbs;
		
		//execute binary merging iterations until fixpoint 
		ArrayList<StatementBlock> tmpList = new ArrayList<>(sbs);
		boolean merged = true;
		while( merged ) {
			merged = false;
			//NOTE(review): the next line is damaged in this copy of the source. The
			//following pieces are missing:
			//  - the for-loop bound and increment,
			//  - the lookup of sb1/sb2 (presumably tmpList.get(i) and tmpList.get(i+1),
			//    since sb1 is later removed at index i),
			//  - the merge-eligibility condition and its opening braces,
			//  - the declared type of sb1Hops.
			//Restore this line from upstream before compiling.
			for( int i=0; i sb1Hops = sb1.getHops();
					ArrayList<Hop> sb2Hops = sb2.getHops();
					
					//determine transient read inputs s2 
					Hop.resetVisitStatus(sb2Hops);
					HashMap<String, Hop> treads = new HashMap<>();
					HashMap<String, Hop> twrites = new HashMap<>();
					for( Hop root : sb2Hops )
						rCollectTransientReadWrites(root, treads, twrites);
					Hop.resetVisitStatus(sb2Hops);
					
					//merge hop dags of s1 and s2
					Hop.resetVisitStatus(sb1Hops);
					for( Hop root : sb1Hops ) {
						//connect transient writes s1 and reads s2
						if( HopRewriteUtils.isData(root, DataOpTypes.TRANSIENTWRITE) 
							&& treads.containsKey(root.getName()) ) {
							//rewire transient write and transient read: consumers of the
							//s2 read now consume the value written by s1 directly
							Hop tread = treads.get(root.getName());
							Hop in = root.getInput().get(0);
							//iterate over a copy, presumably because rewiring mutates tread's parent list
							for( Hop parent : new ArrayList<>(tread.getParent()) )
								HopRewriteUtils.replaceChildReference(parent, tread, in);
							HopRewriteUtils.removeAllChildReferences(root);
							//add transient write if necessary: s2 does not write this
							//variable itself, but it is live after the merged block
							if( !twrites.containsKey(root.getName()) 
								&& sb2.liveOut().containsVariable(root.getName()) ) {
								sb2Hops.add(HopRewriteUtils.createDataOp(
									root.getName(), in, DataOpTypes.TRANSIENTWRITE));
							}
						}
						//add remaining roots from s1 to s2; a transient write of s1 is dropped
						//if s2 writes the same variable or the variable is not live-out of s2
						else if( !(HopRewriteUtils.isData(root, DataOpTypes.TRANSIENTWRITE)
							&& (twrites.containsKey(root.getName()) || !sb2.liveOut().containsVariable(root.getName()))) ) {
							sb2Hops.add(root);
						}
					}
					//clear partial hops from the merged statement block to avoid problems with 
					//other statement block rewrites that iterate over the original program
					sb1Hops.clear();
					
					//run common-subexpression elimination
					Hop.resetVisitStatus(sb2Hops);
					rewriter.rewriteHopDAG(sb2Hops, new ProgramRewriteStatus());
					
					//modify live variable sets of s2
					sb2.setLiveIn(sb1.liveIn()); //liveOut remains unchanged
					sb2.setGen(VariableSet.minus(VariableSet.union(sb1.getGen(), sb2.getGen()), sb1.getKill()));
					sb2.setKill(VariableSet.union(sb1.getKill(), sb2.getKill()));
					sb2.setReadVariables(VariableSet.union(sb1.variablesRead(), sb2.variablesRead()));
					sb2.setUpdatedVariables(VariableSet.union(sb1.variablesUpdated(), sb2.variablesUpdated()));
					
					LOG.debug("Applied mergeStatementBlockSequences "
							+ "(blocks of lines "+sb1.getBeginLine()+"-"+sb1.getEndLine()
							+" and "+sb2.getBeginLine()+"-"+sb2.getEndLine()+").");
					
					//modify line numbers of s2
					sb2.setBeginLine(sb1.getBeginLine());
					sb2.setBeginColumn(sb1.getBeginColumn());
					
					//remove sb1 from list of statement blocks
					tmpList.remove(i);
					merged = true;
					break; //for
				}
			}
		}
		
		return tmpList;
	}
	
	/**
	 * Collects transient reads and transient writes of the DAG rooted at
	 * {@code current}, keyed by variable name, in a post-order traversal.
	 * <p>
	 * Output variables of function calls are recorded as writes with a
	 * {@code null} hop, because only their names are looked up. Visited nodes are
	 * marked and skipped, so callers reset the visit status before and after.
	 * <p>
	 * NOTE(review): a later read of an already-seen variable name overwrites the
	 * earlier entry in {@code treads}, so only one read hop per variable is kept
	 * (and rewired by the caller). Confirm that reads are unique per block.
	 *
	 * @param current current hop
	 * @param treads output map: variable name to transient read hop
	 * @param twrites output map: variable name to transient write hop (or null)
	 */
	private static void rCollectTransientReadWrites(Hop current, 
		HashMap<String, Hop> treads, HashMap<String, Hop> twrites)
	{
		if( current.isVisited() )
			return;
		//process nodes recursively
		for( Hop c : current.getInput() )
			rCollectTransientReadWrites(c, treads, twrites);
		//collect all transient reads
		if( HopRewriteUtils.isData(current, DataOpTypes.TRANSIENTREAD) )
			treads.put(current.getName(), current);
		else if( HopRewriteUtils.isData(current, DataOpTypes.TRANSIENTWRITE) )
			twrites.put(current.getName(), current);
		else if( current instanceof FunctionOp ) {
			for( String output : ((FunctionOp)current).getOutputVariableNames() )
				twrites.put(output, null); //only name lookup
		}
		current.setVisited();
	}
	
	/**
	 * Indicates whether any root hop of the given block is a function call.
	 *
	 * @param sb statement block (may be null)
	 * @return true if some root is a {@link FunctionOp}; false for a null block
	 *         or a block without hops
	 * @throws HopsException declared for caller compatibility; not thrown here
	 */
	private static boolean hasFunctionOpRoot(StatementBlock sb) 
			throws HopsException {
		if( sb == null || sb.getHops() == null )
			return false;
		for( Hop root : sb.getHops() )
			if( root instanceof FunctionOp )
				return true; //no need to scan the remaining roots
		return false;
	}
	
	/**
	 * Indicates whether any root hop of the given block calls a function whose
	 * statement block holds an external function statement with side effects.
	 *
	 * @param sb statement block (may be null)
	 * @return true if such a function call root exists
	 * @throws HopsException declared for caller compatibility
	 */
	private static boolean hasExternalFunctionOpRootWithSideEffect(StatementBlock sb) 
			throws HopsException {
		if( sb == null || sb.getHops() == null )
			return false;
		for( Hop root : sb.getHops() )
			if( root instanceof FunctionOp ) {
				FunctionStatementBlock fsb = sb.getDMLProg()
					.getFunctionStatementBlock(((FunctionOp)root).getFunctionKey());
				//note: in case of builtin multi-return functions such as qr (namespace _internal), 
				//there is no function statement block and hence we need to check for null
				if( fsb != null && fsb.getStatement(0) instanceof ExternalFunctionStatement
					&& ((ExternalFunctionStatement)fsb.getStatement(0)).hasSideEffects() )
					return true; 
			}
		return false;
	}
	
	/**
	 * Indicates whether a function call root in {@code sb1} writes a variable
	 * that {@code sb2} reads or updates.
	 *
	 * @param sb1 block whose function call roots are inspected
	 * @param sb2 block whose read/updated variable sets are checked
	 * @return true on a conflict; false if sb1 has no hops or no function outputs
	 * @throws HopsException declared for caller compatibility
	 */
	private static boolean hasFunctionIOConflict(StatementBlock sb1, StatementBlock sb2) 
		throws HopsException 
	{
		//semantics: a function op root in sb1 conflicts with sb2 if this function op writes
		//to a variable that is read or written by sb2, where the write might be either
		//a traditional transient write or another function op.
		if( sb1 == null || sb1.getHops() == null )
			return false; //no hops, hence no function outputs
		
		//collect all function output variables of sb1
		HashSet<String> outSb1 = new HashSet<>();
		for( Hop root : sb1.getHops() )
			if( root instanceof FunctionOp )
				outSb1.addAll(Arrays.asList(((FunctionOp)root).getOutputVariableNames()));
		if( outSb1.isEmpty() )
			return false; //nothing written by function calls, nothing can conflict
		
		//check all output variables against read/updated sets
		return sb2.variablesRead().containsAnyName(outSb1)
			|| sb2.variablesUpdated().containsAnyName(outSb1);
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy