org.apache.sysml.hops.rewrite.RewriteIndexingVectorization Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LeftIndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
/**
* Rule: Indexing vectorization. This rewrite rule set simplifies
* multiple right / left indexing accesses within a DAG into row/column
* index accesses, which is beneficial for two reasons: (1) it is an
* enabler for later row/column partitioning, and (2) it reduces the number
* of operations over potentially large data (i.e., prevents unnecessary MR
* operations and reduces pressure on the buffer pool due to copy on write
* on left indexing).
*
*/
public class RewriteIndexingVectorization extends HopRewriteRule
{
/** Logger of this rewrite rule; used to report (at debug level) which vectorization rewrites were applied. */
private static final Log LOG = LogFactory.getLog(RewriteIndexingVectorization.class.getName());
/**
 * Applies the indexing vectorization rewrites to each of the given DAG roots.
 * The DAGs are modified in place; the list itself is returned unchanged.
 *
 * @param roots roots of the hop DAGs to rewrite (may be null)
 * @param state program rewrite status (not used by this rule)
 * @return the given list of roots, or null if {@code roots} was null
 * @throws HopsException propagated from the per-DAG rewrite
 */
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state)
	throws HopsException
{
	if( roots == null )
		return roots;
	
	//typed list: iterating a raw ArrayList yields Object, not Hop
	for( Hop h : roots )
		rule_IndexingVectorization( h );
	
	return roots;
}
/**
 * Applies the indexing vectorization rewrites to the single hop DAG rooted
 * at {@code root}, modifying that DAG in place.
 *
 * @param root root of the hop DAG to rewrite (may be null, in which case
 *        nothing is done)
 * @param state program rewrite status (not used by this rule)
 * @return the given root, unchanged as an object reference
 * @throws HopsException propagated from the DAG rewrite
 */
@Override
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state)
	throws HopsException
{
	if( root != null ) {
		rule_IndexingVectorization( root );
	}
	return root;
}
/**
*
* @param hop
* @param descendFirst
* @throws HopsException
*/
private void rule_IndexingVectorization( Hop hop )
throws HopsException
{
if(hop.getVisited() == Hop.VisitStatus.DONE)
return;
//recursively process children
for( int i=0; i X[i,];
vectorizeLeftIndexing( hi ); //e.g., multiple left indexing X[i,1], X[i,3] -> X[i,];
//process childs recursively after rewrites
rule_IndexingVectorization( hi );
}
hop.setVisited(Hop.VisitStatus.DONE);
}
/**
 * Vectorizes multiple single-cell right indexing operations that read the same
 * input at the same row (or at the same column) into one row (or column)
 * indexing operation, whose result the original operations then index at row
 * (or column) 1, e.g., X[i,1], X[i,3] -> tmp = X[i,]; tmp[1,1], tmp[1,3].
 * The row-wise rewrite is tried first; the column-wise rewrite is only tried
 * if the row-wise rewrite was not applied.
 *
 * Candidates are matched by reference equality of the input and index
 * expression hops, i.e., this depends on common subexpression elimination
 * having merged equal expressions beforehand.
 *
 * Note: unnecessary row or column indexing is then later removed via
 * dynamic rewrites.
 *
 * @param hop candidate hop; only an {@link IndexingOp} whose row and column
 *        bounds are each a single index triggers a rewrite
 * @throws HopsException propagated from DAG rewiring
 */
@SuppressWarnings("unused")
private void vectorizeRightIndexing( Hop hop )
	throws HopsException
{
	if( hop instanceof IndexingOp ) //right indexing
	{
		IndexingOp ihop0 = (IndexingOp) hop;
		boolean isSingleRow = ihop0.getRowLowerEqualsUpper();
		boolean isSingleCol = ihop0.getColLowerEqualsUpper();
		boolean appliedRow = false;
		
		//search for multiple indexing in same row
		if( isSingleRow && isSingleCol ){
			Hop input = ihop0.getInput().get(0);
			//find candidate set
			//dependence on common subexpression elimination to find equal input / row expression
			ArrayList<Hop> ihops = new ArrayList<Hop>();
			ihops.add(ihop0);
			for( Hop c : input.getParent() ){
				if( c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input
					&& ((IndexingOp) c).getRowLowerEqualsUpper()
					&& c.getInput().get(1)==ihop0.getInput().get(1) )
				{
					ihops.add( c );
				}
			}
			
			//apply rewrite if found candidates
			if( ihops.size() > 1 ){
				//new row indexing operator over the common row expression
				IndexingOp newRix = new IndexingOp("tmp", DataType.MATRIX, ValueType.DOUBLE, input,
					ihop0.getInput().get(1), ihop0.getInput().get(1), new LiteralOp(1),
					HopRewriteUtils.createValueHop(input, false), true, false);
				HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
				newRix.refreshSizeInformation();
				
				//rewire current operator and all candidates to read newRix at row 1
				for( Hop c : ihops ) {
					HopRewriteUtils.removeChildReference(c, input); //input data
					HopRewriteUtils.addChildReference(c, newRix, 0);
					HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(1),1); //row lower expr
					HopRewriteUtils.addChildReference(c, new LiteralOp(1), 1);
					HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(2),2); //row upper expr
					HopRewriteUtils.addChildReference(c, new LiteralOp(1), 2);
					c.refreshSizeInformation();
				}
				
				appliedRow = true;
				LOG.debug("Applied vectorizeRightIndexingRow");
			}
		}
		
		//search for multiple indexing in same col (only if row rewrite not applied)
		if( isSingleRow && isSingleCol && !appliedRow ){
			Hop input = ihop0.getInput().get(0);
			//find candidate set
			//dependence on common subexpression elimination to find equal input / column expression
			ArrayList<Hop> ihops = new ArrayList<Hop>();
			ihops.add(ihop0);
			for( Hop c : input.getParent() ){
				if( c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input
					&& ((IndexingOp) c).getColLowerEqualsUpper()
					&& c.getInput().get(3)==ihop0.getInput().get(3) )
				{
					ihops.add( c );
				}
			}
			
			//apply rewrite if found candidates
			if( ihops.size() > 1 ){
				//new column indexing operator over the common column expression
				IndexingOp newRix = new IndexingOp("tmp", DataType.MATRIX, ValueType.DOUBLE, input,
					new LiteralOp(1), HopRewriteUtils.createValueHop(input, true),
					ihop0.getInput().get(3), ihop0.getInput().get(3), false, true);
				HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
				newRix.refreshSizeInformation();
				
				//rewire current operator and all candidates to read newRix at column 1
				for( Hop c : ihops ) {
					HopRewriteUtils.removeChildReference(c, input); //input data
					HopRewriteUtils.addChildReference(c, newRix, 0);
					HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(3),3); //col lower expr
					HopRewriteUtils.addChildReference(c, new LiteralOp(1), 3);
					HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(4),4); //col upper expr
					HopRewriteUtils.addChildReference(c, new LiteralOp(1), 4);
					c.refreshSizeInformation();
				}
				
				LOG.debug("Applied vectorizeRightIndexingCol");
			}
		}
	}
}
/**
*
* @param hop
* @throws HopsException
*/
@SuppressWarnings("unchecked")
private void vectorizeLeftIndexing( Hop hop )
throws HopsException
{
if( hop instanceof LeftIndexingOp ) //left indexing
{
LeftIndexingOp ihop0 = (LeftIndexingOp) hop;
boolean isSingleRow = ihop0.getRowLowerEqualsUpper();
boolean isSingleCol = ihop0.getColLowerEqualsUpper();
boolean appliedRow = false;
if( isSingleRow && isSingleCol )
{
//collect simple chains (w/o multiple consumers) of left indexing ops
ArrayList ihops = new ArrayList();
ihops.add(ihop0);
Hop current = ihop0;
while( current.getInput().get(0) instanceof LeftIndexingOp ) {
LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
if( tmp.getParent().size()>1 //multiple consumers, i.e., not a simple chain
|| !((LeftIndexingOp) tmp).getRowLowerEqualsUpper() //row merge not applicable
|| tmp.getInput().get(2) != ihop0.getInput().get(2) //not the same row
|| tmp.getInput().get(0).getDim2() <= 1 ) //target is single column or unknown
{
break;
}
ihops.add( tmp );
current = tmp;
}
//apply rewrite if found candidates
if( ihops.size() > 1 ){
Hop input = current.getInput().get(0);
Hop rowExpr = ihop0.getInput().get(2); //keep before reset
//new row indexing operator
IndexingOp newRix = new IndexingOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, input,
rowExpr, rowExpr, new LiteralOp(1),
HopRewriteUtils.createValueHop(input, false), true, false);
HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
newRix.refreshSizeInformation();
//rewrite bottom left indexing operator
HopRewriteUtils.removeChildReference(current, input); //input data
HopRewriteUtils.addChildReference(current, newRix, 0);
//reset row index all candidates
for( Hop c : ihops ) {
HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(2), 2); //row lower expr
HopRewriteUtils.addChildReference(c, new LiteralOp(1), 2);
HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(3), 3); //row upper expr
HopRewriteUtils.addChildReference(c, new LiteralOp(1), 3);
c.refreshSizeInformation();
}
//new row left indexing operator (for all parents, only intermediates are guaranteed to have 1 parent)
//(note: it's important to clone the parent list before creating newLix on top of ihop0)
ArrayList ihop0parents = (ArrayList) ihop0.getParent().clone();
ArrayList ihop0parentsPos = new ArrayList();
for( Hop parent : ihop0parents ) {
int posp = HopRewriteUtils.getChildReferencePos(parent, ihop0);
HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp); //input data
ihop0parentsPos.add(posp);
}
LeftIndexingOp newLix = new LeftIndexingOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, input, ihop0,
rowExpr, rowExpr, new LiteralOp(1),
HopRewriteUtils.createValueHop(input, false), true, false);
HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
newLix.refreshSizeInformation();
for( int i=0; i ihops = new ArrayList();
ihops.add(ihop0);
Hop current = ihop0;
while( current.getInput().get(0) instanceof LeftIndexingOp ) {
LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
if( tmp.getParent().size()>1 //multiple consumers, i.e., not a simple chain
|| !((LeftIndexingOp) tmp).getColLowerEqualsUpper() //row merge not applicable
|| tmp.getInput().get(4) != ihop0.getInput().get(4) //not the same col
|| tmp.getInput().get(0).getDim1() <= 1 ) //target is single row or unknown
{
break;
}
ihops.add( tmp );
current = tmp;
}
//apply rewrite if found candidates
if( ihops.size() > 1 ){
Hop input = current.getInput().get(0);
Hop colExpr = ihop0.getInput().get(4); //keep before reset
//new row indexing operator
IndexingOp newRix = new IndexingOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, input,
new LiteralOp(1), HopRewriteUtils.createValueHop(input, true),
colExpr, colExpr, false, true);
HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
newRix.refreshSizeInformation();
//rewrite bottom left indexing operator
HopRewriteUtils.removeChildReference(current, input); //input data
HopRewriteUtils.addChildReference(current, newRix, 0);
//reset row index all candidates
for( Hop c : ihops ) {
HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(4), 4); //col lower expr
HopRewriteUtils.addChildReference(c, new LiteralOp(1), 4);
HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(5), 5); //col upper expr
HopRewriteUtils.addChildReference(c, new LiteralOp(1), 5);
c.refreshSizeInformation();
}
//new row left indexing operator (for all parents, only intermediates are guaranteed to have 1 parent)
//(note: it's important to clone the parent list before creating newLix on top of ihop0)
ArrayList ihop0parents = (ArrayList) ihop0.getParent().clone();
ArrayList ihop0parentsPos = new ArrayList();
for( Hop parent : ihop0parents ) {
int posp = HopRewriteUtils.getChildReferencePos(parent, ihop0);
HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp); //input data
ihop0parentsPos.add(posp);
}
LeftIndexingOp newLix = new LeftIndexingOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, input, ihop0,
new LiteralOp(1), HopRewriteUtils.createValueHop(input, true),
colExpr, colExpr, false, true);
HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
newLix.refreshSizeInformation();
for( int i=0; i