org.apache.sysml.parser.FunctionStatementBlock Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.parser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.FunctionOp.FunctionType;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
/**
 * Statement block that wraps exactly one {@link FunctionStatement}, i.e. the definition of either
 * a DML-bodied function or an {@link ExternalFunctionStatement external function}.
 *
 * <p>This block validates the function signature against its body and drives forward/backward
 * liveness analysis over the statement blocks of the function body. It must not carry HOPs of
 * its own (see {@link #get_hops()}).
 */
public class FunctionStatementBlock extends StatementBlock
{
	/** Set via {@link #setRecompileOnce(boolean)}; false unless explicitly enabled. */
	private boolean _recompileOnce = false;

	/**
	 * Validates the function definition held by this block.
	 *
	 * <p>Requires exactly one {@link FunctionStatement} and checks that every MATRIX input
	 * parameter has value type DOUBLE. Then:
	 * <ul>
	 *   <li>DML-bodied functions: validates the body statement blocks in order, threading the
	 *       variable set and constant map through them, and checks every declared return value
	 *       against the variable of the same name defined by the body.</li>
	 *   <li>External functions: validates the external-function parameters, then the body
	 *       statement blocks.</li>
	 * </ul>
	 *
	 * <p>TODO (carried over): add R-style checks that once parameters with default values start
	 * they continue to the right, and validate the remaining external-function parameters.
	 *
	 * @param dmlProg     program containing this function
	 * @param ids         variables visible on entry
	 * @param constVars   known constant values on entry
	 * @param conditional passed to body validation and used as the "conditional" flag of
	 *                    {@code raiseValidateError} for undefined / mismatched return variables
	 * @return the variable set after validating the function body
	 * @throws LanguageException if the block does not hold exactly one statement or a type
	 *                           check fails
	 * @throws ParseException    propagated from validating body statement blocks
	 * @throws IOException       propagated from validating body statement blocks
	 */
	@Override
	public VariableSet validate(DMLProgram dmlProg, VariableSet ids, HashMap<String, ConstIdentifier> constVars, boolean conditional)
		throws LanguageException, ParseException, IOException
	{
		FunctionStatement fstmt = getSingleFunctionStatement();

		// validate all function input parameters: input matrices must have value type double
		ArrayList<DataIdentifier> inputValues = fstmt.getInputParams();
		for( DataIdentifier inputValue : inputValues ) {
			if( inputValue.getDataType()==DataType.MATRIX && inputValue.getValueType()!=ValueType.DOUBLE ) {
				raiseValidateError("for function " + fstmt.getName() + ", input variable " + inputValue.getName()
					+ " has an unsupported value type of " + inputValue.getValueType() + ".", false);
			}
		}

		if( !(fstmt instanceof ExternalFunctionStatement) ) {
			// handle DML-bodied functions: validate the function body block by block
			this._dmlProg = dmlProg;
			for( StatementBlock sb : fstmt.getBody() ) {
				ids = sb.validate(dmlProg, ids, constVars, conditional);
				constVars = sb.getConstOut();
			}
			if( fstmt.getBody().size() > 0 )
				_constVarsIn.putAll(fstmt.getBody().get(0).getConstIn());
			// NOTE(review): with '> 1' a single-block body leaves _constVarsOut untouched;
			// kept as-is to preserve behavior -- confirm whether '> 0' was intended.
			if( fstmt.getBody().size() > 1 )
				_constVarsOut.putAll(fstmt.getBody().get(fstmt.getBody().size()-1).getConstOut());

			// each declared return value must be defined in the body with a compatible type
			ArrayList<DataIdentifier> returnValues = fstmt.getOutputParams();
			for( DataIdentifier returnValue : returnValues )
				validateReturnValue(fstmt, returnValue, ids, constVars, conditional);
		}
		else {
			// handle external functions: validate specified attributes and attribute values
			ExternalFunctionStatement efstmt = (ExternalFunctionStatement) fstmt;
			efstmt.validateParameters(this);
			// validate child statements
			this._dmlProg = dmlProg;
			for( StatementBlock sb : efstmt.getBody() ) {
				ids = sb.validate(dmlProg, ids, constVars, conditional);
				constVars = sb.getConstOut();
			}
		}
		return ids;
	}

	/**
	 * Checks one declared return value of a DML-bodied function against the variable of the
	 * same name in {@code ids}.
	 *
	 * <p>An INT scalar returned where the signature declares a DOUBLE scalar is cast to DOUBLE:
	 * the variable's value type is updated in {@code ids}, and a matching constant in
	 * {@code constVars} is replaced by a {@link DoubleIdentifier}. Any other value-type mismatch
	 * on a DOUBLE/INT scalar or a non-scalar raises a {@link LanguageException}.
	 */
	private void validateReturnValue(FunctionStatement fstmt, DataIdentifier returnValue, VariableSet ids,
		HashMap<String, ConstIdentifier> constVars, boolean conditional) throws LanguageException
	{
		DataIdentifier curr = ids.getVariable(returnValue.getName());
		if( curr == null ) {
			raiseValidateError("for function " + fstmt.getName() + ", return variable " + returnValue.getName() + " must be defined in function ", conditional);
			// raiseValidateError may return without throwing (it is also called with 'true'
			// below for non-fatal mismatches); without a definition nothing else can be checked
			return;
		}
		if( curr.getDataType() == DataType.UNKNOWN ) {
			raiseValidateError("for function " + fstmt.getName() + ", return variable " + curr.getName() + " data type of " + curr.getDataType() + " may not match data type in function signature of " + returnValue.getDataType(), true);
		}
		if( curr.getValueType() == ValueType.UNKNOWN ) {
			raiseValidateError("for function " + fstmt.getName() + ", return variable " + curr.getName() + " value type of " + curr.getValueType() + " may not match value type in function signature of " + returnValue.getValueType(), true);
		}
		if( curr.getDataType() != DataType.UNKNOWN && !curr.getDataType().equals(returnValue.getDataType()) ) {
			raiseValidateError("for function " + fstmt.getName() + ", return variable " + curr.getName() + " data type of " + curr.getDataType() + " does not match data type in function signature of " + returnValue.getDataType(), conditional);
		}

		// nothing more to check if the value type is unknown or already matches
		if( curr.getValueType() == ValueType.UNKNOWN || curr.getValueType().equals(returnValue.getValueType()) )
			return;

		// captured before any cast below mutates curr's value type
		String mismatch = "for function " + fstmt.getName()
			+ ", return variable " + curr.getName() + " value type of "
			+ curr.getValueType() + " does not match value type in function signature of "
			+ returnValue.getValueType();
		String castFailure = " and cannot safely cast " + curr.getValueType() + " as " + returnValue.getValueType();

		if( curr.getDataType() != DataType.SCALAR || returnValue.getDataType() != DataType.SCALAR )
			throw logError(curr.printErrorLocation() + mismatch + castFailure);

		// attempt to convert value type: only scalar INT -> DOUBLE is considered safe
		if( returnValue.getValueType() == ValueType.DOUBLE ) {
			if( curr.getValueType() != ValueType.INT )
				throw logError(curr.printErrorLocation() + mismatch + " and cannot safely cast value");

			ConstIdentifier constValue = constVars.get(curr.getName());
			if( constValue instanceof IntIdentifier ) {
				IntIdentifier currIntValue = (IntIdentifier) constValue;
				DoubleIdentifier currDoubleValue = new DoubleIdentifier(currIntValue.getValue(),
					curr.getFilename(), curr.getBeginLine(), curr.getBeginColumn(),
					curr.getEndLine(), curr.getEndColumn());
				constVars.put(curr.getName(), currDoubleValue);
			}
			LOG.warn(curr.printWarningLocation() + mismatch + " but was safely cast");
			curr.setValueType(ValueType.DOUBLE);
			ids.addVariable(curr.getName(), curr);
		}
		else if( returnValue.getValueType() == ValueType.INT ) {
			throw logError(curr.printErrorLocation() + mismatch + castFailure);
		}
		// NOTE(review): scalar mismatches where the signature declares a value type other than
		// DOUBLE or INT (e.g. BOOLEAN, STRING) are accepted silently -- confirm this is intended.
	}

	/**
	 * Determines how calls to this function are executed.
	 *
	 * @return {@link FunctionType#DML} for DML-bodied functions. For external functions,
	 *         {@link FunctionType#EXTERNAL_MEM} if the exec type is {@code IN_MEMORY},
	 *         {@link FunctionType#EXTERNAL_FILE} for any other specified exec type, and
	 *         {@link FunctionType#UNKNOWN} if no exec type is specified.
	 */
	public FunctionType getFunctionOpType()
	{
		FunctionStatement fstmt = (FunctionStatement) _statements.get(0);
		if( !(fstmt instanceof ExternalFunctionStatement) )
			return FunctionType.DML;

		ExternalFunctionStatement efstmt = (ExternalFunctionStatement) fstmt;
		String execType = efstmt.getOtherParams().get(ExternalFunctionStatement.EXEC_TYPE);
		if( execType == null )
			return FunctionType.UNKNOWN;
		return execType.equals(ExternalFunctionStatement.IN_MEMORY) ?
			FunctionType.EXTERNAL_MEM : FunctionType.EXTERNAL_FILE;
	}

	/**
	 * Forward liveness pass over the function body.
	 *
	 * <p>Resets the read and gen sets, then accumulates read, updated, gen and kill information
	 * from the body statement blocks in order.
	 *
	 * @param activeInPassed variables live on entry to the function body
	 * @return the live-out set: the variables propagated through the body plus all variables
	 *         updated in it
	 * @throws LanguageException if this block does not hold exactly one statement
	 */
	public VariableSet initializeforwardLV(VariableSet activeInPassed) throws LanguageException {
		FunctionStatement fstmt = getSingleFunctionStatement();
		_read = new VariableSet();
		_gen = new VariableSet();
		VariableSet current = new VariableSet();
		current.addVariables(activeInPassed);
		for( StatementBlock sb : fstmt.getBody() ) {
			current = sb.initializeforwardLV(current);
			// a variable generated by this block counts as generated by the function body only
			// if it is not already in the kill set accumulated from earlier body blocks
			for( String varName : sb._gen.getVariableNames() ) {
				if( !_kill.getVariableNames().contains(varName) )
					_gen.addVariable(varName, sb._gen.getVariable(varName));
			}
			_read.addVariables(sb._read);
			_updated.addVariables(sb._updated);
			// only add kill variables for statement blocks guaranteed to execute
			if( !(sb instanceof WhileStatementBlock) && !(sb instanceof ForStatementBlock) )
				_kill.addVariables(sb._kill);
		}
		// live-out includes variables from the passed live-in and those updated in the body
		_liveOut = new VariableSet();
		_liveOut.addVariables(current);
		_liveOut.addVariables(_updated);
		return _liveOut;
	}

	/**
	 * Backward liveness pass: runs {@code analyze} on the body statement blocks from last to
	 * first, starting from {@code loPassed}.
	 *
	 * @param loPassed variables live after the function body
	 * @return a copy of the variables live before the first body statement block
	 */
	public VariableSet initializebackwardLV(VariableSet loPassed) throws LanguageException{
		FunctionStatement fstmt = (FunctionStatement)_statements.get(0);
		VariableSet lo = new VariableSet();
		lo.addVariables(loPassed);
		// walk the function body in reverse order
		for( int i = fstmt.getBody().size() - 1; i >= 0; i-- )
			lo = fstmt.getBody().get(i).analyze(lo);
		VariableSet loReturn = new VariableSet();
		loReturn.addVariables(lo);
		return loReturn;
	}

	/**
	 * Returns this block's HOPs, which must be absent or empty for a function statement block.
	 *
	 * @throws HopsException if any HOPs are attached to this block
	 */
	public ArrayList<Hop> get_hops() throws HopsException {
		if( _hops != null && _hops.size() > 0 ) {
			LOG.error(this.printBlockErrorLocation() + "there should be no HOPs associated with the FunctionStatementBlock");
			throw new HopsException(this.printBlockErrorLocation() + "there should be no HOPs associated with the FunctionStatementBlock");
		}
		return _hops;
	}

	/**
	 * Not supported: liveness for a function block needs both live-in and live-out; use
	 * {@link #analyze(VariableSet, VariableSet)}.
	 *
	 * @throws LanguageException always
	 */
	public VariableSet analyze(VariableSet loPassed) throws LanguageException{
		throw logError(this.printBlockErrorLocation() + "Both liveIn and liveOut variables need to be specified for liveness analysis for FunctionStatementBlock");
	}

	/**
	 * Liveness analysis with explicit live-in and live-out sets.
	 *
	 * <p>Restricts {@code _liveOut} to those variables of {@code loPassed} plus {@code _gen}
	 * that were already live-out. It then re-runs the backward pass over the body and sets
	 * {@code _liveIn} to {@code liPassed}.
	 *
	 * @return a copy of the new live-in set
	 */
	public VariableSet analyze(VariableSet liPassed, VariableSet loPassed) throws LanguageException{
		VariableSet candidateLO = new VariableSet();
		candidateLO.addVariables(loPassed);
		candidateLO.addVariables(_gen);
		VariableSet origLiveOut = new VariableSet();
		origLiveOut.addVariables(_liveOut);
		_liveOut = new VariableSet();
		for( String name : candidateLO.getVariableNames() ) {
			if( origLiveOut.containsVariable(name) )
				_liveOut.addVariable(name, candidateLO.getVariable(name));
		}
		initializebackwardLV(_liveOut);
		// Cannot remove kill variables
		_liveIn = new VariableSet();
		_liveIn.addVariables(liPassed);
		VariableSet liveInReturn = new VariableSet();
		liveInReturn.addVariables(_liveIn);
		return liveInReturn;
	}

	public void setRecompileOnce( boolean flag ) {
		_recompileOnce = flag;
	}

	public boolean isRecompileOnce() {
		return _recompileOnce;
	}

	/**
	 * Returns the single {@link FunctionStatement} of this block.
	 *
	 * @throws LanguageException if the block holds zero or more than one statement
	 */
	private FunctionStatement getSingleFunctionStatement() throws LanguageException {
		if( _statements.size() != 1 )
			throw logError(this.printBlockErrorLocation() + "FunctionStatementBlock should have only 1 statement (FunctionStatement)");
		return (FunctionStatement) _statements.get(0);
	}

	/** Logs {@code msg} as an error and returns a {@link LanguageException} with the same message. */
	private LanguageException logError(String msg) {
		LOG.error(msg);
		return new LanguageException(msg);
	}
}