All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.hops.globalopt.gdfgraph.GraphBuilder Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.globalopt.gdfgraph;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map.Entry;
import java.util.Set;

import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.DataOpTypes;
import org.apache.sysml.hops.globalopt.Summary;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.utils.Explain;

/**
 * GENERAL 'GDF GRAPH' STRUCTURE, by MB:
 *  1) Each hop is represented by a GDFNode
 *  2) Each loop is represented by a structured GDFLoopNode
 *  3) Transient Read/Write connections are represented via CrossBlockNodes,
 *     a) type PLAIN: single input crossblocknode represents unconditional data flow
 *     b) type MERGE: two-input crossblocknode represents a conditional data flow merge
 * 
 *  In detail, the graph builder essentially does a single pass over the entire program
 *  and constructs the global data flow graph bottom up. We create crossblocknodes for
 *  every transient write, loop nodes for for/while programblocks, and crossblocknodes
 *  after every if programblock. 
 *  
 */
public class GraphBuilder 
{
	
	private static final boolean IGNORE_UNBOUND_UPDATED_VARS = true;
	
	/**
	 * 
	 * @param prog
	 * @return
	 * @throws DMLRuntimeException
	 * @throws HopsException 
	 */
	public static GDFGraph constructGlobalDataFlowGraph( Program prog, Summary summary )
		throws DMLRuntimeException, HopsException
	{
		Timing time = new Timing(true);
		
		HashMap roots = new HashMap();		
		for( ProgramBlock pb : prog.getProgramBlocks() )
			constructGDFGraph( pb, roots );
		
		//create GDF graph root nodes 
		ArrayList ret = new ArrayList();
		for( GDFNode root : roots.values() )
			if( !(root instanceof GDFCrossBlockNode) )
				ret.add(root);
		
		//create GDF graph
		GDFGraph graph = new GDFGraph(prog, ret);
		
		summary.setTimeGDFGraph(time.stop());		
		return graph;
	}
	
	/**
	 * 
	 * @param pb
	 * @param roots
	 * @throws DMLRuntimeException
	 * @throws HopsException
	 */
	@SuppressWarnings("unchecked")
	private static void constructGDFGraph( ProgramBlock pb, HashMap roots ) 
		throws DMLRuntimeException, HopsException
	{
		if (pb instanceof FunctionProgramBlock )
		{
			throw new DMLRuntimeException("FunctionProgramBlocks not implemented yet.");
		}
		else if (pb instanceof WhileProgramBlock)
		{
			WhileProgramBlock wpb = (WhileProgramBlock) pb;
			WhileStatementBlock wsb = (WhileStatementBlock) pb.getStatementBlock();		
			//construct predicate node (conceptually sequence of from/to/incr)
			GDFNode pred = constructGDFGraph(wsb.getPredicateHops(), wpb, new HashMap(), roots);
			HashMap inputs = constructLoopInputNodes(wpb, wsb, roots);
			HashMap lroots = (HashMap) inputs.clone();
			//process childs blocks
			for( ProgramBlock pbc : wpb.getChildBlocks() )
				constructGDFGraph(pbc, lroots);
			HashMap outputs = constructLoopOutputNodes(wsb, lroots);
			GDFLoopNode lnode = new GDFLoopNode(wpb, pred, inputs, outputs );
			//construct crossblock nodes
			constructLoopOutputCrossBlockNodes(wsb, lnode, outputs, roots, wpb);
		}	
		else if (pb instanceof IfProgramBlock)
		{
			IfProgramBlock ipb = (IfProgramBlock) pb;
			IfStatementBlock isb = (IfStatementBlock) pb.getStatementBlock();
			//construct predicate
			if( isb.getPredicateHops()!=null ) {
				Hop pred = isb.getPredicateHops();
				roots.put(pred.getName(), constructGDFGraph(pred, ipb, new HashMap(), roots));
			}
			//construct if and else branch separately
			HashMap ifRoots = (HashMap) roots.clone();
			HashMap elseRoots = (HashMap) roots.clone();
			for( ProgramBlock pbc : ipb.getChildBlocksIfBody() )
				constructGDFGraph(pbc, ifRoots);
			if( ipb.getChildBlocksElseBody()!=null )
				for( ProgramBlock pbc : ipb.getChildBlocksElseBody() )
					constructGDFGraph(pbc, elseRoots);
			//merge data flow roots (if no else, elseRoots refer to original roots)
			reconcileMergeIfProgramBlockOutputs(ifRoots, elseRoots, roots, ipb);
		}
		else if (pb instanceof ForProgramBlock) //incl parfor
		{
			ForProgramBlock fpb = (ForProgramBlock) pb;
			ForStatementBlock fsb = (ForStatementBlock)pb.getStatementBlock();
			//construct predicate node (conceptually sequence of from/to/incr)
			GDFNode pred = constructForPredicateNode(fpb, fsb, roots);
			HashMap inputs = constructLoopInputNodes(fpb, fsb, roots);
			HashMap lroots = (HashMap) inputs.clone();
			//process childs blocks
			for( ProgramBlock pbc : fpb.getChildBlocks() )
				constructGDFGraph(pbc, lroots);
			HashMap outputs = constructLoopOutputNodes(fsb, lroots);
			GDFLoopNode lnode = new GDFLoopNode(fpb, pred, inputs, outputs );
			//construct crossblock nodes
			constructLoopOutputCrossBlockNodes(fsb, lnode, outputs, roots, fpb);
		}
		else //last-level program block
		{
			StatementBlock sb = pb.getStatementBlock();
			ArrayList hops = sb.get_hops();
			if( hops != null )
			{
				//create new local memo structure for local dag
				HashMap lmemo = new HashMap();
				for( Hop hop : hops )
				{
					//recursively construct GDF graph for hop dag root
					GDFNode root = constructGDFGraph(hop, pb, lmemo, roots);
					if( root == null )
						throw new HopsException( "GDFGraphBuilder: failed to constuct dag root for: "+Explain.explain(hop) );
					
					//create cross block nodes for all transient writes
					if( hop instanceof DataOp && ((DataOp)hop).getDataOpType()==DataOpTypes.TRANSIENTWRITE )
						root = new GDFCrossBlockNode(hop, pb, root, hop.getName());
					
					//add GDF root node to global roots 
					roots.put(hop.getName(), root);
				}
			}
			
		}
	}
	
	/**
	 * 
	 * @param hop
	 * @param pb
	 * @param lmemo
	 * @param roots 
	 * @return
	 */
	private static GDFNode constructGDFGraph( Hop hop, ProgramBlock pb, HashMap lmemo, HashMap roots )
	{
		if( lmemo.containsKey(hop.getHopID()) )
			return lmemo.get(hop.getHopID());
		
		//process childs recursively first
		ArrayList inputs = new ArrayList();
		for( Hop c : hop.getInput() )
			inputs.add( constructGDFGraph(c, pb, lmemo, roots) );
		
		//connect transient reads to existing roots of data flow graph 
		if( hop instanceof DataOp && ((DataOp)hop).getDataOpType()==DataOpTypes.TRANSIENTREAD ){
			inputs.add(roots.get(hop.getName()));
		}
		
		//add current hop
		GDFNode gnode = new GDFNode(hop, pb, inputs);
				
		//add GDF node of updated variables to global roots (necessary for loops, where updated local
		//variables might never be bound to their logical variables names
		if( !IGNORE_UNBOUND_UPDATED_VARS ) {
			//NOTE: currently disabled because unnecessary, if no transientwrite by definition included in other transientwrite
			if( pb.getStatementBlock()!=null && pb.getStatementBlock().variablesUpdated().containsVariable(hop.getName()) ) {
				roots.put(hop.getName(), gnode);
			}
		}
		
		//memoize current node
		lmemo.put(hop.getHopID(), gnode);
		
		return gnode;
	}
	
	/**
	 * 
	 * @param fpb
	 * @param fsb
	 * @param roots
	 * @return
	 */
	private static GDFNode constructForPredicateNode(ForProgramBlock fpb, ForStatementBlock fsb, HashMap roots)
	{
		HashMap memo = new HashMap();
		GDFNode from = (fsb.getFromHops()!=null)? constructGDFGraph(fsb.getFromHops(), fpb, memo, roots) : null;
		GDFNode to = (fsb.getToHops()!=null)? constructGDFGraph(fsb.getToHops(), fpb, memo, roots) : null;
		GDFNode incr = (fsb.getIncrementHops()!=null)? constructGDFGraph(fsb.getIncrementHops(), fpb, memo, roots) : null;
		ArrayList inputs = new ArrayList();
		inputs.add(from);
		inputs.add(to);
		inputs.add(incr);
		//TODO for predicates 
		GDFNode pred = new GDFNode(null, fpb, inputs );
		
		return pred;
	}
	
	/**
	 * 
	 * @param fpb
	 * @param fsb
	 * @param roots
	 * @return
	 * @throws DMLRuntimeException 
	 */
	private static HashMap constructLoopInputNodes( ProgramBlock fpb, StatementBlock fsb, HashMap roots ) 
		throws DMLRuntimeException
	{
		HashMap ret = new HashMap();
		Set invars = fsb.variablesRead().getVariableNames();
		for( String var : invars ) {
			if( fsb.liveIn().containsVariable(var) ) {
				GDFNode node = roots.get(var);
				if( node == null )
					throw new DMLRuntimeException("GDFGraphBuilder: Non-existing input node for variable: "+var);
				ret.put(var, node);
			}
		}
		
		return ret;
	}
	
	private static HashMap constructLoopOutputNodes( StatementBlock fsb, HashMap roots ) 
		throws HopsException
	{
		HashMap ret = new HashMap();
		
		Set outvars = fsb.variablesUpdated().getVariableNames();
		for( String var : outvars ) 
		{
			GDFNode node = roots.get(var);
			
			//handle non-existing nodes
			if( node == null ) {
				if( !IGNORE_UNBOUND_UPDATED_VARS )
					throw new HopsException( "GDFGraphBuilder: failed to constuct loop output for variable: "+var );	
				else
					continue; //skip unbound updated variables	
			}
			
			//add existing node to loop outputs 
			ret.put(var, node);
		}
		
		return ret;
	}
	
	/**
	 * 
	 * @param ifRoots
	 * @param elseRoots
	 * @param roots
	 * @param pb
	 */
	private static void reconcileMergeIfProgramBlockOutputs( HashMap ifRoots, HashMap elseRoots, HashMap roots, IfProgramBlock pb )
	{
		//merge same variable names, different data
		//( incl add new vars from if branch if node2==null)
		for( Entry e : ifRoots.entrySet() ){
			GDFNode node1 = e.getValue();
			GDFNode node2 = elseRoots.get(e.getKey()); //original or new
			if( node1 != node2 )
				node1 = new GDFCrossBlockNode(null, pb, node1, node2, e.getKey() );
			roots.put(e.getKey(), node1);	
		}
		
		//add new vars from else branch 
		for( Entry e : elseRoots.entrySet() ){
			if( !ifRoots.containsKey(e.getKey()) )
				roots.put(e.getKey(), e.getValue());	
		}
	}
	
	/**
	 * 
	 * @param sb
	 * @param loop
	 * @param loutputs
	 * @param roots
	 * @param pb
	 */
	private static void constructLoopOutputCrossBlockNodes(StatementBlock sb, GDFLoopNode loop, HashMap loutputs, HashMap roots, ProgramBlock pb)
	{
		//iterate over all output (updated) variables
		for( Entry e : loutputs.entrySet() ) 
		{
			//create crossblocknode, if updated variable is also in liveout
			if( sb.liveOut().containsVariable(e.getKey()) ) {
				GDFCrossBlockNode node = null;
				if( roots.containsKey(e.getKey()) )
					node = new GDFCrossBlockNode(null, pb, roots.get(e.getKey()), loop, e.getKey()); //MERGE
				else
					node = new GDFCrossBlockNode(null, pb, loop, e.getKey()); //PLAIN
				roots.put(e.getKey(), node);
			}
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy