All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.sysml.runtime.controlprogram.ForProgramBlock Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.controlprogram;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.NoSuchElementException;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLScriptException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.IntObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.sysml.yarn.DMLAppMasterUtils;

/**
 * Runtime program block for a DML {@code for} loop.
 *
 * <p>The 'from' and 'to' bounds and the optional 'increment' are each evaluated
 * exactly once at loop entry. Each value comes either from a predicate
 * variable/literal in {@link #_iterablePredicateVars} or from executing the
 * corresponding predicate instructions. The child blocks are then executed once
 * per value of the resulting integer sequence. Finally, the exit instructions run.
 */
public class ForProgramBlock extends ProgramBlock
{	
	protected ArrayList<Instruction> _fromInstructions;
	protected ArrayList<Instruction> _toInstructions;
	protected ArrayList<Instruction> _incrementInstructions;
	
	protected ArrayList<Instruction> _exitInstructions;
	protected ArrayList<ProgramBlock> _childBlocks;

	/**
	 * Iteration predicate variables.
	 *
	 * <p>Index 0 holds the iteration variable name. Indexes 1, 2 and 3 hold the
	 * from, to and increment constants/internal vars that are not captured via
	 * instructions. An entry at 1..3 is null when that value must be computed
	 * from instructions instead.
	 */
	protected String[] _iterablePredicateVars;

	
	/**
	 * Creates a for-loop block with no child blocks and no exit instructions.
	 *
	 * @param prog the program this block belongs to
	 * @param iterPredVars iteration variable name followed by the from/to/increment predicate vars
	 * @throws DMLRuntimeException propagated from the {@link ProgramBlock} constructor
	 */
	public ForProgramBlock(Program prog, String[] iterPredVars) throws DMLRuntimeException
	{
		super(prog);
		
		_exitInstructions = new ArrayList<Instruction>();
		_childBlocks = new ArrayList<ProgramBlock>();
		_iterablePredicateVars = iterPredVars;
	}
	
	/** @return the instructions computing the 'from' bound (may be null) */
	public ArrayList<Instruction> getFromInstructions() {
		return _fromInstructions;
	}
	
	/** @param instructions the instructions computing the 'from' bound */
	public void setFromInstructions(ArrayList<Instruction> instructions) {
		_fromInstructions = instructions;
	}
	
	/** @return the instructions computing the 'to' bound (may be null) */
	public ArrayList<Instruction> getToInstructions() {
		return _toInstructions;
	}
	
	/** @param instructions the instructions computing the 'to' bound */
	public void setToInstructions(ArrayList<Instruction> instructions) {
		_toInstructions = instructions;
	}
	
	/** @return the instructions computing the increment (may be null or empty) */
	public ArrayList<Instruction> getIncrementInstructions() {
		return _incrementInstructions;
	}
	
	/** @param instructions the instructions computing the increment */
	public void setIncrementInstructions(ArrayList<Instruction> instructions) {
		_incrementInstructions = instructions;
	}
	
	/** Appends an instruction to be executed once after the loop finishes. */
	public void addExitInstruction(Instruction inst) {
		_exitInstructions.add(inst);
	}
	
	/** @return the (live, mutable) list of exit instructions */
	public ArrayList<Instruction> getExitInstructions() {
		return _exitInstructions;
	}
	
	/** @param inst the exit instructions, replacing the current list */
	public void setExitInstructions(ArrayList<Instruction> inst) {
		_exitInstructions = inst;
	}
	
	/** Appends a child block to the loop body. */
	public void addProgramBlock(ProgramBlock childBlock) {
		_childBlocks.add(childBlock);
	}
	
	/** @return the (live, mutable) list of loop-body child blocks */
	public ArrayList<ProgramBlock> getChildBlocks() {
		return _childBlocks;
	}
	
	/** @param pbs the loop-body child blocks, replacing the current list */
	public void setChildBlocks(ArrayList<ProgramBlock> pbs) {
		_childBlocks = pbs;
	}
	
	/** @return the iteration variable name followed by the from/to/increment predicate vars */
	public String[] getIterablePredicateVars() {
		return _iterablePredicateVars;
	}
	
	/** @param iterPredVars the iteration variable name followed by the from/to/increment predicate vars */
	public void setIterablePredicateVars(String[] iterPredVars) {
		_iterablePredicateVars = iterPredVars;
	}
	
	/**
	 * Executes the loop.
	 *
	 * <p>The from, to and increment values are evaluated once. All child blocks
	 * then run for every value of the sequence, with the iteration variable bound
	 * in {@code ec}. Finally, the exit instructions run.
	 *
	 * <p>If no increment is specified (no increment instructions and no increment
	 * predicate var), the increment defaults to 1 when from &lt;= to and to -1
	 * otherwise.
	 *
	 * @param ec the execution context in which variables are read and set
	 * @throws DMLRuntimeException if the increment evaluates to zero, if predicate
	 *         evaluation fails, or if a child block or exit instruction fails.
	 *         A {@link DMLScriptException} (stop call) is rethrown unwrapped.
	 */
	@Override	
	public void execute(ExecutionContext ec) 
		throws DMLRuntimeException
	{
		// add the iterable predicate variable to the variable set
		String iterVarName = _iterablePredicateVars[0];

		// evaluate from, to, incr only once (assumption: known at for entry)
		IntObject from = executePredicateInstructions( 1, _fromInstructions, ec );
		IntObject to   = executePredicateInstructions( 2, _toInstructions, ec );
		boolean noIncrSpecified = (_incrementInstructions == null || _incrementInstructions.isEmpty()) 
			&& _iterablePredicateVars[3] == null;
		IntObject incr = noIncrSpecified ? 
			new IntObject((from.getLongValue() <= to.getLongValue()) ? 1 : -1) :
			executePredicateInstructions( 3, _incrementInstructions, ec );
		
		if ( incr.getLongValue() == 0 ) //would produce infinite loop
			throw new DMLRuntimeException(this.printBlockErrorLocation() + "Expression for increment of variable '" + iterVarName + "' must evaluate to a non-zero value.");
		
		// execute for loop
		try 
		{
			// prepare update in-place variables
			UpdateType[] flags = prepareUpdateInPlaceVariables(ec, _tid);
			
			// run for loop body for each instance of predicate sequence 
			SequenceIterator seqIter = new SequenceIterator(iterVarName, from, to, incr);
			for( IntObject iterVar : seqIter ) 
			{
				//set iteration variable
				ec.setVariable(iterVarName, iterVar); 
				
				//execute all child blocks (index doubles as debug state)
				for( int i = 0; i < _childBlocks.size(); i++ ) {
					ec.updateDebugState( i );
					_childBlocks.get(i).execute(ec);
				}
			}
			
			// reset update-in-place variables
			resetUpdateInPlaceVariableFlags(ec, flags);
		}
		catch (DMLScriptException e) {
			//propagate stop call
			throw e;
		}
		catch (Exception e) {
			throw new DMLRuntimeException(printBlockErrorLocation() + "Error evaluating for program block", e);
		}
		
		//execute exit instructions
		try {
			executeInstructions(_exitInstructions, ec);	
		}
		catch (Exception e){
			throw new DMLRuntimeException(printBlockErrorLocation() + "Error evaluating for exit instructions", e);
		}
	}

	/**
	 * Evaluates one loop predicate (from, to, or increment) to an integer.
	 *
	 * <p>If a predicate var is registered at {@code pos}, its value is used
	 * directly. That value is looked up in {@code ec} as a scalar variable, or,
	 * if no such scalar exists, parsed as a long literal. Otherwise the given
	 * instructions are executed via {@code executePredicate}. When a statement
	 * block is attached, its predicate HOPs and recompile flag for {@code pos}
	 * are passed along.
	 *
	 * @param pos predicate position: 1 = from, 2 = to, 3 = increment
	 * @param instructions predicate instructions, used only if no predicate var is set at {@code pos}
	 * @param ec the execution context used for variable lookup and instruction execution
	 * @return the predicate value as an {@link IntObject}; a non-int scalar is
	 *         converted via {@code getLongValue()}
	 * @throws DMLRuntimeException wrapping any failure during evaluation
	 */
	protected IntObject executePredicateInstructions( int pos, ArrayList<Instruction> instructions, ExecutionContext ec ) 
		throws DMLRuntimeException
	{
		ScalarObject tmp = null;
		
		try
		{
			String predVar = _iterablePredicateVars[pos];
			if( predVar != null )
			{
				//probe for scalar variables, otherwise handle as literal
				Data ldat = ec.getVariable( predVar );
				if( ldat instanceof ScalarObject )
					tmp = (ScalarObject) ldat;
				else
					tmp = new IntObject( UtilFunctions.parseToLong(predVar) );
			}
			else if( _sb != null )
			{
				if( DMLScript.isActiveAM() ) //set program block specific remote memory
					DMLAppMasterUtils.setupProgramBlockRemoteMaxMemory(this);
				
				ForStatementBlock fsb = (ForStatementBlock)_sb;
				Hop predHops = null;
				boolean recompile = false;
				if (pos == 1) { 
					predHops = fsb.getFromHops();
					recompile = fsb.requiresFromRecompilation();
				}
				else if (pos == 2) {
					predHops = fsb.getToHops();
					recompile = fsb.requiresToRecompilation();
				}
				else if (pos == 3) {
					predHops = fsb.getIncrementHops();
					recompile = fsb.requiresIncrementRecompilation();
				}
				tmp = executePredicate(instructions, predHops, recompile, ValueType.INT, ec);
			}
			else
			{
				tmp = executePredicate(instructions, null, false, ValueType.INT, ec);
			}
		}
		catch(Exception ex)
		{
			throw new DMLRuntimeException(this.printBlockErrorLocation() +"Error evaluating '" + getPredicateName(pos) + "' predicate", ex);
		}
		
		//final check of resulting int object (guaranteed to be non-null, see executePredicate)
		if( tmp instanceof IntObject )
			return (IntObject) tmp;
		//downcast to int if necessary
		return new IntObject(tmp.getName(), tmp.getLongValue());
	}
	
	/**
	 * Maps a predicate position to its name for error messages.
	 *
	 * @param pos predicate position (1..3)
	 * @return "from", "to" or "increment", or null for any other position
	 */
	private static String getPredicateName(int pos) {
		switch( pos ) {
			case 1:  return "from";
			case 2:  return "to";
			case 3:  return "increment";
			default: return null;
		}
	}
	
	/** @return the error-location prefix used in this block's exception messages */
	public String printBlockErrorLocation(){
		return "ERROR: Runtime error in for program block generated from for statement block between lines " + _beginLine + " and " + _endLine + " -- ";
	}
	
	/**
	 * Utility class for iterating over positive or negative predicate sequences
	 * {@code from, from+incr, ...}, including {@code to} if it is hit exactly.
	 *
	 * <p>The iterator is single-use: {@link #iterator()} may be called only once.
	 * If advancing to the next value would overflow a {@code long}, iteration
	 * ends instead of wrapping around. Wrapping around would otherwise cause an
	 * infinite loop when {@code to} is near {@code Long.MAX_VALUE} or
	 * {@code Long.MIN_VALUE}.
	 */
	protected class SequenceIterator implements Iterator<IntObject>, Iterable<IntObject>
	{
		private String _varName = null;
		private long _cur = -1;
		private long _to = -1;
		private long _incr = -1;
		private boolean _inuse = false;
		private boolean _exhausted = false; //set once the next value would overflow
		
		protected SequenceIterator(String varName, IntObject from, IntObject to, IntObject incr) {
			_varName = varName;
			_cur = from.getLongValue();
			_to = to.getLongValue();
			_incr = incr.getLongValue();
		}

		@Override
		public boolean hasNext() {
			if( _exhausted )
				return false;
			return _incr > 0 ? _cur <= _to : _cur >= _to;
		}

		@Override
		public IntObject next() {
			if( !hasNext() )
				throw new NoSuchElementException("No more values in sequence of '" + _varName + "'.");
			IntObject ret = new IntObject( _varName, _cur );
			long nextVal = _cur + _incr;
			//signed overflow iff the result's sign differs from the signs of both operands
			if( ((_cur ^ nextVal) & (_incr ^ nextVal)) < 0 )
				_exhausted = true;
			else
				_cur = nextVal;
			return ret;
		}

		@Override
		public Iterator<IntObject> iterator() {
			if( _inuse )
				throw new IllegalStateException("Unsupported reuse of iterator.");				
			_inuse = true;
			return this;
		}

		@Override
		public void remove() {
			throw new UnsupportedOperationException("Unsupported remove on iterator.");
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy