org.apache.sysml.hops.rewrite.RewriteMergeBlockSequence Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.DataOpTypes;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.parser.ExternalFunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.VariableSet;
/**
* Rule: Simplify program structure by merging sequences of last-level
* statement blocks in order to create optimization opportunities.
*
* Two adjacent blocks sb1, sb2 are merged by splicing the hop DAG of sb1 into
* sb2. A transient write in sb1 whose variable is transiently read in sb2 is
* short-circuited: the readers are rewired to the written value. The remaining
* sb1 roots are appended to sb2, common-subexpression elimination is run on the
* merged DAG, the live-in/gen/kill/read/updated sets of sb2 are updated to
* cover sb1, and sb1 is dropped from the list. This is repeated until no
* further merge happens (fixpoint).
*
* NOTE(review): this copy of the file appears to have lost its generic type
* arguments (e.g. raw List/ArrayList/HashMap declarations whose elements are
* iterated as Hop) and part of the merge loop in rewriteStatementBlocks (see
* the note there), so it will not compile as-is. Restore from the upstream
* source before building.
*/
public class RewriteMergeBlockSequence extends StatementBlockRewriteRule
{
//nested rewriter used to run common-subexpression elimination on each merged hop DAG
private ProgramRewriter rewriter = new ProgramRewriter(
new RewriteCommonSubexpressionElimination(true));
/**
* Single-block entry point. This rule only operates on sequences of blocks,
* so the given block is returned unchanged, wrapped in a one-element list.
*
* @param sb statement block
* @param state rewrite status (unused)
* @return a fixed-size list containing only {@code sb}
* @throws HopsException declared by the rewrite-rule interface
*/
@Override
public List rewriteStatementBlock(StatementBlock sb,
ProgramRewriteStatus state) throws HopsException {
return Arrays.asList(sb);
}
/**
* Repeatedly merges adjacent statement blocks until no more merges apply.
* The input list itself is not modified (a copy is worked on). The merged
* statement blocks ARE modified in place: sb1's hop list is cleared, and
* sb2's hops, live-variable sets and begin line/column are updated.
*
* @param sbs sequence of statement blocks; returned as-is if null or empty
* @param sate rewrite status (sic: presumably "state"; not referenced in the
*        visible code)
* @return a new list with merged blocks replacing the original pairs, or
*         {@code sbs} itself when it is null or empty
* @throws HopsException if hop DAG manipulation fails
*/
@Override
public List rewriteStatementBlocks(List sbs,
ProgramRewriteStatus sate) throws HopsException
{
//nothing to merge for absent/empty input
if( sbs == null || sbs.isEmpty() )
return sbs;
//execute binary merging iterations until fixpoint
ArrayList tmpList = new ArrayList<>(sbs);
boolean merged = true;
while( merged ) {
merged = false;
//NOTE(review): the next line is truncated in this copy. Missing are the loop
//bound over adjacent pairs (i, i+1), the declarations of sb1/sb2 (used below
//without visible declarations), the merge-eligibility condition (presumably an
//if-block closed by the brace after 'break'; the private has*() helpers of this
//class are likely used there), and the declared type of sb1Hops.
//Restore from the upstream source.
for( int i=0; i sb1Hops = sb1.getHops();
ArrayList sb2Hops = sb2.getHops();
//determine transient read inputs (and transient writes / function outputs) of s2
Hop.resetVisitStatus(sb2Hops);
HashMap treads = new HashMap<>();
HashMap twrites = new HashMap<>();
for( Hop root : sb2Hops )
rCollectTransientReadWrites(root, treads, twrites);
Hop.resetVisitStatus(sb2Hops);
//merge hop dags of s1 and s2
Hop.resetVisitStatus(sb1Hops);
for( Hop root : sb1Hops ) {
//connect transient writes s1 and reads s2
if( HopRewriteUtils.isData(root, DataOpTypes.TRANSIENTWRITE)
&& treads.containsKey(root.getName()) ) {
//rewire transient write and transient read: every consumer of the s2 read
//now consumes the value written in s1 directly; the s1 write is detached
Hop tread = treads.get(root.getName());
Hop in = root.getInput().get(0);
for( Hop parent : new ArrayList<>(tread.getParent()) )
HopRewriteUtils.replaceChildReference(parent, tread, in);
HopRewriteUtils.removeAllChildReferences(root);
//add transient write if necessary: s2 does not write the variable
//itself but it must still be live after the merged block
if( !twrites.containsKey(root.getName())
&& sb2.liveOut().containsVariable(root.getName()) ) {
sb2Hops.add(HopRewriteUtils.createDataOp(
root.getName(), in, DataOpTypes.TRANSIENTWRITE));
}
}
//add remaining roots from s1 to s2, except transient writes that are either
//overwritten by s2 or not live after s2
else if( !(HopRewriteUtils.isData(root, DataOpTypes.TRANSIENTWRITE)
&& (twrites.containsKey(root.getName()) || !sb2.liveOut().containsVariable(root.getName()))) ) {
sb2Hops.add(root);
}
}
//clear partial hops from the merged statement block to avoid problems with
//other statement block rewrites that iterate over the original program
sb1Hops.clear();
//run common-subexpression elimination
Hop.resetVisitStatus(sb2Hops);
rewriter.rewriteHopDAG(sb2Hops, new ProgramRewriteStatus());
//modify live variable sets of s2
sb2.setLiveIn(sb1.liveIn()); //liveOut remains unchanged
//NOTE(review): the usual gen of a sequence s1;s2 is gen1 U (gen2 - kill1).
//The expression below also removes variables that are in both gen1 and kill1
//(e.g. read and then updated in s1). Confirm this is intended.
sb2.setGen(VariableSet.minus(VariableSet.union(sb1.getGen(), sb2.getGen()), sb1.getKill()));
sb2.setKill(VariableSet.union(sb1.getKill(), sb2.getKill()));
sb2.setReadVariables(VariableSet.union(sb1.variablesRead(), sb2.variablesRead()));
sb2.setUpdatedVariables(VariableSet.union(sb1.variablesUpdated(), sb2.variablesUpdated()));
LOG.debug("Applied mergeStatementBlockSequences "
+ "(blocks of lines "+sb1.getBeginLine()+"-"+sb1.getEndLine()
+" and "+sb2.getBeginLine()+"-"+sb2.getEndLine()+").");
//modify line numbers of s2 so the merged block spans both originals
sb2.setBeginLine(sb1.getBeginLine());
sb2.setBeginColumn(sb1.getBeginColumn());
//remove sb1 from list of statement blocks
tmpList.remove(i);
merged = true;
break; //for: restart the scan over the modified list
}
}
}
return tmpList;
}
/**
* Recursively collects, keyed by variable name, the transient reads and
* transient writes reachable from {@code current}. Inputs are processed
* before the node itself. Output variables of a FunctionOp are recorded in
* {@code twrites} with a null value (only the name is needed for lookups).
* Uses the hop visit flags; callers reset visit status before and after.
*
* NOTE(review): only one hop per name is kept (the last one visited). This
* presumably relies on a DAG containing at most one transient read per
* variable name. TODO confirm.
*
* @param current hop to process
* @param treads output map: variable name to transient read hop
* @param twrites output map: variable name to transient write hop, or null
*        for function outputs
*/
private void rCollectTransientReadWrites(Hop current, HashMap treads, HashMap twrites) {
if( current.isVisited() )
return;
//process nodes recursively
for( Hop c : current.getInput() )
rCollectTransientReadWrites(c, treads, twrites);
//collect all transient reads
if( HopRewriteUtils.isData(current, DataOpTypes.TRANSIENTREAD) )
treads.put(current.getName(), current);
else if( HopRewriteUtils.isData(current, DataOpTypes.TRANSIENTWRITE) )
twrites.put(current.getName(), current);
else if( current instanceof FunctionOp ) {
for( String output : ((FunctionOp)current).getOutputVariableNames() )
twrites.put(output, null); //only name lookup
}
current.setVisited();
}
/**
* Checks whether any root hop of the given block is a function call.
*
* @param sb statement block (may be null)
* @return true if some root is a FunctionOp; false if {@code sb} or its
*         hops are null
* @throws HopsException declared for hop access
*/
private static boolean hasFunctionOpRoot(StatementBlock sb)
throws HopsException {
if( sb == null || sb.getHops() == null )
return false;
boolean ret = false;
for( Hop root : sb.getHops() )
ret |= (root instanceof FunctionOp);
return ret;
}
/**
* Checks whether any root hop of the given block calls an external function
* whose statement is marked as having side effects.
*
* @param sb statement block (may be null)
* @return true if such a function-call root exists; false if {@code sb} or
*         its hops are null
* @throws HopsException declared for hop access
*/
private static boolean hasExternalFunctionOpRootWithSideEffect(StatementBlock sb)
throws HopsException {
if( sb == null || sb.getHops() == null )
return false;
for( Hop root : sb.getHops() )
if( root instanceof FunctionOp ) {
FunctionStatementBlock fsb = sb.getDMLProg()
.getFunctionStatementBlock(((FunctionOp)root).getFunctionKey());
//note: in case of builtin multi-return functions such as qr (namespace _internal),
//there is no function statement block and hence we need to check for null
if( fsb != null && fsb.getStatement(0) instanceof ExternalFunctionStatement
&& ((ExternalFunctionStatement)fsb.getStatement(0)).hasSideEffects() )
return true;
}
return false;
}
/**
* Checks whether a function call rooted in {@code sb1} writes a variable that
* {@code sb2} reads or updates.
*
* @param sb1 block whose FunctionOp roots are inspected; its hops are
*        iterated without a null check
* @param sb2 block whose read/updated variable sets are checked
* @return true if any function output of sb1 is read or updated by sb2
* @throws HopsException declared for hop access
*/
private static boolean hasFunctionIOConflict(StatementBlock sb1, StatementBlock sb2)
throws HopsException
{
//semantics: a function op root in sb1 conflicts with sb2 if this function op writes
//to a variable that is read or written by sb2, where the write might be either
//a traditional transient write or another function op.
//collect all function output variables of sb1
HashSet outSb1 = new HashSet<>();
for( Hop root : sb1.getHops() )
if( root instanceof FunctionOp )
outSb1.addAll(Arrays.asList(((FunctionOp)root).getOutputVariableNames()));
//check all output variables against read/updated sets
return sb2.variablesRead().containsAnyName(outSb1)
|| sb2.variablesUpdated().containsAnyName(outSb1);
}
}