org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.instructions.cp;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLScriptException;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
/**
 * CP instruction that calls a function identified by a namespace and a
 * function name.
 *
 * <p>An instance carries the callee's namespace and name plus the bound input
 * operands/names and bound output names of the call site. On execution it
 * looks the function program block up in the program of the execution
 * context, runs it in a separate execution context, and binds the function's
 * outputs to the bound output names in the caller's context.
 */
public class FunctionCallCPInstruction extends CPInstruction
{
// name of the called function; setFunctionName keeps the instruction string in sync with it
private String _functionName;
// namespace that, together with the function name, identifies the function program block
private String _namespace;
/**
 * Returns the name of the function this instruction calls.
 *
 * @return the current function name
 */
public String getFunctionName() {
    return this._functionName;
}
/**
 * Returns the namespace of the function this instruction calls.
 *
 * @return the namespace passed in at construction
 */
public String getNamespace() {
    return this._namespace;
}
// Bound input and output parameters of the call site.
// The element types are spelled out because the code that reads these lists
// assigns their elements directly to CPOperand / String variables.

// operands of the bound inputs (parseInstruction fills this list and
// _boundInputParamNames in lockstep, one entry per input)
private ArrayList<CPOperand> _boundInputParamOperands;
// names of the bound inputs, as reported by the corresponding operand
private ArrayList<String> _boundInputParamNames;
// names of the caller-side variables that receive the function outputs
private ArrayList<String> _boundOutputParamNames;
/**
 * Creates a function-call instruction.
 *
 * <p>The superclass constructor receives {@code null} as its first argument,
 * the function name as its second and the instruction string as its third;
 * the instruction type is then set to {@code CPINSTRUCTION_TYPE.External}.
 * The three lists are stored by reference, not copied.
 *
 * @param namespace            namespace of the called function
 * @param functName            name of the called function
 * @param boundInParamOperands operands of the bound inputs
 * @param boundInParamNames    names of the bound inputs
 * @param boundOutParamNames   names of the variables that receive the outputs
 * @param istr                 string form of the instruction
 */
public FunctionCallCPInstruction(String namespace, String functName, ArrayList boundInParamOperands, ArrayList boundInParamNames, ArrayList boundOutParamNames, String istr) {
    super(null, functName, istr);
    _cptype = CPINSTRUCTION_TYPE.External;

    // identity of the callee
    this._namespace = namespace;
    this._functionName = functName;

    // call-site bindings
    this._boundOutputParamNames = boundOutParamNames;
    this._boundInputParamNames = boundInParamNames;
    this._boundInputParamOperands = boundInParamOperands;
}
/**
 * Parses the string form of a function-call instruction.
 *
 * <p>The parts returned by
 * {@code InstructionUtils.getInstructionPartsWithValueType} are read as:
 * {@code [1]} namespace, {@code [2]} function name, {@code [3]} number of
 * inputs, {@code [4]} number of outputs, followed by that many input
 * operands and then that many output names. Part {@code [0]} is not read
 * here.
 *
 * @param str string form of the instruction
 * @return the parsed instruction; {@code str} is kept as its instruction string
 * @throws DMLRuntimeException if the instruction has fewer parts than its
 *         header requires, or if either count is not a non-negative integer
 */
public static FunctionCallCPInstruction parseInstruction(String str)
    throws DMLRuntimeException
{
    //schema: extfunct, namespace, fname, num inputs, num outputs, inputs, outputs
    final int headerLen = 5;
    String[] parts = InstructionUtils.getInstructionPartsWithValueType ( str );
    if( parts.length < headerLen ) {
        throw new DMLRuntimeException("Invalid function call instruction, expected at least "
            + headerLen + " parts but found " + parts.length + ": " + str);
    }

    String namespace = parts[1];
    String functionName = parts[2];

    // parseInt avoids the boxing/unboxing round trip of Integer.valueOf
    int numInputs;
    int numOutputs;
    try {
        numInputs = Integer.parseInt(parts[3]);
        numOutputs = Integer.parseInt(parts[4]);
    }
    catch( NumberFormatException e ) {
        throw new DMLRuntimeException("Invalid number of inputs/outputs in function call instruction: " + str, e);
    }

    // fail with a descriptive error instead of an ArrayIndexOutOfBoundsException below
    int numArgs = parts.length - headerLen;
    if( numInputs < 0 || numOutputs < 0 || numInputs > numArgs || numOutputs > numArgs - numInputs ) {
        throw new DMLRuntimeException("Function call instruction declares " + numInputs
            + " inputs and " + numOutputs + " outputs but provides " + numArgs + " arguments: " + str);
    }

    ArrayList<CPOperand> boundInParamOperands = new ArrayList<CPOperand>(numInputs);
    ArrayList<String> boundInParamNames = new ArrayList<String>(numInputs);
    ArrayList<String> boundOutParamNames = new ArrayList<String>(numOutputs);

    // inputs are full operands; their names are tracked in a parallel list
    for (int i = 0; i < numInputs; i++) {
        CPOperand operand = new CPOperand(parts[headerLen + i]);
        boundInParamOperands.add(operand);
        boundInParamNames.add(operand.getName());
    }
    // outputs are kept as the raw part strings
    for (int i = 0; i < numOutputs; i++) {
        boundOutParamNames.add(parts[headerLen + numInputs + i]);
    }

    return new FunctionCallCPInstruction ( namespace,functionName,
        boundInParamOperands, boundInParamNames, boundOutParamNames, str );
}
/**
 * Runs the superclass pre-processing and, when debug mode is enabled,
 * reports the function entry to the execution context.
 *
 * @param ec execution context the instruction is about to run in
 * @return the instruction returned by the superclass pre-processing
 * @throws DMLRuntimeException declared by the overridden method; this
 *         override does not throw it directly
 */
@Override
public Instruction preprocessInstruction(ExecutionContext ec)
    throws DMLRuntimeException
{
    // default pre-process behavior
    final Instruction preprocessed = super.preprocessInstruction(ec);

    // nothing else to do outside of debug mode
    if (!DMLScript.ENABLE_DEBUG_MODE) {
        return preprocessed;
    }

    // maintain debug state (function call stack)
    ec.handleDebugFunctionEntry((FunctionCallCPInstruction) preprocessed);
    return preprocessed;
}
/**
 * Executes the function call against the caller's execution context.
 *
 * <p>Steps, in order: look up the function program block for
 * {@code _namespace}/{@code _functionName}; bind the call's inputs to the
 * function's formal parameters in a fresh {@code LocalVariableMap}; pin the
 * caller's bound input variables; execute the function block in a newly
 * created execution context; remove from that context every variable that
 * is not a declared output parameter; unpin the inputs; bind each output
 * value to the corresponding bound output name in the caller's context.
 *
 * <p>NOTE(review): the header of the parameter-binding loop is corrupted in
 * this copy of the file (see the note inside the method). The method does
 * not compile as shown and must be restored from the upstream source.
 *
 * @param ec execution context of the caller; it supplies the program and the
 *        input variables, and receives the output bindings
 * @throws DMLRuntimeException if a non-literal input variable does not exist
 *         in {@code ec}, if executing the function block fails with anything
 *         other than a {@code DMLScriptException} (which is rethrown
 *         unchanged), or if an output parameter has no value after execution
 */
@Override
public void processInstruction(ExecutionContext ec)
throws DMLRuntimeException
{
if( LOG.isTraceEnabled() ){
LOG.trace("Executing instruction : " + this.toString());
}
// get the function program block (stored in the Program object)
FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(_namespace, _functionName);
// create bindings to formal parameters for given function call
// These are the bindings passed to the FunctionProgramBlock for function execution
LocalVariableMap functionVariables = new LocalVariableMap();
// NOTE(review): the next line is damaged. Text between the loop condition and
// the first branch condition is missing, including the declarations of
// currFormalParamName, currFormalParamValue and valType that the body uses.
// Presumably the loop iterates over fpb.getInputParams() and the lost branch
// condition compares i against _boundInputParamNames.size() -- restore from
// the upstream source rather than from this description.
// NOTE(review): when restoring, check the comparison operator. With a strict
// "greater than", i == size would reach the else-branch and index past the
// end of _boundInputParamOperands -- confirm.
for( int i=0; i _boundInputParamNames.size() )
{
// default-value branch: read the formal parameter's default value and
// resolve it through ec.getScalarInput
String defaultVal = fpb.getInputParams().get(i).getDefaultValue();
currFormalParamValue = ec.getScalarInput(defaultVal, valType, false);
}
// CASE (b) literals or symbol table entries
else {
CPOperand operand = _boundInputParamOperands.get(i);
String varname = operand.getName();
//error handling non-existing variables
if( !operand.isLiteral() && !ec.containsVariable(varname) ) {
throw new DMLRuntimeException("Input variable '"+varname+"' not existing on call of " +
DMLProgram.constructFunctionKey(_namespace, _functionName) + " (line "+getLineNum()+").");
}
//get input matrix/frame/scalar
// non-scalars are passed as the caller's variable object itself; scalars are
// resolved through getScalarInput, which is also told whether it is a literal
currFormalParamValue = (operand.getDataType()!=DataType.SCALAR) ? ec.getVariable(varname) :
ec.getScalarInput(varname, operand.getValueType(), operand.isLiteral());
}
functionVariables.put(currFormalParamName,currFormalParamValue);
}
// Pin the input variables so that they do not get deleted
// from pb's symbol table at the end of execution of function
HashMap pinStatus = ec.pinVariables(_boundInputParamNames);
// Create a symbol table under a new execution context for the function invocation,
// and copy the function arguments into the created table.
ExecutionContext fn_ec = ExecutionContextFactory.createContext(false, ec.getProgram());
fn_ec.setVariables(functionVariables);
// execute the function block
// NOTE(review): both catch blocks rethrow, so on failure the unpinVariables
// call further down is never reached -- confirm that this is intended.
try {
// record the invoked name and namespace on the function block before running it
fpb._functionName = this._functionName;
fpb._namespace = this._namespace;
fpb.execute(fn_ec);
}
catch (DMLScriptException e) {
// rethrown unchanged, i.e. not wrapped like other exceptions below
throw e;
}
catch (Exception e){
String fname = DMLProgram.constructFunctionKey(_namespace, _functionName);
throw new DMLRuntimeException("error executing function " + fname, e);
}
LocalVariableMap retVars = fn_ec.getVariables();
// cleanup all returned variables w/o binding
// iterate over a copy of the key set, since variables are removed from fn_ec
// inside the loop (presumably the same map as retVars)
Collection retVarnames = new LinkedList(retVars.keySet());
// names of the declared output parameters; everything else is a cleanup candidate
HashSet probeVars = new HashSet();
for(DataIdentifier di : fpb.getOutputParams())
probeVars.add(di.getName());
for( String var : retVarnames ) {
if( !probeVars.contains(var) ) //cleanup candidate
{
Data dat = fn_ec.removeVariable(var);
if( dat != null && dat instanceof MatrixObject )
fn_ec.cleanupMatrixObject((MatrixObject)dat);
}
}
// Unpin the pinned variables
ec.unpinVariables(_boundInputParamNames, pinStatus);
// add the updated binding for each return variable to the variables in original symbol table
// NOTE(review): the loop bound is the number of declared output parameters, but
// _boundOutputParamNames is indexed with the same i; a call that binds fewer
// outputs than the function declares would fail on get(i) -- confirm that
// this cannot happen.
for (int i=0; i< fpb.getOutputParams().size(); i++){
String boundVarName = _boundOutputParamNames.get(i);
Data boundValue = retVars.get(fpb.getOutputParams().get(i).getName());
if (boundValue == null)
throw new DMLRuntimeException(boundVarName + " was not assigned a return value");
//cleanup existing data bound to output variable name
// the identity check skips cleanup when the old object is the one being re-bound
Data exdata = ec.removeVariable(boundVarName);
if ( exdata != null && exdata instanceof MatrixObject && exdata != boundValue ) {
ec.cleanupMatrixObject( (MatrixObject)exdata );
}
//add/replace data in symbol table
if( boundValue instanceof MatrixObject )
((MatrixObject) boundValue).setVarName(boundVarName);
ec.setVariable(boundVarName, boundValue);
}
}
/**
 * Reports the function exit to the execution context when debug mode is
 * enabled, then runs the superclass post-processing.
 *
 * @param ec execution context the instruction ran in
 * @throws DMLRuntimeException declared by the overridden method; this
 *         override does not throw it directly
 */
@Override
public void postprocessInstruction(ExecutionContext ec)
    throws DMLRuntimeException
{
    // maintain debug state (function call stack) before the default handling
    final boolean debugging = DMLScript.ENABLE_DEBUG_MODE;
    if (debugging) {
        ec.handleDebugFunctionExit(this);
    }

    // default post-process behavior
    super.postprocessInstruction(ec);
}
/**
 * Logs this instruction's string form at debug level, prefixed with
 * "ExternalBuiltInFunction: ".
 */
@Override
public void printMe() {
    // string concatenation with "this" goes through toString()
    final String message = "ExternalBuiltInFunction: " + this;
    LOG.debug(message);
}
/**
 * Returns a short label for this instruction: the fixed prefix
 * "ExtBuiltinFunc: " followed by the function name.
 *
 * @return the label string
 */
public String getGraphString() {
    StringBuilder label = new StringBuilder("ExtBuiltinFunc: ");
    label.append(_functionName);
    return label.toString();
}
/**
 * Returns the names of the bound input parameters.
 *
 * <p>The returned list is the instruction's own list, not a copy, so
 * changes made through it are visible to this instruction.
 *
 * @return the bound input parameter names
 */
public ArrayList<String> getBoundInputParamNames()
{
    return _boundInputParamNames;
}
/**
 * Returns the names of the variables that receive the function outputs.
 *
 * <p>The returned list is the instruction's own list, not a copy, so
 * changes made through it are visible to this instruction.
 *
 * @return the bound output parameter names
 */
public ArrayList<String> getBoundOutputParamNames()
{
    return _boundOutputParamNames;
}
/**
 * Renames the called function and keeps the instruction string and the
 * opcode consistent with the new name.
 *
 * @param fname new function name
 */
public void setFunctionName(String fname)
{
    // rewrite the instruction string first, while _functionName still holds the old name
    instString = updateInstStringFunctionName(_functionName, fname);

    // then switch the attribute and the opcode over to the new name
    _functionName = fname;
    instOpcode = fname;
}
/**
 * Builds a copy of the current instruction string in which the field at
 * index 3 is replaced by {@code replace}, provided that field equals
 * {@code pattern}. The stored instruction string is not modified; the
 * caller decides what to do with the result.
 *
 * <p>Instruction strings with fewer than four fields are returned without
 * replacement instead of failing with an
 * {@code ArrayIndexOutOfBoundsException}. As before, trailing empty fields
 * are dropped because {@code String.split} discards them.
 *
 * @param pattern function name expected at index 3
 * @param replace function name to put in its place
 * @return the re-assembled instruction string
 */
public String updateInstStringFunctionName(String pattern, String replace)
{
    // index of the function name in the raw instruction string
    // (presumably one higher than in parseInstruction because the raw string
    // carries an additional leading field -- confirm)
    final int fnameIndex = 3;

    //split current instruction
    // NOTE(review): String.split treats the delimiter as a regular expression;
    // this equals a literal split only if Lop.OPERAND_DELIMITOR contains no
    // regex metacharacters -- confirm.
    String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
    if( parts.length > fnameIndex && parts[fnameIndex].equals(pattern) )
        parts[fnameIndex] = replace;

    //construct modified instruction, with the delimiter between the parts
    //(nothing has to be trimmed off the end, so an empty parts array is safe)
    StringBuilder sb = new StringBuilder();
    for( int i = 0; i < parts.length; i++ ) {
        if( i > 0 )
            sb.append(Lop.OPERAND_DELIMITOR);
        sb.append(parts[i]);
    }
    return sb.toString();
}
}