// org.apache.sysml.parser.IfStatementBlock (Maven / Gradle / Ivy)
// Artifact: systemml -- Declarative Machine Learning
// Repository page navigation (not part of the source): Go to download;
// Show more of this group; Show more artifacts with this name;
// Show all versions of systemml; Show documentation
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.parser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map.Entry;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.parser.Expression.DataType;
/**
 * Statement block that wraps exactly one {@link IfStatement}.
 * <p>
 * The class keeps the compiled predicate (hops and lops) together with its
 * recompilation flag, and implements validation as well as forward/backward
 * live-variable analysis for if/else constructs. In all of these passes the
 * if body and the else body are processed independently and the two results
 * are reconciled afterwards.
 */
public class IfStatementBlock extends StatementBlock
{
	private Hop _predicateHops;
	private Lop _predicateLops = null;
	private boolean _requiresPredicateRecompile = false;

	/**
	 * Validates the predicate and both branches of the wrapped if statement and
	 * reconciles the variables and constants that result from the two branches.
	 * <p>
	 * Both branches are validated with {@code conditional=true} on private copies
	 * of {@code ids} and {@code constVars}. Afterwards {@code _constVarsIn} receives
	 * the incoming constants and {@code _constVarsOut} the reconciled constants.
	 *
	 * @param dmlProg     program this block belongs to (stored in {@code _dmlProg})
	 * @param ids         variables known before this block; not modified directly, copies are used
	 * @param constVars   constant variables known before this block
	 * @param conditional forwarded to {@code raiseValidateError} for problems found in this block
	 * @return a new variable set holding the reconciled variables of both branches
	 * @throws LanguageException as declared by the overridden method
	 * @throws ParseException    as declared by the overridden method
	 * @throws IOException       as declared by the overridden method
	 */
	@Override
	public VariableSet validate(DMLProgram dmlProg, VariableSet ids, HashMap<String, ConstIdentifier> constVars, boolean conditional)
		throws LanguageException, ParseException, IOException
	{
		if (_statements.size() > 1){
			raiseValidateError("IfStatementBlock should only have 1 statement (IfStatement)", conditional);
		}
		IfStatement ifstmt = (IfStatement) _statements.get(0);

		// validate conditional predicate (incl constant propagation): a predicate that
		// is just a known constant variable is replaced by that constant
		Expression pred = ifstmt.getConditionalPredicate().getPredicate();
		pred.validateExpression(ids.getVariables(), constVars, conditional);
		if( pred instanceof DataIdentifier && constVars.containsKey( ((DataIdentifier)pred).getName()) ) {
			ifstmt.getConditionalPredicate().setPredicate(constVars.get(((DataIdentifier)pred).getName()));
		}

		// each branch works on its own copies so that they cannot influence each other
		HashMap<String, ConstIdentifier> constVarsIfCopy = new HashMap<>(constVars);
		HashMap<String, ConstIdentifier> constVarsElseCopy = new HashMap<>(constVars);
		VariableSet idsIfCopy = new VariableSet(ids);
		VariableSet idsElseCopy = new VariableSet(ids);
		VariableSet idsOrigCopy = new VariableSet(ids);

		// handle if stmt body (conditional execution)
		_dmlProg = dmlProg;
		ArrayList<StatementBlock> ifBody = ifstmt.getIfBody();
		for (StatementBlock sb : ifBody) {
			idsIfCopy = sb.validate(dmlProg, idsIfCopy, constVarsIfCopy, true);
			constVarsIfCopy = sb.getConstOut();
		}

		// handle else stmt body (conditional execution)
		ArrayList<StatementBlock> elseBody = ifstmt.getElseBody();
		for (StatementBlock sb : elseBody) {
			idsElseCopy = sb.validate(dmlProg, idsElseCopy, constVarsElseCopy, true);
			constVarsElseCopy = sb.getConstOut();
		}

		// reject conditional data type changes, warn on value type differences
		checkConditionalTypeChanges(idsOrigCopy, idsIfCopy, idsElseCopy, conditional);

		// reconcile constants and variables of both branches
		HashMap<String, ConstIdentifier> recConstVars =
			reconcileBranchConstVars(constVarsIfCopy, constVarsElseCopy, idsIfCopy, idsElseCopy);
		VariableSet recVars = reconcileBranchVariables(idsOrigCopy, idsIfCopy, idsElseCopy);

		// propagate updated variables
		VariableSet allIdVars = new VariableSet(recVars);
		_constVarsIn.putAll(constVars);
		_constVarsOut.putAll(recConstVars);
		return allIdVars;
	}

	/**
	 * Checks, for every updated variable, that its data type is the same in the if
	 * and else branch (or, if only one branch defines it, the same as before the
	 * block) and logs a warning if the value types of the two branches differ.
	 */
	private void checkConditionalTypeChanges(VariableSet idsOrig, VariableSet idsIf, VariableSet idsElse, boolean conditional)
		throws LanguageException, ParseException, IOException
	{
		for (String updatedVar : _updated.getVariableNames()) {
			DataIdentifier origVersion = idsOrig.getVariable(updatedVar);
			DataIdentifier ifVersion = idsIf.getVariable(updatedVar);
			DataIdentifier elseVersion = idsElse.getVariable(updatedVar);

			if( ifVersion != null && elseVersion != null ) { // both branches exist
				// data type handling: reject conditional data type change
				if (!ifVersion.getOutput().getDataType().equals(elseVersion.getOutput().getDataType())){
					raiseValidateError("IfStatementBlock has unsupported conditional data type change of variable '"+updatedVar+"' in if/else branch.", conditional);
				}
				// value type handling: only a warning
				if (!ifVersion.getOutput().getValueType().equals(elseVersion.getOutput().getValueType())){
					LOG.warn(elseVersion.printWarningLocation() + "Variable " + elseVersion.getName() + " defined with different value type in if and else clause.");
				}
			}
			else if( origVersion != null ) { // at most one branch version exists
				// compare whichever branch version is present against the original;
				// guarding against null avoids dereferencing a missing if version
				DataIdentifier branchVersion = (ifVersion != null) ? ifVersion : elseVersion;
				if( branchVersion != null
					&& !branchVersion.getOutput().getDataType().equals(origVersion.getOutput().getDataType()) )
				{
					String branch = (ifVersion != null) ? "if" : "else";
					raiseValidateError("IfStatementBlock has unsupported conditional data type change of variable '"+updatedVar+"' in "+branch+" branch.", conditional);
				}
			}
		}
	}

	/**
	 * Reconciles the constant variables of both branches:
	 * 1) (IF UNION ELSE) MINUS updated vars are taken over as is (else wins on conflicts),
	 * 2) an updated variable stays constant only if both branches see it as a scalar
	 *    of the same value type with the same constant value; otherwise it is left out.
	 */
	private HashMap<String, ConstIdentifier> reconcileBranchConstVars(
		HashMap<String, ConstIdentifier> constVarsIf, HashMap<String, ConstIdentifier> constVarsElse,
		VariableSet idsIf, VariableSet idsElse)
	{
		HashMap<String, ConstIdentifier> recConstVars = new HashMap<>();

		// STEP 1: (IF UNION ELSE) MINUS updated vars
		for (Entry<String, ConstIdentifier> e : constVarsIf.entrySet()) {
			if (!_updated.containsVariable(e.getKey()))
				recConstVars.put(e.getKey(), e.getValue());
		}
		for (Entry<String, ConstIdentifier> e : constVarsElse.entrySet()) {
			if (!_updated.containsVariable(e.getKey()))
				recConstVars.put(e.getKey(), e.getValue());
		}

		// STEP 2: updated constants need the same data type (SCALAR), value type, and value
		for (String updatedVar : _updated.getVariableNames()) {
			DataIdentifier ifVersion = idsIf.getVariable(updatedVar);
			DataIdentifier elseVersion = idsElse.getVariable(updatedVar);
			if (ifVersion != null && elseVersion != null
				&& ifVersion.getOutput().getDataType().equals(DataType.SCALAR)
				&& elseVersion.getOutput().getDataType().equals(DataType.SCALAR)
				&& ifVersion.getOutput().getValueType().equals(elseVersion.getOutput().getValueType()))
			{
				ConstIdentifier ifConstVersion = constVarsIf.get(updatedVar);
				ConstIdentifier elseConstVersion = constVarsElse.get(updatedVar);
				if (haveSameConstValue(ifConstVersion, elseConstVersion))
					recConstVars.put(updatedVar, ifConstVersion);
			}
		}
		return recConstVars;
	}

	/**
	 * Returns true if both constants are non-null, of the same supported kind
	 * (int, double, boolean, string), and carry the same value.
	 */
	private static boolean haveSameConstValue(ConstIdentifier a, ConstIdentifier b) {
		if (a == null || b == null)
			return false;
		if (a instanceof IntIdentifier && b instanceof IntIdentifier)
			return ((IntIdentifier)a).getValue() == ((IntIdentifier)b).getValue();
		if (a instanceof DoubleIdentifier && b instanceof DoubleIdentifier)
			return ((DoubleIdentifier)a).getValue() == ((DoubleIdentifier)b).getValue();
		if (a instanceof BooleanIdentifier && b instanceof BooleanIdentifier)
			return ((BooleanIdentifier)a).getValue() == ((BooleanIdentifier)b).getValue();
		if (a instanceof StringIdentifier && b instanceof StringIdentifier)
			return ((StringIdentifier)a).getValue().equals(((StringIdentifier)b).getValue());
		return false;
	}

	/**
	 * Reconciles the DataIdentifier variables of both branches:
	 * 1) (IF UNION ELSE) MINUS updated vars are taken over as is,
	 * 2) for updated variables each dimension is kept only if both compared versions
	 *    agree on it, otherwise it is set to -1; nnz is always reset to -1,
	 * 3) the reconciled versions of the updated variables are added to the result.
	 */
	private VariableSet reconcileBranchVariables(VariableSet idsOrig, VariableSet idsIf, VariableSet idsElse)
	{
		// STEP 1: (IF UNION ELSE) MINUS updated vars
		VariableSet recVars = new VariableSet();
		for (String varName : idsIf.getVariableNames()) {
			if (!_updated.containsVariable(varName))
				recVars.addVariable(varName, idsIf.getVariable(varName));
		}
		for (String varName : idsElse.getVariableNames()) {
			if (!_updated.containsVariable(varName))
				recVars.addVariable(varName, idsElse.getVariable(varName));
		}

		// STEP 2: reconcile size of updated variables
		// NOTE: nnz is not propagated via validate, and hence, we conservatively
		// assume that nnz has been changed (always reset to -1).
		for (String updatedVar : _updated.getVariableNames()) {
			DataIdentifier ifVersion = idsIf.getVariable(updatedVar);
			DataIdentifier elseVersion = idsElse.getVariable(updatedVar);
			DataIdentifier origVersion = idsOrig.getVariable(updatedVar);

			if (ifVersion != null && elseVersion != null) {
				// CASE: defined in both branches, compare if version against else version
				long ifVersionDim1 = origDim1(ifVersion);
				long elseVersionDim1 = origDim1(elseVersion);
				long ifVersionDim2 = origDim2(ifVersion);
				long elseVersionDim2 = origDim2(elseVersion);
				long updatedDim1 = (ifVersionDim1 == elseVersionDim1) ? ifVersionDim1 : -1;
				long updatedDim2 = (ifVersionDim2 == elseVersionDim2) ? ifVersionDim2 : -1;

				// add reconciled version (copy of ifVersion as plain DataIdentifier)
				DataIdentifier recVersion = new DataIdentifier(ifVersion);
				recVersion.setDimensions(updatedDim1, updatedDim2);
				recVersion.setNnz(-1);
				recVars.addVariable(updatedVar, recVersion);
			}
			else {
				// CASE: defined only in if branch, only in else branch, or in neither
				// (in the last case the version recorded in _updated is used)
				DataIdentifier source = (ifVersion != null) ? ifVersion
					: (elseVersion != null) ? elseVersion
					: _updated.getVariable(updatedVar);
				DataIdentifier recVersion = new DataIdentifier(source);
				recVars.addVariable(updatedVar, recVersion);

				// compare against the version before the block, if there was one
				long updatedDim1 = -1, updatedDim2 = -1;
				if( origVersion != null ) {
					long origVersionDim1 = origDim1(origVersion);
					long origVersionDim2 = origDim2(origVersion);
					// recVersion is always a plain DataIdentifier (see above)
					if( origVersionDim1 == recVersion.getDim1() ){
						updatedDim1 = origVersionDim1;
					}
					if( origVersionDim2 == recVersion.getDim2() ){
						updatedDim2 = origVersionDim2;
					}
				}
				recVersion.setDimensions(updatedDim1, updatedDim2);
				recVersion.setNnz(-1);
			}
		}
		return recVars;
	}

	/** Returns getOrigDim1() for an IndexedIdentifier and getDim1() otherwise. */
	private static long origDim1(DataIdentifier id) {
		return (id instanceof IndexedIdentifier) ? ((IndexedIdentifier)id).getOrigDim1() : id.getDim1();
	}

	/** Returns getOrigDim2() for an IndexedIdentifier and getDim2() otherwise. */
	private static long origDim2(DataIdentifier id) {
		return (id instanceof IndexedIdentifier) ? ((IndexedIdentifier)id).getOrigDim2() : id.getDim2();
	}

	/**
	 * Forward pass of live-variable analysis: (re)initializes the read, gen, kill
	 * and warn sets of this block from the predicate and both branches and computes
	 * the preliminary live-out set.
	 *
	 * @param activeInPassed variables active before this block
	 * @return the live-out set of this block (also stored in {@code _liveOut})
	 * @throws LanguageException if the block holds more than one statement, or if
	 *         raised by a nested block
	 */
	@Override
	public VariableSet initializeforwardLV(VariableSet activeInPassed)
		throws LanguageException
	{
		IfStatement ifstmt = (IfStatement)_statements.get(0);
		if (_statements.size() > 1){
			LOG.error(ifstmt.printErrorLocation() + "IfStatementBlock should have only 1 statement (if statement)");
			throw new LanguageException(ifstmt.printErrorLocation() + "IfStatementBlock should have only 1 statement (if statement)");
		}
		_read = new VariableSet();
		_gen = new VariableSet();
		_kill = new VariableSet();
		_warnSet = new VariableSet();

		///////////////////////////////////////////////////////////////////////
		// HANDLE PREDICATE
		///////////////////////////////////////////////////////////////////////
		_read.addVariables(ifstmt.getConditionalPredicate().variablesRead());
		_updated.addVariables(ifstmt.getConditionalPredicate().variablesUpdated());
		_gen.addVariables(ifstmt.getConditionalPredicate().variablesRead());

		///////////////////////////////////////////////////////////////////////
		// IF STATEMENT
		///////////////////////////////////////////////////////////////////////
		VariableSet genIfBody = new VariableSet();
		VariableSet killIfBody = new VariableSet();
		VariableSet updatedIfBody = new VariableSet();
		VariableSet readIfBody = new VariableSet();
		VariableSet ifCurrent = initForwardLVBody(ifstmt.getIfBody(), activeInPassed,
			genIfBody, killIfBody, updatedIfBody, readIfBody);

		///////////////////////////////////////////////////////////////////////
		// ELSE STATEMENT
		///////////////////////////////////////////////////////////////////////
		VariableSet genElseBody = new VariableSet();
		VariableSet killElseBody = new VariableSet();
		VariableSet updatedElseBody = new VariableSet();
		VariableSet readElseBody = new VariableSet();
		VariableSet elseCurrent = initForwardLVBody(ifstmt.getElseBody(), activeInPassed,
			genElseBody, killElseBody, updatedElseBody, readElseBody);

		///////////////////////////////////////////////////////////////////////
		// PERFORM RECONCILIATION
		///////////////////////////////////////////////////////////////////////
		// "conservative" read -- union of read sets for if and else path
		_read.addVariables(readIfBody);
		_read.addVariables(readElseBody);

		// "conservative" update -- union of updated
		_updated.addVariables(updatedIfBody);
		_updated.addVariables(updatedElseBody);

		// "conservative" gen -- union of gen
		_gen.addVariables(genIfBody);
		_gen.addVariables(genElseBody);

		// "conservative" kill -- kill set is intersection of if-kill and else-kill
		for (String varName : killIfBody.getVariableNames()) {
			if (killElseBody.containsVariable(varName)) {
				_kill.addVariable(varName, killIfBody.getVariable(varName));
			}
		}

		// set preliminary "warn" set -- variables that if used later may cause a
		// runtime error if the branch that assigns them is not executed
		// warnSet = (updated MINUS (updatedIfBody INTERSECT updatedElseBody)) MINUS current
		for (String varName : _updated.getVariableNames()) {
			boolean updatedInBoth = updatedIfBody.containsVariable(varName)
				&& updatedElseBody.containsVariable(varName);
			if (!(updatedInBoth || activeInPassed.containsVariable(varName))) {
				_warnSet.addVariable(varName, _updated.getVariable(varName));
			}
		}

		// set activeOut to (if body current UNION else body current) UNION updated
		_liveOut = new VariableSet();
		_liveOut.addVariables(ifCurrent);
		_liveOut.addVariables(elseCurrent);
		_liveOut.addVariables(_updated);
		return _liveOut;
	}

	/**
	 * Runs the forward live-variable initialization over the statement blocks of
	 * one branch body and accumulates the branch-level gen, kill, updated and read
	 * sets into the passed (initially empty) sets.
	 *
	 * @return the active variable set after the last block of the body
	 */
	private static VariableSet initForwardLVBody(ArrayList<StatementBlock> body, VariableSet activeIn,
		VariableSet genBody, VariableSet killBody, VariableSet updatedBody, VariableSet readBody)
		throws LanguageException
	{
		VariableSet current = new VariableSet(activeIn);
		for (StatementBlock sb : body) {
			current = sb.initializeforwardLV(current);

			// a variable generated by this block needs to be generated by the branch
			// only if it was NOT killed (assigned a value) by a prior block of the same body
			for (String varName : sb._gen.getVariableNames()) {
				if (!killBody.getVariableNames().contains(varName)) {
					genBody.addVariable(varName, sb._gen.getVariable(varName));
				}
			}
			readBody.addVariables(sb._read);
			updatedBody.addVariables(sb._updated);

			// only add kill variables for statement blocks guaranteed to execute
			if (!(sb instanceof WhileStatementBlock) && !(sb instanceof ForStatementBlock)) {
				killBody.addVariables(sb._kill);
			}
		}
		return current;
	}

	/**
	 * Backward pass of live-variable analysis: analyzes the blocks of the if body
	 * and of the else body in reverse order, each starting from a copy of
	 * {@code loPassed}, and returns the union of the two resulting sets.
	 *
	 * @param loPassed live-out variables passed in from the successor
	 * @return union of the sets obtained at the start of the if and else bodies
	 * @throws LanguageException if the block holds more than one statement, or if
	 *         raised by a nested block
	 */
	@Override
	public VariableSet initializebackwardLV(VariableSet loPassed)
		throws LanguageException
	{
		IfStatement ifstmt = (IfStatement)_statements.get(0);
		if (_statements.size() > 1){
			LOG.error(ifstmt.printErrorLocation() + "IfStatementBlock should have only 1 statement (if statement)");
			throw new LanguageException(ifstmt.printErrorLocation() + "IfStatementBlock should have only 1 statement (if statement)");
		}

		VariableSet currentLiveOutIf = new VariableSet(loPassed);
		VariableSet currentLiveOutElse = new VariableSet(loPassed);

		// walk each body from its last block to its first block
		for (int i = ifstmt.getIfBody().size() - 1; i >= 0; i--) {
			currentLiveOutIf = ifstmt.getIfBody().get(i).analyze(currentLiveOutIf);
		}
		for (int i = ifstmt.getElseBody().size() - 1; i >= 0; i--) {
			currentLiveOutElse = ifstmt.getElseBody().get(i).analyze(currentLiveOutElse);
		}

		// Any variable defined in either if-body or else-body is available for later use
		VariableSet bothPathsLiveOut = new VariableSet();
		bothPathsLiveOut.addVariables(currentLiveOutIf);
		bothPathsLiveOut.addVariables(currentLiveOutElse);
		return bothPathsLiveOut;
	}

	/** Sets the hops of the if predicate. */
	public void setPredicateHops(Hop hops) {
		_predicateHops = hops;
	}

	/** Returns the hops of the if predicate. */
	public Hop getPredicateHops(){
		return _predicateHops;
	}

	/** Returns the lops of the if predicate. */
	public Lop get_predicateLops() {
		return _predicateLops;
	}

	/** Sets the lops of the if predicate. */
	public void set_predicateLops(Lop predicateLops) {
		_predicateLops = predicateLops;
	}

	/**
	 * Finalizes live-variable analysis for this block: restricts live-out to the
	 * variables of {@code loPassed} that were in the previously computed live-out,
	 * runs the backward pass over both branches, prunes and logs the warn set, and
	 * derives live-in.
	 *
	 * @param loPassed live-out variables passed in from the successor
	 * @return a new variable set with the live-in variables of this block
	 * @throws LanguageException if raised by the backward pass
	 */
	@Override
	public VariableSet analyze(VariableSet loPassed) throws LanguageException{
		// NOTE(review): predVars is never read below; the calls are retained because
		// it cannot be confirmed from this file that they are free of side effects.
		VariableSet predVars = ((IfStatement)_statements.get(0)).getConditionalPredicate().variablesRead();
		predVars.addVariables(((IfStatement)_statements.get(0)).getConditionalPredicate().variablesUpdated());

		// live out = passed live out restricted to the previously computed live out
		VariableSet candidateLO = new VariableSet(loPassed);
		VariableSet origLiveOut = new VariableSet(_liveOut);
		_liveOut = new VariableSet();
		for (String name : candidateLO.getVariableNames()) {
			if (origLiveOut.containsVariable(name)) {
				_liveOut.addVariable(name, candidateLO.getVariable(name));
			}
		}

		// analyze both branch bodies (the returned set is not used here)
		initializebackwardLV(_liveOut);

		// set final warnSet: remove variables NOT in live out
		VariableSet finalWarnSet = new VariableSet();
		for (String varName : _warnSet.getVariableNames()) {
			if (_liveOut.containsVariable(varName)) {
				finalWarnSet.addVariable(varName, _warnSet.getVariable(varName));
			}
		}
		_warnSet = finalWarnSet;

		// for now just print the warn set
		for (String varName : _warnSet.getVariableNames()) {
			LOG.warn(_warnSet.getVariable(varName).printWarningLocation() + "Initialization of " + varName + " depends on if-else execution");
		}

		// data flow equation: livein = gen \cup (liveout - kill)
		_liveIn = new VariableSet();
		_liveIn.addVariables(_liveOut);
		_liveIn.removeVariables(_kill);
		_liveIn.addVariables(_gen);

		VariableSet liveInReturn = new VariableSet();
		liveInReturn.addVariables(_liveIn);
		return liveInReturn;
	}

	/////////
	// materialized hops recompilation flags
	////

	/**
	 * Recomputes and stores the predicate recompilation flag: true only if dynamic
	 * recompilation is enabled and the predicate hops require recompilation.
	 *
	 * @return the newly stored flag
	 * @throws HopsException as declared; presumably raised by the recompilation check
	 */
	public boolean updatePredicateRecompilationFlag() throws HopsException {
		return (_requiresPredicateRecompile =
			ConfigurationManager.isDynamicRecompilation()
			&& Recompiler.requiresRecompilation(getPredicateHops()));
	}

	/** Returns the flag stored by the last call to {@link #updatePredicateRecompilationFlag()} (initially false). */
	public boolean requiresPredicateRecompilation() {
		return _requiresPredicateRecompile;
	}
}