org.apache.sysml.runtime.controlprogram.parfor.opt.ProgramRecompiler Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.controlprogram.parfor.opt;
import java.io.IOException;
import java.util.ArrayList;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.Hop.VisitStatus;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.compile.Dag;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.DMLTranslator;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptNode.NodeType;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.cp.ArithmeticBinaryCPInstruction;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
/**
*
*/
public class ProgramRecompiler
{
/**
*
* @param rtprog
* @param sbs
* @return
* @throws IOException
* @throws DMLUnsupportedOperationException
* @throws DMLRuntimeException
* @throws LopsException
* @throws HopsException
*/
public static ArrayList generatePartitialRuntimeProgram(Program rtprog, ArrayList sbs)
throws LopsException, DMLRuntimeException, DMLUnsupportedOperationException, IOException, HopsException
{
ArrayList ret = new ArrayList();
DMLConfig config = ConfigurationManager.getConfig();
//construct lops from hops if not existing
DMLTranslator dmlt = new DMLTranslator(sbs.get(0).getDMLProg());
for( StatementBlock sb : sbs ) {
dmlt.constructLops(sb);
}
//construct runtime program from lops
for( StatementBlock sb : sbs ) {
DMLProgram prog = sb.getDMLProg();
ret.add( prog.createRuntimeProgramBlock(rtprog, sb, config) );
}
return ret;
}
/**
* NOTE: if force is set, we set and recompile the respective indexing hops;
* otherwise, we release the forced exec type and recompile again. Hence,
* any changes can be exactly reverted with the same access behavior.
*
* @param sb
* @param pb
* @param var
* @param ec
* @param force
* @throws DMLUnsupportedOperationException
* @throws DMLRuntimeException
*/
public static void rFindAndRecompileIndexingHOP( StatementBlock sb, ProgramBlock pb, String var, ExecutionContext ec, boolean force )
throws DMLUnsupportedOperationException, DMLRuntimeException
{
if( pb instanceof IfProgramBlock && sb instanceof IfStatementBlock )
{
IfProgramBlock ipb = (IfProgramBlock) pb;
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement is = (IfStatement) sb.getStatement(0);
//process if condition
if( isb.getPredicateHops()!=null )
ipb.setPredicate( rFindAndRecompileIndexingHOP(isb.getPredicateHops(),ipb.getPredicate(),var,ec,force) );
//process if branch
int len = is.getIfBody().size();
for( int i=0; i newInst = Recompiler.recompileHopsDag(sb, sb.get_hops(), ec.getVariables(), null, true, 0);
pb.setInstructions( newInst );
}
}
catch(Exception ex)
{
throw new DMLRuntimeException(ex);
}
}
}
/**
*
* @param prog
* @param parforSB
* @param vars
* @return
* @throws DMLUnsupportedOperationException
* @throws DMLRuntimeException
*/
public static LocalVariableMap getReusableScalarVariables( DMLProgram prog, StatementBlock parforSB, LocalVariableMap vars )
throws DMLUnsupportedOperationException, DMLRuntimeException
{
LocalVariableMap constVars = new LocalVariableMap();
for( String varname : vars.keySet() )
{
Data dat = vars.get(varname);
if( dat instanceof ScalarObject //scalar
&& isApplicableForReuseVariable(prog, parforSB, varname) ) //constant
{
constVars.put(varname, dat);
}
}
return constVars;
}
public static void replaceConstantScalarVariables( StatementBlock sb, LocalVariableMap vars )
throws DMLUnsupportedOperationException, DMLRuntimeException, HopsException
{
if( sb instanceof IfStatementBlock )
{
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement is = (IfStatement) sb.getStatement(0);
replacePredicateLiterals(isb.getPredicateHops(), vars);
for( StatementBlock lsb : is.getIfBody() )
replaceConstantScalarVariables(lsb, vars);
for( StatementBlock lsb : is.getElseBody() )
replaceConstantScalarVariables(lsb, vars);
}
else if( sb instanceof WhileStatementBlock )
{
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement ws = (WhileStatement) sb.getStatement(0);
replacePredicateLiterals(wsb.getPredicateHops(), vars);
for( StatementBlock lsb : ws.getBody() )
replaceConstantScalarVariables(lsb, vars);
}
else if( sb instanceof ForStatementBlock ) //for or parfor
{
ForStatementBlock fsb = (ForStatementBlock)sb;
ForStatement fs = (ForStatement) fsb.getStatement(0);
replacePredicateLiterals(fsb.getFromHops(), vars);
replacePredicateLiterals(fsb.getToHops(), vars);
replacePredicateLiterals(fsb.getIncrementHops(), vars);
for( StatementBlock lsb : fs.getBody() )
replaceConstantScalarVariables(lsb, vars);
}
else //last level block
{
ArrayList hops = sb.get_hops();
if( hops != null )
{
//replace constant literals
Hop.resetVisitStatus(hops);
for( Hop hopRoot : hops )
Recompiler.rReplaceLiterals( hopRoot, vars );
}
}
}
/**
*
* @param pred
* @param vars
* @throws DMLRuntimeException
*/
private static void replacePredicateLiterals( Hop pred, LocalVariableMap vars )
throws DMLRuntimeException
{
if( pred != null ){
pred.resetVisitStatus();
Recompiler.rReplaceLiterals(pred, vars);
}
}
/**
* This function determines if an parfor input variable is guaranteed to be read-only
* across multiple invocations of parfor optimization (e.g., in a surrounding while loop).
* In case of invariant variables we can reuse partitioned matrices and propagate constants
* for better size estimation.
*
* @param prog
* @param parforSB
* @param var
* @return
* @throws DMLUnsupportedOperationException
* @throws DMLRuntimeException
*/
public static boolean isApplicableForReuseVariable( DMLProgram prog, StatementBlock parforSB, String var )
throws DMLUnsupportedOperationException, DMLRuntimeException
{
boolean ret = false;
for( StatementBlock sb : prog.getStatementBlocks() )
ret |= isApplicableForReuseVariable(sb, parforSB, var);
return ret;
}
/**
*
* @param sb
* @param parforSB
* @param var
* @throws DMLUnsupportedOperationException
* @throws DMLRuntimeException
*/
private static boolean isApplicableForReuseVariable( StatementBlock sb, StatementBlock parforSB, String var )
throws DMLUnsupportedOperationException, DMLRuntimeException
{
boolean ret = false;
if( sb instanceof IfStatementBlock )
{
IfStatement is = (IfStatement) sb.getStatement(0);
for( StatementBlock lsb : is.getIfBody() )
ret |= isApplicableForReuseVariable(lsb, parforSB, var);
for( StatementBlock lsb : is.getElseBody() )
ret |= isApplicableForReuseVariable(lsb, parforSB, var);
}
else if( sb instanceof WhileStatementBlock )
{
WhileStatement ws = (WhileStatement) sb.getStatement(0);
for( StatementBlock lsb : ws.getBody() )
ret |= isApplicableForReuseVariable(lsb, parforSB, var);
}
else if( sb instanceof ForStatementBlock ) //for or parfor
{
ForStatementBlock fsb = (ForStatementBlock)sb;
ForStatement fs = (ForStatement) fsb.getStatement(0);
if( fsb == parforSB ) {
//found parfor statement
ret = true;
}
else {
for( StatementBlock lsb : fs.getBody() )
ret |= isApplicableForReuseVariable(lsb, parforSB, var);
}
}
return ret && !sb.variablesUpdated().containsVariable(var);
}
/**
*
* @param pb
* @return
*/
public static boolean containsAtLeastOneFunction( ProgramBlock pb )
{
if( pb instanceof IfProgramBlock )
{
IfProgramBlock ipb = (IfProgramBlock) pb;
for( ProgramBlock lpb : ipb.getChildBlocksIfBody() )
if( containsAtLeastOneFunction(lpb) )
return true;
for( ProgramBlock lpb : ipb.getChildBlocksElseBody() )
if( containsAtLeastOneFunction(lpb) )
return true;
}
else if( pb instanceof WhileProgramBlock )
{
WhileProgramBlock wpb = (WhileProgramBlock) pb;
for( ProgramBlock lpb : wpb.getChildBlocks() )
if( containsAtLeastOneFunction(lpb) )
return true;
}
else if( pb instanceof ForProgramBlock ) //incl parfor
{
ForProgramBlock fpb = (ForProgramBlock) pb;
for( ProgramBlock lpb : fpb.getChildBlocks() )
if( containsAtLeastOneFunction(lpb) )
return true;
}
else
{
if( pb.getInstructions() != null )
for( Instruction inst : pb.getInstructions() )
if( inst instanceof FunctionCallCPInstruction )
return true;
}
return false;
}
/**
*
* @param hop
* @param in
* @param force
* @return
* @throws DMLRuntimeException
*/
private static ArrayList rFindAndRecompileIndexingHOP( Hop hop, ArrayList in, String var, ExecutionContext ec, boolean force )
throws DMLRuntimeException
{
ArrayList tmp = in;
try
{
boolean ret = false;
hop.resetVisitStatus();
if( force ) //set forced execution type
ret = rFindAndSetCPIndexingHOP(hop, var);
else //release forced execution type
ret = rFindAndReleaseIndexingHOP(hop, var);
//recompilation on-demand
if( ret )
{
//construct new instructions
tmp = Recompiler.recompileHopsDag(hop, ec.getVariables(), null, true, 0);
}
}
catch(Exception ex)
{
throw new DMLRuntimeException(ex);
}
return tmp;
}
/**
*
* @param hop
* @param var
* @return
*/
private static boolean rFindAndSetCPIndexingHOP(Hop hop, String var)
{
boolean ret = false;
if( hop.getVisited() == VisitStatus.DONE )
return ret;
ArrayList in = hop.getInput();
if( hop instanceof IndexingOp )
{
String inMatrix = hop.getInput().get(0).getName();
if( inMatrix.equals(var) )
{
//NOTE: mem estimate of RIX, set to output size by parfor optmizer
//(rowblock/colblock only applied if in total less than two blocks,
// hence always mem_est in = hop.getInput();
if( hop instanceof IndexingOp )
{
String inMatrix = hop.getInput().get(0).getName();
if( inMatrix.equals(var) )
{
hop.setForcedExecType(null);
hop.clearMemEstimate();
ret = true;
}
}
//recursive search
if( in != null )
for( Hop hin : in )
ret |= rFindAndReleaseIndexingHOP(hin,var);
hop.setVisited(VisitStatus.DONE);
return ret;
}
///////
// additional general-purpose functionalities
/**
*
* @param iterVar
* @param offset
* @return
* @throws DMLRuntimeException
* @throws DMLUnsupportedOperationException
*/
protected static ArrayList createNestedParallelismToInstructionSet(String iterVar, String offset)
throws DMLRuntimeException, DMLUnsupportedOperationException
{
//create instruction string
StringBuilder sb = new StringBuilder("CP"+Lop.OPERAND_DELIMITOR+"+"+Lop.OPERAND_DELIMITOR);
sb.append(iterVar);
sb.append(Lop.DATATYPE_PREFIX+"SCALAR"+Lop.VALUETYPE_PREFIX+"INT"+Lop.OPERAND_DELIMITOR);
sb.append(offset);
sb.append(Lop.DATATYPE_PREFIX+"SCALAR"+Lop.VALUETYPE_PREFIX+"INT"+Lop.OPERAND_DELIMITOR);
sb.append(iterVar);
sb.append(Lop.DATATYPE_PREFIX+"SCALAR"+Lop.VALUETYPE_PREFIX+"INT");
String str = sb.toString();
//create instruction set
ArrayList tmp = new ArrayList();
Instruction inst = ArithmeticBinaryCPInstruction.parseInstruction(str);
tmp.add(inst);
return tmp;
}
/////////////////////////////////
// experimental functionality
//////////
/**
*
* @param n
* @throws DMLRuntimeException
*/
protected static void recompilePartialPlan( OptNode n )
throws DMLRuntimeException
{
//NOTE: need to recompile complete programblock because (1) many to many relationships
//between hops and instructions and (2) due to changed internal variable names
try
{
//get parent program and statement block
OptTreePlanMappingAbstract map = OptTreeConverter.getAbstractPlanMapping();
long pid = map.getMappedParentID(n.getID());
Object[] o = map.getMappedProg(pid);
StatementBlock sbOld = (StatementBlock) o[0];
ProgramBlock pbOld = (ProgramBlock) o[1];
//get changed node and set type appropriately
Hop hop = (Hop) map.getMappedHop(n.getID());
hop.setForcedExecType(n.getExecType().toLopsExecType());
hop.setLops(null); //to enable fresh construction
//get all hops of statement and construct new instructions
Dag dag = new Dag();
for( Hop hops : sbOld.get_hops() )
{
hops.resetVisitStatus();
Recompiler.rClearLops(hops);
Lop lops = hops.constructLops();
lops.addToDag(dag);
}
//construct new instructions
ArrayList newInst = dag.getJobs(sbOld, ConfigurationManager.getConfig());
//exchange instructions
pbOld.getInstructions().clear();
pbOld.getInstructions().addAll(newInst);
}
catch(Exception ex)
{
throw new DMLRuntimeException(ex);
}
}
/**
* NOTE: need to recompile complete programblock because (1) many to many relationships
* between hops and instructions and (2) due to changed internal variable names
*
* @param n
* @return
* @throws DMLRuntimeException
*/
protected static ProgramBlock recompile( OptNode n )
throws DMLRuntimeException
{
ProgramBlock pbNew = null;
try
{
if( n.getNodeType() == NodeType.HOP )
{
//get parent program and statement block
OptTreePlanMappingAbstract map = OptTreeConverter.getAbstractPlanMapping();
long pid = map.getMappedParentID(n.getID());
Object[] o = map.getMappedProg(pid);
StatementBlock sbOld = (StatementBlock) o[0];
ProgramBlock pbOld = (ProgramBlock) o[1];
LopProperties.ExecType oldtype = null;
//get changed node and set type appropriately
Hop hop = (Hop) map.getMappedHop(n.getID());
hop.setForcedExecType(n.getExecType().toLopsExecType());
hop.setLops(null); //to enable fresh construction
//get all hops of statement and construct new lops
Dag dag = new Dag();
for( Hop hops : sbOld.get_hops() )
{
hops.resetVisitStatus();
Recompiler.rClearLops(hops);
Lop lops = hops.constructLops();
lops.addToDag(dag);
}
//construct new instructions
ArrayList newInst = dag.getJobs(sbOld, ConfigurationManager.getConfig());
//exchange instructions
pbNew = new ProgramBlock(pbOld.getProgram());
pbNew.setInstructions(newInst);
//reset type global repository
hop.setForcedExecType(oldtype);
}
else if( n.getNodeType() == NodeType.PARFOR )
{
//no recompilation required
OptTreePlanMappingAbstract map = OptTreeConverter.getAbstractPlanMapping();
ParForProgramBlock pb = (ParForProgramBlock)map.getMappedProg(n.getID())[1];
pbNew = ProgramConverter.createShallowCopyParForProgramBlock(pb, pb.getProgram());
((ParForProgramBlock)pbNew).setExecMode(n.getExecType().toParForExecMode());
}
else
{
throw new DMLRuntimeException("Unexpected node type.");
}
}
catch(Exception ex)
{
throw new DMLRuntimeException(ex);
}
return pbNew;
}
/**
*
* @param hlNodeID
* @param pbNew
* @throws DMLRuntimeException
*/
protected static void exchangeProgram(long hlNodeID, ProgramBlock pbNew)
throws DMLRuntimeException
{
OptTreePlanMappingAbstract map = OptTreeConverter.getAbstractPlanMapping();
OptNode node = map.getOptNode(hlNodeID);
if( node.getNodeType() == NodeType.HOP )
{
long pid = map.getMappedParentID(hlNodeID);
Object[] o = map.getMappedProg(pid);
ProgramBlock pbOld = (ProgramBlock) o[1];
//exchange instructions (save version)
pbOld.getInstructions().clear();
pbOld.getInstructions().addAll( pbNew.getInstructions() );
}
else if( node.getNodeType() == NodeType.PARFOR )
{
ParForProgramBlock pbOld = (ParForProgramBlock) map.getMappedProg(node.getID())[1];
pbOld.setExecMode(((ParForProgramBlock)pbNew).getExecMode());
//TODO extend as required
}
else
{
throw new DMLRuntimeException("Unexpected node type: "+node.getNodeType());
}
}
}