org.apache.sysml.hops.rewrite.RewriteSplitDagDataDependentOperators Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
import java.util.HashSet;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.OpOp1;
import org.apache.sysml.hops.Hop.OpOp3;
import org.apache.sysml.hops.Hop.ParamBuiltinOp;
import org.apache.sysml.hops.Hop.DataOpTypes;
import org.apache.sysml.hops.Hop.ReOrgOp;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.ReorgOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.VariableSet;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysml.runtime.matrix.data.Pair;
/**
* Rule: Split Hop DAG after specific data-dependent operators. This is
* important to create recompile hooks if output dimensions are usually
* significantly overestimated.
*
* This is a recursive statementblock rewrite rule.
*
* NOTE: Before we used AssignmentStatement.controlStatement() in order to force
* statementblock cuts. However, this (1) cuts not only after but before-and-after
* (which prevents certain rewrites because the input operators are unknown),
* and (2) is statement-centric which potentially prevents the cut right after
* the problematic operation.
*
* TODO: Cleanup runtime to never access individual statements of potentially
* split statements blocks again (for consistency). However, currently it is
* only used in places (e.g., parfor optimizer) that are not directly affected.
*/
public class RewriteSplitDagDataDependentOperators extends StatementBlockRewriteRule
{
private static String _varnamePredix = "_sbcvar";
private static IDSequence _seq = new IDSequence();
@Override
public ArrayList rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state)
throws HopsException
{
ArrayList ret = new ArrayList();
//collect all unknown csv reads hops
ArrayList cand = new ArrayList();
collectDataDependentOperators( sb.get_hops(), cand );
Hop.resetVisitStatus(sb.get_hops());
//split hop dag on demand
if( !cand.isEmpty() )
{
//collect child operators of candidates (to prevent rewrite anomalies)
HashSet candChilds = new HashSet();
collectCandidateChildOperators( cand, candChilds );
try
{
//duplicate sb incl live variable sets
StatementBlock sb1 = new StatementBlock();
sb1.setDMLProg(sb.getDMLProg());
sb1.setAllPositions(sb.getFilename(), sb.getBeginLine(), sb.getBeginColumn(), sb.getEndLine(), sb.getEndColumn());
sb1.setLiveIn(new VariableSet());
sb1.setLiveOut(new VariableSet());
//move data-dependent ops incl transient writes to new statement block
//(and replace original persistent read with transient read)
ArrayList sb1hops = new ArrayList();
for( Hop c : cand )
{
//if there are already transient writes use them and don't introduce artificial variables;
//unless there are transient reads w/ the same variable name in the current dag which can
//lead to invalid reordering if variable consumers are not feeding into the candidate op.
boolean hasTWrites = hasTransientWriteParents(c);
boolean moveTWrite = hasTWrites ? HopRewriteUtils.rHasSimpleReadChain(c,
getFirstTransientWriteParent(c).getName()) : false;
String varname = null;
long rlen = c.getDim1();
long clen = c.getDim2();
long nnz = c.getNnz();
UpdateType update = c.getUpdateType();
long brlen = c.getRowsInBlock();
long bclen = c.getColsInBlock();
if( hasTWrites && moveTWrite) //reuse existing transient_write
{
Hop twrite = getFirstTransientWriteParent(c);
varname = twrite.getName();
//create new transient read
DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(),
DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, update, brlen, bclen);
tread.setVisited();
HopRewriteUtils.copyLineNumbers(c, tread);
//replace data-dependent operator with transient read
ArrayList parents = new ArrayList(c.getParent());
for( int i=0; i parents = new ArrayList(c.getParent());
for( int i=0; i tmp = rewriteStatementBlock( sb1, state);
//add new statement blocks to output
ret.addAll(tmp); //statement block with data dependent hops
ret.add(sb); //statement block with remaining hops
}
catch(Exception ex)
{
throw new HopsException("Failed to split hops dag for data dependent operators with unknown size.", ex);
}
LOG.debug("Applied splitDagDataDependentOperators (lines "+sb.getBeginLine()+"-"+sb.getEndLine()+").");
}
//keep original hop dag
else
{
ret.add(sb);
}
return ret;
}
private void collectDataDependentOperators( ArrayList roots, ArrayList cand )
{
if( roots == null )
return;
Hop.resetVisitStatus(roots);
for( Hop root : roots )
rCollectDataDependentOperators(root, cand);
}
private void rCollectDataDependentOperators( Hop hop, ArrayList cand )
{
if( hop.isVisited() )
return;
//prevent unnecessary dag split (dims known or no consumer operations)
boolean noSplitRequired = ( hop.dimsKnown() || HopRewriteUtils.hasOnlyWriteParents(hop, true, true) );
boolean investigateChilds = true;
//collect data dependent operations (to be extended as necessary)
//#1 removeEmpty
if( hop instanceof ParameterizedBuiltinOp
&& ((ParameterizedBuiltinOp) hop).getOp()==ParamBuiltinOp.RMEMPTY
&& !noSplitRequired
&& !(hop.getParent().size()==1 && hop.getParent().get(0) instanceof TernaryOp
&& ((TernaryOp)hop.getParent().get(0)).isMatrixIgnoreZeroRewriteApplicable()))
{
ParameterizedBuiltinOp pbhop = (ParameterizedBuiltinOp)hop;
cand.add(pbhop);
investigateChilds = false;
//keep interesting consumer information, flag hops accordingly
boolean noEmptyBlocks = true;
boolean onlyPMM = true;
boolean diagInput = pbhop.isTargetDiagInput();
for( Hop p : hop.getParent() ) {
//list of operators without need for empty blocks to be extended as needed
noEmptyBlocks &= ( p instanceof AggBinaryOp && hop == p.getInput().get(0)
|| HopRewriteUtils.isUnary(p, OpOp1.NROW) );
onlyPMM &= (p instanceof AggBinaryOp && hop == p.getInput().get(0));
}
pbhop.setOutputEmptyBlocks(!noEmptyBlocks);
if( onlyPMM && diagInput ){
//configure rmEmpty to directly output selection vector
//(only applied if dynamic recompilation enabled)
if( ConfigurationManager.isDynamicRecompilation() )
pbhop.setOutputPermutationMatrix(true);
for( Hop p : hop.getParent() )
((AggBinaryOp)p).setHasLeftPMInput(true);
}
}
//#2 ctable with unknown dims
if( hop instanceof TernaryOp
&& ((TernaryOp) hop).getOp()==OpOp3.CTABLE
&& hop.getInput().size() < 4 //dims not provided
&& !noSplitRequired )
{
cand.add(hop);
investigateChilds = false;
//keep interesting consumer information, flag hops accordingly
boolean onlyPMM = true;
for( Hop p : hop.getParent() ) {
onlyPMM &= (p instanceof AggBinaryOp && hop == p.getInput().get(0));
}
if( onlyPMM && HopRewriteUtils.isBasic1NSequence(hop.getInput().get(0)) )
hop.setOutputEmptyBlocks(false);
}
//#3 orderby childs computed in same DAG
if( hop instanceof ReorgOp
&& ((ReorgOp)hop).getOp()==ReOrgOp.SORT )
{
//params 'decreasing' / 'indexreturn'
for( int i=2; i<=3; i++ ) {
Hop c = hop.getInput().get(i);
if( !(c instanceof LiteralOp || c instanceof DataOp) ){
cand.add(c);
c.setVisited();
investigateChilds = false;
}
}
}
//process children (if not already found a special operators;
//otherwise, processed by recursive rule application)
if( investigateChilds )
if( hop.getInput()!=null )
for( Hop c : hop.getInput() )
rCollectDataDependentOperators(c, cand);
hop.setVisited();
}
private boolean hasTransientWriteParents( Hop hop )
{
for( Hop p : hop.getParent() )
if( p instanceof DataOp && ((DataOp)p).getDataOpType()==DataOpTypes.TRANSIENTWRITE )
return true;
return false;
}
private Hop getFirstTransientWriteParent( Hop hop )
{
for( Hop p : hop.getParent() )
if( p instanceof DataOp && ((DataOp)p).getDataOpType()==DataOpTypes.TRANSIENTWRITE )
return p;
return null;
}
private void handleReplicatedOperators( ArrayList rootsSB1, ArrayList rootsSB2, VariableSet sb1out, VariableSet sb2in )
{
//step 1: create probe set SB1
HashSet probeSet = new HashSet();
Hop.resetVisitStatus(rootsSB1);
for( Hop h : rootsSB1 )
rAddHopsToProbeSet( h, probeSet );
//step 2: probe SB2 operators top-down (collect cut candidates)
HashSet> candSet = new HashSet>();
Hop.resetVisitStatus(rootsSB2);
for( Hop h : rootsSB2 )
rProbeAndAddHopsToCandidateSet(h, probeSet, candSet);
//step 3: create additional cuts
for( Pair p : candSet )
{
String varname = _varnamePredix + _seq.getNextID();
Hop hop = p.getKey();
Hop c = p.getValue();
DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(), DataOpTypes.TRANSIENTREAD,
null, c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock());
tread.setVisited();
HopRewriteUtils.copyLineNumbers(c, tread);
DataOp twrite = new DataOp(varname, c.getDataType(), c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null);
twrite.setVisited();
twrite.setOutputParams(c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock());
HopRewriteUtils.copyLineNumbers(c, twrite);
//create additional cut by rewriting both hop dags
int pos = HopRewriteUtils.getChildReferencePos(hop, c);
HopRewriteUtils.removeChildReferenceByPos(hop, c, pos);
HopRewriteUtils.addChildReference(hop, tread, pos);
//update live in and out of new statement block (for piggybacking)
DataIdentifier diVar = new DataIdentifier(varname);
diVar.setDimensions(c.getDim1(), c.getDim2());
diVar.setBlockDimensions(c.getRowsInBlock(), c.getColsInBlock());
diVar.setDataType(c.getDataType());
diVar.setValueType(c.getValueType());
sb1out.addVariable(varname, new DataIdentifier(diVar));
sb2in.addVariable(varname, new DataIdentifier(diVar));
rootsSB1.add(twrite);
}
}
private void rAddHopsToProbeSet( Hop hop, HashSet probeSet )
{
if( hop.isVisited() )
return;
//prevent cuts for no-ops
if( !( (hop instanceof DataOp && !((DataOp)hop).isPersistentReadWrite() )
|| hop instanceof LiteralOp) )
{
probeSet.add(hop);
}
if( hop.getInput() != null )
for( Hop c : hop.getInput() )
rAddHopsToProbeSet(c, probeSet);
hop.setVisited();
}
/**
* NOTE: candset is a set of parent-child pairs because a parent might have
* multiple references to replicated hops.
*
* @param hop high-level operator
* @param probeSet probe set?
* @param candSet candidate set?
*/
private void rProbeAndAddHopsToCandidateSet( Hop hop, HashSet probeSet, HashSet> candSet )
{
if( hop.isVisited() )
return;
if( hop.getInput() != null )
for( Hop c : hop.getInput() ) {
//probe for replicated operator, if any child is replicated, keep parent
//for cut between parent-child; otherwise recursively descend.
if( !probeSet.contains(c) )
rProbeAndAddHopsToCandidateSet(c, probeSet, candSet);
else
{
candSet.add(new Pair(hop,c));
}
}
hop.setVisited();
}
private void collectCandidateChildOperators( ArrayList cand, HashSet candChilds )
{
Hop.resetVisitStatus(cand);
if( cand != null )
for( Hop root : cand )
rCollectCandidateChildOperators(root, cand, candChilds, false);
// Immediately reset the visit status because candidates might be inner nodes in the DAG.
// Subsequent resets on the root nodes of the DAG would otherwise not necessarily reach
// these nodes which could lead to missing checks on subsequent passes (e.g., when checking
// for replicated operators).
Hop.resetVisitStatus(cand);
}
private void rCollectCandidateChildOperators( Hop hop, ArrayList cand, HashSet candChilds, boolean collect )
{
if( hop.isVisited() )
return;
//collect operator if necessary
if( collect ) {
candChilds.add(hop);
}
//activate collection if we passed a candidate
boolean passedFlag = collect;
if( cand.contains(hop) ) {
passedFlag = true;
}
//process childs recursively
if( hop.getInput()!=null ) {
for( Hop c : hop.getInput() )
rCollectCandidateChildOperators(c, cand, candChilds, passedFlag);
}
hop.setVisited();
}
}