org.apache.sysml.parser.ForStatementBlock Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.parser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.runtime.instructions.cp.BooleanObject;
import org.apache.sysml.runtime.instructions.cp.DoubleObject;
public class ForStatementBlock extends StatementBlock
{
protected Hop _fromHops = null;
protected Hop _toHops = null;
protected Hop _incrementHops = null;
protected Lop _fromLops = null;
protected Lop _toLops = null;
protected Lop _incrementLops = null;
protected boolean _requiresFromRecompile = false;
protected boolean _requiresToRecompile = false;
protected boolean _requiresIncrementRecompile = false;
public IterablePredicate getIterPredicate(){
return ((ForStatement)_statements.get(0)).getIterablePredicate();
}
@Override
public VariableSet validate(DMLProgram dmlProg, VariableSet ids, HashMap constVars, boolean conditional)
throws LanguageException, ParseException, IOException
{
if (_statements.size() > 1){
raiseValidateError("ForStatementBlock should have only 1 statement (for statement)", conditional);
}
ForStatement fs = (ForStatement) _statements.get(0);
IterablePredicate predicate = fs.getIterablePredicate();
// Record original size information before loop for ALL variables
// Will compare size / type info for these after loop completes
// Replace variables with changed size with unknown value
VariableSet origVarsBeforeBody = new VariableSet();
for (String key : ids.getVariableNames()){
DataIdentifier origId = ids.getVariable(key);
DataIdentifier copyId = new DataIdentifier(origId);
origVarsBeforeBody.addVariable(key, copyId);
}
//////////////////////////////////////////////////////////////////////////////
// FIRST PASS: process the predicate / statement blocks in the body of the for statement
///////////////////////////////////////////////////////////////////////////////
//remove updated vars from constants
for( String var : _updated.getVariableNames() )
if( constVars.containsKey( var ) )
constVars.remove( var );
predicate.validateExpression(ids.getVariables(), constVars, conditional);
ArrayList body = fs.getBody();
//perform constant propagation for ( from, to, incr )
//(e.g., useful for reducing false positives in parfor dependency analysis)
performConstantPropagation(constVars);
//validate body
_dmlProg = dmlProg;
for(StatementBlock sb : body)
{
ids = sb.validate(dmlProg, ids, constVars, true);
constVars = sb.getConstOut();
}
if (!body.isEmpty()){
_constVarsIn.putAll(body.get(0).getConstIn());
_constVarsOut.putAll(body.get(body.size()-1).getConstOut());
}
// for each updated variable
boolean revalidationRequired = false;
for (String key : _updated.getVariableNames())
{
DataIdentifier startVersion = origVarsBeforeBody.getVariable(key);
DataIdentifier endVersion = ids.getVariable(key);
if (startVersion != null && endVersion != null)
{
//handle data type change (reject)
if (!startVersion.getOutput().getDataType().equals(endVersion.getOutput().getDataType())){
raiseValidateError("ForStatementBlock has unsupported conditional data type change of variable '"+key+"' in loop body.", conditional);
}
//handle size change
long startVersionDim1 = (startVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)startVersion).getOrigDim1() : startVersion.getDim1();
long endVersionDim1 = (endVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)endVersion).getOrigDim1() : endVersion.getDim1();
long startVersionDim2 = (startVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)startVersion).getOrigDim2() : startVersion.getDim2();
long endVersionDim2 = (endVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)endVersion).getOrigDim2() : endVersion.getDim2();
boolean sizeUnchanged = ((startVersionDim1 == endVersionDim1) &&
(startVersionDim2 == endVersionDim2) );
//handle sparsity change
//NOTE: nnz not propagated via validate, and hence, we conservatively assume that nnz have been changed.
//long startVersionNNZ = startVersion.getNnz();
//long endVersionNNZ = endVersion.getNnz();
//boolean nnzUnchanged = (startVersionNNZ == endVersionNNZ);
boolean nnzUnchanged = false;
// IF size has changed --
if (!sizeUnchanged || !nnzUnchanged){
revalidationRequired = true;
DataIdentifier recVersion = new DataIdentifier(endVersion);
if(!sizeUnchanged)
recVersion.setDimensions(-1, -1);
if(!nnzUnchanged)
recVersion.setNnz(-1);
origVarsBeforeBody.addVariable(key, recVersion);
}
}
}
// revalidation is required -- size was updated for at least 1 variable
if (revalidationRequired){
// update ids to the reconciled values
ids = origVarsBeforeBody;
//////////////////////////////////////////////////////////////////////////////
// SECOND PASS: process the predicate / statement blocks in the body of the for statement
///////////////////////////////////////////////////////////////////////////////
//remove updated vars from constants
for( String var : _updated.getVariableNames() )
if( constVars.containsKey( var ) )
constVars.remove( var );
//perform constant propagation for ( from, to, incr )
//(e.g., useful for reducing false positives in parfor dependency analysis)
performConstantPropagation(constVars);
predicate.validateExpression(ids.getVariables(), constVars, conditional);
body = fs.getBody();
//validate body
_dmlProg = dmlProg;
for(StatementBlock sb : body)
{
ids = sb.validate(dmlProg, ids, constVars, true);
constVars = sb.getConstOut();
}
if (!body.isEmpty()){
_constVarsIn.putAll(body.get(0).getConstIn());
_constVarsOut.putAll(body.get(body.size()-1).getConstOut());
}
}
return ids;
}
public VariableSet initializeforwardLV(VariableSet activeInPassed) throws LanguageException {
ForStatement fstmt = (ForStatement)_statements.get(0);
if (_statements.size() > 1){
LOG.error(_statements.get(0).printErrorLocation() + "ForStatementBlock should have only 1 statement (for statement)");
throw new LanguageException(_statements.get(0).printErrorLocation() + "ForStatementBlock should have only 1 statement (for statement)");
}
_read = new VariableSet();
_read.addVariables(fstmt.getIterablePredicate().variablesRead());
_updated.addVariables(fstmt.getIterablePredicate().variablesUpdated());
_gen = new VariableSet();
_gen.addVariables(fstmt.getIterablePredicate().variablesRead());
// add the iterVar from iterable predicate to kill set
_kill.addVariables(fstmt.getIterablePredicate().variablesUpdated());
VariableSet current = new VariableSet();
current.addVariables(activeInPassed);
current.addVariables(_updated);
for( StatementBlock sb : fstmt.getBody())
{
current = sb.initializeforwardLV(current);
// for each generated variable in this block, check variable not killed
// in prior statement block in while stmt blody
for (String varName : sb._gen.getVariableNames()){
// IF the variable is NOT set in the while loop PRIOR to this stmt block,
// THEN needs to be generated
if (!_kill.getVariableNames().contains(varName)){
_gen.addVariable(varName, sb._gen.getVariable(varName));
}
}
_read.addVariables(sb._read);
_updated.addVariables(sb._updated);
// only add kill variables for statement blocks guaranteed to execute
if (!(sb instanceof WhileStatementBlock) && !(sb instanceof ForStatementBlock) ){
_kill.addVariables(sb._kill);
}
}
// set preliminary "warn" set -- variables that if used later may cause runtime error
// if the loop is not executed
// warnSet = (updated MINUS (updatedIfBody INTERSECT updatedElseBody)) MINUS current
for (String varName : _updated.getVariableNames()){
if (!activeInPassed.containsVariable(varName)) {
_warnSet.addVariable(varName, _updated.getVariable(varName));
}
}
// activeOut includes variables from passed live in and updated in the while body
_liveOut = new VariableSet();
_liveOut.addVariables(current);
_liveOut.addVariables(_updated);
return _liveOut;
}
public VariableSet initializebackwardLV(VariableSet loPassed) throws LanguageException{
ForStatement fstmt = (ForStatement)_statements.get(0);
VariableSet lo = new VariableSet();
lo.addVariables(loPassed);
// calls analyze for each statement block in while stmt body
int numBlocks = fstmt.getBody().size();
for (int i = numBlocks - 1; i >= 0; i--){
lo = fstmt.getBody().get(i).analyze(lo);
}
VariableSet loReturn = new VariableSet();
loReturn.addVariables(lo);
return loReturn;
}
public ArrayList get_hops() throws HopsException {
if (_hops != null && !_hops.isEmpty()){
LOG.error(this.printBlockErrorLocation() + "there should be no HOPs associated with the ForStatementBlock");
throw new HopsException(this.printBlockErrorLocation() + "there should be no HOPs associated with the ForStatementBlock");
}
return _hops;
}
public void setFromHops(Hop hops) { _fromHops = hops; }
public void setToHops(Hop hops) { _toHops = hops; }
public void setIncrementHops(Hop hops) { _incrementHops = hops; }
public Hop getFromHops() { return _fromHops; }
public Hop getToHops() { return _toHops; }
public Hop getIncrementHops() { return _incrementHops; }
public void setFromLops(Lop lops) {
_fromLops = lops;
}
public void setToLops(Lop lops) { _toLops = lops; }
public void setIncrementLops(Lop lops) { _incrementLops = lops; }
public Lop getFromLops() { return _fromLops; }
public Lop getToLops() { return _toLops; }
public Lop getIncrementLops() { return _incrementLops; }
public VariableSet analyze(VariableSet loPassed) throws LanguageException{
VariableSet predVars = new VariableSet();
IterablePredicate ip = ((ForStatement)_statements.get(0)).getIterablePredicate();
predVars.addVariables(ip.variablesRead());
predVars.addVariables(ip.variablesUpdated());
VariableSet candidateLO = new VariableSet();
candidateLO.addVariables(loPassed);
candidateLO.addVariables(_gen);
candidateLO.addVariables(predVars);
VariableSet origLiveOut = new VariableSet();
origLiveOut.addVariables(_liveOut);
origLiveOut.addVariables(predVars);
origLiveOut.addVariables(_gen);
_liveOut = new VariableSet();
for (String name : candidateLO.getVariableNames()){
if (origLiveOut.containsVariable(name)){
_liveOut.addVariable(name, candidateLO.getVariable(name));
}
}
initializebackwardLV(_liveOut);
// set final warnSet: remove variables NOT in live out
VariableSet finalWarnSet = new VariableSet();
for (String varName : _warnSet.getVariableNames()){
if (_liveOut.containsVariable(varName)){
finalWarnSet.addVariable(varName,_warnSet.getVariable(varName));
}
}
_warnSet = finalWarnSet;
// for now just print the warn set
for (String varName : _warnSet.getVariableNames()) {
if( !ip.getIterVar().getName().equals( varName) )
LOG.warn(_warnSet.getVariable(varName).printWarningLocation() + "Initialization of " + varName + " depends on for execution");
}
// Cannot remove kill variables
_liveIn = new VariableSet();
_liveIn.addVariables(_liveOut);
_liveIn.addVariables(_gen);
VariableSet liveInReturn = new VariableSet();
liveInReturn.addVariables(_liveIn);
return liveInReturn;
}
public void performConstantPropagation(HashMap currConstVars)
throws LanguageException
{
IterablePredicate ip = getIterPredicate();
// handle replacement in from expression
Expression replacementExpr = replaceConstantVar(ip.getFromExpr(), currConstVars);
if (replacementExpr != null)
ip.setFromExpr(replacementExpr);
// handle replacement in to expression
replacementExpr = replaceConstantVar(ip.getToExpr(), currConstVars);
if (replacementExpr != null)
ip.setToExpr(replacementExpr);
// handle replacement in increment expression
replacementExpr = replaceConstantVar(ip.getIncrementExpr(), currConstVars);
if (replacementExpr != null)
ip.setIncrementExpr(replacementExpr);
}
private Expression replaceConstantVar(Expression expr, HashMap currConstVars)
{
Expression ret = null;
if (expr instanceof DataIdentifier && !(expr instanceof IndexedIdentifier))
{
// check if the DataIdentifier variable is a ConstIdentifier
String identifierName = ((DataIdentifier)expr).getName();
if (currConstVars.containsKey(identifierName))
{
ConstIdentifier constValue = currConstVars.get(identifierName);
//AUTO CASTING (using runtime operations for consistency)
switch( constValue.getValueType() )
{
case DOUBLE:
ret = new IntIdentifier(new DoubleObject(((DoubleIdentifier)constValue).getValue()).getLongValue(),
expr.getFilename(), expr.getBeginLine(), expr.getBeginColumn(),
expr.getEndLine(), expr.getEndColumn());
break;
case INT:
ret = new IntIdentifier((IntIdentifier)constValue,
expr.getFilename(), expr.getBeginLine(), expr.getBeginColumn(),
expr.getEndLine(), expr.getEndColumn());
break;
case BOOLEAN:
ret = new IntIdentifier(new BooleanObject(((BooleanIdentifier)constValue).getValue()).getLongValue(),
expr.getFilename(), expr.getBeginLine(), expr.getBeginColumn(),
expr.getEndLine(), expr.getEndColumn());
break;
default:
//do nothing
}
}
}
else
{
//do nothing, cannot replace full expression
ret = expr;
}
return ret;
}
/////////
// materialized hops recompilation flags
////
public void updatePredicateRecompilationFlags()
throws HopsException
{
if( ConfigurationManager.isDynamicRecompilation() ) {
_requiresFromRecompile = Recompiler.requiresRecompilation(getFromHops());
_requiresToRecompile = Recompiler.requiresRecompilation(getToHops());
_requiresIncrementRecompile = Recompiler.requiresRecompilation(getIncrementHops());
}
}
public boolean requiresFromRecompilation()
{
return _requiresFromRecompile;
}
public boolean requiresToRecompilation()
{
return _requiresToRecompile;
}
public boolean requiresIncrementRecompilation()
{
return _requiresIncrementRecompile;
}
}