All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.hops.cost.CostEstimatorStaticRuntime Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.cost;

import java.util.ArrayList;
import java.util.HashSet;

import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.lops.DataGen;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.MapMult;
import org.apache.sysml.lops.LopProperties.ExecType;
import org.apache.sysml.lops.MMTSJ.MMTSJType;
import org.apache.sysml.lops.compile.JobType;
import org.apache.sysml.parser.DMLTranslator;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
import org.apache.sysml.runtime.controlprogram.caching.LazyWriteBuffer;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.instructions.CPInstructionParser;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.MRInstructionParser;
import org.apache.sysml.runtime.instructions.MRJobInstruction;
import org.apache.sysml.runtime.instructions.cp.CPInstruction;
import org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE;
import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysml.runtime.instructions.mr.BinaryMRInstructionBase;
import org.apache.sysml.runtime.instructions.mr.CM_N_COVInstruction;
import org.apache.sysml.runtime.instructions.mr.DataGenMRInstruction;
import org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction;
import org.apache.sysml.runtime.instructions.mr.IDistributedCacheConsumer;
import org.apache.sysml.runtime.instructions.mr.MMTSJMRInstruction;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.instructions.mr.MapMultChainInstruction;
import org.apache.sysml.runtime.instructions.mr.PickByCountInstruction;
import org.apache.sysml.runtime.instructions.mr.RemoveEmptyMRInstruction;
import org.apache.sysml.runtime.instructions.mr.TernaryInstruction;
import org.apache.sysml.runtime.instructions.mr.UnaryMRInstructionBase;
import org.apache.sysml.runtime.instructions.mr.MRInstruction.MRINSTRUCTION_TYPE;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.mapred.DistributedCacheInput;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
import org.apache.sysml.yarn.ropt.MRJobResourceInstruction;
import org.apache.sysml.yarn.ropt.YarnClusterAnalyzer;

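/**
 * Static cost estimator that derives time estimates from estimated numbers
 * of floating point operations and default I/O throughput constants.
 */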
public class CostEstimatorStaticRuntime extends CostEstimator
{
	
	//time-conversion: divisor that turns a FLOP count into a time estimate
	//(see getInstTimeEstimate: time = nflops / DEFAULT_FLOPS, logged in seconds)
	private static final long DEFAULT_FLOPS = 2L * 1024 * 1024 * 1024; //2GFLOPS
	//private static final long UNKNOWN_TIME = -1;
	
	//floating point operations (FLOP counts / per-cell FLOP factors used by getNFLOP)
	private static final double DEFAULT_NFLOP_NOOP = 10; //flat cost of metadata-only ops (e.g., nrow, rm, mv)
	private static final double DEFAULT_NFLOP_UNKNOWN = 1; //factor for ops without known cost (e.g., extfunct, cdf)
	private static final double DEFAULT_NFLOP_CP = 1; 	//per-cell factor for simple copy-like ops
	private static final double DEFAULT_NFLOP_TEXT_IO = 350; //per-cell factor for writes in text formats (textcell, csv)
	
	//MR job latency
	//NOTE(review): not referenced in the visible part of this file; unit is 
	//presumably seconds like the other time estimates -- confirm at the usage sites
	private static final double DEFAULT_MR_JOB_LATENCY_LOCAL = 2;
	private static final double DEFAULT_MR_JOB_LATENCY_REMOTE = 20;
	private static final double DEFAULT_MR_TASK_LATENCY_LOCAL = 0.001;
	private static final double DEFAULT_MR_TASK_LATENCY_REMOTE = 1.5;
	
	//IO READ throughput (divisor of the estimated on-disk size in MB, i.e., MB per time unit)
	private static final double DEFAULT_MBS_FSREAD_BINARYBLOCK_DENSE = 200;
	private static final double DEFAULT_MBS_FSREAD_BINARYBLOCK_SPARSE = 100;
	private static final double DEFAULT_MBS_HDFSREAD_BINARYBLOCK_DENSE = 150;
	private static final double DEFAULT_MBS_HDFSREAD_BINARYBLOCK_SPARSE = 75;
	//IO WRITE throughput (divisor of the estimated on-disk size in MB, i.e., MB per time unit)
	private static final double DEFAULT_MBS_FSWRITE_BINARYBLOCK_DENSE = 150;
	private static final double DEFAULT_MBS_FSWRITE_BINARYBLOCK_SPARSE = 75;
	private static final double DEFAULT_MBS_HDFSWRITE_BINARYBLOCK_DENSE = 120;
	private static final double DEFAULT_MBS_HDFSWRITE_BINARYBLOCK_SPARSE = 60;
	private static final double DEFAULT_MBS_HDFSWRITE_TEXT_DENSE = 40;
	private static final double DEFAULT_MBS_HDFSWRITE_TEXT_SPARSE = 30;
	
	@Override
	@SuppressWarnings("unused")
	protected double getCPInstTimeEstimate( Instruction inst, VarStats[] vs, String[] args ) 
		throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		CPInstruction cpinst = (CPInstruction)inst;
		
		//load time into mem
		double ltime = 0;
		if( !vs[0]._inmem ){
			ltime += getHDFSReadTime( vs[0]._rlen, vs[0]._clen, vs[0].getSparsity() );
			//eviction costs
			if( CacheableData.CACHING_WRITE_CACHE_ON_READ &&
				LazyWriteBuffer.getWriteBufferSize()=0 )
					vs[2] = stats[ minst.getInput3() ];
				
				if( vs[0] == null ) //scalar input, 
					vs[0] = _scalarStats;
				if( vs[1] == null ) //scalar input, 
					vs[1] = _scalarStats;
				if( vs[2] == null ) //scalar input
					vs[2] = _scalarStats;
			}
		}
		
		//maintain var status (CP output always inmem)
		vs[2]._inmem = true;
		
		ret[0] = vs;
		ret[1] = attr;
		
		return ret;
	}
	
	

	/////////////////////
	// Utilities       //
	/////////////////////	
	
	/**
	 * 
	 * @param inputVars
	 * @return
	 */
	private byte[] getInputIndexes(String[] inputVars)
	{
		byte[] inIx = new byte[inputVars.length];
		for( int i=0; i ixMap = new HashSet();
		for( byte ix : inIx )
			ixMap.add(ix);
		
		if( rdInst!=null && rdInst.length()>0 ) {
			rdInst = replaceInstructionPatch(rdInst);
			DataGenMRInstruction[] ins = MRInstructionParser.parseDataGenInstructions(rdInst);
			for( DataGenMRInstruction inst : ins )
				for( byte ix : inst.getAllIndexes() )
					ixMap.add(ix);
		}
		
		if( mapInst!=null && mapInst.length()>0 ) {
			mapInst = replaceInstructionPatch(mapInst);
			MRInstruction[] ins = MRInstructionParser.parseMixedInstructions(mapInst);
			for( MRInstruction inst : ins )
				for( byte ix : inst.getAllIndexes() )
					ixMap.add(ix);
		}
		
		//reduce indices
		HashSet ixRed = new HashSet();
		for( byte ix : retIx )
			ixRed.add(ix);
	

		if( shfInst!=null && shfInst.length()>0 ) {
			shfInst = replaceInstructionPatch(shfInst);
			MRInstruction[] ins = MRInstructionParser.parseMixedInstructions(shfInst);
			for( MRInstruction inst : ins )
				for( byte ix : inst.getAllIndexes() )
					ixRed.add(ix);
		}
		
		if( aggInst!=null && aggInst.length()>0 ) {
			aggInst = replaceInstructionPatch(aggInst);
			MRInstruction[] ins = MRInstructionParser.parseAggregateInstructions(aggInst);
			for( MRInstruction inst : ins )
				for( byte ix : inst.getAllIndexes() )
					ixRed.add(ix);
		}
		
		if( otherInst!=null && otherInst.length()>0 ) {
			otherInst = replaceInstructionPatch(otherInst);
			MRInstruction[] ins = MRInstructionParser.parseMixedInstructions(otherInst);
			for( MRInstruction inst : ins )
				for( byte ix : inst.getAllIndexes() )
					ixRed.add(ix);
		}

		//difference
		ixMap.retainAll(ixRed);
			
		//copy result
		byte[] ret = new byte[ixMap.size()];
		int i = 0;
		for( byte ix : ixMap )
			ret[i++] = ix;
		
		return ret;
	}
	
	/**
	 * 
	 * @param vs
	 * @param inputIx
	 * @param blocksize
	 * @param maxPMap
	 * @param jobtype
	 * @return
	 */
	private int computeNumMapTasks( VarStats[] vs, byte[] inputIx, double blocksize, int maxPMap, JobType jobtype )
	{
		//special cases
		if( jobtype == JobType.DATAGEN )
			return maxPMap;
			
		//input size, num blocks
		double mapInputSize = 0;
		int numBlocks = 0;
		for( int i=0; i indexes = new ArrayList();
		
		if( InstructionUtils.isDistributedCacheUsed(inst) ) {
			MRInstruction mrinst = MRInstructionParser.parseSingleInstruction(inst);
			if( mrinst instanceof IDistributedCacheConsumer )
				((IDistributedCacheConsumer)mrinst).addDistCacheIndex(inst, indexes);
		}
		
		if( !indexes.isEmpty() )
			return indexes.get(0);
		else
			return -1;
	}
	
	
	/////////////////////
	// I/O Costs       //
	/////////////////////	
	
	/**
	 * Estimates the time for reading a matrix from HDFS as the estimated
	 * on-disk size in MB divided by the default HDFS read throughput for
	 * the evaluated on-disk format (sparse or dense).
	 * NOTE: Does not handle unknowns.
	 * 
	 * @param dm number of rows
	 * @param dn number of columns
	 * @param ds sparsity, multiplied with dm*dn to obtain the number of non-zeros
	 * @return estimated read time
	 */
	private double getHDFSReadTime( long dm, long dn, double ds )
	{
		final long nnz = (long)(ds*dm*dn);
		
		//format decision first, then size estimate (same call order as before)
		final boolean sparseOnDisk = MatrixBlock.evalSparseFormatOnDisk(dm, dn, nnz);
		final double sizeMB = ((double)MatrixBlock.estimateSizeOnDisk(dm, dn, nnz)) / (1024*1024);
		
		final double throughput = sparseOnDisk ?
			DEFAULT_MBS_HDFSREAD_BINARYBLOCK_SPARSE :
			DEFAULT_MBS_HDFSREAD_BINARYBLOCK_DENSE;
		
		return sizeMB / throughput;
	}
	
	/**
	 * Estimates the time for writing a matrix to HDFS as the estimated
	 * on-disk size in MB divided by the default HDFS binary-block write
	 * throughput for the evaluated on-disk format (sparse or dense).
	 * 
	 * @param dm number of rows
	 * @param dn number of columns
	 * @param ds sparsity, multiplied with dm*dn to obtain the number of non-zeros
	 * @return estimated write time
	 */
	private double getHDFSWriteTime( long dm, long dn, double ds )
	{
		final long nnz = (long)(ds*dm*dn);
		final boolean sparseOnDisk = MatrixBlock.evalSparseFormatOnDisk(dm, dn, nnz);
		
		final double sizeBytes = (double)MatrixBlock.estimateSizeOnDisk(dm, dn, nnz);
		final double sizeMB = sizeBytes / (1024*1024);
		
		if( sparseOnDisk )
			return sizeMB / DEFAULT_MBS_HDFSWRITE_BINARYBLOCK_SPARSE;
		
		//dense
		return sizeMB / DEFAULT_MBS_HDFSWRITE_BINARYBLOCK_DENSE;
	}
	
	/**
	 * Estimates the time for writing a matrix to HDFS in a given format.
	 * For the formats "textcell" and "csv", the text write throughputs are
	 * used and the result is scaled by 2.75 to account for the larger text
	 * representation; all other formats use the binary-block throughputs.
	 * 
	 * @param dm number of rows
	 * @param dn number of columns
	 * @param ds sparsity, multiplied with dm*dn to obtain the number of non-zeros
	 * @param format output format name (must not be null)
	 * @return estimated write time
	 */
	private double getHDFSWriteTime( long dm, long dn, double ds, String format )
	{
		final long nnz = (long)(ds*dm*dn);
		final boolean sparseOnDisk = MatrixBlock.evalSparseFormatOnDisk(dm, dn, nnz);
		
		final double sizeBytes = (double)MatrixBlock.estimateSizeOnDisk(dm, dn, nnz);
		final double sizeMB = sizeBytes / (1024*1024);
		
		final boolean textFormat = format.equals("textcell") || format.equals("csv");
		if( textFormat )
		{
			final double textThroughput = sparseOnDisk ?
				DEFAULT_MBS_HDFSWRITE_TEXT_SPARSE :
				DEFAULT_MBS_HDFSWRITE_TEXT_DENSE;
			
			//text commonly 2x-3.5x larger than binary
			return (sizeMB / textThroughput) * 2.75;
		}
		
		final double binaryThroughput = sparseOnDisk ?
			DEFAULT_MBS_HDFSWRITE_BINARYBLOCK_SPARSE :
			DEFAULT_MBS_HDFSWRITE_BINARYBLOCK_DENSE;
		
		return sizeMB / binaryThroughput;
	}

	/**
	 * Estimates the time for reading a matrix from the local file system as
	 * the estimated on-disk size in MB divided by the default local read
	 * throughput for the evaluated on-disk format (sparse or dense).
	 * NOTE: Does not handle unknowns.
	 * 
	 * @param dm number of rows
	 * @param dn number of columns
	 * @param ds sparsity, multiplied with dm*dn to obtain the number of non-zeros
	 * @return estimated read time
	 */
	private double getFSReadTime( long dm, long dn, double ds )
	{
		final long nnz = (long)(ds*dm*dn);
		final boolean sparseOnDisk = MatrixBlock.evalSparseFormatOnDisk(dm, dn, nnz);
		final double sizeMB = ((double)MatrixBlock.estimateSizeOnDisk(dm, dn, nnz)) / (1024*1024);
		
		if( sparseOnDisk )
			return sizeMB / DEFAULT_MBS_FSREAD_BINARYBLOCK_SPARSE;
		
		//dense
		return sizeMB / DEFAULT_MBS_FSREAD_BINARYBLOCK_DENSE;
	}

	/**
	 * Estimates the time for writing a matrix to the local file system as
	 * the estimated on-disk size in MB divided by the default local write
	 * throughput for the evaluated on-disk format (sparse or dense).
	 * 
	 * @param dm number of rows
	 * @param dn number of columns
	 * @param ds sparsity, multiplied with dm*dn to obtain the number of non-zeros
	 * @return estimated write time
	 */
	private double getFSWriteTime( long dm, long dn, double ds )
	{
		final long nnz = (long)(ds*dm*dn);
		final boolean sparseOnDisk = MatrixBlock.evalSparseFormatOnDisk(dm, dn, nnz);
		final double sizeMB = ((double)MatrixBlock.estimateSizeOnDisk(dm, dn, nnz)) / (1024*1024);
		
		final double throughput = sparseOnDisk ?
			DEFAULT_MBS_FSWRITE_BINARYBLOCK_SPARSE :
			DEFAULT_MBS_FSWRITE_BINARYBLOCK_DENSE;
		
		return sizeMB / throughput;
	}

	
	/////////////////////
	// Operation Costs //
	/////////////////////
	
	/**
	 * Convenience overload that extracts dimensions and sparsity of the three
	 * operand statistics and delegates to the dimension-based time estimate.
	 * An unknown number of non-zeros (negative nnz) is treated as sparsity 1.0.
	 * 
	 * @param opcode instruction opcode
	 * @param vs statistics of the three operands (vs[0], vs[1], vs[2])
	 * @param args additional opcode-specific arguments
	 * @param et execution type; MR marks the instruction as executed in MR
	 * @return estimated execution time
	 * @throws DMLRuntimeException
	 * @throws DMLUnsupportedOperationException
	 */
	private double getInstTimeEstimate(String opcode, VarStats[] vs, String[] args, ExecType et) 
		throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		final VarStats op1 = vs[0];
		final VarStats op2 = vs[1];
		final VarStats op3 = vs[2];
		
		//sparsity per operand, falling back to dense for unknown nnz
		final double sp1 = (op1._nnz<0) ? 1.0 : (double)op1._nnz/op1._rlen/op1._clen;
		final double sp2 = (op2._nnz<0) ? 1.0 : (double)op2._nnz/op2._rlen/op2._clen;
		final double sp3 = (op3._nnz<0) ? 1.0 : (double)op3._nnz/op3._rlen/op3._clen;
		
		return getInstTimeEstimate(opcode, et == ExecType.MR,
			op1._rlen, op1._clen, sp1,
			op2._rlen, op2._clen, sp2,
			op3._rlen, op3._clen, sp3,
			args);
	}
	
	/**
	 * Returns the estimated instruction execution time, w/o data transfer and 
	 * single-threaded, computed as the estimated number of floating point
	 * operations divided by DEFAULT_FLOPS.
	 * For scalars input dims must be set to 1 before invocation. 
	 * 
	 * NOTE: Does not handle unknowns.
	 * 
	 * @param opcode instruction opcode
	 * @param inMR true if the instruction is executed in MR
	 * @param d1m number of rows of operand 1
	 * @param d1n number of columns of operand 1
	 * @param d1s sparsity of operand 1
	 * @param d2m number of rows of operand 2
	 * @param d2n number of columns of operand 2
	 * @param d2s sparsity of operand 2
	 * @param d3m number of rows of operand 3
	 * @param d3n number of columns of operand 3
	 * @param d3s sparsity of operand 3
	 * @param args additional opcode-specific arguments
	 * @return estimated execution time
	 * @throws DMLRuntimeException 
	 * @throws DMLUnsupportedOperationException 
	 */
	private double getInstTimeEstimate( String opcode, boolean inMR, long d1m, long d1n, double d1s, long d2m, long d2n, double d2s, long d3m, long d3n, double d3s, String[] args ) throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		final double flopCount = getNFLOP(opcode, inMR, d1m, d1n, d1s, d2m, d2n, d2s, d3m, d3n, d3s, args);
		final double seconds = flopCount / DEFAULT_FLOPS;
		
		if( LOG.isDebugEnabled() ) {
			//comma-separated operand characteristics for the debug message
			String dims = d1m+","+d1n+","+d1s+","+d2m+","+d2n+","+d2s+","+d3m+","+d3n+","+d3s;
			LOG.debug("Cost["+opcode+"] = "+seconds+"s, "+flopCount+" flops ("+dims+").");
		}
		
		return seconds;
	}
	
	/**
	 * Returns the estimated number of floating point operations of a single
	 * instruction on matrix block level (for CP and MR instructions). The
	 * estimate excludes IO and parallelism and assumes known dims for all
	 * inputs and outputs.
	 * 
	 * The opcode is first looked up as CP instruction type (which also covers
	 * MR instructions with equivalent costs) and only if not found there as
	 * MR instruction type.
	 * 
	 * @param optype instruction opcode, e.g., "ba+*"
	 * @param inMR true if the instruction is executed in MR (only read by the Partition costs)
	 * @param d1m number of rows of operand 1
	 * @param d1n number of columns of operand 1
	 * @param d1s sparsity of operand 1
	 * @param d2m number of rows of operand 2
	 * @param d2n number of columns of operand 2
	 * @param d2s sparsity of operand 2
	 * @param d3m number of rows of operand 3 (used as output by most cost functions)
	 * @param d3n number of columns of operand 3
	 * @param d3s sparsity of operand 3
	 * @param args additional opcode-specific arguments (e.g., operation type for cm and 
	 *             groupedagg, format for write, tsmm type, number of mappers for MR aggregates)
	 * @return estimated number of floating point operations; 0 for opcodes without
	 *         cost function; -1 for MR PickByCount instructions
	 * @throws DMLRuntimeException if the opcode is neither a known CP nor a known MR 
	 *         opcode, or if its CP instruction type is not supported
	 * @throws DMLUnsupportedOperationException
	 */
	private double getNFLOP( String optype, boolean inMR, long d1m, long d1n, double d1s, long d2m, long d2n, double d2s, long d3m, long d3n, double d3s, String[] args ) 
		throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		//operation costs in FLOP on matrix block level (for CP and MR instructions)
		//(excludes IO and parallelism; assumes known dims for all inputs, outputs )
	
		boolean leftSparse = MatrixBlock.evalSparseFormatInMemory(d1m, d1n, (long)(d1s*d1m*d1n));
		boolean rightSparse = MatrixBlock.evalSparseFormatInMemory(d2m, d2n, (long)(d2s*d2m*d2n));
		boolean onlyLeft = (d1m>=0 && d1n>=0 && d2m<0 && d2n<0 );
		boolean allExists = (d1m>=0 && d1n>=0 && d2m>=0 && d2n>=0 && d3m>=0 && d3n>=0 );
		
		//NOTE: all instruction types that are equivalent in CP and MR are only
		//included in CP to prevent redundancy
		CPINSTRUCTION_TYPE cptype = CPInstructionParser.String2CPInstructionType.get(optype);
		if( cptype != null ) //for CP Ops and equivalent MR ops 
		{
			//general approach: count of floating point *, /, +, -, ^, builtin ;
			switch(cptype) 
			{
			
				case AggregateBinary: //opcodes: ba+*, cov
					if( optype.equals("ba+*") ) { //matrix mult
						//reduction by factor 2 because matrix mult better than
						//average flop count
						if( !leftSparse && !rightSparse )
							return 2 * (d1m * d1n * ((d2n>1)?d1s:1.0) * d2n) /2;
						else if( !leftSparse && rightSparse )
							return 2 * (d1m * d1n * d1s * d2n * d2s) /2;
						else if( leftSparse && !rightSparse )
							return 2 * (d1m * d1n * d1s * d2n) /2;
						else //leftSparse && rightSparse
							return 2 * (d1m * d1n * d1s * d2n * d2s) /2;
					}
					else if( optype.equals("cov") ) {
						//note: output always scalar, d3 used as weights block
						//if( allExists ), same runtime for 2 and 3 inputs
						return 23 * d1m; //(11+3*k+)
					}
					
					return 0;
				
				case MMChain:
					//reduction by factor 2 because matrix mult better than average flop count
					//(mmchain essentially two matrix-vector multiplications)
					if( !leftSparse  )
						return (2+2) * (d1m * d1n) /2;
					else 
						return (2+2) * (d1m * d1n * d1s) /2;
					
				case AggregateTernary: //opcodes: tak+*
					return 6 * d1m * d1n; //2*1(*) + 4 (k+)
					
				case AggregateUnary: //opcodes: uak+, uark+, uack+, uasqk+, uarsqk+, uacsqk+,
				                     //         uamean, uarmean, uacmean, uavar, uarvar, uacvar,
				                     //         uamax, uarmax, uarimax, uacmax, uamin, uarmin, uacmin,
				                     //         ua+, uar+, uac+, ua*, uatrace, uaktrace,
				                     //         nrow, ncol, length, cm
					
					if( optype.equals("nrow") || optype.equals("ncol") || optype.equals("length") )
						return DEFAULT_NFLOP_NOOP;
					else if( optype.equals( "cm" ) ) {
						double xcm = 1;
						switch( Integer.parseInt(args[0]) ) {
							case 0: xcm=1; break; //count
							case 1: xcm=8; break; //mean
							case 2: xcm=16; break; //cm2
							case 3: xcm=31; break; //cm3
							case 4: xcm=51; break; //cm4
							case 5: xcm=16; break; //variance
						}
						return (leftSparse) ? xcm * (d1m * d1s + 1) : xcm * d1m;
					}
					else if( optype.equals("uatrace") || optype.equals("uaktrace") )
						return 2 * d1m * d1n;
					else if( optype.equals("ua+") || optype.equals("uar+") || optype.equals("uac+")  ){
						//sparse safe operations
						if( !leftSparse ) //dense
							return d1m * d1n;
						else //sparse
							return d1m * d1n * d1s;
					}
					else if( optype.equals("uak+") || optype.equals("uark+") || optype.equals("uack+"))
						return 4 * d1m * d1n; //1*k+
					else if( optype.equals("uasqk+") || optype.equals("uarsqk+") || optype.equals("uacsqk+"))
						return 5 * d1m * d1n; // +1 for multiplication to square term
					else if( optype.equals("uamean") || optype.equals("uarmean") || optype.equals("uacmean"))
						return 7 * d1m * d1n; //1*k+
					else if( optype.equals("uavar") || optype.equals("uarvar") || optype.equals("uacvar"))
						return 14 * d1m * d1n;
					else if(   optype.equals("uamax") || optype.equals("uarmax") || optype.equals("uacmax")
							|| optype.equals("uamin") || optype.equals("uarmin") || optype.equals("uacmin")
							|| optype.equals("uarimax") || optype.equals("ua*") )
						return d1m * d1n;
					
					return 0;	
					
				case ArithmeticBinary: //opcodes: +, -, *, /, ^ (incl. ^2, *2)
					//note: covers scalar-scalar, scalar-matrix, matrix-matrix
					//note: explicit parentheses required because && binds tighter than ||;
					//without them "+" would always take the sparse-safe path
					if( (optype.equals("+") || optype.equals("-")) //sparse safe
						&& ( leftSparse || rightSparse ) )
						return d1m*d1n*d1s + d2m*d2n*d2s;
					else
						return d3m*d3n;
					
				case Ternary: //opcodes: ctable
					if( optype.equals("ctable") ){
						if( leftSparse )
							return d1m * d1n * d1s; //add
						else 
							return d1m * d1n;
					}
					return 0;
					
				case BooleanBinary: //opcodes: &&, ||
					return 1; //always scalar-scalar
						
				case BooleanUnary: //opcodes: !
					return 1; //always scalar-scalar

				case Builtin: //opcodes: log 
					//note: covers scalar-scalar, scalar-matrix, matrix-matrix
					//note: can be unary or binary
					if( allExists ) //binary
						return 3 * d3m * d3n;
					else //unary
						return d3m * d3n;
					
				case BuiltinBinary: //opcodes: max, min, solve
					//note: covers scalar-scalar, scalar-matrix, matrix-matrix
					if( optype.equals("solve") ) //see also MultiReturnBuiltin
						return d1m * d1n * d1n; //for 1kx1k ~ 1GFLOP -> 0.5s
					else //default
						return d3m * d3n;

					
				case BuiltinUnary: //opcodes: exp, abs, sin, cos, tan, sign, sqrt, plogp, print, round, sprop, sigmoid
					//TODO add cost functions for commons math builtins: inverse, cholesky
					if( optype.equals("print") ) //scalar only
						return 1;
					else
					{
						double xbu = 1; //default for all ops
						if( optype.equals("plogp") ) xbu = 2;
						else if( optype.equals("round") ) xbu = 4;
						
						if( optype.equals("sin") || optype.equals("tan") || optype.equals("round")
							|| optype.equals("abs") || optype.equals("sqrt") || optype.equals("sprop")
							|| optype.equals("sigmoid") || optype.equals("sign") ) //sparse-safe
						{
							if( leftSparse ) //sparse
								return xbu * d1m * d1n * d1s;	
							else //dense
								return xbu * d1m * d1n;
						}
						else
							return xbu * d1m * d1n;
					}
										
				case Reorg: //opcodes: r', rdiag
				case MatrixReshape: //opcodes: rshape
					if( leftSparse )
						return d1m * d1n * d1s;
					else
						return d1m * d1n;
					
				case Append: //opcodes: append
					return DEFAULT_NFLOP_CP * 
					       (((leftSparse) ? d1m * d1n * d1s : d1m * d1n ) +
					        ((rightSparse) ? d2m * d2n * d2s : d2m * d2n ));
					
				case RelationalBinary: //opcodes: ==, !=, <, >, <=, >=  
					//note: all relational ops are not sparsesafe
					return d3m * d3n; //covers all combinations of scalar and matrix  
					
				case File: //opcodes: rm, mv
					return DEFAULT_NFLOP_NOOP;
					
				case Variable: //opcodes: assignvar, cpvar, rmvar, rmfilevar, assignvarwithfile, attachfiletovar, valuepick, iqsize, read, write, createvar, setfilename, castAsMatrix
					if( optype.equals("write") ){
						boolean text = args[0].equals("textcell") || args[0].equals("csv");
						double xwrite =  text ? DEFAULT_NFLOP_TEXT_IO : DEFAULT_NFLOP_CP;
						
						if( !leftSparse )
							return d1m * d1n * xwrite; 
						else
							return d1m * d1n * d1s * xwrite;
					}
					else if ( optype.equals("inmem-iqm") )
						//note: assumes uniform distribution
						return 2 * d1m + //sum of weights
						       5 + 0.25d * d1m + //scan to lower quantile
						       8 * 0.5 * d1m; //scan from lower to upper quantile
					else
						return DEFAULT_NFLOP_NOOP;
			
				case Rand: //opcodes: rand, seq
					if( optype.equals(DataGen.RAND_OPCODE) ){
						int nflopRand = 32; //per random number
						switch(Integer.parseInt(args[0])) {
							case 0: return DEFAULT_NFLOP_NOOP; //empty matrix
							case 1: return d3m * d3n * 8; //allocate, arrayfill
							case 2: //full rand
							{
								if( d3s==1.0 )
									return d3m * d3n * nflopRand + d3m * d3n * 8; //DENSE gen (incl allocate)    
								else 
									return (d3s>=MatrixBlock.SPARSITY_TURN_POINT)? 
										    2 * d3m * d3n * nflopRand + d3m * d3n * 8: //DENSE gen (incl allocate)    
									        3 * d3m * d3n * d3s * nflopRand + d3m * d3n * d3s * 24; //SPARSE gen (incl allocate)
							}
						}
						//note: any other rand argument value falls through to the 
						//StringInit costs below (no return or break at this point)
					}
					else //seq
						return d3m * d3n * DEFAULT_NFLOP_CP;
				
				case StringInit: //sinit
					return d3m * d3n * DEFAULT_NFLOP_CP;
					
				case External: //opcodes: extfunct
					//note: should be invoked independently for multiple outputs
					return d1m * d1n * d1s * DEFAULT_NFLOP_UNKNOWN;
				
				case MultiReturnBuiltin: //opcodes: qr, lu, eigen
					//note: they all have cubic complexity, the scaling factor refers to commons.math
					double xf = 2; //default e.g, qr
					if( optype.equals("eigen") ) 
						xf = 32;
					else if ( optype.equals("lu") )
						xf = 16;
					return xf * d1m * d1n * d1n; //for 1kx1k ~ 2GFLOP -> 1s
					
				case ParameterizedBuiltin: //opcodes: cdf, invcdf, groupedagg, rmempty
					if( optype.equals("cdf") || optype.equals("invcdf"))
						return DEFAULT_NFLOP_UNKNOWN; //scalar call to commons.math
					else if( optype.equals("groupedagg") ){	
						double xga = 1;
						switch( Integer.parseInt(args[0]) ) {
							case 0: xga=4; break; //sum, see uk+
							case 1: xga=1; break; //count, see cm
							case 2: xga=8; break; //mean
							case 3: xga=16; break; //cm2
							case 4: xga=31; break; //cm3
							case 5: xga=51; break; //cm4
							case 6: xga=16; break; //variance
						}						
						return 2 * d1m + xga * d1m; //scan for min/max, groupedagg
					}	
					else if( optype.equals("rmempty") ){
						switch(Integer.parseInt(args[0])){
							case 0: //remove rows
								return ((leftSparse) ? d1m : d1m * Math.ceil(1.0d/d1s)/2) +
									   DEFAULT_NFLOP_CP * d3m * d2m;
							case 1: //remove cols
								return d1n * Math.ceil(1.0d/d1s)/2 + 
								       DEFAULT_NFLOP_CP * d3m * d2m;
						}
						
					}	
					return 0;
					
				case QSort: //opcodes: sort
					if( optype.equals("sort") ){
						//note: mergesort since comparator used
						double sortCosts = 0;
						if( onlyLeft )
							sortCosts = DEFAULT_NFLOP_CP * d1m + d1m;
						else //w/ weights
							sortCosts = DEFAULT_NFLOP_CP * ((leftSparse)?d1m*d1s:d1m); 
						return sortCosts + d1m*(int)(Math.log(d1m)/Math.log(2)) + //mergesort
										   DEFAULT_NFLOP_CP * d1m;
					}
					return 0;
					
				case MatrixIndexing: //opcodes: rangeReIndex, leftIndex
					if( optype.equals("leftIndex") ){
						return DEFAULT_NFLOP_CP * ((leftSparse)? d1m*d1n*d1s : d1m*d1n)
						       + 2 * DEFAULT_NFLOP_CP * ((rightSparse)? d2m*d2n*d2s : d2m*d2n );
					}
					else if( optype.equals("rangeReIndex") ){
						return DEFAULT_NFLOP_CP * ((leftSparse)? d2m*d2n*d2s : d2m*d2n );
					}
					return 0;
					
				case MMTSJ: //opcodes: tsmm
					//diff to ba+* only upper triangular matrix
					//reduction by factor 2 because matrix mult better than
					//average flop count
					if( MMTSJType.valueOf(args[0]).isLeft() ) { //lefttranspose
						if( !rightSparse ) //dense						
							return d1m * d1n * d1s * d1n /2;
						else //sparse
							return d1m * d1n * d1s * d1n * d1s /2; 
					}
					else if(onlyLeft) { //righttranspose
						if( !leftSparse ) //dense
							return (double)d1m * d1n * d1m /2;
						else //sparse
							return   d1m * d1n * d1s //reorg sparse
							       + d1m * d1n * d1s * d1n * d1s /2; //core tsmm
					}					
					return 0;
				
				case Partition:
					return d1m * d1n * d1s + //partitioning costs
						   (inMR ? 0 : //include write cost if in CP  	
							getHDFSWriteTime(d1m, d1n, d1s)* DEFAULT_FLOPS);
					
				case INVALID:
					return 0;
				
				default: 
					throw new DMLRuntimeException("CostEstimator: unsupported instruction type: "+optype);
			}
				
		}
		
		//if not found in CP instructions
		MRINSTRUCTION_TYPE mrtype = MRInstructionParser.String2MRInstructionType.get(optype);
		if ( mrtype != null ) //for specific MR ops
		{
			switch(mrtype)
			{
				case Aggregate: //opcodes: a+, ak+, asqk+, a*, amax, amin, amean
					//TODO should be aggregate unary
					int numMap = Integer.parseInt(args[0]);
					if( optype.equals("ak+") )
						return 4 * numMap * d1m * d1n * d1s;
					else if( optype.equals("asqk+") )
						return 5 * numMap * d1m * d1n * d1s; // +1 for multiplication to square term
					else if( optype.equals("avar") )
						return 14 * numMap * d1m * d1n * d1s;
					else
						return numMap * d1m * d1n * d1s;
					
				case AggregateBinary: //opcodes: cpmm, rmm, mapmult
					//note: copy from CP costs
					if(    optype.equals("cpmm") || optype.equals("rmm") 
						|| optype.equals(MapMult.OPCODE) ) //matrix mult
					{
						//reduction by factor 2 because matrix mult better than
						//average flop count
						if( !leftSparse && !rightSparse )
							return 2 * (d1m * d1n * ((d2n>1)?d1s:1.0) * d2n) /2;
						else if( !leftSparse && rightSparse )
							return 2 * (d1m * d1n * d1s * d2n * d2s) /2;
						else if( leftSparse && !rightSparse )
							return 2 * (d1m * d1n * d1s * d2n) /2;
						else //leftSparse && rightSparse
							return 2 * (d1m * d1n * d1s * d2n * d2s) /2;
					}
					return 0;
					
				case MapMultChain: //opcodes: mapmultchain	
					//assume dense input2 and input3
					return   2 * d1m * d2n * d1n * ((d2n>1)?d1s:1.0) //ba(+*) 
						   + d1m * d2n //cellwise b(*) 
					       + d1m * d2n //r(t)
					       + 2 * d2n * d1n * d1m * (leftSparse?d1s:1.0) //ba(+*)
					       + d2n * d1n; //r(t)
					
				case ArithmeticBinary: //opcodes: s-r, so, max, min, 
					                   //         >, >=, <, <=, ==, != 
					//TODO Should be relational 
				
					//note: all relational ops are not sparsesafe
					return d3m * d3n; //covers all combinations of scalar and matrix  
	
				case CombineUnary: //opcodes: combineunary
					return d1m * d1n * d1s;
					
				case CombineBinary: //opcodes: combinebinary
					return   d1m * d1n * d1s
					       + d2m * d2n * d2s;
					
				case CombineTernary: //opcodes: combinetertiary
					return   d1m * d1n * d1s
				           + d2m * d2n * d2s
				           + d3m * d3n * d3s;
					
				case Unary: //opcodes: log, slog, pow 			
					//TODO requires opcode consolidation (builtin, arithmetic)
					//note: covers scalar, matrix, matrix-scalar
					return d3m * d3n;
					
				case Ternary: //opcodes: ctabletransform, ctabletransformscalarweight, ctabletransformhistogram, ctabletransformweightedhistogram
					//note: copy from cp
					if( leftSparse )
						return d1m * d1n * d1s; //add
					else 
						return d1m * d1n;
			
				case Quaternary:
					//TODO pattern specific and all 4 inputs requires
					return d1m * d1n * d1s *4;
					
				case Reblock: //opcodes: rblk
					return DEFAULT_NFLOP_CP * ((leftSparse)? d1m*d1n*d1s : d1m*d1n); 
					
				case Replicate: //opcodes: rblk
					return DEFAULT_NFLOP_CP * ((leftSparse)? d1m*d1n*d1s : d1m*d1n); 
					
				case CM_N_COV: //opcodes: mean
					double xcm = 8;
					return (leftSparse) ? xcm * (d1m * d1s + 1) : xcm * d1m;
					
				case GroupedAggregate: //opcodes: groupedagg		
					//TODO: need to consolidate categories (ParameterizedBuiltin)
					//copy from CP operation
					double xga = 1;
					switch( Integer.parseInt(args[0]) ) {
						case 0: xga=4; break; //sum, see uk+
						case 1: xga=1; break; //count, see cm
						case 2: xga=8; break; //mean
						case 3: xga=16; break; //cm2
						case 4: xga=31; break; //cm3
						case 5: xga=51; break; //cm4
						case 6: xga=16; break; //variance
					}						
					return 2 * d1m + xga * d1m; //scan for min/max, groupedagg
					
				case PickByCount: //opcodes: valuepick, rangepick
					//no cost function yet: leaves the switch and returns -1 below
					break;
					//TODO
					//String2MRInstructionType.put( "valuepick"  , MRINSTRUCTION_TYPE.PickByCount);  // for quantile()
					//String2MRInstructionType.put( "rangepick"  , MRINSTRUCTION_TYPE.PickByCount);  // for interQuantile()
					
				case RangeReIndex: //opcodes: rangeReIndex, rangeReIndexForLeft
					//TODO: requires category consolidation
					if( optype.equals("rangeReIndex") )
						return DEFAULT_NFLOP_CP * ((leftSparse)? d2m*d2n*d2s : d2m*d2n );
					else //rangeReIndexForLeft
						return   DEFAULT_NFLOP_CP * ((leftSparse)? d1m*d1n*d1s : d1m*d1n)
					           + DEFAULT_NFLOP_CP * ((rightSparse)? d2m*d2n*d2s : d2m*d2n );
	
				case ZeroOut: //opcodes: zeroOut
					return   DEFAULT_NFLOP_CP * ((leftSparse)? d1m*d1n*d1s : d1m*d1n)
				           + DEFAULT_NFLOP_CP * ((rightSparse)? d2m*d2n*d2s : d2m*d2n );								
					
				default:
					return 0;
			}
		}
		else
		{
			throw new DMLRuntimeException("CostEstimator: unsupported instruction type: "+optype);
		}
		
		//TODO Parameterized Builtin Functions
		//String2CPFileInstructionType.put( "rmempty"	    , CPINSTRUCTION_TYPE.ParameterizedBuiltin);
		
		return -1; //only reached for MR PickByCount (see break above)
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy