All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.hops.recompile.LiteralReplacement Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.recompile;

import java.util.ArrayList;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.DataOpTypes;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp1;
import org.apache.sysml.hops.Hop.VisitStatus;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.utils.Statistics;

public class LiteralReplacement 
{
	
	//internal configuration parameters
	private static final long REPLACE_LITERALS_MAX_MATRIX_SIZE = 1000000; //10^6 cells (8MB)
	private static final boolean REPORT_LITERAL_REPLACE_OPS_STATS = true; 	
	
	/**
	 * 
	 * @param hop
	 * @param vars
	 * @throws DMLRuntimeException
	 */
	protected static void rReplaceLiterals( Hop hop, LocalVariableMap vars ) 
		throws DMLRuntimeException
	{
		if( hop.getVisited() == VisitStatus.DONE )
			return;

		if( hop.getInput() != null )
		{
			//indexed access to allow parent-child modifications
			for( int i=0; i 1 ) { //multiple parents
						ArrayList parents = new ArrayList(c.getParent());
						for( Hop p : parents ) {
							int pos = HopRewriteUtils.getChildReferencePos(p, c);
							HopRewriteUtils.removeChildReferenceByPos(p, c, pos);
							HopRewriteUtils.addChildReference(p, lit, pos);
						}
					}
					else { //current hop is only parent
						HopRewriteUtils.removeChildReferenceByPos(hop, c, i);
						HopRewriteUtils.addChildReference(hop, lit, i);
					}
				}
				//recursively process children
				else 
				{
					rReplaceLiterals(c, vars);	
				}			
			}
		}
		
		hop.setVisited(VisitStatus.DONE);
	}
	

	///////////////////////////////
	// Literal replacement rules
	///////////////////////////////
	
	/**
	 * 
	 * @param c
	 * @param vars
	 * @return
	 */
	private static LiteralOp replaceLiteralScalarRead(Hop c, LocalVariableMap vars)
	{
		LiteralOp ret = null;
		
		//scalar read - literal replacement
		if( c instanceof DataOp && ((DataOp)c).getDataOpType() != DataOpTypes.PERSISTENTREAD 
			&& c.getDataType()==DataType.SCALAR )
		{
			Data dat = vars.get(c.getName());
			if( dat != null ) //required for selective constant propagation
			{
				ScalarObject sdat = (ScalarObject)dat;
				switch( sdat.getValueType() ) {
					case INT:
						ret = new LiteralOp(sdat.getLongValue());		
						break;
					case DOUBLE:
						ret = new LiteralOp(sdat.getDoubleValue());	
						break;
					case BOOLEAN:
						ret = new LiteralOp(sdat.getBooleanValue());
						break;
					default:	
						//otherwise: do nothing
				}
			}
		}
		
		return ret;
	}
	
	/**
	 * 
	 * @param c
	 * @param vars
	 * @return
	 */
	private static LiteralOp replaceLiteralValueTypeCastScalarRead( Hop c, LocalVariableMap vars )
	{
		LiteralOp ret = null;
		
		//as.double/as.integer/as.boolean over scalar read - literal replacement
		if( c instanceof UnaryOp && (((UnaryOp)c).getOp() == OpOp1.CAST_AS_DOUBLE 
			|| ((UnaryOp)c).getOp() == OpOp1.CAST_AS_INT || ((UnaryOp)c).getOp() == OpOp1.CAST_AS_BOOLEAN )	
				&& c.getInput().get(0) instanceof DataOp && c.getDataType()==DataType.SCALAR )
		{
			Data dat = vars.get(c.getInput().get(0).getName());
			if( dat != null ) //required for selective constant propagation
			{
				ScalarObject sdat = (ScalarObject)dat;
				UnaryOp cast = (UnaryOp) c;
				switch( cast.getOp() ) {
					case CAST_AS_INT:
						ret = new LiteralOp(sdat.getLongValue());		
						break;
					case CAST_AS_DOUBLE:
						ret = new LiteralOp(sdat.getDoubleValue());		
						break;						
					case CAST_AS_BOOLEAN:
						ret = new LiteralOp(sdat.getBooleanValue());		
						break;
					default:	
						//otherwise: do nothing
				}
			}	
		}
		
		return ret;
	}
	
	/**
	 * 
	 * @param c
	 * @param vars
	 * @return
	 * @throws DMLRuntimeException 
	 */
	private static LiteralOp replaceLiteralValueTypeCastLiteral( Hop c, LocalVariableMap vars ) 
		throws DMLRuntimeException
	{
		LiteralOp ret = null;
		
		//as.double/as.integer/as.boolean over scalar literal (potentially created by other replacement 
		//rewrite in same dag) - literal replacement
		if( c instanceof UnaryOp && (((UnaryOp)c).getOp() == OpOp1.CAST_AS_DOUBLE 
			|| ((UnaryOp)c).getOp() == OpOp1.CAST_AS_INT || ((UnaryOp)c).getOp() == OpOp1.CAST_AS_BOOLEAN )	
				&& c.getInput().get(0) instanceof LiteralOp )
		{
			LiteralOp sdat = (LiteralOp)c.getInput().get(0);
			UnaryOp cast = (UnaryOp) c;
			try
			{
				switch( cast.getOp() ) {
					case CAST_AS_INT:
						long ival = HopRewriteUtils.getIntValue(sdat);
						ret = new LiteralOp(ival);		
						break;
					case CAST_AS_DOUBLE:
						double dval = HopRewriteUtils.getDoubleValue(sdat);
						ret = new LiteralOp(dval);		
						break;						
					case CAST_AS_BOOLEAN:
						boolean bval = HopRewriteUtils.getBooleanValue(sdat);
						ret = new LiteralOp(bval);		
						break;
					default:	
						//otherwise: do nothing
				}
			}
			catch(HopsException ex) {
				throw new DMLRuntimeException(ex);
			}
		}
		
		return ret;
	}
	
	/**
	 * 
	 * @param c
	 * @param vars
	 * @return
	 * @throws DMLRuntimeException 
	 */
	private static LiteralOp replaceLiteralDataTypeCastMatrixRead( Hop c, LocalVariableMap vars ) 
		throws DMLRuntimeException
	{
		LiteralOp ret = null;
		
		//as.scalar/matrix read - literal replacement
		if( c instanceof UnaryOp && ((UnaryOp)c).getOp() == OpOp1.CAST_AS_SCALAR 
			&& c.getInput().get(0) instanceof DataOp )
		{
			Data dat = vars.get(c.getInput().get(0).getName());
			if( dat != null ) //required for selective constant propagation
			{
				//cast as scalar (see VariableCPInstruction)
				MatrixObject mo = (MatrixObject)dat;
				MatrixBlock mBlock = mo.acquireRead();
				if( mBlock.getNumRows()!=1 || mBlock.getNumColumns()!=1 )
					throw new DMLRuntimeException("Dimension mismatch - unable to cast matrix of dimension ("+mBlock.getNumRows()+" x "+mBlock.getNumColumns()+") to scalar.");
				double value = mBlock.getValue(0,0);
				mo.release();
				
				//literal substitution (always double)
				ret = new LiteralOp(value);
			}
		}

		return ret;
	}
	
	/**
	 * 
	 * @param c
	 * @param vars
	 * @return
	 * @throws DMLRuntimeException 
	 */
	private static LiteralOp replaceLiteralValueTypeCastRightIndexing( Hop c, LocalVariableMap vars ) 
		throws DMLRuntimeException
	{
		LiteralOp ret = null;
		
		//as.scalar/right indexing w/ literals/vars and matrix less than 10^6 cells
		if( c instanceof UnaryOp && ((UnaryOp)c).getOp() == OpOp1.CAST_AS_SCALAR 
			&& c.getInput().get(0) instanceof IndexingOp )
		{
			IndexingOp rix = (IndexingOp)c.getInput().get(0);
			Hop data = rix.getInput().get(0);
			Hop rl = rix.getInput().get(1);
			Hop ru = rix.getInput().get(2);
			Hop cl = rix.getInput().get(3);
			Hop cu = rix.getInput().get(4);
			if(    rix.dimsKnown() && rix.getDim1()==1 && rix.getDim2()==1
				&& data instanceof DataOp && vars.keySet().contains(data.getName())
				&& isIntValueDataLiteral(rl, vars) && isIntValueDataLiteral(ru, vars) 
				&& isIntValueDataLiteral(cl, vars) && isIntValueDataLiteral(cu, vars) ) 
			{
				long rlval = getIntValueDataLiteral(rl, vars);
				long clval = getIntValueDataLiteral(cl, vars);

				MatrixObject mo = (MatrixObject)vars.get(data.getName());
				
				//get the dimension information from the matrix object because the hop
				//dimensions might not have been updated during recompile
				if( mo.getNumRows()*mo.getNumColumns() < REPLACE_LITERALS_MAX_MATRIX_SIZE )
				{
					MatrixBlock mBlock = mo.acquireRead();
					double value = mBlock.getValue((int)rlval-1,(int)clval-1);
					mo.release();
					
					//literal substitution (always double)
					ret = new LiteralOp(value);
				}
			}
		}
		
		return ret;
	}

	/**
	 * 
	 * @param c
	 * @param vars
	 * @return
	 * @throws DMLRuntimeException 
	 */
	private static LiteralOp replaceLiteralFullUnaryAggregate( Hop c, LocalVariableMap vars ) 
		throws DMLRuntimeException
	{
		LiteralOp ret = null;
		
		//full unary aggregate w/ matrix less than 10^6 cells
		if( c instanceof AggUnaryOp 
			&& isReplaceableUnaryAggregate((AggUnaryOp)c)
			&& c.getInput().get(0) instanceof DataOp
			&& vars.keySet().contains(c.getInput().get(0).getName()) )
		{
			Hop data = c.getInput().get(0);
			MatrixObject mo = (MatrixObject) vars.get(data.getName());
				
			//get the dimension information from the matrix object because the hop
			//dimensions might not have been updated during recompile
			if( mo.getNumRows()*mo.getNumColumns() < REPLACE_LITERALS_MAX_MATRIX_SIZE )
			{
				MatrixBlock mBlock = mo.acquireRead();
				double value = replaceUnaryAggregate((AggUnaryOp)c, mBlock);
				mo.release();
					
				//literal substitution (always double)
				ret = new LiteralOp(value);
			}		
		}
		
		return ret;
	}
	
	/**
	 * 
	 * @param c
	 * @param vars
	 * @return
	 * @throws DMLRuntimeException 
	 */
	private static LiteralOp replaceLiteralFullUnaryAggregateRightIndexing( Hop c, LocalVariableMap vars ) 
		throws DMLRuntimeException
	{
		LiteralOp ret = null;
		
		//full unary aggregate w/ indexed matrix less than 10^6 cells
		if( c instanceof AggUnaryOp 
			&& isReplaceableUnaryAggregate((AggUnaryOp)c)
			&& c.getInput().get(0) instanceof IndexingOp
			&& c.getInput().get(0).getInput().get(0) instanceof DataOp  )
		{
			IndexingOp rix = (IndexingOp)c.getInput().get(0);
			Hop data = rix.getInput().get(0);
			Hop rl = rix.getInput().get(1);
			Hop ru = rix.getInput().get(2);
			Hop cl = rix.getInput().get(3);
			Hop cu = rix.getInput().get(4);
			
			if(    data instanceof DataOp && vars.keySet().contains(data.getName())
				&& isIntValueDataLiteral(rl, vars) && isIntValueDataLiteral(ru, vars) 
				&& isIntValueDataLiteral(cl, vars) && isIntValueDataLiteral(cu, vars)  ) 
			{
				long rlval = getIntValueDataLiteral(rl, vars);
				long ruval = getIntValueDataLiteral(ru, vars);
				long clval = getIntValueDataLiteral(cl, vars);
				long cuval = getIntValueDataLiteral(cu, vars);

				MatrixObject mo = (MatrixObject) vars.get(data.getName());	
				
				//get the dimension information from the matrix object because the hop
				//dimensions might not have been updated during recompile
				if( mo.getNumRows()*mo.getNumColumns() < REPLACE_LITERALS_MAX_MATRIX_SIZE )
				{
					MatrixBlock mBlock = mo.acquireRead();
					MatrixBlock mBlock2 = mBlock.sliceOperations((int)(rlval-1), (int)(ruval-1), (int)(clval-1), (int)(cuval-1), new MatrixBlock());
					double value = replaceUnaryAggregate((AggUnaryOp)c, mBlock2);
					mo.release();
						
					//literal substitution (always double)
					ret = new LiteralOp(value);
				}
			}		
		}
		
		return ret;
	}

	
	///////////////////////////////
	// Utility functions
	///////////////////////////////

	/**
	 * 
	 * @param h
	 * @param vars
	 * @return
	 */
	private static boolean isIntValueDataLiteral(Hop h, LocalVariableMap vars)
	{
		return (  (h instanceof DataOp && vars.keySet().contains(h.getName())) 
				|| h instanceof LiteralOp
				||(h instanceof UnaryOp && (((UnaryOp)h).getOp()==OpOp1.NROW || ((UnaryOp)h).getOp()==OpOp1.NCOL)
				   && h.getInput().get(0) instanceof DataOp && vars.keySet().contains(h.getInput().get(0).getName())) );
	}
	
	/**
	 * 
	 * @param hop
	 * @param vars
	 * @return
	 * @throws DMLRuntimeException 
	 */
	private static long getIntValueDataLiteral(Hop hop, LocalVariableMap vars) 
		throws DMLRuntimeException
	{
		long value = -1;
		
		try 
		{
			if( hop instanceof LiteralOp )
			{
				value = HopRewriteUtils.getIntValue((LiteralOp)hop);
			}
			else if( hop instanceof UnaryOp && ((UnaryOp)hop).getOp()==OpOp1.NROW )
			{
				//get the dimension information from the matrix object because the hop
				//dimensions might not have been updated during recompile
				MatrixObject mo = (MatrixObject)vars.get(hop.getInput().get(0).getName());
				value = mo.getNumRows();
			}
			else if( hop instanceof UnaryOp && ((UnaryOp)hop).getOp()==OpOp1.NCOL )
			{
				//get the dimension information from the matrix object because the hop
				//dimensions might not have been updated during recompile
				MatrixObject mo = (MatrixObject)vars.get(hop.getInput().get(0).getName());
				value = mo.getNumColumns();
			}
			else
			{
				ScalarObject sdat = (ScalarObject) vars.get(hop.getName());
				value = sdat.getLongValue();
			}
		}
		catch(HopsException ex)
		{
			throw new DMLRuntimeException("Failed to get int value for literal replacement", ex);
		}
		
		return value;
	}
	
	
	/**
	 * 
	 * @param auop
	 * @return
	 */
	private static boolean isReplaceableUnaryAggregate( AggUnaryOp auop )
	{
		boolean cdir = (auop.getDirection() == Direction.RowCol);		
		boolean cop = (  auop.getOp() == AggOp.SUM
				      || auop.getOp() == AggOp.SUM_SQ
				      || auop.getOp() == AggOp.MIN
				      || auop.getOp() == AggOp.MAX ); 
		
		return cdir && cop;
	}
	
	/**
	 * 
	 * @param auop
	 * @param mb
	 * @return
	 * @throws DMLRuntimeException
	 */
	private static double replaceUnaryAggregate( AggUnaryOp auop, MatrixBlock mb ) 
		throws DMLRuntimeException
	{
		//setup stats reporting if necessary
		boolean REPORT_STATS = (DMLScript.STATISTICS && REPORT_LITERAL_REPLACE_OPS_STATS); 
		long t0 = REPORT_STATS ? System.nanoTime() : 0;
		
		//compute required unary aggregate 
		double val = Double.MAX_VALUE;
		switch( auop.getOp() ) {
			case SUM: 
				val = mb.sum(); 
				break;
			case SUM_SQ:
				val = mb.sumSq();
				break;
			case MIN:
				val = mb.min(); 
				break;
			case MAX: 
				val = mb.max(); 
				break;
			default:
				throw new DMLRuntimeException("Unsupported unary aggregate replacement: "+auop.getOp());
		}

		//report statistics if necessary
		if( REPORT_STATS ){
			long t1 = System.nanoTime();
			Statistics.maintainCPHeavyHitters("rlit", t1-t0);
		}
		
		return val;
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy