All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.hops.codegen.template.TemplateUtils Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.codegen.template;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;

import org.apache.commons.lang.ArrayUtils;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.cplan.CNodeTernary;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;
import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType;
import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType;
import org.apache.sysml.runtime.util.UtilFunctions;

public class TemplateUtils 
{
	public static final TemplateBase[] TEMPLATES = new TemplateBase[]{new TemplateRow(), new TemplateCell(), new TemplateOuterProduct()};
	
	public static boolean isVector(Hop hop) {
		return (hop.getDataType() == DataType.MATRIX 
			&& (hop.getDim1() != 1 && hop.getDim2() == 1 
			  || hop.getDim1() == 1 && hop.getDim2() != 1 ) );
	}
	
	public static boolean isColVector(CNode hop) {
		return (hop.getDataType() == DataType.MATRIX 
			&& hop.getNumRows() != 1 && hop.getNumCols() == 1);
	}
	
	public static boolean isRowVector(CNode hop) {
		return (hop.getDataType() == DataType.MATRIX 
			&& hop.getNumRows() == 1 && hop.getNumCols() != 1);
	}
	
	public static CNode wrapLookupIfNecessary(CNode node, Hop hop) {
		CNode ret = node;
		if( isColVector(node) )
			ret = new CNodeUnary(node, UnaryType.LOOKUP_R);
		else if( isRowVector(node) )
			ret = new CNodeUnary(node, UnaryType.LOOKUP_C);
		else if( node instanceof CNodeData && hop.getDataType().isMatrix() )
			ret = new CNodeUnary(node, UnaryType.LOOKUP_RC);
		return ret;
	}
	
	public static boolean isMatrix(Hop hop) {
		return (hop.getDataType() == DataType.MATRIX && hop.getDim1() != 1 && hop.getDim2()!=1);
	}
	
	public static boolean isVectorOrScalar(Hop hop) {
		return hop.dimsKnown() && (hop.getDataType() == DataType.SCALAR || isVector(hop) );
	}
	
	public static boolean isBinaryMatrixRowVector(Hop hop) {
		if( !(hop instanceof BinaryOp) )
			return false;
		Hop left = hop.getInput().get(0);
		Hop right = hop.getInput().get(1);
		return left.dimsKnown() && right.dimsKnown() 
			&& left.getDataType().isMatrix() && right.getDataType().isMatrix()
			&& left.getDim1() > right.getDim1();
	}
	
	public static boolean isBinaryMatrixColVector(Hop hop) {
		if( !(hop instanceof BinaryOp) )
			return false;
		Hop left = hop.getInput().get(0);
		Hop right = hop.getInput().get(1);
		return left.dimsKnown() && right.dimsKnown() 
			&& left.getDataType().isMatrix() && right.getDataType().isMatrix()
			&& left.getDim2() > right.getDim2();
	}
	
	public static boolean hasMatrixInput( Hop hop ) {
		for( Hop c : hop.getInput() )
			if( isMatrix(c) )
				return true;
		return false;
	}

	public static boolean isOperationSupported(Hop h) {
		if(h instanceof  UnaryOp)
			return UnaryType.contains(((UnaryOp)h).getOp().name());
		else if(h instanceof BinaryOp)
			return BinType.contains(((BinaryOp)h).getOp().name());
		else if(h instanceof TernaryOp)
			return TernaryType.contains(((TernaryOp)h).getOp().name());
		else if(h instanceof ParameterizedBuiltinOp) 
			return TernaryType.contains(((ParameterizedBuiltinOp)h).getOp().name());
		return false;
	}

	private static void rfindChildren(Hop hop, HashSet children ) {		
		if( hop instanceof UnaryOp || (hop instanceof BinaryOp && hop.getInput().get(0).getDataType() == DataType.MATRIX  &&  TemplateUtils.isVectorOrScalar( hop.getInput().get(1))) || (hop instanceof BinaryOp && TemplateUtils.isVectorOrScalar( hop.getInput().get(0))  &&  hop.getInput().get(1).getDataType() == DataType.MATRIX)    //unary operation or binary operaiton with one matrix and a scalar
					&& 	hop.getDataType() == DataType.MATRIX )
		{	
			if(!children.contains(hop))
				children.add(hop);
			Hop matrix = TemplateUtils.isMatrix(hop.getInput().get(0)) ? hop.getInput().get(0) : hop.getInput().get(1);
			rfindChildren(matrix,children);
		}
		else 
			children.add(hop);
	}
	
	private static Hop findCommonChild(Hop hop1, Hop hop2) {
		//this method assumes that each two nodes have at most one common child 
		LinkedHashSet children1 = new LinkedHashSet();
		LinkedHashSet children2 = new LinkedHashSet();
		
		rfindChildren(hop1, children1 );
		rfindChildren(hop2, children2 );
		
		//iterate on one set and find the first common child in the other set
		Iterator iter = children1.iterator();
		while (iter.hasNext()) {
			Hop candidate = iter.next();
			if(children2.contains(candidate))
				return candidate;
		}
		return null;
	}
	
	public static Hop commonChild(ArrayList _adddedMatrices, Hop input) {
		Hop currentChild = null;
		//loop on every added matrix and find its common child with the input, if all of them have the same common child then return it, otherwise null 
		for(Hop addedMatrix : _adddedMatrices)
		{
			Hop child = findCommonChild(addedMatrix,input);
			if(child == null)  // did not find a common child
				return null;
			if(currentChild == null) // first common child to be seen
				currentChild = child;
			else if(child.getHopID() != currentChild.getHopID())
				return null;
		}
		return currentChild;
	}

	public static HashSet rGetInputHopIDs( CNode node, HashSet ids ) {
		if( node instanceof CNodeData && !node.isLiteral() )
			ids.add(((CNodeData)node).getHopID());
		
		for( CNode c : node.getInput() )
			rGetInputHopIDs(c, ids);
			
		return ids;
	}
	
	public static Hop[] mergeDistinct(HashSet ids, Hop[] input1, Hop[] input2) {
		Hop[] ret = new Hop[ids.size()];
		int pos = 0;
		for( Hop[] input : new Hop[][]{input1, input2} )
			for( Hop c : input )
				if( ids.contains(c.getHopID()) )
					ret[pos++] = c; 
		return ret;
	}

	public static TemplateBase createTemplate(TemplateType type) {
		return createTemplate(type, false);
	}
	
	public static TemplateBase createTemplate(TemplateType type, boolean closed) {
		TemplateBase tpl = null;
		switch( type ) {
			case CellTpl: tpl = new TemplateCell(closed); break;
			case RowTpl: tpl = new TemplateRow(closed); break;
			case MultiAggTpl: tpl = new TemplateMultiAgg(closed); break;
			case OuterProdTpl: tpl = new TemplateOuterProduct(closed); break;
		}
		return tpl;
	}
	
	public static TemplateBase[] createCompatibleTemplates(TemplateType type, boolean closed) {
		TemplateBase[] tpl = null;
		switch( type ) {
			case CellTpl: tpl = new TemplateBase[]{new TemplateCell(closed), new TemplateRow(closed)}; break;
			case RowTpl: tpl = new TemplateBase[]{new TemplateRow(closed)}; break;
			case MultiAggTpl: tpl = new TemplateBase[]{new TemplateMultiAgg(closed)}; break;
			case OuterProdTpl: tpl = new TemplateBase[]{new TemplateOuterProduct(closed)}; break;
		}
		return tpl;
	}
	
	public static CellType getCellType(Hop hop) {
		return (hop instanceof AggBinaryOp) ? CellType.FULL_AGG :
			(hop instanceof AggUnaryOp) ? ((((AggUnaryOp) hop).getDirection() == Direction.RowCol) ? 
			CellType.FULL_AGG : CellType.ROW_AGG) : CellType.NO_AGG;
	}
	
	public static RowType getRowType(Hop output, Hop input) {
		if( HopRewriteUtils.isEqualSize(output, input) )
			return RowType.NO_AGG;
		else if( output.getDim1()==input.getDim1() && output.getDim2()==1 )
			return RowType.ROW_AGG;
		else if( output.getDim1()==input.getDim2() && output.getDim2()==1 )
			return RowType.COL_AGG_T;
		else
			return RowType.COL_AGG;
	}
	
	public static AggOp getAggOp(Hop hop) {
		return (hop instanceof AggUnaryOp) ? ((AggUnaryOp)hop).getOp() :
			(hop instanceof AggBinaryOp) ? AggOp.SUM : null;
	}
	
	public static OutProdType getOuterProductType(Hop X, Hop U, Hop V, Hop out) {
		if( out.getDataType() == DataType.SCALAR ) 
			return OutProdType.AGG_OUTER_PRODUCT;
		else if( (out instanceof AggBinaryOp && (out.getInput().get(0) == U 
				|| HopRewriteUtils.isTransposeOperation(out.getInput().get(0))
				&& out.getInput().get(0).getInput().get(0) == U))
				|| HopRewriteUtils.isTransposeOperation(out) ) 
			return OutProdType.LEFT_OUTER_PRODUCT;
		else if( out instanceof AggBinaryOp && (out.getInput().get(1) == V
				|| HopRewriteUtils.isTransposeOperation(out.getInput().get(1)) 
				&& out.getInput().get(1).getInput().get(0) == V ) )
			return OutProdType.RIGHT_OUTER_PRODUCT;
		else if( out instanceof BinaryOp && HopRewriteUtils.isEqualSize(out.getInput().get(0), out.getInput().get(1)) )
			return OutProdType.CELLWISE_OUTER_PRODUCT;
		
		//should never come here
		throw new RuntimeException("Undefined outer product type for hop "+out.getHopID());
	}
	
	public static boolean isLookup(CNode node) {
		return isUnary(node, UnaryType.LOOKUP_R, UnaryType.LOOKUP_C, UnaryType.LOOKUP_RC)
			|| isTernary(node, TernaryType.LOOKUP_RC1);
	}
	
	public static boolean isUnary(CNode node, UnaryType...types) {
		return node instanceof CNodeUnary
			&& ArrayUtils.contains(types, ((CNodeUnary)node).getType());
	}
	
	public static boolean isTernary(CNode node, TernaryType...types) {
		return node instanceof CNodeTernary
			&& ArrayUtils.contains(types, ((CNodeTernary)node).getType());
	}

	public static CNodeData createCNodeData(Hop hop, boolean compileLiterals) {
		CNodeData cdata = new CNodeData(hop);
		cdata.setLiteral(hop instanceof LiteralOp && (compileLiterals 
				|| UtilFunctions.isIntegerNumber(((LiteralOp)hop).getStringValue())));
		return cdata;
	}

	public static CNode skipTranspose(CNode cdataOrig, Hop hop, HashMap tmp, boolean compileLiterals) {
		if( HopRewriteUtils.isTransposeOperation(hop) ) {
			CNode cdata = tmp.get(hop.getInput().get(0).getHopID());
			if( cdata == null ) { //never accessed
				cdata = TemplateUtils.createCNodeData(hop.getInput().get(0), compileLiterals);
				tmp.put(hop.getInput().get(0).getHopID(), cdata);
			}
			tmp.put(hop.getHopID(), cdata);
			return cdata;
		}
		else {
			return cdataOrig;
		}
	}
	
	public static boolean hasTransposeParentUnderOuterProduct(Hop hop) {
		for( Hop p : hop.getParent() )
			if( HopRewriteUtils.isTransposeOperation(p) )
				for( Hop p2 : p.getParent() )
					if( HopRewriteUtils.isOuterProductLikeMM(p2) )
						return true;
		return false;
	}

	public static boolean hasSingleOperation(CNodeTpl tpl) {
		CNode output = tpl.getOutput();
		return (output instanceof CNodeUnary || output instanceof CNodeBinary
				|| output instanceof CNodeTernary) && hasOnlyDataNodeOrLookupInputs(output);
	}
	
	public static boolean hasNoOperation(CNodeTpl tpl) {
		return tpl.getOutput() instanceof CNodeData 
			|| isLookup(tpl.getOutput());
	}
	
	public static boolean hasOnlyDataNodeOrLookupInputs(CNode node) {
		boolean ret = true;
		for( CNode c : node.getInput() )
			ret &= (c instanceof CNodeData || (c instanceof CNodeUnary 
				&& (((CNodeUnary)c).getType()==UnaryType.LOOKUP0 
					|| ((CNodeUnary)c).getType()==UnaryType.LOOKUP_R 
					|| ((CNodeUnary)c).getType()==UnaryType.LOOKUP_RC)));
		return ret;
	}
	
	public static int countVectorIntermediates(CNode node, HashSet memo) {
		//memoization to prevent double counting
		if( memo.contains(node.getID()) )
			return 0;
		memo.add(node.getID());
		//compute vector requirements over all inputs
		int ret = 0;
		for( CNode c : node.getInput() )
			ret += countVectorIntermediates(c, memo);
		//compute vector requirements of current node
		int cntBin = ((node instanceof CNodeBinary 
			&& ((CNodeBinary)node).getType().isVectorScalarPrimitive()) ? 1 : 0);
		int cntUn = ((node instanceof CNodeUnary
				&& ((CNodeUnary)node).getType().isVectorScalarPrimitive()) ? 1 : 0);
		return ret + cntBin + cntUn;
	}

	public static boolean isType(TemplateType type, TemplateType... validTypes) {
		return ArrayUtils.contains(validTypes, type);
	}
	
	public static boolean hasCommonRowTemplateMatrixInput(Hop input1, Hop input2, CPlanMemoTable memo) {
		//if second input has no row template, it's always true
		if( !memo.contains(input2.getHopID(), TemplateType.RowTpl) )
			return true;
		//check for common row template input
		long tmp1 = getRowTemplateMatrixInput(input1, memo);
		long tmp2 = getRowTemplateMatrixInput(input2, memo);
		return (tmp1 == tmp2);
	}
	
	public static long getRowTemplateMatrixInput(Hop current, CPlanMemoTable memo) {
		MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.RowTpl);
		long ret = -1;
		for( int i=0; ret<0 && i




© 2015 - 2024 Weber Informatics LLC | Privacy Policy