All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.sysml.runtime.controlprogram.IfProgramBlock Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.controlprogram;

import java.util.ArrayList;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLScriptException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.Instruction.INSTRUCTION_TYPE;
import org.apache.sysml.runtime.instructions.cp.BooleanObject;
import org.apache.sysml.runtime.instructions.cp.CPInstruction;
import org.apache.sysml.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.cp.StringObject;
import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE;
import org.apache.sysml.yarn.DMLAppMasterUtils;


/**
 * Runtime program block for an if/else statement.
 * <p>
 * {@link #execute(ExecutionContext)} evaluates the predicate, runs either the
 * if-body or the else-body child blocks depending on the boolean outcome, and
 * finally runs the exit instructions.
 */
public class IfProgramBlock extends ProgramBlock 
{
	/** Instructions computing the predicate; may be null or empty (see {@link #setPredicate}). */
	private ArrayList<Instruction> _predicate;
	/** Name of the variable (or literal) holding the predicate outcome. */
	private String _predicateResultVar;
	/** Instructions executed after whichever body ran. */
	private ArrayList<Instruction> _exitInstructions;
	/** Child blocks executed when the predicate is true. */
	private ArrayList<ProgramBlock> _childBlocksIfBody;
	/** Child blocks executed when the predicate is false. */
	private ArrayList<ProgramBlock> _childBlocksElseBody;
	
	/**
	 * Creates an if program block with empty bodies and no exit instructions.
	 *
	 * @param prog the enclosing program
	 * @param predicate instructions computing the predicate; may be null or empty,
	 *        in which case no predicate result variable is derived here
	 * @throws DMLRuntimeException propagated from base-class initialization, if any
	 */
	public IfProgramBlock(Program prog, ArrayList<Instruction> predicate) throws DMLRuntimeException{
		super(prog);
		
		_childBlocksIfBody = new ArrayList<>();
		_childBlocksElseBody = new ArrayList<>();
		
		_predicate = predicate;
		_predicateResultVar = findPredicateResultVar();
		_exitInstructions = new ArrayList<>();
	}
	
	/** @return the (mutable) list of child blocks executed when the predicate is true */
	public ArrayList<ProgramBlock> getChildBlocksIfBody() { 
		return _childBlocksIfBody; 
	}

	/** Replaces the list of child blocks executed when the predicate is true. */
	public void setChildBlocksIfBody(ArrayList<ProgramBlock> blocks) { 
		_childBlocksIfBody = blocks; 
	}
	
	/** Appends a child block to the if-body. */
	public void addProgramBlockIfBody(ProgramBlock pb) { 
		_childBlocksIfBody.add(pb); 
	}	
	
	/** @return the (mutable) list of child blocks executed when the predicate is false */
	public ArrayList<ProgramBlock> getChildBlocksElseBody() { 
		return _childBlocksElseBody; 
	}

	/** Replaces the list of child blocks executed when the predicate is false. */
	public void setChildBlocksElseBody(ArrayList<ProgramBlock> blocks) { 
		_childBlocksElseBody = blocks; 
	}
	
	/** Appends a child block to the else-body. */
	public void addProgramBlockElseBody(ProgramBlock pb) {
		_childBlocksElseBody.add(pb); 
	}
	
	/** Replaces the list of exit instructions run after either body. */
	public void setExitInstructions2(ArrayList<Instruction> exitInstructions){
		_exitInstructions = exitInstructions;
	}

	/**
	 * Despite its name, replaces the predicate instruction list.
	 * <p>
	 * Unlike {@link #setPredicate}, this does NOT recompute the predicate
	 * result variable.
	 */
	public void setExitInstructions1(ArrayList<Instruction> predicate){
		_predicate = predicate;
	}
	
	/** Appends an instruction to the exit instructions. */
	public void addExitInstruction(Instruction inst){
		_exitInstructions.add(inst);
	}
	
	/** @return the predicate instruction list (may be null or empty) */
	public ArrayList<Instruction> getPredicate(){
		return _predicate;
	}

	/**
	 * Replaces the predicate instruction list.
	 * <p>
	 * If the new predicate is non-empty, the predicate result variable is
	 * re-derived from it. Otherwise the current result variable (e.g. a plain
	 * variable name acting as the predicate) is kept.
	 */
	public void setPredicate(ArrayList<Instruction> predicate) 
	{
		_predicate = predicate;
		
		//update result var if non-empty predicate (otherwise,
		//do not overwrite varname predicate in predicateResultVar)
		if( _predicate != null && !_predicate.isEmpty()  )
			_predicateResultVar = findPredicateResultVar();
	}
	
	/** @return the name of the variable (or literal) holding the predicate outcome */
	public String getPredicateResultVar(){
		return _predicateResultVar;
	}
	
	/** Sets the name of the variable (or literal) holding the predicate outcome. */
	public void setPredicateResultVar(String resultVar) {
		_predicateResultVar = resultVar;
	}
	
	/** @return the (mutable) list of exit instructions */
	public ArrayList<Instruction> getExitInstructions(){
		return _exitInstructions;
	}
	
	/**
	 * Evaluates the predicate and runs the matching body, then the exit instructions.
	 *
	 * @param ec execution context holding the live variables
	 * @throws DMLRuntimeException if predicate evaluation, a body, or the exit
	 *         instructions fail; {@link DMLScriptException}s are rethrown unwrapped
	 */
	@Override
	public void execute(ExecutionContext ec) 
		throws DMLRuntimeException
	{	
		BooleanObject predResult = executePredicate(ec); 
	
		//execute either the if body or the else body
		if( predResult.getBooleanValue() )
			executeChildBlocks(_childBlocksIfBody, ec, "Error evaluating if statement body ");
		else
			executeChildBlocks(_childBlocksElseBody, ec, "Error evaluating else statement body ");
		
		//execute exit instructions
		try { 
			executeInstructions(_exitInstructions, ec);
		}
		catch(DMLScriptException e) {
			throw e;
		}
		catch (Exception e){
			throw new DMLRuntimeException(this.printBlockErrorLocation() + "Error evaluating if exit instructions ", e);
		}
	}

	/**
	 * Executes the given child blocks in order, updating the debug state with each block's index.
	 * <p>
	 * A {@link DMLScriptException} is rethrown as-is. Any other exception
	 * (including a null block list) is wrapped in a {@link DMLRuntimeException}
	 * whose message is the block error location followed by {@code errorMsg}.
	 */
	private void executeChildBlocks(ArrayList<ProgramBlock> blocks, ExecutionContext ec, String errorMsg) 
		throws DMLRuntimeException
	{
		try 
		{
			for( int i=0; i < blocks.size(); i++ ) {
				ec.updateDebugState(i);
				blocks.get(i).execute(ec);
			}
		}
		catch(DMLScriptException e) {
			throw e;
		}
		catch(Exception e) {
			throw new DMLRuntimeException(this.printBlockErrorLocation() + errorMsg, e);
		}
	}

	/**
	 * Evaluates the predicate to a boolean.
	 * <p>
	 * With a non-empty predicate instruction list, the inherited predicate
	 * executor runs it. If a statement block is attached, its predicate hops and
	 * recompilation flag are passed along. Otherwise the value named by
	 * {@code _predicateResultVar} is read directly. String values are rejected,
	 * and other non-boolean scalars are converted via {@code getBooleanValue()}.
	 *
	 * @throws DMLRuntimeException wrapping any failure during evaluation
	 */
	private BooleanObject executePredicate(ExecutionContext ec) 
		throws DMLRuntimeException 
	{
		BooleanObject result = null;
		try
		{
			if( _predicate != null && !_predicate.isEmpty() )
			{
				if( _sb != null )
				{
					//set program block specific remote memory
					if( DMLScript.isActiveAM() )
						DMLAppMasterUtils.setupProgramBlockRemoteMaxMemory(this);
					
					IfStatementBlock isb = (IfStatementBlock)_sb;
					Hop predicateOp = isb.getPredicateHops();
					boolean recompile = isb.requiresPredicateRecompilation();
					result = (BooleanObject) executePredicate(_predicate, predicateOp, recompile, ValueType.BOOLEAN, ec);
				}
				else
					result = (BooleanObject) executePredicate(_predicate, null, false, ValueType.BOOLEAN, ec);
			}
			else 
			{
				//get result var: if no variable of this name exists, the name itself
				//is resolved as a literal (per the original note, such literals are
				//not strings, so a literal cannot be confused with a variable name)
				Data resultData = ec.getVariable(_predicateResultVar);
				ScalarObject scalarResult = ec.getScalarInput(_predicateResultVar, ValueType.BOOLEAN, resultData == null);
				
				//check for invalid type String 
				if (scalarResult instanceof StringObject)
					throw new DMLRuntimeException(this.printBlockErrorLocation() + "\nIf predicate variable "+ _predicateResultVar + " evaluated to string " + scalarResult + " which is not allowed for predicates in DML");
				
				//process result
				if( scalarResult instanceof BooleanObject )
					result = (BooleanObject)scalarResult;
				else
					result = new BooleanObject( scalarResult.getBooleanValue() ); //auto casting
			}
		}
		catch(Exception ex)
		{
			throw new DMLRuntimeException(this.printBlockErrorLocation() + "Failed to evaluate the IF predicate.", ex);
		}
		
		//(guaranteed to be non-null, see executePredicate/getScalarInput)
		return result;
	}
	
	/**
	 * Derives the predicate result variable from the predicate instructions.
	 * <p>
	 * Returns the output variable of the last instruction that is either a
	 * non-variable CP instruction or a variable cast instruction. Returns null
	 * if there is no such instruction, including when the predicate is null or empty.
	 */
	private String findPredicateResultVar() {
		String result = null;
		//a null predicate is a valid state (see setPredicate); avoid NPE here
		if( _predicate == null )
			return result;
		for( Instruction si : _predicate ) {
			if ( si.getType() == INSTRUCTION_TYPE.CONTROL_PROGRAM && ((CPInstruction)si).getCPInstructionType() != CPINSTRUCTION_TYPE.Variable ) {
				result = ((ComputationCPInstruction) si).getOutputVariableName();  
			}
			else if(si instanceof VariableCPInstruction && ((VariableCPInstruction)si).isVariableCastInstruction()){
				result = ((VariableCPInstruction)si).getOutputVariableName();
			}
		}
		return result;
	}
	
	/** @return an error-message prefix naming the source line range of this if block */
	public String printBlockErrorLocation(){
		return "ERROR: Runtime error in if program block generated from if statement block between lines " + _beginLine + " and " + _endLine + " -- ";
	}
	
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy