All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.instructions.cp;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLScriptException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;


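/**
 * CP instruction that invokes a function: the {@link FunctionProgramBlock}
 * registered in the program under this instruction's namespace and function
 * name is executed in a fresh execution context. The call's input operands are
 * bound to the function's formal parameters, and the function's declared outputs
 * are bound back to the caller's output variable names.
 */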
public class FunctionCallCPInstruction extends CPInstruction 
{	
	/** Name of the invoked function; reassigned by setFunctionName. */
	private String _functionName;
	/** Namespace that, together with the function name, identifies the function program block. */
	private String _namespace;
	
	/**
	 * @return the name of the invoked function
	 */
	public String getFunctionName() {
		return this._functionName;
	}
	
	/**
	 * @return the namespace of the invoked function
	 */
	public String getNamespace() {
		return this._namespace;
	}
	
	// stores both the bound input and output parameters
	// (typed lists: element access such as _boundInputParamOperands.get(i).isLiteral()
	// and String-typed reads of the name lists require the type arguments)
	private final ArrayList<CPOperand> _boundInputParamOperands;
	private final ArrayList<String> _boundInputParamNames;
	private final ArrayList<String> _boundOutputParamNames;
	
	/**
	 * Creates a function call instruction.
	 * 
	 * @param namespace namespace of the called function
	 * @param functName name of the called function; also passed to the superclass constructor
	 * @param boundInParamOperands operands passed as inputs, in call order
	 * @param boundInParamNames names of the input operands, parallel to {@code boundInParamOperands}
	 * @param boundOutParamNames caller-side variable names that receive the function's outputs, in order
	 * @param istr serialized instruction string
	 */
	public FunctionCallCPInstruction(String namespace, String functName, ArrayList<CPOperand> boundInParamOperands, 
			ArrayList<String> boundInParamNames, ArrayList<String> boundOutParamNames, String istr) {
		super(null, functName, istr);
		
		_cptype = CPINSTRUCTION_TYPE.External;
		_functionName = functName;
		_namespace = namespace;
		_boundInputParamOperands = boundInParamOperands;
		_boundInputParamNames = boundInParamNames;
		_boundOutputParamNames = boundOutParamNames;
	}
		
	/**
	 * Parses a function call instruction string.
	 * <p>
	 * After splitting with {@link InstructionUtils#getInstructionPartsWithValueType(String)}, the
	 * parts are: [extFunct], [namespace], [function name], [num input params], [num output params],
	 * then the input operands, then the output variable names. These are the "bound names" for the
	 * inputs / outputs. For example, {@code out1 = foo(in1, in2)} yields (delimiters written as
	 * ":::" for readability) {@code extFunct:::[namespace]:::foo:::2:::1:::in1:::in2:::out1}.
	 * 
	 * @param str serialized instruction
	 * @return the parsed instruction
	 * @throws DMLRuntimeException if the instruction has too few parts, or if the declared
	 *         numbers of inputs/outputs are negative or exceed the available parts
	 * @throws DMLUnsupportedOperationException declared in the signature; not thrown directly here
	 */
	public static FunctionCallCPInstruction parseInstruction(String str) 
		throws DMLRuntimeException, DMLUnsupportedOperationException 
	{	
		// index of the first input operand; parts 1-4 hold namespace, name and the two counts
		final int FIRST_PARAM_INDEX = 5;
		
		String[] parts = InstructionUtils.getInstructionPartsWithValueType ( str );
		if( parts.length < FIRST_PARAM_INDEX )
			throw new DMLRuntimeException("Invalid function call instruction (too few parts): " + str);
		
		String namespace = parts[1];
		String functionName = parts[2];
		int numInputs = Integer.parseInt(parts[3]);
		int numOutputs = Integer.parseInt(parts[4]);
		
		// Validate the declared counts up front, so a malformed instruction fails with a
		// descriptive error rather than an ArrayIndexOutOfBoundsException.
		// Use long arithmetic so that huge counts cannot overflow the sum.
		if( numInputs < 0 || numOutputs < 0 
			|| parts.length < (long) FIRST_PARAM_INDEX + numInputs + numOutputs )
		{
			throw new DMLRuntimeException("Invalid function call instruction (expected " 
				+ numInputs + " inputs and " + numOutputs + " outputs): " + str);
		}
		
		ArrayList<CPOperand> boundInParamOperands = new ArrayList<CPOperand>(numInputs);
		ArrayList<String> boundInParamNames = new ArrayList<String>(numInputs);
		ArrayList<String> boundOutParamNames = new ArrayList<String>(numOutputs);
		
		for (int i = 0; i < numInputs; i++) {
			CPOperand operand = new CPOperand(parts[FIRST_PARAM_INDEX + i]);
			boundInParamOperands.add(operand);
			boundInParamNames.add(operand.getName());
		}
		for (int i = 0; i < numOutputs; i++) {
			boundOutParamNames.add(parts[FIRST_PARAM_INDEX + numInputs + i]);
		}
		
		return new FunctionCallCPInstruction ( namespace, functionName, boundInParamOperands, boundInParamNames, boundOutParamNames, str );
	}

	
	
	/**
	 * Runs the inherited pre-processing. In debug mode, it also records entry into
	 * the called function in the execution context's debug state.
	 * 
	 * @param ec execution context of the caller
	 * @return the instruction produced by the inherited pre-processing
	 */
	@Override
	public Instruction preprocessInstruction(ExecutionContext ec)
		throws DMLRuntimeException, DMLUnsupportedOperationException 
	{
		final Instruction preprocessed = super.preprocessInstruction(ec);
		if( !DMLScript.ENABLE_DEBUG_MODE )
			return preprocessed;
		
		// debug mode: maintain the debugger's function call stack
		ec.handleDebugFunctionEntry((FunctionCallCPInstruction) preprocessed);
		return preprocessed;
	}

	/**
	 * Executes the function call. It binds the call's inputs to the function's formal
	 * parameters and runs the function program block in a fresh execution context.
	 * Function-scope variables that are not declared outputs are dropped. The declared
	 * outputs are bound to the caller's output variable names in {@code ec}.
	 * 
	 * @param ec execution context of the caller
	 * @throws DMLRuntimeException if the function body fails. A DMLScriptException is
	 *         propagated unchanged; any other exception is wrapped with the qualified
	 *         function name.
	 * @throws DMLUnsupportedOperationException if a declared output was not assigned
	 */
	@Override
	public void processInstruction(ExecutionContext ec) 
		throws DMLRuntimeException, DMLUnsupportedOperationException 
	{		
		if( LOG.isTraceEnabled() ){
			LOG.trace("Executing instruction : " + this.toString());
		}
		
		// get the function program block (stored in the Program object)
		FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(_namespace, _functionName);
		
		// create bindings to formal parameters for given function call
		// These are the bindings passed to the FunctionProgramBlock for function execution 
		LocalVariableMap functionVariables = new LocalVariableMap();		
		// NOTE(review): the following line is damaged in this copy of the file. The text between
		// "i" and "_boundInputParamNames.size()" is missing. That text includes the loop bound and
		// increment, the opening brace, and the declarations of currFormalParamName,
		// currFormalParamValue and valType (used below, but declared nowhere visible). It also
		// includes the start of the default-value condition. Restore the line from upstream
		// before compiling.
		// When restoring, check the comparison against the number of bound inputs. If it is
		// "i > size()", then for i == size() the calls _boundInputParamOperands.get(i) and
		// _boundInputParamNames.get(i) below throw IndexOutOfBoundsException instead of
		// falling back to the default value. "i >= size()" looks intended. TODO confirm.
		for( int i=0; i _boundInputParamNames.size() 
				|| (!_boundInputParamOperands.get(i).isLiteral() && ec.getVariable(_boundInputParamNames.get(i)) == null))
			{	
				// use the formal parameter's declared default value. This applies when the
				// argument is missing, or when it names a variable absent from the caller's
				// symbol table. The default is read via getScalarInput with the formal value type.
				String defaultVal = fpb.getInputParams().get(i).getDefaultValue();
				currFormalParamValue = ec.getScalarInput(defaultVal, valType, false);
			}
			// CASE (b) literals or symbol table entries
			else {
				CPOperand operand = _boundInputParamOperands.get(i);
				// scalars (including literals) go through getScalarInput; other data types
				// are passed by reference from the caller's symbol table
				if( operand.getDataType()==DataType.SCALAR )
					currFormalParamValue = ec.getScalarInput(operand.getName(), operand.getValueType(), operand.isLiteral());
				else
					currFormalParamValue = ec.getVariable(operand.getName());					
			}
				
			functionVariables.put(currFormalParamName,currFormalParamValue);						
		}
		
		// Pin the input variables so that they do not get deleted 
		// from pb's symbol table at the end of execution of function
		// NOTE(review): the matching unpinVariables call further down is not reached if
		// fpb.execute throws. Consider try/finally if pin status must be restored on failure.
		// Confirm.
	    HashMap pinStatus = ec.pinVariables(_boundInputParamNames);
		
		// Create a symbol table under a new execution context for the function invocation,
		// and copy the function arguments into the created table. 
		ExecutionContext fn_ec = ExecutionContextFactory.createContext(false, ec.getProgram());
		fn_ec.setVariables(functionVariables);
		
		// execute the function block
		try {
			fpb.execute(fn_ec);
		}
		catch (DMLScriptException e) {
			// propagate unchanged, without wrapping
			throw e;
		}
		catch (Exception e){
			String fname = this._namespace + "::" + this._functionName;
			throw new DMLRuntimeException("error executing function " + fname, e);
		}
		
		// variables remaining in the function's symbol table after execution
		LocalVariableMap retVars = fn_ec.getVariables();  
		
		// cleanup all returned variables w/o binding 
		// (everything that is not a declared output is removed; matrix objects are also cleaned up)
		// NOTE(review): retVarnames appears to have lost its type argument in this copy. Iterating
		// it with a String loop variable requires Collection<String>, e.g. new LinkedList<String>(...).
		// Restore from upstream.
		Collection retVarnames = new LinkedList(retVars.keySet());
		HashSet probeVars = new HashSet();
		for(DataIdentifier di : fpb.getOutputParams())
			probeVars.add(di.getName());
		for( String var : retVarnames ) {
			if( !probeVars.contains(var) ) //cleanup candidate
			{
				Data dat = fn_ec.removeVariable(var);
				if( dat != null && dat instanceof MatrixObject )
					fn_ec.cleanupMatrixObject((MatrixObject)dat);
			}
		}
		
		// Unpin the pinned variables
		ec.unpinVariables(_boundInputParamNames, pinStatus);
		
		// add the updated binding for each return variable to the variables in original symbol table
		// NOTE(review): _boundOutputParamNames.get(i) throws IndexOutOfBoundsException if the call
		// binds fewer outputs than the function declares. Confirm that callers guarantee equal counts.
		for (int i=0; i< fpb.getOutputParams().size(); i++){
		
			String boundVarName = _boundOutputParamNames.get(i); 
			Data boundValue = retVars.get(fpb.getOutputParams().get(i).getName());
			if (boundValue == null)
				throw new DMLUnsupportedOperationException(boundVarName + " was not assigned a return value");

			//cleanup existing data bound to output variable name
			//(skipped when the existing object is the very object being returned)
			Data exdata = ec.removeVariable(boundVarName);
			if ( exdata != null && exdata instanceof MatrixObject && exdata != boundValue ) {
				ec.cleanupMatrixObject( (MatrixObject)exdata );
			}
			
			//add/replace data in symbol table
			//(matrix objects are renamed to the caller-side variable name)
			if( boundValue instanceof MatrixObject )
				((MatrixObject) boundValue).setVarName(boundVarName);
			ec.setVariable(boundVarName, boundValue);
		}
	}

	/**
	 * In debug mode, records exit from the called function in the execution context's
	 * debug state. It then runs the inherited post-processing.
	 * 
	 * @param ec execution context of the caller
	 * @throws DMLRuntimeException propagated from the inherited post-processing
	 */
	@Override
	public void postprocessInstruction(ExecutionContext ec)
		throws DMLRuntimeException 
	{
		// debug mode: pop this call from the debugger's function call stack first
		if( DMLScript.ENABLE_DEBUG_MODE )
			ec.handleDebugFunctionExit(this);
		
		super.postprocessInstruction(ec);
	}

	/**
	 * Logs this instruction at debug level.
	 */
	@Override
	public void printMe() {
		String text = "ExternalBuiltInFunction: " + toString();
		LOG.debug(text);
	}

	/**
	 * @return a label of the form {@code "ExtBuiltinFunc: <function name>"}
	 */
	public String getGraphString() {
		String label = "ExtBuiltinFunc: ";
		return label + _functionName;
	}
	
	/**
	 * @return names of the input operands, in call order (the instruction's own list, not a copy)
	 */
	public ArrayList<String> getBoundInputParamNames()
	{
		return _boundInputParamNames;
	}
	
	/**
	 * @return caller-side variable names bound to the function outputs, in order
	 *         (the instruction's own list, not a copy)
	 */
	public ArrayList<String> getBoundOutputParamNames()
	{
		return _boundOutputParamNames;
	}
	
	/**
	 * Renames the invoked function. The serialized instruction string, the
	 * function-name attribute and the opcode are all updated to the new name.
	 * 
	 * @param fname new function name
	 */
	public void setFunctionName(String fname)
	{
		// rewrite the instruction string first, while _functionName
		// still holds the old name to match against
		instString = updateInstStringFunctionName(_functionName, fname);
		_functionName = fname;
		instOpcode = fname;
	}

	/**
	 * Returns a copy of the instruction string with the function-name field replaced.
	 * <p>
	 * The field at index 3 (the function name, as used by setFunctionName) is replaced
	 * by {@code replace} if it currently equals {@code pattern}. Otherwise the string is
	 * returned unchanged. {@code instString} itself is not modified.
	 * 
	 * @param pattern expected current function name
	 * @param replace new function name
	 * @return the (possibly) updated instruction string
	 */
	public String updateInstStringFunctionName(String pattern, String replace)
	{
		final String delim = Lop.OPERAND_DELIMITOR;
		
		// Split on the literal delimiter: String.split takes a regex, so quote it.
		// Limit -1 keeps trailing empty fields, so the rebuilt string differs from
		// the original only in the replaced field.
		String[] parts = instString.split(java.util.regex.Pattern.quote(delim), -1);
		if( parts.length > 3 && parts[3].equals(pattern) )
			parts[3] = replace;
		
		// join the fields back with the delimiter
		StringBuilder sb = new StringBuilder(instString.length());
		for( int i = 0; i < parts.length; i++ ) {
			if( i > 0 )
				sb.append(delim);
			sb.append(parts[i]);
		}
		return sb.toString();
	}
	
	
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy