All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.sysml.runtime.controlprogram.ExternalFunctionProgramBlock Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.controlprogram;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.StringTokenizer;
import java.util.TreeMap;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.ReBlock;
import org.apache.sysml.lops.compile.JobType;
import org.apache.sysml.parser.DMLTranslator;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.parser.ExternalFunctionStatement;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.CacheException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.MRJobInstruction;
import org.apache.sysml.runtime.instructions.cp.BooleanObject;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.DoubleObject;
import org.apache.sysml.runtime.instructions.cp.IntObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.cp.StringObject;
import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.udf.ExternalFunctionInvocationInstruction;
import org.apache.sysml.udf.FunctionParameter;
import org.apache.sysml.udf.Matrix;
import org.apache.sysml.udf.PackageFunction;
import org.apache.sysml.udf.PackageRuntimeException;
import org.apache.sysml.udf.Scalar;
import org.apache.sysml.udf.FunctionParameter.FunctionParameterType;
import org.apache.sysml.udf.BinaryObject;
import org.apache.sysml.udf.Scalar.ScalarValueType;

public class ExternalFunctionProgramBlock extends FunctionProgramBlock 
{
		
	/** Source of the run IDs assigned at the start of every {@link #execute(ExecutionContext)} call. */
	protected static final IDSequence _idSeq = new IDSequence();

	/** Base directory that is handed to the package function before it is executed. */
	protected String _baseDir = null;

	// Instructions executed before (block-to-cell) and after (cell-to-block)
	// the external function itself; the cell-to-block list is null when there
	// is no output matrix to convert.
	ArrayList<Instruction> block2CellInst; 
	ArrayList<Instruction> cell2BlockInst; 

	// holds other key value parameters specified in function declaration
	protected HashMap<String, String> _otherParams;

	// variable name -> file name of the unblocked (input) / blocked (output) matrix
	protected HashMap<String, String> _unblockedFileNames;
	protected HashMap<String, String> _blockedFileNames;

	protected long _runID = -1; //ID for block of statements
	
	/**
	 * Creates a program block that only records the base directory. In
	 * contrast to the public constructor, no key/value parameters are stored,
	 * the file-name maps are not allocated and {@link #createInstructions()}
	 * is not called.
	 * 
	 * @param prog program passed on to the superclass constructor
	 * @param inputParams input parameters passed on to the superclass constructor
	 * @param outputParams output parameters passed on to the superclass constructor
	 * @param baseDir base directory returned by {@link #getBaseDir()}
	 * @throws DMLRuntimeException as declared; presumably raised by the superclass constructor
	 */
	protected ExternalFunctionProgramBlock(Program prog, ArrayList inputParams,
			ArrayList outputParams, String baseDir) 
		throws DMLRuntimeException
	{
		super(prog, inputParams, outputParams);
		this._baseDir = baseDir;
	}
	
	/**
	 * Creates a program block for an external function and immediately
	 * generates its instructions.
	 * 
	 * @param prog program passed on to the superclass constructor
	 * @param inputParams input parameters passed on to the superclass constructor
	 * @param outputParams output parameters passed on to the superclass constructor
	 * @param otherParams key/value parameters of the function declaration; copied, must not be null
	 * @param baseDir base directory returned by {@link #getBaseDir()}
	 * @throws DMLRuntimeException as declared; presumably raised by the superclass constructor
	 */
	public ExternalFunctionProgramBlock(Program prog, ArrayList inputParams,
			ArrayList outputParams, HashMap otherParams, String baseDir) 
		throws DMLRuntimeException 
	{
		super(prog, inputParams, outputParams);
		this._baseDir = baseDir;

		// defensive copy of the key/value parameters
		this._otherParams = new HashMap(otherParams);

		// filled while the conversion instructions are generated
		this._unblockedFileNames = new HashMap();
		this._blockedFileNames = new HashMap();

		// NOTE: overridable method invoked from the constructor; the state
		// above has to be initialized before this call
		createInstructions();
	}
	
	/**
	 * Regenerates the block-to-cell instructions for the current input
	 * parameters and then records, for every variable whose input reblock is
	 * skipped, the file name of the matrix object currently bound to it.
	 * 
	 * @param id run ID (not used by this method)
	 * @param ec execution context used to look up the skipped variables
	 */
	private void changeTmpInput( long id, ExecutionContext ec )
	{
		block2CellInst = getBlock2CellInstructions(getInputParams(), _unblockedFileNames);
		
		//post processing FUNCTION PATCH
		for( String skippedVar : _skipInReblock )
		{
			Data value = ec.getVariable(skippedVar);
			if( !(value instanceof MatrixObject) )
				continue; //only matrix objects carry a file name
			
			MatrixObject mo = (MatrixObject) value;
			_unblockedFileNames.put(skippedVar, mo.getFileName());
		}
	}
	
	/**
	 * Regenerates the cell-to-block instructions for the current output
	 * parameters. This is done per run because only file handles are passed
	 * out by the external function program block, so the temporary file names
	 * have to be renewed.
	 * 
	 * @param id run ID (not used by this method)
	 */
	private void changeTmpOutput( long id )
	{
		cell2BlockInst = getCell2BlockInstructions(getOutputParams(), _blockedFileNames);
	}
	
	/**
	 * Returns the base directory given at construction time.
	 * 
	 * @return the base directory, possibly null
	 */
	public String getBaseDir()
	{
		return this._baseDir;
	}
	
	/**
	 * Executes the external function invocation. The steps are: draw a new
	 * run ID and regenerate the conversion instructions, export all input
	 * matrices, run the block-to-cell instructions (if any), run every
	 * {@link ExternalFunctionInvocationInstruction} of this block, run the
	 * cell-to-block instructions (if any) and finally check the output
	 * parameters.
	 * 
	 * @param ec execution context holding the variables of the function call
	 * @throws DMLRuntimeException as declared by the overridden method; failures of the
	 *         individual steps are reported as {@link PackageRuntimeException}
	 */
	@Override
	public void execute(ExecutionContext ec) 
		throws DMLRuntimeException
	{
		_runID = _idSeq.getNextID();
		
		changeTmpInput( _runID, ec ); 
		changeTmpOutput( _runID );
		
		// export input variables to HDFS (see RunMRJobs)
		try {
			ArrayList<DataIdentifier> inputParams = getInputParams();
			for( DataIdentifier di : inputParams ) {
				Data d = ec.getVariable(di.getName());
				if ( d.getDataType() == DataType.MATRIX ) {
					MatrixObject inputObj = (MatrixObject) d;
					inputObj.exportData();
				}
			}
		}
		catch (Exception e){
			throw new PackageRuntimeException(this.printBlockErrorLocation() + "Error exporting input variables to HDFS", e);
		}
		
		// convert block to cell
		if( block2CellInst != null )
		{
			// work on a copy so that the stored instruction list stays untouched
			ArrayList<Instruction> tempInst = new ArrayList<Instruction>();
			tempInst.addAll(block2CellInst);
			try {
				this.executeInstructions(tempInst,ec);
			} catch (Exception e) {
				throw new PackageRuntimeException(this.printBlockErrorLocation() + "Error executing "
						+ tempInst.toString(), e);
			}
		}
		
		// now execute package function
		for (int i = 0; i < _inst.size(); i++) 
		{
			try {
				if (_inst.get(i) instanceof ExternalFunctionInvocationInstruction)
					executeInstruction(ec, (ExternalFunctionInvocationInstruction) _inst.get(i));
			} 
			catch(Exception e) {
				throw new PackageRuntimeException(this.printBlockErrorLocation() + 
						"Failed to execute instruction " + _inst.get(i).toString(), e);
			}
		}

		// convert cell to block
		if( cell2BlockInst != null )
		{
			ArrayList<Instruction> tempInst = new ArrayList<Instruction>();
			try {
				tempInst.addAll(cell2BlockInst);
				this.executeInstructions(tempInst, ec);
			} catch (Exception e) {
				throw new PackageRuntimeException(this.printBlockErrorLocation() + "Failed to execute instruction "
						+ cell2BlockInst.toString(), e);
			}
		}
		
		// check return values
		checkOutputParameters(ec.getVariables());
	}

	/**
	 * Given a list of parameters as data identifiers, returns a string
	 * representation. Parameters are separated by commas; a matrix or scalar
	 * is encoded as {@code <datatype>:<name>:<valuetype>}, an object as
	 * {@code <datatype>:<name>:}. Parameters of any other data type contribute
	 * no text (the separating comma is still written).
	 * 
	 * @param params parameters to encode
	 * @return the encoded parameter list, empty for an empty list
	 */
	protected String getParameterString(ArrayList<DataIdentifier> params) {
		StringBuilder parameterString = new StringBuilder();

		for (int i = 0; i < params.size(); i++) {
			if (i != 0)
				parameterString.append(",");

			DataIdentifier param = params.get(i);
			DataType dataType = param.getDataType();

			if (dataType == DataType.MATRIX || dataType == DataType.SCALAR) {
				parameterString.append(getDataTypeString(dataType)).append(":");
				parameterString.append(param.getName()).append(":");
				parameterString.append(getValueTypeString(param.getValueType()));
			}
			else if (dataType == DataType.OBJECT) {
				// objects carry no value type, the trailing ':' is kept
				parameterString.append(getDataTypeString(dataType)).append(":");
				parameterString.append(param.getName()).append(":");
			}
		}

		return parameterString.toString();
	}

	/**
	 * (Re)creates all instructions of this block: the block-to-cell
	 * instructions for the inputs, the single external function invocation
	 * instruction stored in {@code _inst}, and the cell-to-block instructions
	 * for the outputs.
	 */
	protected void createInstructions() {

		_inst = new ArrayList<Instruction>();

		// unblock all input matrices
		block2CellInst = getBlock2CellInstructions(getInputParams(),_unblockedFileNames);

		// assemble information provided through keyvalue pairs;
		// only the class name is mandatory, the config file may be null
		String className = _otherParams.get(ExternalFunctionStatement.CLASS_NAME);
		String configFile = _otherParams.get(ExternalFunctionStatement.CONFIG_FILE);
		if (className == null)
			throw new PackageRuntimeException(this.printBlockErrorLocation() + ExternalFunctionStatement.CLASS_NAME + " not provided!");

		// generate instruction from the encoded input and output parameters
		ArrayList<DataIdentifier> inParams = getInputParams();
		ArrayList<DataIdentifier> outParams = getOutputParams();
		ExternalFunctionInvocationInstruction einst = new ExternalFunctionInvocationInstruction(
				className, configFile, 
				getParameterString(inParams),
				getParameterString(outParams));
		
		// location: first input, else first output, else the block itself
		if (!inParams.isEmpty())
			einst.setLocation(inParams.get(0));
		else if (!outParams.isEmpty())
			einst.setLocation(outParams.get(0));
		else
			einst.setLocation(this._beginLine, this._endLine, this._beginColumn, this._endColumn);
		
		_inst.add(einst);

		// block output matrices
		cell2BlockInst = getCell2BlockInstructions(outParams,_blockedFileNames);
	}

	
	/**
	 * Method to generate a reblock job to convert the cell representation into block representation
	 * @param outputParams
	 * @param blockedFileNames
	 * @return
	 */
	private ArrayList getCell2BlockInstructions(
			ArrayList outputParams,
			HashMap blockedFileNames) {
		
		ArrayList c2binst = null;
		
		//list of matrices that need to be reblocked
		ArrayList matrices = new ArrayList();
		ArrayList matricesNoReblock = new ArrayList();

		// identify outputs that are matrices
		for (int i = 0; i < outputParams.size(); i++) {
			if (outputParams.get(i).getDataType() == DataType.MATRIX) {
				if( _skipOutReblock.contains(outputParams.get(i).getName()) )
					matricesNoReblock.add(outputParams.get(i));
				else
					matrices.add(outputParams.get(i));
			}
		}

		if( !matrices.isEmpty() )
		{
			c2binst = new ArrayList();
			MRJobInstruction reblkInst = new MRJobInstruction(JobType.REBLOCK);
			TreeMap> MRJobLineNumbers = null;
			if(DMLScript.ENABLE_DEBUG_MODE) {
				MRJobLineNumbers = new TreeMap>();
			}
			
			ArrayList inLabels = new ArrayList();
			ArrayList outLabels = new ArrayList();
			String[] outputs = new String[matrices.size()];
			byte[] resultIndex = new byte[matrices.size()];
			String reblock = "";
			String reblockStr = ""; //Keep a copy of a single MR reblock instruction
	
			String scratchSpaceLoc = ConfigurationManager.getConfig().getTextValue(DMLConfig.SCRATCH_SPACE);
			
			try {
				// create a RBLK job that transforms each output matrix from cell to block
				for (int i = 0; i < matrices.size(); i++) {
					inLabels.add(matrices.get(i).getName());
					outLabels.add(matrices.get(i).getName() + "_extFnOutput");
					outputs[i] = scratchSpaceLoc +
					             Lop.FILE_SEPARATOR + Lop.PROCESS_PREFIX + DMLScript.getUUID() + Lop.FILE_SEPARATOR + 
		                         _otherParams.get(ExternalFunctionStatement.CLASS_NAME) + _runID + "_" + i + "Output";
					blockedFileNames.put(matrices.get(i).getName(), outputs[i]);
					resultIndex[i] = (byte) i; // (matrices.size()+i);
		
					if (i > 0)
						reblock += Lop.INSTRUCTION_DELIMITOR;
		
					reblock += "MR" + ReBlock.OPERAND_DELIMITOR + "rblk" + ReBlock.OPERAND_DELIMITOR + 
									i + ReBlock.DATATYPE_PREFIX + matrices.get(i).getDataType() + ReBlock.VALUETYPE_PREFIX + matrices.get(i).getValueType() + ReBlock.OPERAND_DELIMITOR + 
									i + ReBlock.DATATYPE_PREFIX + matrices.get(i).getDataType() + ReBlock.VALUETYPE_PREFIX + matrices.get(i).getValueType() + ReBlock.OPERAND_DELIMITOR + 
									DMLTranslator.DMLBlockSize + ReBlock.OPERAND_DELIMITOR + DMLTranslator.DMLBlockSize + ReBlock.OPERAND_DELIMITOR + "true";
					
					if(DMLScript.ENABLE_DEBUG_MODE) {
						//Create a copy of reblock instruction but as a single instruction (FOR DEBUGGER)
						reblockStr = "MR" + ReBlock.OPERAND_DELIMITOR + "rblk" + ReBlock.OPERAND_DELIMITOR + 
										i + ReBlock.DATATYPE_PREFIX + matrices.get(i).getDataType() + ReBlock.VALUETYPE_PREFIX + matrices.get(i).getValueType() + ReBlock.OPERAND_DELIMITOR + 
										i + ReBlock.DATATYPE_PREFIX + matrices.get(i).getDataType() + ReBlock.VALUETYPE_PREFIX + matrices.get(i).getValueType() + ReBlock.OPERAND_DELIMITOR + 
										DMLTranslator.DMLBlockSize + ReBlock.OPERAND_DELIMITOR + DMLTranslator.DMLBlockSize  + ReBlock.OPERAND_DELIMITOR + "true";					
						//Set MR reblock instruction line number (FOR DEBUGGER)
						if (!MRJobLineNumbers.containsKey(matrices.get(i).getBeginLine())) {
							MRJobLineNumbers.put(matrices.get(i).getBeginLine(), new ArrayList()); 
						}
						MRJobLineNumbers.get(matrices.get(i).getBeginLine()).add(reblockStr);					
					}
					// create metadata instructions to populate symbol table 
					// with variables that hold blocked matrices
					
			  		/*StringBuilder mtdInst = new StringBuilder();
					mtdInst.append("CP" + Lops.OPERAND_DELIMITOR + "createvar");
			 		mtdInst.append(Lops.OPERAND_DELIMITOR + outLabels.get(i) + Lops.DATATYPE_PREFIX + matrices.get(i).getDataType() + Lops.VALUETYPE_PREFIX + matrices.get(i).getValueType());
			  		mtdInst.append(Lops.OPERAND_DELIMITOR + outputs[i] + Lops.DATATYPE_PREFIX + DataType.SCALAR + Lops.VALUETYPE_PREFIX + ValueType.STRING);
			  		mtdInst.append(Lops.OPERAND_DELIMITOR + OutputInfo.outputInfoToString(OutputInfo.BinaryBlockOutputInfo) ) ;
					c2binst.add(CPInstructionParser.parseSingleInstruction(mtdInst.toString()));*/
					Instruction createInst = VariableCPInstruction.prepareCreateVariableInstruction(outLabels.get(i), outputs[i], false, OutputInfo.outputInfoToString(OutputInfo.BinaryBlockOutputInfo));
					
					createInst.setLocation(matrices.get(i));
					
					c2binst.add(createInst);

				}
		
				reblkInst.setReBlockInstructions(inLabels.toArray(new String[inLabels.size()]), "", reblock, "", 
						outLabels.toArray(new String[inLabels.size()]), resultIndex, 1, 1);
				c2binst.add(reblkInst);
		
				// generate instructions that rename the output variables of REBLOCK job
				Instruction cpInst = null, rmInst = null;
				for (int i = 0; i < matrices.size(); i++) {
					cpInst = VariableCPInstruction.prepareCopyInstruction(outLabels.get(i), matrices.get(i).getName());
					rmInst = VariableCPInstruction.prepareRemoveInstruction(outLabels.get(i));
					
					cpInst.setLocation(matrices.get(i));
					rmInst.setLocation(matrices.get(i));
					
					c2binst.add(cpInst);
					c2binst.add(rmInst);
					//c2binst.add(CPInstructionParser.parseSingleInstruction("CP" + Lops.OPERAND_DELIMITOR + "cpvar"+Lops.OPERAND_DELIMITOR+ outLabels.get(i) + Lops.OPERAND_DELIMITOR + matrices.get(i).getName()));
				}
			} catch (Exception e) {
				throw new PackageRuntimeException(this.printBlockErrorLocation() + "error generating instructions", e);
			}
			
			//LOGGING instructions
			if (LOG.isTraceEnabled()){
				LOG.trace("\n--- Cell-2-Block Instructions ---");
				for(Instruction i : c2binst) {
					LOG.trace(i.toString());
				}
				LOG.trace("----------------------------------");
			}
			
		}
		
		return c2binst; //null if no output matrices
	}

	/**
	 * Method to generate instructions to convert input matrices from block to
	 * cell. We generate a GMR job here.
	 * 
	 * @param inputParams
	 * @return
	 */
	private ArrayList getBlock2CellInstructions(
			ArrayList inputParams,
			HashMap unBlockedFileNames) {
		
		ArrayList b2cinst = null;
		
		//list of input matrices
		ArrayList matrices = new ArrayList();
		ArrayList matricesNoReblock = new ArrayList();

		// find all inputs that are matrices
		for (int i = 0; i < inputParams.size(); i++) {
			if (inputParams.get(i).getDataType() == DataType.MATRIX) {
				if( _skipInReblock.contains(inputParams.get(i).getName()) )
					matricesNoReblock.add(inputParams.get(i));
				else
					matrices.add(inputParams.get(i));
			}
		}
		
		if( !matrices.isEmpty() )
		{
			b2cinst = new ArrayList();
			MRJobInstruction gmrInst = new MRJobInstruction(JobType.GMR);
			TreeMap> MRJobLineNumbers = null;
			if(DMLScript.ENABLE_DEBUG_MODE) {
				MRJobLineNumbers = new TreeMap>();
			}
			String gmrStr="";
			ArrayList inLabels = new ArrayList();
			ArrayList outLabels = new ArrayList();
			String[] outputs = new String[matrices.size()];
			byte[] resultIndex = new byte[matrices.size()];
	
			String scratchSpaceLoc = ConfigurationManager.getConfig().getTextValue(DMLConfig.SCRATCH_SPACE);
			
			
			try {
				// create a GMR job that transforms each of these matrices from block to cell
				for (int i = 0; i < matrices.size(); i++) {
					
					//inputs[i] = "##" + matrices.get(i).getName() + "##";
					//inputInfo[i] = binBlockInputInfo;
					//outputInfo[i] = textCellOutputInfo;
					//numRows[i] = numCols[i] = numRowsPerBlock[i] = numColsPerBlock[i] = -1;
					//resultDimsUnknown[i] = 1;
	
					inLabels.add(matrices.get(i).getName());
					outLabels.add(matrices.get(i).getName()+"_extFnInput");
					resultIndex[i] = (byte) i; //(matrices.size()+i);
	
					outputs[i] = scratchSpaceLoc +
									Lop.FILE_SEPARATOR + Lop.PROCESS_PREFIX + DMLScript.getUUID() + Lop.FILE_SEPARATOR + 
									_otherParams.get(ExternalFunctionStatement.CLASS_NAME) + _runID + "_" + i + "Input";
					unBlockedFileNames.put(matrices.get(i).getName(), outputs[i]);
	
					if(DMLScript.ENABLE_DEBUG_MODE) {
						//Create a dummy gmr instruction (FOR DEBUGGER)
						gmrStr = "MR" + Lop.OPERAND_DELIMITOR + "gmr" + Lop.OPERAND_DELIMITOR + 
										i + Lop.DATATYPE_PREFIX + matrices.get(i).getDataType() + Lop.VALUETYPE_PREFIX + matrices.get(i).getValueType() + Lop.OPERAND_DELIMITOR + 
										i + Lop.DATATYPE_PREFIX + matrices.get(i).getDataType() + Lop.VALUETYPE_PREFIX + matrices.get(i).getValueType() + Lop.OPERAND_DELIMITOR + 
										DMLTranslator.DMLBlockSize + Lop.OPERAND_DELIMITOR + DMLTranslator.DMLBlockSize;
						
						//Set MR gmr instruction line number (FOR DEBUGGER)
						if (!MRJobLineNumbers.containsKey(matrices.get(i).getBeginLine())) {
							MRJobLineNumbers.put(matrices.get(i).getBeginLine(), new ArrayList()); 
						}
						MRJobLineNumbers.get(matrices.get(i).getBeginLine()).add(gmrStr);
					}
					// create metadata instructions to populate symbol table 
					// with variables that hold unblocked matrices
				 	
					/*StringBuilder mtdInst = new StringBuilder();
					mtdInst.append("CP" + Lops.OPERAND_DELIMITOR + "createvar");
						mtdInst.append(Lops.OPERAND_DELIMITOR + outLabels.get(i) + Lops.DATATYPE_PREFIX + matrices.get(i).getDataType() + Lops.VALUETYPE_PREFIX + matrices.get(i).getValueType());
				 		mtdInst.append(Lops.OPERAND_DELIMITOR + outputs[i] + Lops.DATATYPE_PREFIX + DataType.SCALAR + Lops.VALUETYPE_PREFIX + ValueType.STRING);
				 		mtdInst.append(Lops.OPERAND_DELIMITOR + OutputInfo.outputInfoToString(OutputInfo.TextCellOutputInfo) ) ;
					b2cinst.add(CPInstructionParser.parseSingleInstruction(mtdInst.toString()));*/
					
			 		Instruction createInst = VariableCPInstruction.prepareCreateVariableInstruction(outLabels.get(i), outputs[i], false, OutputInfo.outputInfoToString(OutputInfo.TextCellOutputInfo));
			 		
			 		createInst.setLocation(matrices.get(i));
			 		
			 		b2cinst.add(createInst);
				}
			
				// Finally, generate GMR instruction that performs block2cell conversion
				gmrInst.setGMRInstructions(inLabels.toArray(new String[inLabels.size()]), "", "", "", "", 
						outLabels.toArray(new String[outLabels.size()]), resultIndex, 0, 1);
				
				b2cinst.add(gmrInst);
			
				// generate instructions that rename the output variables of GMR job
				Instruction cpInst=null, rmInst=null;
				for (int i = 0; i < matrices.size(); i++) {
						cpInst = VariableCPInstruction.prepareCopyInstruction(outLabels.get(i), matrices.get(i).getName());
						rmInst = VariableCPInstruction.prepareRemoveInstruction(outLabels.get(i));
						
						cpInst.setLocation(matrices.get(i));
						rmInst.setLocation(matrices.get(i));
						
						b2cinst.add(cpInst);
						b2cinst.add(rmInst);
				}
			} catch (Exception e) {
				throw new PackageRuntimeException(e);
			}
		
			//LOG instructions
			if (LOG.isTraceEnabled()){
				LOG.trace("\n--- Block-2-Cell Instructions ---");
				for(Instruction i : b2cinst) {
					LOG.trace(i.toString());
				}
				LOG.trace("----------------------------------");
			}			
		}
		
		//BEGIN FUNCTION PATCH
		if( !matricesNoReblock.isEmpty() )
		{
			//if( b2cinst==null )
			//	b2cinst = new ArrayList();
			
			for( int i=0; i cla = (Class) Class.forName(className);
			o = cla.newInstance();
		} 
		catch (Exception e) 
		{
			throw new PackageRuntimeException(this.printBlockErrorLocation() + "Error generating package function object " ,e );
		}

		if (!(o instanceof PackageFunction))
			throw new PackageRuntimeException(this.printBlockErrorLocation() + "Class is not of type PackageFunction");

		PackageFunction func = (PackageFunction) o;

		// add inputs to this package function based on input parameter
		// and their mappings.
		setupInputs(func, inst.getInputParams(), ec.getVariables());
		func.setConfiguration(configFile);
		func.setBaseDir(_baseDir);
		
		//executes function
		func.execute();
		
		// verify output of function execution matches declaration
		// and add outputs to variableMapping and Metadata
		verifyAndAttachOutputs(ec, func, inst.getOutputParams());
	}

	/**
	 * Method to verify that function outputs match with declared outputs and
	 * to bind the verified outputs to their declared variable names in the
	 * execution context. Each declared output has the form
	 * {@code <datatype>:<name>:<valuetype>}.
	 * 
	 * @param ec execution context that receives the output variables
	 * @param returnFunc package function after its execution
	 * @param outputParams comma-separated encoding of the declared outputs
	 * @throws DMLRuntimeException as declared; mismatches and unsupported output
	 *         types are reported as {@link PackageRuntimeException}
	 */
	protected void verifyAndAttachOutputs(ExecutionContext ec, PackageFunction returnFunc,
			String outputParams) throws DMLRuntimeException {

		ArrayList<String> outputs = getParameters(outputParams);
		
		// make sure they are of equal size first
		if (outputs.size() != returnFunc.getNumFunctionOutputs()) {
			throw new PackageRuntimeException(
					"Number of function outputs ("+returnFunc.getNumFunctionOutputs()+") " +
					"does not match with declaration ("+outputs.size()+").");
		}

		// iterate over each output and verify that type matches
		for (int i = 0; i < outputs.size(); i++) {
			// tokens: [0] data type, [1] variable name, [2] value type (absent for objects)
			ArrayList<String> tokens = new ArrayList<String>();
			StringTokenizer tk = new StringTokenizer(outputs.get(i), ":");
			while (tk.hasMoreTokens()) {
				tokens.add(tk.nextToken());
			}

			FunctionParameter output = returnFunc.getFunctionOutput(i);
			FunctionParameterType outputType = output.getType();

			if (outputType == FunctionParameterType.Matrix) {
				Matrix m = (Matrix) output;

				if (!tokens.get(0).equals(getFunctionParameterDataTypeString(FunctionParameterType.Matrix))
						|| !tokens.get(2).equals(
								getMatrixValueTypeString(m.getValueType()))) {
					throw new PackageRuntimeException(
							"Function output '"+outputs.get(i)+"' does not match with declaration.");
				}

				// add result to variableMapping; cleanup of a previous value is
				// not required because done at central position (FunctionCallCPInstruction)
				String varName = tokens.get(1);
				MatrixObject newVar = createOutputMatrixObject( m ); 
				newVar.setVarName(varName);
				ec.setVariable(varName, newVar);
			}
			else if (outputType == FunctionParameterType.Scalar) {
				Scalar s = (Scalar) output;

				if (!tokens.get(0).equals(getFunctionParameterDataTypeString(FunctionParameterType.Scalar))
						|| !tokens.get(2).equals(
								getScalarValueTypeString(s.getScalarType()))) {
					throw new PackageRuntimeException(
							"Function output '"+outputs.get(i)+"' does not match with declaration.");
				}

				// allocate and set appropriate object based on type
				ScalarObject scalarObject = null;
				ScalarValueType type = s.getScalarType();
				switch (type) {
				case Integer:
					scalarObject = new IntObject(tokens.get(1),
							Long.parseLong(s.getValue()));
					break;
				case Double:
					scalarObject = new DoubleObject(tokens.get(1),
							Double.parseDouble(s.getValue()));
					break;
				case Boolean:
					scalarObject = new BooleanObject(tokens.get(1),
							Boolean.parseBoolean(s.getValue()));
					break;
				case Text:
					scalarObject = new StringObject(tokens.get(1), s.getValue());
					break;
				default:
					throw new PackageRuntimeException(
							"Unknown scalar value type '"+type+"' of output '"+outputs.get(i)+"'.");
				}

				ec.setVariable(tokens.get(1), scalarObject);
			}
			else if (outputType == FunctionParameterType.Object) {
				if (!tokens.get(0).equals(getFunctionParameterDataTypeString(FunctionParameterType.Object))) {
					throw new PackageRuntimeException(
							"Function output '"+outputs.get(i)+"' does not match with declaration.");
				}

				// even a correctly declared object output is rejected
				throw new PackageRuntimeException(
						"Object types not yet supported");
			}
			else {
				throw new PackageRuntimeException(
						"Unknown data type '"+outputType+"' " +
						"of output '"+outputs.get(i)+"'.");
			}
		}
	}

	/**
	 * Wraps a matrix returned by a package function into a matrix object in
	 * text-cell format that points to the file of the given matrix.
	 * 
	 * @param m matrix produced by the package function
	 * @return matrix object of value type DOUBLE for the file path of {@code m}
	 * @throws CacheException as declared; presumably raised when the matrix object is created
	 */
	protected MatrixObject createOutputMatrixObject( Matrix m ) 
		throws CacheException 
	{
		final int blockSize = DMLTranslator.DMLBlockSize;
		MatrixCharacteristics dims = new MatrixCharacteristics(
				m.getNumRows(), m.getNumCols(), blockSize, blockSize);
		MatrixFormatMetaData meta = new MatrixFormatMetaData(
				dims, OutputInfo.TextCellOutputInfo, InputInfo.TextCellInputInfo);
		
		return new MatrixObject(ValueType.DOUBLE, m.getFilePath(), meta);
	}

	/**
	 * Maps a scalar value type to the name used in the encoded parameter
	 * strings. Note that {@code Text} is rendered as {@code "String"}.
	 * 
	 * @param scalarType scalar value type, must not be null
	 * @return "Double", "Integer", "Boolean" or "String"
	 */
	protected String getScalarValueTypeString(ScalarValueType scalarType) {
		switch (scalarType) {
			case Double:
				return "Double";
			case Integer:
				return "Integer";
			case Boolean:
				return "Boolean";
			case Text:
				return "String";
			default:
				throw new PackageRuntimeException("Unknown scalar value type");
		}
	}

	/**
	 * Method to parse inputs and add them to the package function: the
	 * encoded input parameters are split, converted into function parameter
	 * objects and registered with the function by position.
	 * 
	 * @param func package function that receives the inputs
	 * @param inputParams comma-separated encoding of the input parameters
	 * @param variableMapping variables from which the input values are taken
	 */
	protected void setupInputs (PackageFunction func, String inputParams,
			LocalVariableMap variableMapping) {

		ArrayList<String> inputs = getParameters(inputParams);
		ArrayList<FunctionParameter> inputObjects = getInputObjects(inputs, variableMapping);
		
		func.setNumFunctionInputs(inputObjects.size());
		for (int i = 0; i < inputObjects.size(); i++)
			func.setInput(inputObjects.get(i), i);

	}

	/**
	 * Method to convert string representation of input into function input
	 * object. Every entry has the form {@code <datatype>:<name>:<valuetype>}
	 * where the data type is "Matrix", "Scalar" or "Object"; entries with any
	 * other data type are ignored.
	 * 
	 * @param inputs encoded input parameters
	 * @param variableMapping variables from which the values are looked up by name
	 * @return the function input objects in the order of {@code inputs}
	 */
	protected ArrayList<FunctionParameter> getInputObjects(ArrayList<String> inputs,
			LocalVariableMap variableMapping) {
		ArrayList<FunctionParameter> inputObjects = new ArrayList<FunctionParameter>();

		for (int i = 0; i < inputs.size(); i++) {
			// tokens: [0] data type, [1] variable name, [2] value type
			ArrayList<String> tokens = new ArrayList<String>();
			StringTokenizer tk = new StringTokenizer(inputs.get(i), ":");
			while (tk.hasMoreTokens()) {
				tokens.add(tk.nextToken());
			}

			String dataType = tokens.get(0);

			if (dataType.equals("Matrix")) {
				String varName = tokens.get(1);
				MatrixObject mobj = (MatrixObject) variableMapping.get(varName);
				MatrixCharacteristics mc = mobj.getMatrixCharacteristics();
				Matrix m = new Matrix(mobj.getFileName(),
						mc.getRows(), mc.getCols(),
						getMatrixValueType(tokens.get(2)));
				modifyInputMatrix(m, mobj); // extension hook for subclasses
				inputObjects.add(m);
			}
			else if (dataType.equals("Scalar")) {
				String varName = tokens.get(1);
				ScalarObject so = (ScalarObject) variableMapping.get(varName);
				Scalar s = new Scalar(getScalarValueType(tokens.get(2)),
						so.getStringValue());
				inputObjects.add(s);
			}
			else if (dataType.equals("Object")) {
				String varName = tokens.get(1);
				Object o = variableMapping.get(varName);
				BinaryObject obj = new BinaryObject(o);
				inputObjects.add(obj);
			}
		}

		return inputObjects;

	}

	/**
	 * Extension hook called by {@code getInputObjects} for every matrix input
	 * after the function input matrix has been created and before it is added
	 * to the list of input objects. The default implementation is empty.
	 * 
	 * @param m function input matrix created for the variable
	 * @param mobj matrix object the input was created from
	 */
	protected void modifyInputMatrix(Matrix m, MatrixObject mobj) 
	{
		//do nothing, intended for extensions
	}

	/**
	 * Converts the name of a scalar value type, as used in the encoded
	 * parameter strings, into the enum constant. "String" maps to
	 * {@code Text}.
	 * 
	 * @param string one of "Double", "Integer", "Boolean" or "String"; must not be null
	 * @return the matching scalar value type
	 */
	protected ScalarValueType getScalarValueType(String string) {
		ScalarValueType result = null;
		
		if (string.equals("Double"))
			result = ScalarValueType.Double;
		else if (string.equals("Integer"))
			result = ScalarValueType.Integer;
		else if (string.equals("Boolean"))
			result = ScalarValueType.Boolean;
		else if (string.equals("String"))
			result = ScalarValueType.Text;

		if (result == null)
			throw new PackageRuntimeException("Unknown scalar type");

		return result;
	}

	/**
	 * Maps a matrix value type to the name used in the encoded parameter
	 * strings.
	 * 
	 * @param t matrix value type, must not be null
	 * @return "Double" or "Integer"
	 */
	protected String getMatrixValueTypeString(Matrix.ValueType t) {
		final boolean isDouble = t.equals(Matrix.ValueType.Double);
		
		// anything besides Double and Integer is rejected
		if (!isDouble && !t.equals(Matrix.ValueType.Integer))
			throw new PackageRuntimeException("Unknown matrix value type");

		return isDouble ? "Double" : "Integer";
	}

	/**
	 * Converts the name of a matrix value type, as used in the encoded
	 * parameter strings, into the enum constant.
	 * 
	 * @param string "Double" or "Integer"; must not be null
	 * @return the matching matrix value type
	 */
	protected org.apache.sysml.udf.Matrix.ValueType getMatrixValueType(String string) {
		final boolean isDouble = string.equals("Double");
		
		// anything besides the two known names is rejected
		if (!isDouble && !string.equals("Integer"))
			throw new PackageRuntimeException("Unknown matrix value type");

		return isDouble ? Matrix.ValueType.Double : Matrix.ValueType.Integer;
	}

	/**
	 * Method to break the comma separated input parameters into an arraylist of
	 * parameters. Empty entries (e.g. caused by consecutive commas) are
	 * dropped, so an empty string yields an empty list.
	 * 
	 * @param inputParams comma-separated parameters, must not be null
	 * @return the individual parameters in their original order
	 */
	protected ArrayList<String> getParameters(String inputParams) {
		ArrayList<String> inputs = new ArrayList<String>();

		StringTokenizer tk = new StringTokenizer(inputParams, ",");
		while (tk.hasMoreTokens()) {
			inputs.add(tk.nextToken());
		}

		return inputs;
	}

	/**
	 * Maps a data type to the name used in the encoded parameter strings.
	 * 
	 * @param d data type, must not be null
	 * @return "Matrix", "Scalar" or "Object"
	 */
	protected String getDataTypeString(DataType d) {
		switch (d) {
			case MATRIX:
				return "Matrix";
			case SCALAR:
				return "Scalar";
			case OBJECT:
				return "Object";
			default:
				throw new PackageRuntimeException("Should never come here");
		}
	}

	/**
	 * Maps the type of a function parameter to the name used in the encoded
	 * parameter strings.
	 * 
	 * @param t function parameter type, must not be null
	 * @return "Matrix", "Scalar" or "Object"
	 */
	protected String getFunctionParameterDataTypeString(FunctionParameterType t) {
		switch (t) {
			case Matrix:
				return "Matrix";
			case Scalar:
				return "Scalar";
			case Object:
				return "Object";
			default:
				throw new PackageRuntimeException("Should never come here");
		}
	}

	/**
	 * Maps a value type to the name used in the encoded parameter strings.
	 * 
	 * @param v value type, must not be null
	 * @return "Double", "Integer", "Boolean" or "String"
	 */
	protected String getValueTypeString(ValueType v) {
		switch (v) {
			case DOUBLE:
				return "Double";
			case INT:
				return "Integer";
			case BOOLEAN:
				return "Boolean";
			case STRING:
				return "String";
			default:
				throw new PackageRuntimeException("Should never come here");
		}
	}

	/**
	 * Prints every instruction stored in {@code _inst} by delegating to the
	 * instruction's own {@code printMe()}.
	 */
	public void printMe() {
		//System.out.println("***** INSTRUCTION BLOCK *****");
		for (int pos = 0; pos < this._inst.size(); pos++) {
			Instruction current = this._inst.get(pos);
			current.printMe();
		}
	}
	
	/**
	 * Returns the key/value parameters of the function declaration. The
	 * internal map itself is returned, not a copy.
	 * 
	 * @return the key/value parameters
	 */
	public HashMap getOtherParams()
	{
		return this._otherParams;
	}
	
	/**
	 * Builds the prefix used for error messages of this block, naming the
	 * begin and end line of the block.
	 * 
	 * @return the error message prefix, ending in " -- "
	 */
	public String printBlockErrorLocation(){
		StringBuilder location = new StringBuilder();
		location.append("ERROR: Runtime error in external function program block generated from external function statement block between lines ");
		location.append(_beginLine);
		location.append(" and ");
		location.append(_endLine);
		location.append(" -- ");
		return location.toString();
	}
	
	
	/////////////////////////////////////////////////
	// Extension for Global Data Flow Optimization
	// by Mathias Peters
	///////
	
	//FUNCTION PATCH
	
	// names of input / output matrix variables that are excluded from the
	// generated block-to-cell / cell-to-block conversion
	private Collection<String> _skipInReblock = new HashSet<String>();
	private Collection<String> _skipOutReblock = new HashSet<String>();
	
	/**
	 * Replaces the lists of variables that are excluded from reblocking. Both
	 * lists are always cleared first; the instructions are only regenerated
	 * if at least one of the arguments is not null.
	 * 
	 * @param varsIn names of input variables to skip, may be null
	 * @param varsOut names of output variables to skip, may be null
	 */
	public void setSkippedReblockLists( Collection varsIn, Collection varsOut )
	{
		_skipInReblock.clear();
		_skipOutReblock.clear();
		
		//nothing to add, keep the existing instructions
		if( varsIn==null && varsOut==null )
			return;
		
		if( varsIn != null )
			_skipInReblock.addAll(varsIn);
		if( varsOut != null )
			_skipOutReblock.addAll(varsOut);
		
		//regenerate instructions
		createInstructions();
	}
	
	
	/**
	 * Returns a new list with the conversion instructions of this block: the
	 * cell-to-block instructions first, followed by the block-to-cell
	 * instructions. Lists that are null are left out; the external function
	 * invocation instruction itself is not part of the result.
	 * 
	 * @return a freshly allocated list, possibly empty
	 */
	@Override
	public ArrayList getInstructions()
	{
		ArrayList collected = (cell2BlockInst != null) ?
				new ArrayList(cell2BlockInst) : new ArrayList();
		
		if( block2CellInst != null )
			collected.addAll(block2CellInst);
		
		return collected;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy