All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.sysml.parser.WhileStatementBlock Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.parser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;

import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.lops.Lop;


/**
 * Statement block that wraps exactly one {@link WhileStatement}.
 * <p>
 * In addition to the state inherited from {@link StatementBlock}, it holds the
 * {@link Hop} and {@link Lop} of the loop predicate and a flag telling whether
 * the predicate requires recompilation.
 */
public class WhileStatementBlock extends StatementBlock 
{
	
	private Hop _predicateHops;
	private Lop _predicateLops = null;
	private boolean _requiresPredicateRecompile = false;
	
	/**
	 * Validates the while predicate and all statement blocks of the loop body.
	 * <p>
	 * Validation runs in up to two passes. The first pass validates predicate and
	 * body against the incoming variables. Afterwards every updated variable that
	 * is known both before the loop and after the body is reconciled: its nnz is
	 * reset to unknown (-1) and, if its dimensions differ between loop entry and
	 * end of body, its dimensions are reset to unknown (-1, -1) as well. If at
	 * least one variable was reconciled, predicate and body are validated a second
	 * time against the reconciled variable set.
	 *
	 * @param dmlProg     program passed on to the validation of each body block
	 * @param ids         variables known on entry to the loop
	 * @param constVars   constants known on entry; entries of variables updated by
	 *                    the loop are removed from this map
	 *                    (NOTE(review): type arguments are not visible here -- confirm
	 *                    against StatementBlock before parameterizing)
	 * @param conditional forwarded to predicate validation and to
	 *                    {@code raiseValidateError}
	 * @return the variable set produced by the last validation pass
	 */
	@Override
	public VariableSet validate(DMLProgram dmlProg, VariableSet ids, HashMap constVars, boolean conditional) 
		throws LanguageException, ParseException, IOException 
	{	
		if (_statements.size() > 1){
			raiseValidateError("WhileStatementBlock should have only 1 statement (while statement)", conditional);
		}
		
		WhileStatement wstmt = (WhileStatement) _statements.get(0);
		
		// Record a copy of ALL variables before the loop; after the body has been
		// validated, size information is compared against these copies and
		// variables with changed characteristics are replaced by unknowns.
		VariableSet origVarsBeforeBody = new VariableSet();
		for (String key : ids.getVariableNames()){
			origVarsBeforeBody.addVariable(key, new DataIdentifier(ids.getVariable(key)));
		}
		
		//////////////////////////////////////////////////////////////////////////////
		// FIRST PASS: process the predicate / statement blocks in the body of the while statement
		///////////////////////////////////////////////////////////////////////////////
		ids = validatePredicateAndBody(dmlProg, wstmt, ids, constVars, conditional);
		
		// reconcile each updated variable that exists before and after the body
		boolean revalidationRequired = false;
		for (String key : _updated.getVariableNames())
		{	
			DataIdentifier startVersion = origVarsBeforeBody.getVariable(key);
			DataIdentifier endVersion   = ids.getVariable(key);
			if (startVersion == null || endVersion == null)
				continue;
			
			//handle data type change (reject) 
			if (!startVersion.getOutput().getDataType().equals(endVersion.getOutput().getDataType())){
				raiseValidateError("WhileStatementBlock has unsupported conditional data type change of variable '"+key+"' in loop body.", conditional);
			}
			
			//handle size change
			long startVersionDim1 = origDim1(startVersion);
			long endVersionDim1   = origDim1(endVersion);
			long startVersionDim2 = origDim2(startVersion);
			long endVersionDim2   = origDim2(endVersion);
			boolean sizeUnchanged = (startVersionDim1 == endVersionDim1)
				&& (startVersionDim2 == endVersionDim2);
			
			//handle sparsity change
			//NOTE: nnz is not propagated via validate, and hence, we conservatively 
			//assume that nnz has changed. As a consequence every variable reaching 
			//this point is reconciled and triggers the second pass.
			revalidationRequired = true;
			DataIdentifier recVersion = new DataIdentifier(endVersion);
			if (!sizeUnchanged)
				recVersion.setDimensions(-1, -1);
			recVersion.setNnz(-1);
			origVarsBeforeBody.addVariable(key, recVersion);
		}
		
		// revalidation is required -- at least one variable was reconciled
		if (revalidationRequired)
		{
			// The first pass threads the constants through the body, so the second
			// pass continues with the constants leaving the last body block (or
			// with the incoming map if the body is empty).
			ArrayList<StatementBlock> body = wstmt.getBody();
			HashMap constVarsAfterBody = constVars;
			if (!body.isEmpty())
				constVarsAfterBody = body.get(body.size()-1).getConstOut();
			
			//////////////////////////////////////////////////////////////////////////////
			// SECOND PASS: process the predicate / statement blocks in the body of the while statement
			///////////////////////////////////////////////////////////////////////////////
			// validate against the reconciled values
			ids = validatePredicateAndBody(dmlProg, wstmt, origVarsBeforeBody, constVarsAfterBody, conditional);
		}
		
		return ids;
	}

	/**
	 * Runs one validation pass: removes all updated variables from the constants,
	 * validates the predicate expression, validates every body block in order
	 * (always as conditional), and records the constants entering the first and
	 * leaving the last body block in {@code _constVarsIn} / {@code _constVarsOut}.
	 *
	 * @param dmlProg     program passed on to the validation of each body block
	 * @param wstmt       the while statement of this block
	 * @param ids         variables to validate against
	 * @param constVars   constants to validate against (modified in place)
	 * @param conditional forwarded to the predicate validation
	 * @return the variable set returned by the last body block, or {@code ids}
	 *         if the body is empty
	 */
	private VariableSet validatePredicateAndBody(DMLProgram dmlProg, WhileStatement wstmt, VariableSet ids, HashMap constVars, boolean conditional)
		throws LanguageException, ParseException, IOException
	{
		//remove updated vars from constants (removing an absent key is a no-op)
		for( String var : _updated.getVariableNames() )
			constVars.remove( var );
		
		// process the predicate of the while statement
		ConditionalPredicate predicate = wstmt.getConditionalPredicate();
		predicate.getPredicate().validateExpression(ids.getVariables(), constVars, conditional);
		
		// process the statement blocks in the body of the while statement
		ArrayList<StatementBlock> body = wstmt.getBody();
		_dmlProg = dmlProg;
		for(StatementBlock sb : body)
		{
			//always conditional
			ids = sb.validate(dmlProg, ids, constVars, true);
			constVars = sb.getConstOut();
		}
		
		if (!body.isEmpty()) {
			_constVarsIn.putAll(body.get(0).getConstIn());
			_constVarsOut.putAll(body.get(body.size()-1).getConstOut());
		}
		
		return ids;
	}
	
	/**
	 * Returns the first dimension used for size comparison: the original first
	 * dimension for an {@link IndexedIdentifier}, otherwise {@code getDim1()}.
	 */
	private static long origDim1(DataIdentifier id) {
		return (id instanceof IndexedIdentifier) ? ((IndexedIdentifier)id).getOrigDim1() : id.getDim1();
	}
	
	/**
	 * Returns the second dimension used for size comparison: the original second
	 * dimension for an {@link IndexedIdentifier}, otherwise {@code getDim2()}.
	 */
	private static long origDim2(DataIdentifier id) {
		return (id instanceof IndexedIdentifier) ? ((IndexedIdentifier)id).getOrigDim2() : id.getDim2();
	}


	/**
	 * Computes the read, gen, updated, kill and warn sets of this block from the
	 * predicate and the body blocks, and derives the outgoing variable set.
	 * (NOTE(review): "LV" presumably stands for live-variable analysis -- confirm.)
	 *
	 * @param activeInPassed variables passed into this block
	 * @return the new {@code _liveOut}: the variables produced by chaining the body
	 *         blocks plus all variables updated by predicate or body
	 * @throws LanguageException if this block holds more than one statement
	 */
	public VariableSet initializeforwardLV(VariableSet activeInPassed) throws LanguageException {
		
		// check the single-statement invariant before interpreting the statement
		if (_statements.size() > 1){
			String msg = _statements.get(0).printErrorLocation() + "WhileStatementBlock should have only 1 statement (while statement)";
			LOG.error(msg);
			throw new LanguageException(msg);
		}
		WhileStatement wstmt = (WhileStatement)_statements.get(0);
		ConditionalPredicate predicate = wstmt.getConditionalPredicate();
		
		// the predicate contributes to read, updated and gen
		_read = new VariableSet();
		_read.addVariables(predicate.variablesRead());
		_updated.addVariables(predicate.variablesUpdated());
		
		_gen = new VariableSet();
		_gen.addVariables(predicate.variablesRead());
				
		VariableSet current = new VariableSet();
		current.addVariables(activeInPassed);
		
		for( StatementBlock sb : wstmt.getBody() )
		{
			current = sb.initializeforwardLV(current);	
			
			// a variable generated by this block is generated by the loop only if
			// it was NOT killed by a prior statement block of the while body
			for (String varName : sb._gen.getVariableNames()){
				if (!_kill.getVariableNames().contains(varName)){
					_gen.addVariable(varName, sb._gen.getVariable(varName));	
				}
			}
			
			_read.addVariables(sb._read);
			_updated.addVariables(sb._updated);
		
			// only add kill variables for statement blocks guaranteed to execute
			boolean isLoopBlock = (sb instanceof WhileStatementBlock) || (sb instanceof ForStatementBlock);
			if (!isLoopBlock){
				_kill.addVariables(sb._kill);
			}	
		}
		
		// set preliminary "warn" set -- variables that if used later may cause 
		// runtime error if the loop is not executed: 
		// all updated variables that were not passed into this block
		for (String varName : _updated.getVariableNames()){
			if (!activeInPassed.containsVariable(varName)) {
				_warnSet.addVariable(varName, _updated.getVariable(varName));
			}
		}
		
		// activeOut includes variables from passed live in and updated in the while body
		_liveOut = new VariableSet();
		_liveOut.addVariables(current);
		_liveOut.addVariables(_updated);
		return _liveOut;
	}

	/**
	 * Calls {@code analyze} on the statement blocks of the while body in reverse
	 * order, feeding the result of each block into its predecessor.
	 *
	 * @param loPassed variable set handed to the last body block
	 * @return a copy of the set returned by the first body block, or a copy of
	 *         {@code loPassed} if the body is empty
	 * @throws LanguageException declared for callers; may be raised by the body blocks
	 */
	public VariableSet initializebackwardLV(VariableSet loPassed) throws LanguageException{
		
		WhileStatement wstmt = (WhileStatement)_statements.get(0);
		ArrayList<StatementBlock> body = wstmt.getBody();
			
		VariableSet lo = new VariableSet();
		lo.addVariables(loPassed);
		
		// calls analyze for each statement block in while stmt body, last block first
		for (int i = body.size() - 1; i >= 0; i--){
			lo = body.get(i).analyze(lo);
		}	
		
		VariableSet loReturn = new VariableSet();
		loReturn.addVariables(lo);
		return loReturn;
	
	}
	
	/**
	 * Sets the Hop of the while predicate.
	 *
	 * @param hops predicate Hop
	 */
	public void setPredicateHops(Hop hops) {
		_predicateHops = hops;
	}
	
	/**
	 * Returns the inherited {@code _hops} list, which is expected to be null or
	 * empty for a while block.
	 * (NOTE(review): the raw return type is kept as found; the element type is not
	 * visible here -- confirm against StatementBlock before parameterizing.)
	 *
	 * @return {@code _hops} (null or empty)
	 * @throws HopsException if {@code _hops} is non-empty
	 */
	public ArrayList get_hops() throws HopsException {
		
		if (_hops != null && !_hops.isEmpty()){
			String msg = this._statements.get(0).printErrorLocation() + "there should be no HOPs associated with the WhileStatementBlock";
			LOG.error(msg);
			throw new HopsException(msg);
		}
		
		return _hops;
	}
	
	/**
	 * @return the Hop of the while predicate, as set via {@link #setPredicateHops(Hop)}
	 */
	public Hop getPredicateHops(){
		return _predicateHops;
	}
	
	/**
	 * @return the Lop of the while predicate (initially null)
	 */
	public Lop get_predicateLops() {
		return _predicateLops;
	}

	/**
	 * Sets the Lop of the while predicate.
	 *
	 * @param predicateLops predicate Lop
	 */
	public void set_predicateLops(Lop predicateLops) {
		_predicateLops = predicateLops;
	}
	
	/**
	 * Recomputes {@code _liveOut}, {@code _warnSet} and {@code _liveIn} of this
	 * block and runs the backward pass over the loop body.
	 * <p>
	 * The new live-out set contains those variables of (passed set + gen +
	 * predicate variables) that are also in (previous live-out + predicate
	 * variables + gen). The warn set is restricted to variables in the new
	 * live-out set and each remaining entry is logged as a warning.
	 *
	 * @param loPassed variable set passed in by the caller
	 * @return a copy of the new {@code _liveIn} (live-out plus gen)
	 * @throws LanguageException declared for callers; may be raised by the body blocks
	 */
	public VariableSet analyze(VariableSet loPassed) throws LanguageException{
	 		
		// variables read or updated by the predicate
		ConditionalPredicate predicate = ((WhileStatement)_statements.get(0)).getConditionalPredicate();
		VariableSet predVars = new VariableSet();
		predVars.addVariables(predicate.variablesRead());
		predVars.addVariables(predicate.variablesUpdated());
		
		VariableSet candidateLO = new VariableSet();
		candidateLO.addVariables(loPassed);
		candidateLO.addVariables(_gen);
		candidateLO.addVariables(predVars);
		
		VariableSet origLiveOut = new VariableSet();
		origLiveOut.addVariables(_liveOut);
		origLiveOut.addVariables(predVars);
		origLiveOut.addVariables(_gen);
		
		// new live out: candidates that are also part of the original live out
		_liveOut = new VariableSet();
		for (String name : candidateLO.getVariableNames()){
			if (origLiveOut.containsVariable(name)){
				_liveOut.addVariable(name, candidateLO.getVariable(name));
			}
		}
		
		// the result of the backward pass over the body is not used here
		initializebackwardLV(_liveOut);
		
		// set final warnSet: remove variables NOT in live out
		VariableSet finalWarnSet = new VariableSet();
		for (String varName : _warnSet.getVariableNames()){
			if (_liveOut.containsVariable(varName)){
				finalWarnSet.addVariable(varName,_warnSet.getVariable(varName));
			}
		}
		_warnSet = finalWarnSet;
		
		// for now just print the warn set
		for (String varName : _warnSet.getVariableNames()){
			LOG.warn(_warnSet.getVariable(varName).printWarningLocation() + "Initialization of " + varName + " depends on while execution");
		}
		
		// Cannot remove kill variables
		_liveIn = new VariableSet();
		_liveIn.addVariables(_liveOut);
		_liveIn.addVariables(_gen);
		
		VariableSet liveInReturn = new VariableSet();
		liveInReturn.addVariables(_liveIn);
		
		return liveInReturn;
	}
	
	/////////
	// materialized hops recompilation flags
	////
	
	/**
	 * Updates the predicate recompilation flag: it is set if dynamic recompilation
	 * is allowed and the recompiler reports that the predicate Hop requires
	 * recompilation.
	 *
	 * @throws HopsException declared for callers; may be raised by the recompiler check
	 */
	public void updatePredicateRecompilationFlag() 
		throws HopsException
	{
		_requiresPredicateRecompile =  OptimizerUtils.ALLOW_DYN_RECOMPILATION 
			                           && Recompiler.requiresRecompilation(getPredicateHops());
	}
	
	/**
	 * @return the flag computed by the last call to
	 *         {@link #updatePredicateRecompilationFlag()} (false if never called)
	 */
	public boolean requiresPredicateRecompilation()
	{
		return _requiresPredicateRecompile;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy