/*
 * org.apache.sysml.hops.cost.CostEstimator
 * Artifact: systemml (Declarative Machine Learning)
 */
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.cost;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map.Entry;
import java.util.StringTokenizer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.ExternalFunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.MRInstructionParser;
import org.apache.sysml.runtime.instructions.MRJobInstruction;
import org.apache.sysml.runtime.instructions.cp.AggregateTernaryCPInstruction;
import org.apache.sysml.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysml.runtime.instructions.cp.BinaryCPInstruction;
import org.apache.sysml.runtime.instructions.cp.CPInstruction;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.DataGenCPInstruction;
import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysml.runtime.instructions.cp.MMTSJCPInstruction;
import org.apache.sysml.runtime.instructions.cp.MultiReturnBuiltinCPInstruction;
import org.apache.sysml.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.apache.sysml.runtime.instructions.cp.StringInitCPInstruction;
import org.apache.sysml.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
import org.apache.sysml.runtime.util.UtilFunctions;
public abstract class CostEstimator
{
/** Logger shared by all cost estimators. */
protected static final Log LOG = 
	LogFactory.getLog(CostEstimator.class.getName());

/**
 * Number of iterations assumed for while loops; also handed to
 * OptimizerUtils.getNumIterations for for loops.
 */
private static final int DEFAULT_NUMITER = 15;

/**
 * Placeholder statistics for variables without known characteristics:
 * 1x1, unknown block sizes and nnz, not in memory.
 * NOTE: this instance is shared and its fields are writable; do not mutate it.
 */
protected static final VarStats _unknownStats = 
	new VarStats(1, 1, -1, -1, -1, false);

/** Statistics used for scalar variables: 1x1 with unit blocks and nnz, in memory. */
protected static final VarStats _scalarStats = 
	new VarStats(1, 1, 1, 1, 1, true);
/**
 * Estimates the execution time of an entire runtime program.
 * <p>
 * Statistics of the variables in {@code vars} are first copied into {@code stats}
 * (e.g., during recompilation); afterwards the estimates of all top-level program
 * blocks are summed up.
 *
 * @param rtprog runtime program to estimate
 * @param vars symbol table whose variables seed the statistics
 * @param stats variable statistics keyed by variable name (updated in place)
 * @return estimated execution time of the program
 * @throws DMLRuntimeException if an estimate cannot be computed
 */
public double getTimeEstimate(Program rtprog, LocalVariableMap vars, HashMap<String, VarStats> stats)
	throws DMLRuntimeException
{
	//obtain stats from symboltable (e.g., during recompile)
	maintainVariableStatistics(vars, stats);
	
	//get cost estimate; each top-level block gets a fresh recursion guard
	double costs = 0;
	for( ProgramBlock pb : rtprog.getProgramBlocks() )
		costs += rGetTimeEstimate(pb, stats, new HashSet<String>(), true);
	
	return costs;
}
/**
 * Estimates the execution time of a single program block.
 * <p>
 * Statistics of the variables in {@code vars} are first copied into {@code stats}
 * (e.g., during recompilation).
 *
 * @param pb program block to estimate
 * @param vars symbol table whose variables seed the statistics
 * @param stats variable statistics keyed by variable name (updated in place)
 * @param recursive if {@code false}, child blocks of control-flow and function
 *        blocks are not visited
 * @return estimated execution time of the block
 * @throws DMLRuntimeException if an estimate cannot be computed
 */
public double getTimeEstimate(ProgramBlock pb, LocalVariableMap vars, HashMap<String, VarStats> stats, boolean recursive)
	throws DMLRuntimeException
{
	//obtain stats from symboltable (e.g., during recompile)
	maintainVariableStatistics(vars, stats);
	
	//get cost estimate (fresh recursion guard for function calls)
	return rGetTimeEstimate(pb, stats, new HashSet<String>(), recursive);
}
/**
 * Recursively estimates the execution time of one program block.
 * <p>
 * While loops scale the sum of their children by {@link #DEFAULT_NUMITER}, for
 * (and parfor) loops by {@code getNumIterations}, and if-else blocks average the
 * costs of both branches. Function program blocks (except external functions)
 * sum their children. All other blocks, including external functions, sum the
 * estimates of their CP and MR-job instructions and descend into called functions.
 *
 * @param pb program block to estimate
 * @param stats variable statistics keyed by variable name; updated while walking
 *        the instructions (createvar, rmvar, MR jobs, ...)
 * @param memoFunc keys of the functions currently being estimated, used to stop
 *        descending into recursive function calls
 * @param recursive if {@code false}, child blocks of control-flow and function
 *        blocks are not visited (they contribute 0)
 * @return estimated execution time of the block
 * @throws DMLRuntimeException if an estimate cannot be computed
 */
private double rGetTimeEstimate(ProgramBlock pb, HashMap<String, VarStats> stats, HashSet<String> memoFunc, boolean recursive)
	throws DMLRuntimeException
{
	double ret = 0;
	
	if( pb instanceof WhileProgramBlock )
	{
		WhileProgramBlock tmp = (WhileProgramBlock)pb;
		if( recursive )
			ret = rGetChildBlocksTimeEstimate(tmp.getChildBlocks(), stats, memoFunc, recursive);
		ret *= DEFAULT_NUMITER;
	}
	else if( pb instanceof IfProgramBlock )
	{
		IfProgramBlock tmp = (IfProgramBlock)pb;
		if( recursive ) {
			ret = rGetChildBlocksTimeEstimate(tmp.getChildBlocksIfBody(), stats, memoFunc, recursive);
			//weighted sum of both branches: halve once over the complete else body
			//(halving after every else block skewed multi-block else bodies)
			Iterable<? extends ProgramBlock> elseBody = tmp.getChildBlocksElseBody();
			if( elseBody != null && elseBody.iterator().hasNext() ) {
				ret += rGetChildBlocksTimeEstimate(elseBody, stats, memoFunc, recursive);
				ret /= 2;
			}
		}
	}
	else if( pb instanceof ForProgramBlock ) //includes ParFORProgramBlock
	{
		ForProgramBlock tmp = (ForProgramBlock)pb;
		if( recursive )
			ret = rGetChildBlocksTimeEstimate(tmp.getChildBlocks(), stats, memoFunc, recursive);
		ret *= getNumIterations(stats, tmp);
	}
	else if( pb instanceof FunctionProgramBlock
			&& !(pb instanceof ExternalFunctionProgramBlock) ) //see generic
	{
		FunctionProgramBlock tmp = (FunctionProgramBlock)pb;
		if( recursive )
			ret = rGetChildBlocksTimeEstimate(tmp.getChildBlocks(), stats, memoFunc, recursive);
	}
	else //generic (last-level) block, incl. external functions
	{
		ret = getInstructionsTimeEstimate(pb, stats, memoFunc, recursive);
	}
	
	return ret;
}

/**
 * Sums the estimates of the given child blocks in order; a null list contributes 0.
 */
private double rGetChildBlocksTimeEstimate(Iterable<? extends ProgramBlock> blocks, HashMap<String, VarStats> stats, HashSet<String> memoFunc, boolean recursive)
	throws DMLRuntimeException
{
	double ret = 0;
	if( blocks != null )
		for( ProgramBlock pb : blocks )
			ret += rGetTimeEstimate(pb, stats, memoFunc, recursive);
	return ret;
}

/**
 * Sums the estimates of all CP and MR-job instructions of a generic program block,
 * maintaining {@code stats} along the way; other instruction types contribute 0.
 */
private double getInstructionsTimeEstimate(ProgramBlock pb, HashMap<String, VarStats> stats, HashSet<String> memoFunc, boolean recursive)
	throws DMLRuntimeException
{
	double ret = 0;
	for( Instruction inst : pb.getInstructions() )
	{
		if( inst instanceof CPInstruction ) //CP
		{
			//obtain stats from createvar, cpvar, rmvar, rand
			maintainCPInstVariableStatistics((CPInstruction)inst, stats);
			
			//extract statistics (instruction-specific)
			Object[] o = extractCPInstStatistics(inst, stats);
			VarStats[] vs = (VarStats[]) o[0];
			String[] attr = (String[]) o[1];
			
			//call time estimation for inst
			ret += getCPInstTimeEstimate(inst, vs, attr);
			
			if( inst instanceof FunctionCallCPInstruction ) //functions
				ret += rGetFunctionCallTimeEstimate(pb, (FunctionCallCPInstruction)inst, stats, memoFunc, recursive);
		}
		else if( inst instanceof MRJobInstruction ) //MR
		{
			//obtain stats for job
			maintainMRJobInstVariableStatistics(inst, stats);
			
			//extract input statistics
			Object[] o = extractMRJobInstStatistics(inst, stats);
			VarStats[] vs = (VarStats[]) o[0];
			
			if(LOG.isDebugEnabled())
				LOG.debug("Begin MRJob type="+((MRJobInstruction)inst).getJobType());
			
			//call time estimation for complex MR inst
			ret += getMRJobInstTimeEstimate(inst, vs, null);
			
			if(LOG.isDebugEnabled())
				LOG.debug("End MRJob");
			
			//cleanup stats for job
			cleanupMRJobVariableStatistics(inst, stats);
		}
	}
	return ret;
}

/**
 * Estimates the body of the function called by {@code finst}. Returns 0 if the
 * function is already on the current call path (recursion) or if the block is
 * not attached to a program.
 */
private double rGetFunctionCallTimeEstimate(ProgramBlock pb, FunctionCallCPInstruction finst, HashMap<String, VarStats> stats, HashSet<String> memoFunc, boolean recursive)
	throws DMLRuntimeException
{
	String fkey = DMLProgram.constructFunctionKey(finst.getNamespace(), finst.getFunctionName());
	
	//awareness of recursive functions, missing program
	if( memoFunc.contains(fkey) || pb.getProgram() == null )
		return 0;
	
	if(LOG.isDebugEnabled())
		LOG.debug("Begin Function "+fkey);
	
	memoFunc.add(fkey);
	Program prog = pb.getProgram();
	FunctionProgramBlock fpb = prog.getFunctionProgramBlock(
		finst.getNamespace(), finst.getFunctionName());
	double ret = rGetTimeEstimate(fpb, stats, memoFunc, recursive);
	memoFunc.remove(fkey);
	
	if(LOG.isDebugEnabled())
		LOG.debug("End Function "+fkey);
	
	return ret;
}
/**
 * Copies statistics of all variables in the given symbol table into {@code stats}.
 * Matrix objects contribute their matrix characteristics plus whether their cache
 * status is "CACHED"; every other data object is recorded as a scalar. Existing
 * entries with the same variable names are overwritten.
 *
 * @param vars symbol table to read
 * @param stats variable statistics keyed by variable name (updated in place)
 * @throws DMLRuntimeException declared by the original signature for callers
 */
private static void maintainVariableStatistics( LocalVariableMap vars, HashMap<String, VarStats> stats )
	throws DMLRuntimeException
{
	for( String varname : vars.keySet() )
	{
		Data dat = vars.get(varname);
		VarStats vs = _scalarStats; //default: scalar
		
		if( dat instanceof MatrixObject ) //matrix
		{
			MatrixObject mo = (MatrixObject) dat;
			MatrixCharacteristics mc = mo.getMatrixCharacteristics();
			long rlen = mc.getRows();
			long clen = mc.getCols();
			long brlen = mc.getRowsPerBlock();
			long bclen = mc.getColsPerBlock();
			long nnz = mc.getNonZeros();
			//null-safe comparison against the cache status string
			boolean inmem = "CACHED".equals(mo.getStatusAsString());
			vs = new VarStats(rlen, clen, brlen, bclen, nnz, inmem);
		}
		
		stats.put(varname, vs);
	}
}
/**
 * Updates {@code stats} according to the effect of a single CP instruction:
 * <ul>
 *   <li>{@code createvar}: registers the declared dimensions, block sizes and nnz (not in memory);</li>
 *   <li>{@code cpvar}: lets the target name share the source's statistics object;</li>
 *   <li>{@code mvvar}: moves the statistics from the source name to the target name;</li>
 *   <li>{@code rmvar}: drops the statistics of the removed variable;</li>
 *   <li>data generation / string init: registers in-memory output statistics;</li>
 *   <li>function calls: marks all bound outputs as unknown.</li>
 * </ul>
 * All other instructions leave {@code stats} unchanged.
 *
 * @param inst CP instruction being processed
 * @param stats variable statistics keyed by variable name (updated in place)
 */
private static void maintainCPInstVariableStatistics( CPInstruction inst, HashMap<String, VarStats> stats )
{
	if( inst instanceof VariableCPInstruction )
	{
		String optype = inst.getOpcode();
		String[] parts = InstructionUtils.getInstructionParts(inst.toString());
		
		if( optype.equals("createvar") ) {
			//rlen, clen, brlen, bclen and nnz are read from parts[6..10],
			//hence at least 11 parts are required
			if( parts.length < 11 )
				return;
			String varname = parts[1];
			long rlen = Long.parseLong(parts[6]);
			long clen = Long.parseLong(parts[7]);
			long brlen = Long.parseLong(parts[8]);
			long bclen = Long.parseLong(parts[9]);
			long nnz = Long.parseLong(parts[10]);
			stats.put(varname, new VarStats(rlen, clen, brlen, bclen, nnz, false));
		}
		else if( optype.equals("cpvar") ) {
			//target shares the source's stats object
			String varname = parts[1];
			String varname2 = parts[2];
			stats.put(varname2, stats.get(varname));
		}
		else if( optype.equals("mvvar") ) {
			String varname = parts[1];
			String varname2 = parts[2];
			stats.put(varname2, stats.remove(varname));
		}
		else if( optype.equals("rmvar") ) {
			stats.remove(parts[1]);
		}
	}
	else if( inst instanceof DataGenCPInstruction )
	{
		DataGenCPInstruction randInst = (DataGenCPInstruction) inst;
		String varname = randInst.output.getName();
		long rlen = randInst.getRows();
		long clen = randInst.getCols();
		long brlen = randInst.getRowsInBlock();
		long bclen = randInst.getColsInBlock();
		long nnz = (long) (randInst.getSparsity() * rlen * clen);
		stats.put(varname, new VarStats(rlen, clen, brlen, bclen, nnz, true));
	}
	else if( inst instanceof StringInitCPInstruction )
	{
		StringInitCPInstruction iinst = (StringInitCPInstruction) inst;
		String varname = iinst.output.getName();
		long rlen = iinst.getRows();
		long clen = iinst.getCols();
		VarStats vs = new VarStats(rlen, clen, ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize(), rlen*clen, true);
		stats.put(varname, vs);
	}
	else if( inst instanceof FunctionCallCPInstruction )
	{
		//function outputs are not analyzed, hence unknown
		FunctionCallCPInstruction finst = (FunctionCallCPInstruction) inst;
		for( String varname : finst.getBoundOutputParamNames() )
			stats.put(varname, _unknownStats);
	}
}
/**
 * Prepares {@code stats} for the cost estimation of one MR job instruction.
 * Job inputs are re-registered under their positional index ("0", "1", ...),
 * rand (datagen) outputs are registered under their output index, and the
 * dimensions of all intermediate indices are derived by parsing every
 * instruction of the job and applying MatrixCharacteristics.computeDimension;
 * indices without an entry yet receive new (not in-memory) statistics.
 * NOTE(review): generic type arguments appear to have been stripped from this
 * copy (e.g. the raw HashMap stats, whose get() result is assigned to VarStats);
 * presumably HashMap<String, VarStats> — confirm against upstream.
 */
private void maintainMRJobInstVariableStatistics( Instruction inst, HashMap stats )
throws DMLRuntimeException
{
MRJobInstruction jobinst = (MRJobInstruction)inst;
//input sizes (varname, index mapping)
String[] inVars = jobinst.getInputVars();
int index = -1;
for( String varname : inVars )
{
VarStats vs = stats.get(varname);
if( vs==null )
vs = _unknownStats;
stats.put(String.valueOf(++index), vs);
}
//rand output
String rdInst = jobinst.getIv_randInstructions();
if( rdInst != null && rdInst.length()>0 )
{
StringTokenizer st = new StringTokenizer(rdInst,Lop.INSTRUCTION_DELIMITOR);
while( st.hasMoreTokens() ) //foreach rand instruction
{
String[] parts = InstructionUtils.getInstructionParts(st.nextToken());
byte outIndex = Byte.parseByte(parts[2]);
//NOTE(review): rlen/clen become -1 for unresolved placeholders, so the nnz
//computed below is then meaningless (negative, or truncated to 0/1) — confirm intended
long rlen = parts[3].contains(Lop.VARIABLE_NAME_PLACEHOLDER)?-1:UtilFunctions.parseToLong(parts[3]);
long clen = parts[4].contains(Lop.VARIABLE_NAME_PLACEHOLDER)?-1:UtilFunctions.parseToLong(parts[4]);
long brlen = Long.parseLong(parts[5]);
long bclen = Long.parseLong(parts[6]);
long nnz = (long) (Double.parseDouble(parts[9]) * rlen * clen);
VarStats vs = new VarStats(rlen, clen, brlen, bclen, nnz, false);
stats.put(String.valueOf(outIndex), vs);
}
}
//compute intermediate result indices
HashMap dims = new HashMap<>();
//populate input indices
//(every integer-named key in stats is treated as a job index; NOTE(review):
//such keys outside the byte range would make Byte.parseByte throw)
for( Entry e : stats.entrySet() )
{
if(UtilFunctions.isIntegerNumber(e.getKey()))
{
byte ix = Byte.parseByte(e.getKey());
VarStats vs = e.getValue();
if( vs !=null )
{
MatrixCharacteristics mc = new MatrixCharacteristics(vs._rlen, vs._clen, (int)vs._brlen, (int)vs._bclen, (long)vs._nnz);
dims.put(ix, mc);
}
}
}
//compute dims for all instructions
String[] instCat = new String[]{
jobinst.getIv_randInstructions(),
jobinst.getIv_recordReaderInstructions(),
jobinst.getIv_instructionsInMapper(),
jobinst.getIv_shuffleInstructions(),
jobinst.getIv_aggInstructions(),
jobinst.getIv_otherInstructions()};
for( String linstCat : instCat )
if( linstCat !=null && linstCat.length()>0 )
{
String[] linst = linstCat.split(Instruction.INSTRUCTION_DELIM);
for( String instStr : linst )
{
//NOTE(review): replaceInstructionPatch is not defined in this copy of the file
String instStr2 = replaceInstructionPatch(instStr);
MRInstruction mrinst = MRInstructionParser.parseSingleInstruction(instStr2);
MatrixCharacteristics.computeDimension(dims, mrinst);
}
}
//create varstats if necessary
for( Entry e : dims.entrySet() )
{
byte ix = e.getKey();
if( !stats.containsKey(String.valueOf(ix)) )
{
MatrixCharacteristics mc = e.getValue();
VarStats vs = new VarStats(mc.getRows(), mc.getCols(), mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros(), false);
stats.put(String.valueOf(ix), vs);
}
}
//map result indexes
String[] outLabels = jobinst.getOutputVars();
byte[] resultIndexes = jobinst.getIv_resultIndices();
//NOTE(review): the next line is corrupted in this copy. The loop mapping
//resultIndexes to outLabels, the end of this method, possibly other helpers,
//and the signature of extractCPInstStatistics (which builds and returns
//Object[]{VarStats[] stats, String[] attrs}) were lost; restore from upstream.
for( int i=0; i stats )
{
Object[] ret = new Object[2]; //stats, attrs
//three stats slots; for most instruction types: input1, input2, output
VarStats[] vs = new VarStats[3];
String[] attr = null;
if( inst instanceof UnaryCPInstruction )
{
if( inst instanceof DataGenCPInstruction )
{
DataGenCPInstruction rinst = (DataGenCPInstruction) inst;
vs[0] = _unknownStats;
vs[1] = _unknownStats;
vs[2] = stats.get( rinst.output.getName() );
//prepare attributes for cost estimation
//attr[0]: 0 = all-zero (min=max=0), 1 = dense constant (sparsity 1, min=max), 2 = otherwise
int type = 2; //full rand
if( rinst.getMinValue() == 0.0 && rinst.getMaxValue() == 0.0 )
type = 0;
else if( rinst.getSparsity() == 1.0 && rinst.getMinValue() == rinst.getMaxValue() )
type = 1;
attr = new String[]{String.valueOf(type)};
}
else if( inst instanceof StringInitCPInstruction )
{
StringInitCPInstruction rinst = (StringInitCPInstruction) inst;
vs[0] = _unknownStats;
vs[1] = _unknownStats;
vs[2] = stats.get( rinst.output.getName() );
}
else //general unary
{
UnaryCPInstruction uinst = (UnaryCPInstruction) inst;
vs[0] = stats.get( uinst.input1.getName() );
vs[1] = _unknownStats;
vs[2] = stats.get( uinst.output.getName() );
if( vs[0] == null ) //scalar input, e.g., print
vs[0] = _scalarStats;
if( vs[2] == null ) //scalar output
vs[2] = _scalarStats;
if( inst instanceof MMTSJCPInstruction )
{
String type = ((MMTSJCPInstruction)inst).getMMTSJType().toString();
attr = new String[]{type};
}
else if( inst instanceof AggregateUnaryCPInstruction )
{
String[] parts = InstructionUtils.getInstructionParts(inst.toString());
String opcode = parts[0];
//central moment: pass the second-to-last instruction part as attribute
if( opcode.equals("cm") )
attr = new String[]{parts[parts.length-2]};
}
}
}
else if( inst instanceof BinaryCPInstruction )
{
BinaryCPInstruction binst = (BinaryCPInstruction) inst;
vs[0] = stats.get( binst.input1.getName() );
vs[1] = stats.get( binst.input2.getName() );
vs[2] = stats.get( binst.output.getName() );
if( vs[0] == null ) //scalar input,
vs[0] = _scalarStats;
if( vs[1] == null ) //scalar input,
vs[1] = _scalarStats;
if( vs[2] == null ) //scalar output
vs[2] = _scalarStats;
}
else if( inst instanceof AggregateTernaryCPInstruction )
{
AggregateTernaryCPInstruction binst = (AggregateTernaryCPInstruction) inst;
//of same dimension anyway but missing third input
vs[0] = stats.get( binst.input1.getName() );
vs[1] = stats.get( binst.input2.getName() );
vs[2] = stats.get( binst.output.getName() );
if( vs[0] == null ) //scalar input,
vs[0] = _scalarStats;
if( vs[1] == null ) //scalar input,
vs[1] = _scalarStats;
if( vs[2] == null ) //scalar output
vs[2] = _scalarStats;
}
else if( inst instanceof ParameterizedBuiltinCPInstruction )
{
//ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction) inst;
String[] parts = InstructionUtils.getInstructionParts(inst.toString());
String opcode = parts[0];
if( opcode.equals("groupedagg") )
{
HashMap paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
String fn = paramsMap.get("fn");
String order = paramsMap.get("order");
AggregateOperationTypes type = CMOperator.getAggOpType(fn, order);
attr = new String[]{String.valueOf(type.ordinal())};
}
else if( opcode.equals("rmempty") )
{
HashMap paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
//attr[0]: 0 = remove empty rows, 1 = otherwise
//NOTE(review): throws NPE if the instruction carries no "margin" parameter
attr = new String[]{String.valueOf(paramsMap.get("margin").equals("rows")?0:1)};
}
//NOTE(review): substring(7) presumably strips a 7-character "target=" prefix
//from the first parameter — confirm against the instruction format
vs[0] = stats.get( parts[1].substring(7).replaceAll(Lop.VARIABLE_NAME_PLACEHOLDER, "") );
vs[1] = _unknownStats; //TODO
vs[2] = stats.get( parts[parts.length-1] );
if( vs[0] == null ) //scalar input
vs[0] = _scalarStats;
if( vs[2] == null ) //scalar output
vs[2] = _scalarStats;
}
else if( inst instanceof MultiReturnBuiltinCPInstruction )
{
//applies to qr, lu, eigen (cost computation on input1)
//NOTE(review): unlike the other branches there is no null fallback here; a
//missing entry for output 1 makes the vs[2]._inmem write below throw an NPE
MultiReturnBuiltinCPInstruction minst = (MultiReturnBuiltinCPInstruction) inst;
vs[0] = stats.get( minst.input1.getName() );
vs[1] = stats.get( minst.getOutput(0).getName() );
vs[2] = stats.get( minst.getOutput(1).getName() );
}
else if( inst instanceof VariableCPInstruction )
{
setUnknownStats(vs);
VariableCPInstruction varinst = (VariableCPInstruction) inst;
if( varinst.getOpcode().equals("write") ) {
//special handling write of matrix objects (non existing if scalar)
if( stats.containsKey( varinst.getInput1().getName() ) )
vs[0] = stats.get( varinst.getInput1().getName() );
attr = new String[]{varinst.getInput3().getName()};
}
}
else
{
setUnknownStats(vs);
}
//maintain var status (CP output always inmem)
//NOTE(review): if vs[2] is the shared _unknownStats sentinel (e.g. after
//setUnknownStats), this permanently marks that shared instance as in-memory
vs[2]._inmem = true;
ret[0] = vs;
ret[1] = attr;
return ret;
}
/**
 * Fills the three statistics slots with unknown statistics.
 * <p>
 * The input slots share the {@code _unknownStats} sentinel, but the output slot
 * receives a fresh object with the same values. The caller marks the output slot
 * as in-memory ({@code vs[2]._inmem = true}) right afterwards, which would
 * otherwise permanently set that flag on the shared sentinel. The sentinel is
 * also used elsewhere for unknown MR job inputs and function outputs.
 *
 * @param vs statistics slots (at least three entries)
 */
private static void setUnknownStats(VarStats[] vs) {
	vs[0] = _unknownStats;
	vs[1] = _unknownStats;
	//private copy of the _unknownStats values, safe for the caller to mutate
	vs[2] = new VarStats(1,1,-1,-1,-1,false);
}
/**
 * Builds the statistics array passed to getMRJobInstTimeEstimate for an MR job
 * instruction, returned as Object[]{VarStats[] stats, String[] attrs}.
 * The visible part presumably starts by determining the maximum result index
 * of the job — the body is truncated in this copy (see note below).
 */
private static Object[] extractMRJobInstStatistics( Instruction inst, HashMap stats )
{
Object[] ret = new Object[2]; //stats, attrs
VarStats[] vs = null;
String[] attr = null;
MRJobInstruction jinst = (MRJobInstruction)inst;
//get number of indices
byte[] indexes = jinst.getIv_resultIndices();
byte maxIx = -1;
//NOTE(review): the next line is corrupted in this copy. The loop over indexes,
//the remainder of this method, and the signature of cleanupMRJobVariableStatistics
//(called after each MR job in rGetTimeEstimate) were lost; restore from upstream.
for( int i=0; i stats )
{
MRJobInstruction jinst = (MRJobInstruction)inst;
//get number of indices
byte[] indexes = jinst.getIv_resultIndices();
byte maxIx = -1;
//NOTE(review): corrupted again. The remainder of cleanupMRJobVariableStatistics
//and the signature of getNumIterations (used for for/parfor blocks) were lost.
//The surviving body below ignores stats and delegates to
//OptimizerUtils.getNumIterations, passing DEFAULT_NUMITER (presumably the
//fallback for unknown iteration counts — confirm).
for( int i=0; i stats, ForProgramBlock pb) {
return OptimizerUtils.getNumIterations(pb, DEFAULT_NUMITER);
}
/**
 * Estimates the execution time of a single CP instruction.
 *
 * @param inst CP instruction to estimate
 * @param vs statistics slots prepared for this instruction (three slots as
 *        built by extractCPInstStatistics)
 * @param args instruction-specific attributes, or {@code null} if there are none
 * @return estimated execution time of the instruction
 * @throws DMLRuntimeException if the estimate cannot be computed
 */
protected abstract double getCPInstTimeEstimate( Instruction inst, VarStats[] vs, String[] args ) throws DMLRuntimeException;

/**
 * Estimates the execution time of a complete MR job instruction.
 *
 * @param inst MR job instruction to estimate
 * @param vs statistics prepared by extractMRJobInstStatistics
 * @param args instruction-specific attributes (the caller in this class passes {@code null})
 * @return estimated execution time of the job
 * @throws DMLRuntimeException if the estimate cannot be computed
 */
protected abstract double getMRJobInstTimeEstimate( Instruction inst, VarStats[] vs, String[] args ) throws DMLRuntimeException;
}