org.apache.sysml.hops.codegen.template.TemplateUtils Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.codegen.template;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import org.apache.commons.lang.ArrayUtils;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.cplan.CNodeTernary;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;
import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType;
import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType;
import org.apache.sysml.runtime.util.UtilFunctions;
public class TemplateUtils
{
/**
 * One default-constructed instance each of {@code TemplateRow}, {@code TemplateCell},
 * and {@code TemplateOuterProduct}, in this order. {@code TemplateMultiAgg} is not
 * part of this array.
 */
public static final TemplateBase[] TEMPLATES = new TemplateBase[]{new TemplateRow(), new TemplateCell(), new TemplateOuterProduct()};
/**
 * Checks whether the given hop is a matrix with exactly one dimension of size 1.
 *
 * @param hop the hop to inspect
 * @return true if the hop has data type MATRIX and exactly one of
 *         {@code getDim1()} / {@code getDim2()} equals 1
 */
public static boolean isVector(Hop hop) {
    if( hop.getDataType() != DataType.MATRIX )
        return false;
    boolean firstIsOne = (hop.getDim1() == 1);
    boolean secondIsOne = (hop.getDim2() == 1);
    //exactly one of the two dimensions must be 1
    return firstIsOne != secondIsOne;
}
/**
 * Checks whether the given cplan node is a matrix with a single column and
 * a row count different from 1.
 *
 * @param hop the cplan node to inspect
 * @return true if the node has data type MATRIX, {@code getNumCols() == 1},
 *         and {@code getNumRows() != 1}
 */
public static boolean isColVector(CNode hop) {
    if( hop.getDataType() != DataType.MATRIX )
        return false;
    if( hop.getNumRows() == 1 )
        return false;
    return hop.getNumCols() == 1;
}
/**
 * Checks whether the given cplan node is a matrix with a single row and
 * a column count different from 1.
 *
 * @param hop the cplan node to inspect
 * @return true if the node has data type MATRIX, {@code getNumRows() == 1},
 *         and {@code getNumCols() != 1}
 */
public static boolean isRowVector(CNode hop) {
    if( hop.getDataType() != DataType.MATRIX )
        return false;
    if( hop.getNumRows() != 1 )
        return false;
    return hop.getNumCols() != 1;
}
/**
 * Wraps the given cplan node into a lookup operation if one of the
 * following conditions holds, checked in this order:
 * column vector (LOOKUP_R), row vector (LOOKUP_C), or a data node whose
 * hop has a matrix data type (LOOKUP_RC).
 *
 * @param node the cplan node to wrap
 * @param hop the hop whose data type is consulted for the data-node case
 * @return a new {@code CNodeUnary} lookup around {@code node}, or
 *         {@code node} itself if no condition applies
 */
public static CNode wrapLookupIfNecessary(CNode node, Hop hop) {
    if( isColVector(node) )
        return new CNodeUnary(node, UnaryType.LOOKUP_R);
    if( isRowVector(node) )
        return new CNodeUnary(node, UnaryType.LOOKUP_C);
    if( node instanceof CNodeData && hop.getDataType().isMatrix() )
        return new CNodeUnary(node, UnaryType.LOOKUP_RC);
    //no lookup required
    return node;
}
/**
 * Checks whether the given hop is a matrix for which neither dimension equals 1.
 *
 * @param hop the hop to inspect
 * @return true if the hop has data type MATRIX and both {@code getDim1()}
 *         and {@code getDim2()} differ from 1
 */
public static boolean isMatrix(Hop hop) {
    if( hop.getDataType() != DataType.MATRIX )
        return false;
    return !(hop.getDim1() == 1 || hop.getDim2() == 1);
}
/**
 * Checks whether the given hop has known dimensions and is either a scalar
 * or a vector according to {@code isVector}.
 *
 * @param hop the hop to inspect
 * @return true if {@code hop.dimsKnown()} holds and the hop is of data type
 *         SCALAR or a vector
 */
public static boolean isVectorOrScalar(Hop hop) {
    if( !hop.dimsKnown() )
        return false;
    if( hop.getDataType() == DataType.SCALAR )
        return true;
    return isVector(hop);
}
/**
 * Checks whether the given hop is a binary operation over two matrix inputs
 * with known dimensions where the first input has a larger first dimension
 * ({@code getDim1()}) than the second input.
 *
 * @param hop the hop to inspect
 * @return true if all of the above conditions hold, false otherwise
 *         (including when {@code hop} is not a {@code BinaryOp})
 */
public static boolean isBinaryMatrixRowVector(Hop hop) {
    if( !(hop instanceof BinaryOp) )
        return false;
    Hop lhs = hop.getInput().get(0);
    Hop rhs = hop.getInput().get(1);
    if( !lhs.dimsKnown() || !rhs.dimsKnown() )
        return false;
    if( !lhs.getDataType().isMatrix() || !rhs.getDataType().isMatrix() )
        return false;
    return lhs.getDim1() > rhs.getDim1();
}
/**
 * Checks whether the given hop is a binary operation over two matrix inputs
 * with known dimensions where the first input has a larger second dimension
 * ({@code getDim2()}) than the second input.
 *
 * @param hop the hop to inspect
 * @return true if all of the above conditions hold, false otherwise
 *         (including when {@code hop} is not a {@code BinaryOp})
 */
public static boolean isBinaryMatrixColVector(Hop hop) {
    if( !(hop instanceof BinaryOp) )
        return false;
    Hop lhs = hop.getInput().get(0);
    Hop rhs = hop.getInput().get(1);
    if( !lhs.dimsKnown() || !rhs.dimsKnown() )
        return false;
    if( !lhs.getDataType().isMatrix() || !rhs.getDataType().isMatrix() )
        return false;
    return lhs.getDim2() > rhs.getDim2();
}
/**
 * Checks whether at least one direct input of the given hop is a matrix
 * according to {@code isMatrix}.
 *
 * @param hop the hop whose inputs are inspected
 * @return true if any input satisfies {@code isMatrix}, false otherwise
 */
public static boolean hasMatrixInput( Hop hop ) {
    boolean found = false;
    for( Hop in : hop.getInput() ) {
        if( isMatrix(in) ) {
            found = true;
            break;
        }
    }
    return found;
}
/**
 * Checks whether the operation of the given hop is known to the cplan
 * operator type enums, matched by operator name.
 *
 * Unary ops are checked against {@code UnaryType}, binary ops against
 * {@code BinType}, and both ternary and parameterized builtin ops against
 * {@code TernaryType}.
 *
 * @param h the hop to inspect
 * @return true if the operator name is contained in the corresponding type
 *         enum; false for any other kind of hop
 */
public static boolean isOperationSupported(Hop h) {
    if( h instanceof UnaryOp ) {
        String opName = ((UnaryOp)h).getOp().name();
        return UnaryType.contains(opName);
    }
    if( h instanceof BinaryOp ) {
        String opName = ((BinaryOp)h).getOp().name();
        return BinType.contains(opName);
    }
    if( h instanceof TernaryOp ) {
        String opName = ((TernaryOp)h).getOp().name();
        return TernaryType.contains(opName);
    }
    if( h instanceof ParameterizedBuiltinOp ) {
        //parameterized builtins are looked up among the ternary types
        String opName = ((ParameterizedBuiltinOp)h).getOp().name();
        return TernaryType.contains(opName);
    }
    return false;
}
/**
 * Recursively collects the given hop and the hops reachable from it into
 * {@code children}. The traversal continues through a hop if it is a unary
 * operation, or a binary operation with one MATRIX-typed input and one
 * vector-or-scalar input; any other hop is added and ends the traversal.
 * For traversed hops, the recursion follows the first input if it is a
 * matrix (or the only input), and the second input otherwise.
 *
 * @param hop the current hop
 * @param children output set that receives all visited hops
 */
private static void rfindChildren(Hop hop, HashSet<Hop> children) {
    //every visited hop is recorded; a set ignores repeated insertions
    children.add(hop);

    boolean unary = hop instanceof UnaryOp;
    boolean binary = hop instanceof BinaryOp;
    boolean matrixWithVectorOrScalar = binary
        && hop.getInput().get(0).getDataType() == DataType.MATRIX
        && TemplateUtils.isVectorOrScalar(hop.getInput().get(1));
    //NOTE(review): due to operator precedence, the output data type check
    //only applies to this last alternative; it possibly was meant to guard
    //all three alternatives - behavior kept as is, confirm intent
    boolean vectorOrScalarWithMatrix = binary
        && TemplateUtils.isVectorOrScalar(hop.getInput().get(0))
        && hop.getInput().get(1).getDataType() == DataType.MATRIX
        && hop.getDataType() == DataType.MATRIX;

    if( !(unary || matrixWithVectorOrScalar || vectorOrScalarWithMatrix) )
        return; //traversal stops at this hop

    //a unary op has no second input to fall back to, so always follow
    //its first input instead of failing with an index out of bounds
    boolean followFirst = hop.getInput().size() < 2
        || TemplateUtils.isMatrix(hop.getInput().get(0));
    rfindChildren(hop.getInput().get(followFirst ? 0 : 1), children);
}
/**
 * Finds the first hop, in collection order of {@code hop1}'s children, that
 * is also among the children of {@code hop2}. Children are collected via
 * {@code rfindChildren}. This method assumes that each two nodes have at
 * most one common child.
 *
 * @param hop1 first hop
 * @param hop2 second hop
 * @return the first common child, or null if there is none
 */
private static Hop findCommonChild(Hop hop1, Hop hop2) {
    //insertion-ordered sets keep the traversal order of each hop
    LinkedHashSet<Hop> children1 = new LinkedHashSet<Hop>();
    LinkedHashSet<Hop> children2 = new LinkedHashSet<Hop>();
    rfindChildren(hop1, children1);
    rfindChildren(hop2, children2);

    //iterate on one set and find the first common child in the other set
    for( Hop candidate : children1 )
        if( children2.contains(candidate) )
            return candidate;
    return null;
}
/**
 * Determines the single child that the given input has in common with
 * every hop of the given list.
 *
 * @param _adddedMatrices hops to compare against the input
 * @param input the hop to compare with each list entry
 * @return the common child if every list entry shares the same common
 *         child (by hop ID) with the input; null if any entry has no
 *         common child, if the common children differ, or if the list
 *         is empty
 */
public static Hop commonChild(ArrayList<Hop> _adddedMatrices, Hop input) {
    Hop currentChild = null;
    for( Hop addedMatrix : _adddedMatrices ) {
        Hop child = findCommonChild(addedMatrix, input);
        if( child == null ) //did not find a common child
            return null;
        if( currentChild == null ) //first common child to be seen
            currentChild = child;
        else if( child.getHopID() != currentChild.getHopID() )
            return null; //conflicting common children
    }
    return currentChild;
}
/**
 * Recursively collects the hop IDs of all non-literal data nodes in the
 * cplan rooted at the given node.
 *
 * @param node root of the (sub-)plan to scan
 * @param ids output set that receives the hop IDs
 * @return the passed {@code ids} set, for call chaining
 */
public static HashSet<Long> rGetInputHopIDs( CNode node, HashSet<Long> ids ) {
    if( node instanceof CNodeData && !node.isLiteral() )
        ids.add(((CNodeData)node).getHopID());

    //descend into all inputs, also for data nodes
    for( CNode c : node.getInput() )
        rGetInputHopIDs(c, ids);

    return ids;
}
/**
 * Merges two hop arrays into one array that contains, in encounter order
 * (first {@code input1}, then {@code input2}), each hop whose ID is in
 * {@code ids} exactly once.
 *
 * The result array has length {@code ids.size()}; if fewer matching hops
 * are found, the trailing entries remain null.
 *
 * @param ids hop IDs to retain
 * @param input1 first array of hops
 * @param input2 second array of hops
 * @return array of the distinct retained hops
 */
public static Hop[] mergeDistinct(HashSet<Long> ids, Hop[] input1, Hop[] input2) {
    Hop[] ret = new Hop[ids.size()];
    //IDs already emitted: a hop contained in both inputs must only be
    //taken once, otherwise the write position can run past ret.length
    HashSet<Long> seen = new HashSet<Long>();
    int pos = 0;
    for( Hop[] input : new Hop[][]{input1, input2} ) {
        for( Hop c : input ) {
            long id = c.getHopID();
            if( ids.contains(id) && seen.add(id) )
                ret[pos++] = c;
        }
    }
    return ret;
}
/**
 * Creates a new template instance of the given type with the closed flag
 * set to false.
 *
 * @param type the template type
 * @return the new template, or null for a type without a mapping
 */
public static TemplateBase createTemplate(TemplateType type) {
    final boolean closed = false;
    return createTemplate(type, closed);
}
/**
 * Creates a new template instance of the given type.
 *
 * @param type the template type
 * @param closed flag passed to the template constructor
 * @return the new template, or null for a type without a mapping
 */
public static TemplateBase createTemplate(TemplateType type, boolean closed) {
    switch( type ) {
        case CellTpl:
            return new TemplateCell(closed);
        case RowTpl:
            return new TemplateRow(closed);
        case MultiAggTpl:
            return new TemplateMultiAgg(closed);
        case OuterProdTpl:
            return new TemplateOuterProduct(closed);
        default:
            return null;
    }
}
/**
 * Creates new instances of all templates treated as compatible with the
 * given type. For the cell type this is a cell and a row template (in this
 * order); for every other mapped type it is a single template of that type.
 *
 * @param type the template type
 * @param closed flag passed to each template constructor
 * @return array of new templates, or null for a type without a mapping
 */
public static TemplateBase[] createCompatibleTemplates(TemplateType type, boolean closed) {
    switch( type ) {
        case CellTpl:
            return new TemplateBase[]{
                new TemplateCell(closed),
                new TemplateRow(closed)};
        case RowTpl:
            return new TemplateBase[]{new TemplateRow(closed)};
        case MultiAggTpl:
            return new TemplateBase[]{new TemplateMultiAgg(closed)};
        case OuterProdTpl:
            return new TemplateBase[]{new TemplateOuterProduct(closed)};
        default:
            return null;
    }
}
/**
 * Derives the cell type from the kind of the given hop.
 *
 * @param hop the hop to inspect
 * @return FULL_AGG for an {@code AggBinaryOp} or an {@code AggUnaryOp}
 *         with direction RowCol, ROW_AGG for an {@code AggUnaryOp} with
 *         any other direction, and NO_AGG for all other hops
 */
public static CellType getCellType(Hop hop) {
    if( hop instanceof AggBinaryOp )
        return CellType.FULL_AGG;
    if( hop instanceof AggUnaryOp ) {
        Direction dir = ((AggUnaryOp) hop).getDirection();
        if( dir == Direction.RowCol )
            return CellType.FULL_AGG;
        return CellType.ROW_AGG;
    }
    return CellType.NO_AGG;
}
/**
 * Derives the row type by comparing the dimensions of output and input.
 * The checks are applied in the documented order; the first match wins.
 *
 * @param output the output hop
 * @param input the input hop
 * @return NO_AGG if both are of equal size; ROW_AGG if the output has
 *         {@code dim2 == 1} and the same {@code dim1} as the input;
 *         COL_AGG_T if the output has {@code dim2 == 1} and its
 *         {@code dim1} equals the input's {@code dim2}; COL_AGG otherwise
 */
public static RowType getRowType(Hop output, Hop input) {
    if( HopRewriteUtils.isEqualSize(output, input) )
        return RowType.NO_AGG;
    if( output.getDim1()==input.getDim1() && output.getDim2()==1 )
        return RowType.ROW_AGG;
    if( output.getDim1()==input.getDim2() && output.getDim2()==1 )
        return RowType.COL_AGG_T;
    return RowType.COL_AGG;
}
/**
 * Obtains the aggregation operator of the given hop.
 *
 * @param hop the hop to inspect
 * @return the operator of an {@code AggUnaryOp}, SUM for an
 *         {@code AggBinaryOp}, and null for any other hop
 */
public static AggOp getAggOp(Hop hop) {
    if( hop instanceof AggUnaryOp )
        return ((AggUnaryOp)hop).getOp();
    if( hop instanceof AggBinaryOp )
        return AggOp.SUM;
    return null;
}
/**
 * Determines the outer product type from the output hop and its relation
 * to the factors U and V. The checks are applied in order; the first
 * match wins.
 *
 * @param X not referenced in this method body
 * @param U hop compared by reference against the first input of {@code out}
 * @param V hop compared by reference against the second input of {@code out}
 * @param out the output hop
 * @return the outer product type
 * @throws RuntimeException if none of the checks matches
 */
public static OutProdType getOuterProductType(Hop X, Hop U, Hop V, Hop out) {
//scalar output
if( out.getDataType() == DataType.SCALAR )
return OutProdType.AGG_OUTER_PRODUCT;
//aggregate binary op whose first input is U or a transpose of U,
//or the output itself is a transpose operation
else if( (out instanceof AggBinaryOp && (out.getInput().get(0) == U
|| HopRewriteUtils.isTransposeOperation(out.getInput().get(0))
&& out.getInput().get(0).getInput().get(0) == U))
|| HopRewriteUtils.isTransposeOperation(out) )
return OutProdType.LEFT_OUTER_PRODUCT;
//aggregate binary op whose second input is V or a transpose of V
else if( out instanceof AggBinaryOp && (out.getInput().get(1) == V
|| HopRewriteUtils.isTransposeOperation(out.getInput().get(1))
&& out.getInput().get(1).getInput().get(0) == V ) )
return OutProdType.RIGHT_OUTER_PRODUCT;
//binary op over two inputs of equal size
else if( out instanceof BinaryOp && HopRewriteUtils.isEqualSize(out.getInput().get(0), out.getInput().get(1)) )
return OutProdType.CELLWISE_OUTER_PRODUCT;
//should never come here
throw new RuntimeException("Undefined outer product type for hop "+out.getHopID());
}
/**
 * Checks whether the given cplan node is a lookup operation, i.e., a unary
 * node of type LOOKUP_R, LOOKUP_C, or LOOKUP_RC, or a ternary node of type
 * LOOKUP_RC1.
 *
 * @param node the cplan node to inspect
 * @return true if the node is one of the listed lookup operations
 */
public static boolean isLookup(CNode node) {
    if( isUnary(node, UnaryType.LOOKUP_R, UnaryType.LOOKUP_C, UnaryType.LOOKUP_RC) )
        return true;
    return isTernary(node, TernaryType.LOOKUP_RC1);
}
/**
 * Checks whether the given cplan node is a unary node of one of the
 * given types.
 *
 * @param node the cplan node to inspect
 * @param types the accepted unary types
 * @return true if {@code node} is a {@code CNodeUnary} whose type is
 *         contained in {@code types}
 */
public static boolean isUnary(CNode node, UnaryType...types) {
    if( !(node instanceof CNodeUnary) )
        return false;
    UnaryType actual = ((CNodeUnary)node).getType();
    return ArrayUtils.contains(types, actual);
}
/**
 * Checks whether the given cplan node is a ternary node of one of the
 * given types.
 *
 * @param node the cplan node to inspect
 * @param types the accepted ternary types
 * @return true if {@code node} is a {@code CNodeTernary} whose type is
 *         contained in {@code types}
 */
public static boolean isTernary(CNode node, TernaryType...types) {
    if( !(node instanceof CNodeTernary) )
        return false;
    TernaryType actual = ((CNodeTernary)node).getType();
    return ArrayUtils.contains(types, actual);
}
/**
 * Creates a cplan data node for the given hop and sets its literal flag.
 * The flag is set for a {@code LiteralOp} if either {@code compileLiterals}
 * is true or {@code UtilFunctions.isIntegerNumber} accepts the literal's
 * string value.
 *
 * @param hop the hop to wrap
 * @param compileLiterals if true, every literal hop is flagged as literal
 * @return the new data node
 */
public static CNodeData createCNodeData(Hop hop, boolean compileLiterals) {
    CNodeData cdata = new CNodeData(hop);
    boolean literal = false;
    if( hop instanceof LiteralOp ) {
        literal = compileLiterals
            || UtilFunctions.isIntegerNumber(((LiteralOp)hop).getStringValue());
    }
    cdata.setLiteral(literal);
    return cdata;
}
/**
 * Replaces the cplan node of a transpose operation with the cplan node of
 * the transpose input. If the input has no node in {@code tmp} yet, a new
 * data node is created and registered under the input's hop ID. In either
 * case the resulting node is also registered under the ID of the transpose
 * hop itself.
 *
 * @param cdataOrig node returned unchanged if {@code hop} is no transpose
 * @param hop the hop to check for a transpose operation
 * @param tmp map from hop ID to cplan node; updated for transpose hops
 * @param compileLiterals passed on to {@code createCNodeData}
 * @return the node of the transpose input, or {@code cdataOrig}
 */
public static CNode skipTranspose(CNode cdataOrig, Hop hop, HashMap<Long, CNode> tmp, boolean compileLiterals) {
    if( !HopRewriteUtils.isTransposeOperation(hop) )
        return cdataOrig;

    Hop in = hop.getInput().get(0);
    CNode cdata = tmp.get(in.getHopID());
    if( cdata == null ) { //never accessed
        cdata = TemplateUtils.createCNodeData(in, compileLiterals);
        tmp.put(in.getHopID(), cdata);
    }
    tmp.put(hop.getHopID(), cdata);
    return cdata;
}
/**
 * Checks whether the given hop has a transpose parent which in turn has a
 * parent accepted by {@code HopRewriteUtils.isOuterProductLikeMM}.
 *
 * @param hop the hop whose parents are inspected
 * @return true if such a parent / grandparent pair exists
 */
public static boolean hasTransposeParentUnderOuterProduct(Hop hop) {
    for( Hop parent : hop.getParent() ) {
        if( !HopRewriteUtils.isTransposeOperation(parent) )
            continue;
        for( Hop grandParent : parent.getParent() ) {
            if( HopRewriteUtils.isOuterProductLikeMM(grandParent) )
                return true;
        }
    }
    return false;
}
/**
 * Checks whether the output of the given template is a unary, binary, or
 * ternary node whose inputs all pass
 * {@code hasOnlyDataNodeOrLookupInputs}.
 *
 * @param tpl the cplan template
 * @return true if the template output is such a single operation
 */
public static boolean hasSingleOperation(CNodeTpl tpl) {
    CNode output = tpl.getOutput();
    boolean isOperation = output instanceof CNodeUnary
        || output instanceof CNodeBinary
        || output instanceof CNodeTernary;
    if( !isOperation )
        return false;
    return hasOnlyDataNodeOrLookupInputs(output);
}
/**
 * Checks whether the output of the given template is a plain data node or
 * a lookup according to {@code isLookup}.
 *
 * @param tpl the cplan template
 * @return true if the template output is a data node or a lookup
 */
public static boolean hasNoOperation(CNodeTpl tpl) {
    if( tpl.getOutput() instanceof CNodeData )
        return true;
    return isLookup(tpl.getOutput());
}
/**
 * Checks whether every direct input of the given node is either a data
 * node or a unary node of type LOOKUP0, LOOKUP_R, or LOOKUP_RC. Other
 * lookup types (e.g., LOOKUP_C) are not accepted by this check.
 *
 * @param node the cplan node whose inputs are inspected
 * @return true if all inputs qualify (also if there are no inputs)
 */
public static boolean hasOnlyDataNodeOrLookupInputs(CNode node) {
    for( CNode in : node.getInput() ) {
        if( in instanceof CNodeData )
            continue;
        if( !(in instanceof CNodeUnary) )
            return false;
        UnaryType type = ((CNodeUnary)in).getType();
        if( type != UnaryType.LOOKUP0
            && type != UnaryType.LOOKUP_R
            && type != UnaryType.LOOKUP_RC )
            return false;
    }
    return true;
}
/**
 * Counts the binary and unary nodes in the cplan rooted at the given node
 * whose type reports {@code isVectorScalarPrimitive()}. Each node ID is
 * counted at most once per {@code memo} set.
 *
 * @param node root of the (sub-)plan to scan
 * @param memo IDs of nodes already visited; updated by this call
 * @return the number of qualifying nodes not yet recorded in {@code memo}
 */
public static int countVectorIntermediates(CNode node, HashSet<Long> memo) {
    //memoization to prevent double counting; add returns false if the
    //ID was already present
    if( !memo.add(node.getID()) )
        return 0;

    //compute vector requirements over all inputs
    int ret = 0;
    for( CNode c : node.getInput() )
        ret += countVectorIntermediates(c, memo);

    //compute vector requirements of current node
    if( node instanceof CNodeBinary
        && ((CNodeBinary)node).getType().isVectorScalarPrimitive() )
        ret++;
    if( node instanceof CNodeUnary
        && ((CNodeUnary)node).getType().isVectorScalarPrimitive() )
        ret++;
    return ret;
}
/**
 * Checks whether the given template type is one of the valid types.
 *
 * @param type the type to look for
 * @param validTypes the accepted types
 * @return true if {@code type} is contained in {@code validTypes}
 */
public static boolean isType(TemplateType type, TemplateType... validTypes) {
    if( validTypes == null )
        return false;
    //enum constants are singletons, so identity comparison is sufficient
    for( TemplateType valid : validTypes )
        if( valid == type )
            return true;
    return false;
}
/**
 * Checks whether two hops are compatible with respect to their row
 * template matrix input, as given by {@code getRowTemplateMatrixInput}.
 *
 * @param input1 first hop
 * @param input2 second hop
 * @param memo the memo table to consult
 * @return true if the memo table has no row template entry for
 *         {@code input2}, or if both hops yield the same row template
 *         matrix input
 */
public static boolean hasCommonRowTemplateMatrixInput(Hop input1, Hop input2, CPlanMemoTable memo) {
    //if second input has no row template, it's always true
    boolean secondHasRowTpl = memo.contains(input2.getHopID(), TemplateType.RowTpl);
    if( !secondHasRowTpl )
        return true;

    //check for common row template input
    long first = getRowTemplateMatrixInput(input1, memo);
    long second = getRowTemplateMatrixInput(input2, memo);
    return first == second;
}
public static long getRowTemplateMatrixInput(Hop current, CPlanMemoTable memo) {
MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.RowTpl);
long ret = -1;
for( int i=0; ret<0 && i