All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.runtime.controlprogram.parfor.opt.OptimizationWrapper Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.controlprogram.parfor.opt;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map.Entry;
import java.util.Set;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.ipa.InterProceduralAnalysis;
import org.apache.sysml.hops.rewrite.HopRewriteRule;
import org.apache.sysml.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysml.hops.rewrite.ProgramRewriter;
import org.apache.sysml.hops.rewrite.RewriteConstantFolding;
import org.apache.sysml.hops.rewrite.RewriteRemoveUnnecessaryBranches;
import org.apache.sysml.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.LanguageException;
import org.apache.sysml.parser.ParForStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock.POptMode;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.controlprogram.parfor.opt.Optimizer.CostModelType;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Stat;
import org.apache.sysml.runtime.controlprogram.parfor.stat.StatisticMonitor;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.sysml.utils.Statistics;


/**
 * Wrapper to ParFOR cost estimation and optimizer. This is intended to be the 
 * only public access to the optimizer package.
 * 
 * NOTE: There are two main alternatives for invocation of this OptimizationWrapper:
 * (1) During compilation (after creating rtprog), (2) on execute of all top-level ParFOR PBs.
 * We decided to use (2) (and carry the SBs during execution) due to the following advantages
 *   - Known Statistics: problem size of top-level parfor known, in general, less unknown statistics
 *   - No Overhead: preventing overhead for non-parfor scripts (finding top-level parfors)
 *   - Simplicity: no need of finding top-level parfors 
 * 
 */
public class OptimizationWrapper 
{
	
	private static final boolean LDEBUG = false; //internal local debug level
	private static final Log LOG = LogFactory.getLog(OptimizationWrapper.class.getName());
	
	//internal parameters
	public static final double PAR_FACTOR_INFRASTRUCTURE = 1.0;
	private static final boolean ALLOW_RUNTIME_COSTMODEL = false;
	private static final boolean CHECK_PLAN_CORRECTNESS = false; 
	
	static
	{
		// for internal debugging only
		if( LDEBUG ) {
			Logger.getLogger("org.apache.sysml.runtime.controlprogram.parfor.opt")
				  .setLevel((Level) Level.DEBUG);
		}
	}
	
	/**
	 * Called once per DML script (during program compile time) 
	 * in order to optimize all top-level parfor program blocks.
	 * 
	 * NOTE: currently note used at all.
	 * 
	 * @param prog
	 * @param rtprog
	 * @throws DMLRuntimeException 
	 * @throws LanguageException 
	 * @throws DMLUnsupportedOperationException 
	 */
	public static void optimize(DMLProgram prog, Program rtprog, boolean monitor) 
		throws DMLRuntimeException, LanguageException, DMLUnsupportedOperationException 
	{
		LOG.debug("ParFOR Opt: Running optimize all on DML program "+DMLScript.getUUID());
		
		//init internal structures 
		HashMap sbs = new HashMap();
		HashMap pbs = new HashMap();	
		
		//find all top-level paror pbs
		findParForProgramBlocks(prog, rtprog, sbs, pbs);
		
		// Create an empty symbol table
		ExecutionContext ec = ExecutionContextFactory.createContext();
		
		//optimize each top-level parfor pb independently
		for( Entry entry : pbs.entrySet() )
		{
			long key = entry.getKey();
			ParForStatementBlock sb = sbs.get(key);
			ParForProgramBlock pb = entry.getValue();
			
			//optimize (and implicit exchange)
			POptMode type = pb.getOptimizationMode(); //known to be >0
			optimize( type, sb, pb, ec, monitor );
		}		
		
		LOG.debug("ParFOR Opt: Finished optimization for DML program "+DMLScript.getUUID());
	}

	/**
	 * Called once per top-level parfor (during runtime, on parfor execute)
	 * in order to optimize the specific parfor program block.
	 * 
	 * NOTE: this is the default way to invoke parfor optimizers.
	 * 
	 * @param type
	 * @param sb
	 * @param pb
	 * @throws DMLRuntimeException
	 * @throws DMLUnsupportedOperationException 
	 */
	public static void optimize( POptMode type, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor ) 
		throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		Timing time = new Timing(true);
		
		LOG.debug("ParFOR Opt: Running optimization for ParFOR("+pb.getID()+")");
		
		
		//set max contraints if not specified
		int ck = UtilFunctions.toInt( Math.max( InfrastructureAnalyzer.getCkMaxCP(),
						                        InfrastructureAnalyzer.getCkMaxMR() ) * PAR_FACTOR_INFRASTRUCTURE );
		double cm = InfrastructureAnalyzer.getCmMax() * OptimizerUtils.MEM_UTIL_FACTOR; 
		
		//execute optimizer
		optimize( type, ck, cm, sb, pb, ec, monitor );
		
		double timeVal = time.stop();
		LOG.debug("ParFOR Opt: Finished optimization for PARFOR("+pb.getID()+") in "+timeVal+"ms.");
		//System.out.println("ParFOR Opt: Finished optimization for PARFOR("+pb.getID()+") in "+timeVal+"ms.");
		if( monitor )
			StatisticMonitor.putPFStat( pb.getID() , Stat.OPT_T, timeVal);
	}
	
	/**
	 * 
	 * @param optLogLevel
	 */
	public static void setLogLevel( Level optLogLevel )
	{
		if( !LDEBUG ){ //set log level if not overwritten by internal flag
			Logger.getLogger("org.apache.sysml.runtime.controlprogram.parfor.opt")
			      .setLevel( optLogLevel );
		}
	}
	
	/**
	 * 
	 * @param type
	 * @param ck
	 * @param cm
	 * @param sb
	 * @param pb
	 * @throws DMLRuntimeException
	 * @throws DMLUnsupportedOperationException 
	 * @throws  
	 */
	@SuppressWarnings("unused")
	private static void optimize( POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor ) 
		throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		Timing time = new Timing(true);
		
		//maintain statistics
		if( DMLScript.STATISTICS )
			Statistics.incrementParForOptimCount();
		
		//create specified optimizer
		Optimizer opt = createOptimizer( otype );
		CostModelType cmtype = opt.getCostModelType();
		LOG.trace("ParFOR Opt: Created optimizer ("+otype+","+opt.getPlanInputType()+","+opt.getCostModelType());
		
		if( cmtype == CostModelType.RUNTIME_METRICS  //TODO remove check when perftesttool supported
			&& !ALLOW_RUNTIME_COSTMODEL )
		{
			throw new DMLRuntimeException("ParFOR Optimizer "+otype+" requires cost model "+cmtype+" that is not suported yet.");
		}
		
		OptTree tree = null;
		
		//recompile parfor body 
		if( OptimizerUtils.ALLOW_DYN_RECOMPILATION )
		{
			ForStatement fs = (ForStatement) sb.getStatement(0);
			
			//debug output before recompilation
			if( LOG.isDebugEnabled() ) 
			{
				try {
					tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec); 
					LOG.debug("ParFOR Opt: Input plan (before recompilation):\n" + tree.explain(false));
					OptTreeConverter.clear();
				}
				catch(Exception ex)
				{
					throw new DMLRuntimeException("Unable to create opt tree.", ex);
				}
			}
			
			//constant propagation into parfor body 
			//(input scalars to parfor are guaranteed read only, but need to ensure safe-replace on multiple reopt
			//separate propagation required because recompile in-place without literal replacement)
			try{
				LocalVariableMap constVars = ProgramRecompiler.getReusableScalarVariables(sb.getDMLProg(), sb, ec.getVariables());
				ProgramRecompiler.replaceConstantScalarVariables(sb, constVars);
			}
			catch(Exception ex){
				throw new DMLRuntimeException(ex);
			}
			
			
			//program rewrites (e.g., constant folding, branch removal) according to replaced literals
			try {
				ProgramRewriter rewriter = createProgramRewriterWithRuleSets();
				ProgramRewriteStatus state = new ProgramRewriteStatus();
				rewriter.rewriteStatementBlockHopDAGs( sb, state );
				fs.setBody(rewriter.rewriteStatementBlocks(fs.getBody(), state));
				if( state.getRemovedBranches() ){
					LOG.debug("ParFOR Opt: Removed branches during program rewrites, rebuilding runtime program");
					pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody()));
				}
			}
			catch(Exception ex){
				throw new DMLRuntimeException(ex);
			}
			
			//recompilation of parfor body and called functions (if safe)
			try{
				//core parfor body recompilation (based on symbol table entries)
				//* clone of variables in order to allow for statistics propagation across DAGs
				//(tid=0, because deep copies created after opt)
				LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone();
				Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, true);
				
				//inter-procedural optimization (based on previous recompilation)
				if( pb.hasFunctions() ) {
					InterProceduralAnalysis ipa = new InterProceduralAnalysis();
					Set fcand = ipa.analyzeSubProgram(sb);		
					
					if( !fcand.isEmpty() ) {
						//regenerate runtime program of modified functions
						for( String func : fcand )
						{
							String[] funcparts = DMLProgram.splitFunctionKey(func);
							FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(funcparts[0], funcparts[1]);
							//reset recompilation flags according to recompileOnce because it is only safe if function is recompileOnce 
							//because then recompiled for every execution (otherwise potential issues if func also called outside parfor)
							Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0, fpb.isRecompileOnce());
						}		
					}
				}
			}
			catch(Exception ex){
				throw new DMLRuntimeException(ex);
			}
		}
		
		//create opt tree (before optimization)
		try {
			tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec); 
			LOG.debug("ParFOR Opt: Input plan (before optimization):\n" + tree.explain(false));
		}
		catch(Exception ex)
		{
			throw new DMLRuntimeException("Unable to create opt tree.", ex);
		}
		
		//create cost estimator
		CostEstimator est = createCostEstimator( cmtype );
		LOG.trace("ParFOR Opt: Created cost estimator ("+cmtype+")");
		
		//core optimize
		opt.optimize( sb, pb, tree, est, ec );
		LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false));
		
		//assert plan correctness
		if( CHECK_PLAN_CORRECTNESS && LOG.isDebugEnabled() )
		{
			try{
				OptTreePlanChecker.checkProgramCorrectness(pb, sb, new HashSet());
				LOG.debug("ParFOR Opt: Checked plan and program correctness.");
			}
			catch(Exception ex)
			{
				throw new DMLRuntimeException("Failed to check program correctness.", ex);
			}
		}
		
		long ltime = (long) time.stop();
		LOG.trace("ParFOR Opt: Optimized plan in "+ltime+"ms.");
		if( DMLScript.STATISTICS )
			Statistics.incrementParForOptimTime(ltime);
		
		//cleanup phase
		OptTreeConverter.clear();
		
		//monitor stats
		if( monitor )
		{
			StatisticMonitor.putPFStat( pb.getID() , Stat.OPT_OPTIMIZER, otype.ordinal());
			StatisticMonitor.putPFStat( pb.getID() , Stat.OPT_NUMTPLANS, opt.getNumTotalPlans());
			StatisticMonitor.putPFStat( pb.getID() , Stat.OPT_NUMEPLANS, opt.getNumEvaluatedPlans());
		}
	}

	/**
	 * 
	 * @param prog
	 * @param rtprog
	 * @throws LanguageException 
	 */
	private static void findParForProgramBlocks( DMLProgram prog, Program rtprog, 
			HashMap sbs, HashMap pbs ) 
		throws LanguageException
	{
		//handle function program blocks
		HashMap fpbs = rtprog.getFunctionProgramBlocks();
		for( Entry entry : fpbs.entrySet() )
		{
			String[] keypart = entry.getKey().split( Program.KEY_DELIM );
			String namespace = keypart[0];
			String name      = keypart[1]; 
			
			ProgramBlock pb = entry.getValue();
			StatementBlock sb = prog.getFunctionStatementBlock(namespace, name);
			
			//recursive find 
			rfindParForProgramBlocks(sb, pb, sbs, pbs);	
		}
		
		//handle actual program blocks
		ArrayList tpbs = rtprog.getProgramBlocks();
		for( int i=0; i sbs, HashMap pbs )
	{
		if( pb instanceof ParForProgramBlock  ) 
		{
			//put top-level parfor into map, but no recursion
			ParForProgramBlock pfpb = (ParForProgramBlock) pb;
			ParForStatementBlock pfsb = (ParForStatementBlock) sb;
			
			LOG.trace("ParFOR: found ParForProgramBlock with POptMode="+pfpb.getOptimizationMode().toString());
			
			if( pfpb.getOptimizationMode() != POptMode.NONE )
			{
				//register programblock tree for optimization
				long pfid = pfpb.getID();
				pbs.put(pfid, pfpb);
				sbs.put(pfid, pfsb);
			}
		}
		else if( pb instanceof ForProgramBlock )
		{
			//recursive find
			ArrayList fpbs = ((ForProgramBlock) pb).getChildBlocks();
			ArrayList fsbs = ((ForStatement)((ForStatementBlock) sb).getStatement(0)).getBody();
			for( int i=0;  i< fpbs.size(); i++ )
				rfindParForProgramBlocks(fsbs.get(i), fpbs.get(i), sbs, pbs);
		}
		else if( pb instanceof WhileProgramBlock )
		{
			//recursive find
			ArrayList wpbs = ((WhileProgramBlock) pb).getChildBlocks();
			ArrayList wsbs = ((WhileStatement)((WhileStatementBlock) sb).getStatement(0)).getBody();
			for( int i=0;  i< wpbs.size(); i++ )
				rfindParForProgramBlocks(wsbs.get(i), wpbs.get(i), sbs, pbs);	
		}
		else if( pb instanceof IfProgramBlock  )
		{
			//recursive find
			IfProgramBlock ifpb = (IfProgramBlock) pb;
			IfStatement ifs = (IfStatement) ((IfStatementBlock) sb).getStatement(0);			
			ArrayList ipbs1 = ifpb.getChildBlocksIfBody();
			ArrayList ipbs2 = ifpb.getChildBlocksElseBody();
			ArrayList isbs1 = ifs.getIfBody();
			ArrayList isbs2 = ifs.getElseBody();			
			for( int i=0;  i< ipbs1.size(); i++ )
				rfindParForProgramBlocks(isbs1.get(i), ipbs1.get(i), sbs, pbs);				
			for( int i=0;  i< ipbs2.size(); i++ )
				rfindParForProgramBlocks(isbs2.get(i), ipbs2.get(i), sbs, pbs);								
		}
	}
	
	/**
	 * 
	 * @param otype
	 * @return
	 * @throws DMLRuntimeException
	 */
	private static Optimizer createOptimizer( POptMode otype ) 
		throws DMLRuntimeException
	{
		Optimizer opt = null;
		
		switch( otype )
		{
			case HEURISTIC:
				opt = new OptimizerHeuristic();
				break;
			case RULEBASED:
				opt = new OptimizerRuleBased();
				break;	
			case CONSTRAINED:
				opt = new OptimizerConstrained();
				break;	
		
			//MB: removed unused and experimental prototypes
			//case FULL_DP:
			//	opt = new OptimizerDPEnum();
			//	break;
			//case GREEDY:
			//	opt = new OptimizerGreedyEnum();
			//	break;
			
			default:
				throw new DMLRuntimeException("Undefined optimizer: '"+otype+"'.");
		}
		
		return opt;
	}

	/**
	 * 
	 * @param cmtype
	 * @return
	 * @throws DMLRuntimeException
	 */
	private static CostEstimator createCostEstimator( CostModelType cmtype ) 
		throws DMLRuntimeException
	{
		CostEstimator est = null;
		
		switch( cmtype )
		{
			case STATIC_MEM_METRIC:
				est = new CostEstimatorHops( OptTreeConverter.getAbstractPlanMapping() );
				break;
			case RUNTIME_METRICS:
				est = new CostEstimatorRuntime();
				break;
			default:
				throw new DMLRuntimeException("Undefined cost model type: '"+cmtype+"'.");
		}
		
		return est;
	}
	
	/**
	 * 
	 * @return
	 */
	private static ProgramRewriter createProgramRewriterWithRuleSets()
	{
		//create hop rewrite set
		ArrayList hRewrites = new ArrayList();
		hRewrites.add( new RewriteConstantFolding() );
		
		//create statementblock rewrite set
		ArrayList sbRewrites = new ArrayList();
		sbRewrites.add( new RewriteRemoveUnnecessaryBranches() );
		
		ProgramRewriter rewriter = new ProgramRewriter( hRewrites, sbRewrites );
		
		return rewriter;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy