org.apache.sysml.hops.rewrite.RewriteForLoopVectorization Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
import java.util.HashSet;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp1;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LeftIndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
/**
 * Rule: Vectorize simple for loops. A for loop with a literal increment of 1
 * whose body is a single last-level statement block with a single root hop is
 * replaced by that statement block. In the replacement, the single row/column
 * indexed by the loop variable is widened to the loop range {@code from:to}.
 * Supported patterns:
 * <ul>
 * <li>scalar aggregation, e.g., {@code for(i in a:b) s = s + as.scalar(X[i,2])}
 *     becomes {@code s = s + sum(X[a:b,2])} (plus, mult, min, max map to
 *     sum, prod, min, max)</li>
 * <li>element-wise binary, e.g., {@code for(i in a:b) R[i,] = A[i,] * B[i,]}
 *     becomes {@code R[a:b,] = A[a:b,] * B[a:b,]}</li>
 * <li>element-wise unary, e.g., {@code for(i in a:b) R[,i] = abs(A[,i])}
 *     becomes {@code R[,a:b] = abs(A[,a:b])}</li>
 * </ul>
 * Unnecessary row or column indexing introduced by this rewrite is later
 * removed via dynamic rewrites.
 */
public class RewriteForLoopVectorization extends StatementBlockRewriteRule
{
	//positional mapping of scalar binary operations to full aggregates (e.g., PLUS -> SUM)
	private static final OpOp2[] MAP_SCALAR_AGGREGATE_SOURCE_OPS = new OpOp2[]{OpOp2.PLUS, OpOp2.MULT, OpOp2.MIN, OpOp2.MAX};
	private static final AggOp[] MAP_SCALAR_AGGREGATE_TARGET_OPS = new AggOp[]{AggOp.SUM,  AggOp.PROD, AggOp.MIN, AggOp.MAX};
	
	@Override
	public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state)
		throws HopsException
	{
		ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>();
		
		if( sb instanceof ForStatementBlock )
		{
			ForStatementBlock fsb = (ForStatementBlock) sb;
			ForStatement fs = (ForStatement) fsb.getStatement(0);
			Hop from = fsb.getFromHops();
			Hop to = fsb.getToHops();
			Hop incr = fsb.getIncrementHops();
			String iterVar = fsb.getIterPredicate().getIterVar().getName();
			
			if( fs.getBody()!=null && fs.getBody().size()==1 ) //single child block
			{
				StatementBlock csb = (StatementBlock) fs.getBody().get(0);
				if( !(   csb instanceof WhileStatementBlock //last level block
					  || csb instanceof IfStatementBlock
					  || csb instanceof ForStatementBlock ) )
				{
					//auto vectorization patterns: each returns sb if it does not apply,
					//or the modified child block csb if it did (then no other pattern is tried)
					sb = vectorizeScalarAggregate(sb, csb, from, to, incr, iterVar); //e.g., for(i){s = s + as.scalar(X[i,2])}
					if( sb == fsb )
						sb = vectorizeElementwiseBinary(sb, csb, from, to, incr, iterVar);
					if( sb == fsb )
						sb = vectorizeElementwiseUnary(sb, csb, from, to, incr, iterVar);
				}
			}
		}
		
		//if no rewrite applied sb is the original for loop otherwise a last level statement block
		//that includes the equivalent vectorized operations.
		ret.add( sb );
		return ret;
	}
	
	/**
	 * Vectorizes a scalar aggregation over single cells indexed by the loop
	 * variable, e.g., {@code s = s + as.scalar(X[i,2])} becomes
	 * {@code s = s + sum(X[from:to,2])}. The accumulator may be the left or the
	 * right operand.
	 * Note: unnecessary row or column indexing is later removed via dynamic rewrites.
	 *
	 * @param sb original for statement block
	 * @param csb single last-level child statement block of the for loop
	 * @param from hop of the loop's from expression
	 * @param to hop of the loop's to expression
	 * @param increment hop of the loop's increment (only a literal 1 is supported)
	 * @param itervar name of the loop iteration variable
	 * @return csb (modified in place) if the rewrite applied, otherwise sb
	 * @throws HopsException if hop inspection or modification fails
	 */
	private StatementBlock vectorizeScalarAggregate( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar )
		throws HopsException
	{
		StatementBlock ret = sb;
		
		//check supported increment values
		if( !isUnitIncrement(increment) ) {
			return ret;
		}
		
		//check for applicability
		boolean leftScalar = false;
		boolean rightScalar = false;
		boolean rowIx = false; //row or col
		
		if( csb.get_hops()!=null && csb.get_hops().size()==1 ) {
			Hop root = csb.get_hops().get(0);
			
			if( root.getDataType()==DataType.SCALAR && !root.getInput().isEmpty()
				&& root.getInput().get(0) instanceof BinaryOp )
			{
				BinaryOp bop = (BinaryOp) root.getInput().get(0);
				Hop left = bop.getInput().get(0);
				Hop right = bop.getInput().get(1);
				
				//check for left scalar accumulator, e.g., s = s + as.scalar(X[i,2])
				if( HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)
					&& left instanceof DataOp && left.getDataType() == DataType.SCALAR
					&& root.getName().equals(left.getName())
					&& right instanceof UnaryOp && ((UnaryOp) right).getOp() == OpOp1.CAST_AS_SCALAR
					&& right.getInput().get(0) instanceof IndexingOp )
				{
					IndexingOp ix = (IndexingOp) right.getInput().get(0);
					if( ix.getRowLowerEqualsUpper() && isIterVarRead(ix.getInput().get(1), itervar) ) {
						leftScalar = true;
						rowIx = true;
					}
					else if( ix.getColLowerEqualsUpper() && isIterVarRead(ix.getInput().get(3), itervar) ) {
						leftScalar = true;
						rowIx = false;
					}
				}
				//check for right scalar accumulator, e.g., s = as.scalar(X[i,2]) + s
				else if( HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)
					&& right instanceof DataOp && right.getDataType() == DataType.SCALAR
					&& root.getName().equals(right.getName())
					&& left instanceof UnaryOp && ((UnaryOp) left).getOp() == OpOp1.CAST_AS_SCALAR
					&& left.getInput().get(0) instanceof IndexingOp )
				{
					IndexingOp ix = (IndexingOp) left.getInput().get(0);
					if( ix.getRowLowerEqualsUpper() && isIterVarRead(ix.getInput().get(1), itervar) ) {
						rightScalar = true;
						rowIx = true;
					}
					else if( ix.getColLowerEqualsUpper() && isIterVarRead(ix.getInput().get(3), itervar) ) {
						rightScalar = true;
						rowIx = false;
					}
				}
			}
		}
		
		//apply rewrite if possible
		if( leftScalar || rightScalar )
		{
			Hop root = csb.get_hops().get(0);
			BinaryOp bop = (BinaryOp) root.getInput().get(0);
			int castPos = leftScalar ? 1 : 0;
			Hop cast = bop.getInput().get(castPos);
			IndexingOp ix = (IndexingOp) cast.getInput().get(0);
			
			//the loop variable must not be used outside the vectorized dimension
			//(e.g., X[i,i]) because it is no longer bound once the loop is removed
			if( readsIterVarOutsideDim(ix, rowIx, itervar) ) {
				return ret;
			}
			
			int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
			AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos];
			
			//replace cast with full aggregate
			AggUnaryOp newAgg = new AggUnaryOp(cast.getName(), DataType.SCALAR, ValueType.DOUBLE, aggOp, Direction.RowCol, ix);
			HopRewriteUtils.removeChildReference(cast, ix);
			HopRewriteUtils.removeChildReference(bop, cast);
			HopRewriteUtils.addChildReference(bop, newAgg, castPos);
			
			//modify indexing expression according to loop predicate from-to
			//NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites
			int index1 = rowIx ? 1 : 3;
			int index2 = rowIx ? 2 : 4;
			replaceInput(ix, index1, from);
			replaceInput(ix, index2, to);
			
			//the indexing now spans the loop range instead of a single row/column
			clearLowerEqualsUpper(ix, rowIx);
			ix.refreshSizeInformation();
			
			//NOTE: live variable information of csb is left unchanged here
			ret = csb;
			LOG.debug("Applied vectorizeScalarSumForLoop.");
		}
		
		return ret;
	}
	
	/**
	 * Vectorizes an element-wise binary operation over two right indexing
	 * operations that is assigned via left indexing, where all three index the
	 * same single row (or column) by the loop variable, e.g.,
	 * {@code R[i,] = A[i,] * B[i,]} becomes {@code R[from:to,] = A[from:to,] * B[from:to,]}.
	 * Note: unnecessary row or column indexing is later removed via dynamic rewrites.
	 *
	 * @param sb original for statement block
	 * @param csb single last-level child statement block of the for loop
	 * @param from hop of the loop's from expression
	 * @param to hop of the loop's to expression
	 * @param increment hop of the loop's increment (only a literal 1 is supported)
	 * @param itervar name of the loop iteration variable
	 * @return csb (modified in place) if the rewrite applied, otherwise sb
	 * @throws HopsException if hop inspection or modification fails
	 */
	private StatementBlock vectorizeElementwiseBinary( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar )
		throws HopsException
	{
		StatementBlock ret = sb;
		
		//check supported increment values
		if( !isUnitIncrement(increment) ) {
			return ret;
		}
		
		//check for applicability
		boolean apply = false;
		boolean rowIx = false; //row or col
		
		if( csb.get_hops()!=null && csb.get_hops().size()==1 )
		{
			Hop root = csb.get_hops().get(0);
			if( root.getDataType()==DataType.MATRIX && !root.getInput().isEmpty()
				&& root.getInput().get(0) instanceof LeftIndexingOp )
			{
				LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
				Hop lixlhs = lix.getInput().get(0);
				Hop lixrhs = lix.getInput().get(1);
				if( lixlhs instanceof DataOp && lixrhs instanceof BinaryOp
					&& lixrhs.getDataType()==DataType.MATRIX //excludes matrix-to-scalar binary ops
					&& lixrhs.getInput().get(0) instanceof IndexingOp
					&& lixrhs.getInput().get(1) instanceof IndexingOp
					&& lixrhs.getInput().get(0).getInput().get(0) instanceof DataOp
					&& lixrhs.getInput().get(1).getInput().get(0) instanceof DataOp )
				{
					IndexingOp rix0 = (IndexingOp) lixrhs.getInput().get(0);
					IndexingOp rix1 = (IndexingOp) lixrhs.getInput().get(1);
					
					//check for rowwise
					boolean rowwise = lix.getRowLowerEqualsUpper()
						&& rix0.getRowLowerEqualsUpper() && rix1.getRowLowerEqualsUpper()
						&& isIterVarRead(lix.getInput().get(2), itervar)
						&& isIterVarRead(rix0.getInput().get(1), itervar)
						&& isIterVarRead(rix1.getInput().get(1), itervar);
					//check for colwise
					boolean colwise = lix.getColLowerEqualsUpper()
						&& rix0.getColLowerEqualsUpper() && rix1.getColLowerEqualsUpper()
						&& isIterVarRead(lix.getInput().get(4), itervar)
						&& isIterVarRead(rix0.getInput().get(3), itervar)
						&& isIterVarRead(rix1.getInput().get(3), itervar);
					
					if( rowwise || colwise ) {
						rowIx = rowwise;
						//the loop variable must not be used outside the vectorized dimension
						//(this also rejects the case where both match, e.g., X[i,i])
						apply = !readsIterVarOutsideDim(lix, rowIx, itervar)
							&& !readsIterVarOutsideDim(rix0, rowIx, itervar)
							&& !readsIterVarOutsideDim(rix1, rowIx, itervar);
					}
				}
			}
		}
		
		//apply rewrite if possible
		if( apply )
		{
			Hop root = csb.get_hops().get(0);
			LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
			BinaryOp bop = (BinaryOp) lix.getInput().get(1);
			IndexingOp rix0 = (IndexingOp) bop.getInput().get(0);
			IndexingOp rix1 = (IndexingOp) bop.getInput().get(1);
			int index1 = rowIx ? 2 : 4;
			int index2 = rowIx ? 3 : 5;
			
			//modify left indexing bounds
			replaceInput(lix, index1, from);
			replaceInput(lix, index2, to);
			
			//modify both right indexing bounds (right indexing has no rhs input,
			//hence its bound positions are shifted by one)
			replaceInput(rix0, index1-1, from);
			replaceInput(rix0, index2-1, to);
			replaceInput(rix1, index1-1, from);
			replaceInput(rix1, index2-1, to);
			
			//the indexing now spans the loop range instead of a single row/column
			clearLowerEqualsUpper(lix, rowIx);
			clearLowerEqualsUpper(rix0, rowIx);
			clearLowerEqualsUpper(rix1, rowIx);
			rix0.refreshSizeInformation();
			rix1.refreshSizeInformation();
			bop.refreshSizeInformation();
			lix.refreshSizeInformation(); //last, after its inputs
			
			//NOTE: live variable information of csb is left unchanged here
			ret = csb;
			LOG.debug("Applied vectorizeElementwiseBinaryForLoop.");
		}
		
		return ret;
	}
	
	/**
	 * Vectorizes a unary operation over a right indexing operation that is
	 * assigned via left indexing, where both index the same single row (or
	 * column) by the loop variable, e.g., {@code R[,i] = abs(A[,i])} becomes
	 * {@code R[,from:to] = abs(A[,from:to])}.
	 * Note: unnecessary row or column indexing is later removed via dynamic rewrites.
	 *
	 * @param sb original for statement block
	 * @param csb single last-level child statement block of the for loop
	 * @param from hop of the loop's from expression
	 * @param to hop of the loop's to expression
	 * @param increment hop of the loop's increment (only a literal 1 is supported)
	 * @param itervar name of the loop iteration variable
	 * @return csb (modified in place) if the rewrite applied, otherwise sb
	 * @throws HopsException if hop inspection or modification fails
	 */
	private StatementBlock vectorizeElementwiseUnary( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar )
		throws HopsException
	{
		StatementBlock ret = sb;
		
		//check supported increment values
		if( !isUnitIncrement(increment) ) {
			return ret;
		}
		
		//check for applicability
		boolean apply = false;
		boolean rowIx = false; //row or col
		
		if( csb.get_hops()!=null && csb.get_hops().size()==1 )
		{
			Hop root = csb.get_hops().get(0);
			if( root.getDataType()==DataType.MATRIX && !root.getInput().isEmpty()
				&& root.getInput().get(0) instanceof LeftIndexingOp )
			{
				LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
				Hop lixlhs = lix.getInput().get(0);
				Hop lixrhs = lix.getInput().get(1);
				if( lixlhs instanceof DataOp && lixrhs instanceof UnaryOp
					&& lixrhs.getDataType()==DataType.MATRIX //excludes e.g., nrow, as.scalar
					&& lixrhs.getInput().get(0) instanceof IndexingOp
					&& lixrhs.getInput().get(0).getInput().get(0) instanceof DataOp )
				{
					UnaryOp uop = (UnaryOp) lixrhs;
					IndexingOp rix = (IndexingOp) uop.getInput().get(0);
					
					//check for rowwise; cumulative ops (cumsum etc.) accumulate down columns
					//in DML, so widening a single row to multiple rows would change results
					boolean rowwise = lix.getRowLowerEqualsUpper() && rix.getRowLowerEqualsUpper()
						&& isIterVarRead(lix.getInput().get(2), itervar)
						&& isIterVarRead(rix.getInput().get(1), itervar)
						&& !isCumulativeOp(uop);
					//check for colwise
					boolean colwise = lix.getColLowerEqualsUpper() && rix.getColLowerEqualsUpper()
						&& isIterVarRead(lix.getInput().get(4), itervar)
						&& isIterVarRead(rix.getInput().get(3), itervar);
					//NOTE(review): other non-elementwise matrix unary ops are not excluded here
					
					if( rowwise || colwise ) {
						rowIx = rowwise;
						//the loop variable must not be used outside the vectorized dimension
						//(this also rejects the case where both match, e.g., X[i,i])
						apply = !readsIterVarOutsideDim(lix, rowIx, itervar)
							&& !readsIterVarOutsideDim(rix, rowIx, itervar);
					}
				}
			}
		}
		
		//apply rewrite if possible
		if( apply )
		{
			Hop root = csb.get_hops().get(0);
			LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
			UnaryOp uop = (UnaryOp) lix.getInput().get(1);
			IndexingOp rix = (IndexingOp) uop.getInput().get(0);
			int index1 = rowIx ? 2 : 4;
			int index2 = rowIx ? 3 : 5;
			
			//modify left indexing bounds
			replaceInput(lix, index1, from);
			replaceInput(lix, index2, to);
			
			//modify right indexing bounds (shifted by one, no rhs input)
			replaceInput(rix, index1-1, from);
			replaceInput(rix, index2-1, to);
			
			//the indexing now spans the loop range instead of a single row/column
			clearLowerEqualsUpper(lix, rowIx);
			clearLowerEqualsUpper(rix, rowIx);
			rix.refreshSizeInformation();
			uop.refreshSizeInformation();
			lix.refreshSizeInformation(); //last, after its inputs
			
			//NOTE: live variable information of csb is left unchanged here
			ret = csb;
			LOG.debug("Applied vectorizeElementwiseUnaryForLoop.");
		}
		
		return ret;
	}
	
	/**
	 * Indicates if the given increment hop is a literal with value 1.
	 *
	 * @param increment loop increment hop (null yields false)
	 * @return true if the increment is the literal 1
	 * @throws HopsException if the literal value cannot be read as double
	 */
	private static boolean isUnitIncrement( Hop increment )
		throws HopsException
	{
		return increment instanceof LiteralOp
			&& ((LiteralOp)increment).getDoubleValue()==1.0;
	}
	
	/**
	 * Indicates if the given hop is a direct read of the loop iteration variable.
	 *
	 * @param hop hop to check
	 * @param itervar name of the loop iteration variable
	 * @return true if hop is a DataOp named itervar
	 */
	private static boolean isIterVarRead( Hop hop, String itervar ) {
		return hop instanceof DataOp && hop.getName().equals(itervar);
	}
	
	/**
	 * Indicates if the given right indexing reads the loop variable in its matrix
	 * input or in the bounds of the dimension that is NOT vectorized.
	 *
	 * @param rix right indexing operation (inputs: matrix, rl, ru, cl, cu)
	 * @param rowIx true if rows are vectorized, false if columns are
	 * @param itervar name of the loop iteration variable
	 * @return true if any of these inputs reads itervar
	 */
	private static boolean readsIterVarOutsideDim( IndexingOp rix, boolean rowIx, String itervar ) {
		return rowIx ? readsIterVar(rix, itervar, 0, 3, 4) : readsIterVar(rix, itervar, 0, 1, 2);
	}
	
	/**
	 * Indicates if the given left indexing reads the loop variable in its target
	 * input or in the bounds of the dimension that is NOT vectorized.
	 *
	 * @param lix left indexing operation (inputs: lhs, rhs, rl, ru, cl, cu)
	 * @param rowIx true if rows are vectorized, false if columns are
	 * @param itervar name of the loop iteration variable
	 * @return true if any of these inputs reads itervar
	 */
	private static boolean readsIterVarOutsideDim( LeftIndexingOp lix, boolean rowIx, String itervar ) {
		return rowIx ? readsIterVar(lix, itervar, 0, 4, 5) : readsIterVar(lix, itervar, 0, 2, 3);
	}
	
	/**
	 * Indicates if any input of the given hop at the given positions reads the
	 * loop iteration variable, directly or within a sub-expression (e.g., i+1).
	 *
	 * @param hop hop whose inputs are inspected
	 * @param itervar name of the loop iteration variable
	 * @param positions input positions to inspect
	 * @return true if at least one inspected input DAG reads itervar
	 */
	private static boolean readsIterVar( Hop hop, String itervar, int... positions ) {
		HashSet<Long> visited = new HashSet<Long>();
		for( int pos : positions ) {
			if( readsIterVarRec(hop.getInput().get(pos), itervar, visited) )
				return true;
		}
		return false;
	}
	
	/** Recursive DAG traversal for {@link #readsIterVar}; visits each hop at most once. */
	private static boolean readsIterVarRec( Hop hop, String itervar, HashSet<Long> visited ) {
		if( !visited.add(hop.getHopID()) )
			return false; //already inspected via another path
		if( isIterVarRead(hop, itervar) )
			return true;
		for( Hop in : hop.getInput() ) {
			if( readsIterVarRec(in, itervar, visited) )
				return true;
		}
		return false;
	}
	
	/**
	 * Indicates if the given unary operation is a cumulative aggregate.
	 *
	 * @param uop unary operation
	 * @return true for cumsum, cumprod, cummin, and cummax
	 */
	private static boolean isCumulativeOp( UnaryOp uop ) {
		OpOp1 op = uop.getOp();
		return op == OpOp1.CUMSUM || op == OpOp1.CUMPROD
			|| op == OpOp1.CUMMIN || op == OpOp1.CUMMAX;
	}
	
	/**
	 * Replaces the input of the given parent at the given position, using
	 * HopRewriteUtils to remove the old child reference and add the new one.
	 *
	 * @param parent hop whose input is replaced
	 * @param pos input position
	 * @param newInput new input hop
	 * @throws HopsException if the child references cannot be updated
	 */
	private static void replaceInput( Hop parent, int pos, Hop newInput )
		throws HopsException
	{
		HopRewriteUtils.removeChildReferenceByPos(parent, parent.getInput().get(pos), pos);
		HopRewriteUtils.addChildReference(parent, newInput, pos);
	}
	
	/**
	 * Clears the single row/column flag of the vectorized dimension, so that a
	 * subsequent size refresh accounts for the full loop range.
	 *
	 * @param rix right indexing operation
	 * @param rowIx true if rows are vectorized, false if columns are
	 */
	private static void clearLowerEqualsUpper( IndexingOp rix, boolean rowIx ) {
		if( rowIx )
			rix.setRowLowerEqualsUpper(false);
		else
			rix.setColLowerEqualsUpper(false);
	}
	
	/**
	 * Clears the single row/column flag of the vectorized dimension, so that a
	 * subsequent size refresh accounts for the full loop range.
	 *
	 * @param lix left indexing operation
	 * @param rowIx true if rows are vectorized, false if columns are
	 */
	private static void clearLowerEqualsUpper( LeftIndexingOp lix, boolean rowIx ) {
		if( rowIx )
			lix.setRowLowerEqualsUpper(false);
		else
			lix.setColLowerEqualsUpper(false);
	}
}