org.apache.sysml.parser.DMLTranslator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.parser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataGenOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.FunctionOp.FunctionType;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.DataGenMethod;
import org.apache.sysml.hops.Hop.DataOpTypes;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.Hop.OpOp3;
import org.apache.sysml.hops.Hop.ParamBuiltinOp;
import org.apache.sysml.hops.Hop.ReOrgOp;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LeftIndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.MemoTable;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.ReorgOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.ipa.InterProceduralAnalysis;
import org.apache.sysml.hops.rewrite.ProgramRewriter;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.FormatType;
import org.apache.sysml.parser.Expression.ParameterizedBuiltinFunctionOp;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.parser.PrintStatement.PRINTTYPE;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.hops.ConvolutionOp;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.Expression.BuiltinFunctionOp;
public class DMLTranslator
{
private static final Log LOG = LogFactory.getLog(DMLTranslator.class.getName());
private DMLProgram _dmlProg = null;
/**
 * Creates a translator for the given parsed DML program.
 * <p>
 * Construction also resets global compiler state: the default size assumed
 * for unknown dimensions, and the recompiler (re-initialized according to
 * the current optimization-level flags).
 *
 * @param dmlp the parsed DML program this translator works on
 * @throws DMLRuntimeException propagated from re-initializing the recompiler
 */
public DMLTranslator(DMLProgram dmlp)
	throws DMLRuntimeException
{
	this._dmlProg = dmlp;

	// global defaults: size used for unknown dimensions
	OptimizerUtils.resetDefaultSize();

	// rewriter setup depends on the optimization-level flags
	Recompiler.reinitRecompiler();
}
/**
 * Validates the parse tree of the given program.
 * <p>
 * Validation runs in three steps:
 * <ol>
 * <li>pre-processing via {@code prepareReadAfterWrite}, which reports whether
 *     the program contains read-after-write patterns;</li>
 * <li>validation of every function in every namespace, followed by the
 *     statement blocks of the main program;</li>
 * <li>if read-after-write patterns were found, {@code prepareReadAfterWrite}
 *     runs again (to propagate sizes and data types into the reads) and the
 *     main program is validated a second time.</li>
 * </ol>
 *
 * @param dmlp the program to validate
 * @throws LanguageException if a function statement block holds more than one
 *         statement, or propagated from block validation
 * @throws ParseException propagated from block validation
 * @throws IOException propagated from read-after-write preparation or block validation
 */
public void validateParseTree(DMLProgram dmlp)
	throws LanguageException, ParseException, IOException
{
	//STEP1: Pre-processing steps for validate - e.g., prepare read-after-write meta data
	boolean fWriteRead = prepareReadAfterWrite(dmlp, new HashMap());

	//STEP2: Actual Validate
	// handle functions in namespaces (current program has default namespace)
	for (String namespaceKey : dmlp.getNamespaces().keySet()){
		for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) {
			FunctionStatementBlock fblock = dmlp.getFunctionStatementBlock(namespaceKey, fname);
			validateFunctionStatementBlock(dmlp, fblock);
		}
	}

	// handle regular blocks -- "main" program
	validateMainStatementBlocks(dmlp, fWriteRead);

	//STEP3: Post-processing steps after validate - e.g., prepare read-after-write meta data
	if( fWriteRead )
	{
		//propagate size and datatypes into read
		prepareReadAfterWrite(dmlp, new HashMap());

		//re-validate main program for datatype propagation
		validateMainStatementBlocks(dmlp, fWriteRead);
	}
}

/**
 * Validates one function statement block; scalar input parameters get 0x0
 * dimensions before being registered as input variables.
 */
private void validateFunctionStatementBlock(DMLProgram dmlp, FunctionStatementBlock fblock)
	throws LanguageException, ParseException, IOException
{
	HashMap constVars = new HashMap();
	VariableSet vs = new VariableSet();

	// add the input variables for the function to input variable list
	FunctionStatement fstmt = (FunctionStatement)fblock.getStatement(0);
	if (fblock.getNumStatements() > 1){
		LOG.error(fstmt.printErrorLocation() + "FunctionStatementBlock can only have 1 FunctionStatement");
		throw new LanguageException(fstmt.printErrorLocation() + "FunctionStatementBlock can only have 1 FunctionStatement");
	}
	for (DataIdentifier currVar : fstmt.getInputParams()) {
		if (currVar.getDataType() == DataType.SCALAR){
			currVar.setDimensions(0, 0);
		}
		vs.addVariable(currVar.getName(), currVar);
	}
	fblock.validate(dmlp, vs, constVars, false);
}

/**
 * Validates the main-program statement blocks in order, threading the
 * variable set and constant map produced by each block into the next one.
 */
private void validateMainStatementBlocks(DMLProgram dmlp, boolean fWriteRead)
	throws LanguageException, ParseException, IOException
{
	VariableSet vs = new VariableSet();
	HashMap constVars = new HashMap();
	for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) {
		StatementBlock sb = dmlp.getStatementBlock(i);
		vs = sb.validate(dmlp, vs, constVars, fWriteRead);
		constVars = sb.getConstOut();
	}
}
/**
 * Performs live-variable analysis over the whole program.
 * <p>
 * For every function: function calls in its body are merged/inlined, a forward
 * pass is seeded with the input parameters, and a backward pass runs with the
 * input parameters as live-in and the output parameters as live-out. The
 * same forward/backward scheme then runs over the (merged) main program,
 * with an empty live-out set for the last main block.
 *
 * @param dmlp the program to analyze
 * @throws LanguageException propagated from the analysis of individual blocks
 */
public void liveVariableAnalysis(DMLProgram dmlp) throws LanguageException {
	// functions, forward direction (with function inlining)
	for (String ns : dmlp.getNamespaces().keySet()) {
		for (String functionName : dmlp.getFunctionStatementBlocks(ns).keySet()) {
			FunctionStatementBlock funcBlock = dmlp.getFunctionStatementBlock(ns, functionName);
			FunctionStatement funcStmt = (FunctionStatement) funcBlock.getStatement(0);
			funcStmt.setBody(StatementBlock.mergeFunctionCalls(funcStmt.getBody(), dmlp));

			VariableSet inputs = new VariableSet();
			for (DataIdentifier param : funcStmt.getInputParams())
				inputs.addVariable(param.getName(), param);
			funcBlock.initializeforwardLV(inputs);
		}
	}

	// functions, backward direction (outputs form the live-out set)
	for (String ns : dmlp.getNamespaces().keySet()) {
		for (String functionName : dmlp.getFunctionStatementBlocks(ns).keySet()) {
			FunctionStatementBlock funcBlock = dmlp.getFunctionStatementBlock(ns, functionName);
			VariableSet outputs = new VariableSet();
			VariableSet inputs = new VariableSet();
			FunctionStatement funcStmt = (FunctionStatement) funcBlock.getStatement(0);
			for (DataIdentifier param : funcStmt.getInputParams())
				inputs.addVariable(param.getName(), param);
			for (DataIdentifier param : funcStmt.getOutputParams())
				outputs.addVariable(param.getName(), param);
			funcBlock._liveOut = outputs;
			funcBlock.analyze(inputs, outputs);
		}
	}

	// main program: inline function calls first
	VariableSet liveOut = new VariableSet();
	VariableSet forwardIn = new VariableSet();
	dmlp.setStatementBlocks(StatementBlock.mergeFunctionCalls(dmlp.getStatementBlocks(), dmlp));

	// main program, forward direction
	for (int idx = 0; idx < dmlp.getNumStatementBlocks(); idx++)
		forwardIn = dmlp.getStatementBlock(idx).initializeforwardLV(forwardIn);

	// main program, backward direction (nothing is live after the last block)
	if (dmlp.getNumStatementBlocks() <= 0)
		return;
	dmlp.getStatementBlock(dmlp.getNumStatementBlocks() - 1)._liveOut = new VariableSet();
	for (int idx = dmlp.getNumStatementBlocks() - 1; idx >= 0; idx--)
		liveOut = dmlp.getStatementBlock(idx).analyze(liveOut);
}
/**
 * Constructs the HOP DAGs for the whole program: first for every function
 * in every namespace, then for each statement block of the main program.
 *
 * @param dmlp the validated program
 * @throws ParseException propagated from per-block HOP construction
 * @throws LanguageException propagated from per-block HOP construction
 */
public void constructHops(DMLProgram dmlp)
	throws ParseException, LanguageException
{
	// step 1: functions of all namespaces
	for (String ns : dmlp.getNamespaces().keySet()) {
		for (String functionName : dmlp.getFunctionStatementBlocks(ns).keySet()) {
			constructHops(dmlp.getFunctionStatementBlock(ns, functionName));
		}
	}

	// step 2: statement blocks of the main program
	int idx = 0;
	while (idx < dmlp.getNumStatementBlocks()) {
		constructHops(dmlp.getStatementBlock(idx));
		idx++;
	}
}
/**
 * Runs the HOP-level optimization pipeline over the whole program:
 * <ol>
 * <li>static rewrites,</li>
 * <li>inter-procedural analysis (only if
 *     {@code OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS} is set),</li>
 * <li>dynamic rewrites,</li>
 * <li>memory-estimate computation for all HOPs.</li>
 * </ol>
 * The visit status of all HOP DAGs is reset after every phase.
 *
 * @param dmlp the program whose HOP DAGs are rewritten in place
 * @throws ParseException propagated from the individual phases
 * @throws LanguageException propagated from the individual phases
 * @throws HopsException propagated from the individual phases
 */
public void rewriteHopsDAG(DMLProgram dmlp)
	throws ParseException, LanguageException, HopsException
{
	// phase 1: static rewrites
	new ProgramRewriter(true, false).rewriteProgramHopDAGs(dmlp);
	resetHopsDAGVisitStatus(dmlp);

	// phase 2: propagate size information from main into functions (conservatively)
	if (OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS) {
		new InterProceduralAnalysis().analyzeProgram(dmlp);
		resetHopsDAGVisitStatus(dmlp);
	}

	// phase 3: dynamic rewrites, after IPA
	new ProgramRewriter(false, true).rewriteProgramHopDAGs(dmlp);
	resetHopsDAGVisitStatus(dmlp);

	// phase 4: memory estimates, used later e.g. for CP vs. MR scheduling and parfor
	refreshMemEstimates(dmlp);
	resetHopsDAGVisitStatus(dmlp);
}
/**
 * Constructs the LOP DAGs for the whole program: first for every function
 * in every namespace, then for each statement block of the main program.
 *
 * @param dmlp the program whose HOP DAGs have already been built
 * @throws ParseException declared for callers; see per-block construction
 * @throws LanguageException declared for callers; see per-block construction
 * @throws HopsException propagated from per-block LOP construction
 * @throws LopsException propagated from per-block LOP construction
 */
public void constructLops(DMLProgram dmlp) throws ParseException, LanguageException, HopsException, LopsException {
	// function program blocks of all namespaces
	for (String ns : dmlp.getNamespaces().keySet()) {
		for (String functionName : dmlp.getFunctionStatementBlocks(ns).keySet())
			constructLops(dmlp.getFunctionStatementBlock(ns, functionName));
	}

	// main program blocks
	for (int idx = 0; idx < dmlp.getNumStatementBlocks(); idx++)
		constructLops(dmlp.getStatementBlock(idx));
}
/**
 * Recursively constructs the LOP DAGs for a single statement block.
 * <p>
 * Control-flow blocks (while, if, for/parfor, function) must not carry HOPs
 * of their own. For these blocks, the child blocks are processed recursively.
 * The predicate HOPs (while/if predicate; for from/to/increment where present)
 * are compiled into predicate LOPs, and the predicate recompilation flags are
 * updated. For a regular statement block, every HOP root is compiled into a
 * LOP and the block's recompilation flag is updated.
 *
 * @param sb the statement block to process
 * @throws HopsException if a control-flow block unexpectedly holds HOPs, or
 *         propagated from LOP construction
 * @throws LopsException propagated from LOP construction
 */
public void constructLops(StatementBlock sb)
	throws HopsException, LopsException
{
	if (sb instanceof WhileStatementBlock)
	{
		WhileStatementBlock wsb = (WhileStatementBlock)sb;
		WhileStatement whileStmt = (WhileStatement)wsb.getStatement(0);
		ArrayList<StatementBlock> body = whileStmt.getBody();

		if (sb.get_hops() != null && !sb.get_hops().isEmpty()) {
			LOG.error(sb.printBlockErrorLocation() + "WhileStatementBlock should not have hops");
			throw new HopsException(sb.printBlockErrorLocation() + "WhileStatementBlock should not have hops");
		}
		// step through stmt blocks in while stmt body
		for (StatementBlock stmtBlock : body){
			constructLops(stmtBlock);
		}

		// handle while stmt predicate
		Lop l = wsb.getPredicateHops().constructLops();
		wsb.set_predicateLops(l);
		wsb.updatePredicateRecompilationFlag();
	}
	else if (sb instanceof IfStatementBlock)
	{
		IfStatementBlock isb = (IfStatementBlock) sb;
		IfStatement ifStmt = (IfStatement)isb.getStatement(0);
		ArrayList<StatementBlock> ifBody = ifStmt.getIfBody();
		ArrayList<StatementBlock> elseBody = ifStmt.getElseBody();

		if (sb.get_hops() != null && !sb.get_hops().isEmpty()){
			LOG.error(sb.printBlockErrorLocation() + "IfStatementBlock should not have hops");
			throw new HopsException(sb.printBlockErrorLocation() + "IfStatementBlock should not have hops");
		}
		// step through stmt blocks in if stmt ifBody
		for (StatementBlock stmtBlock : ifBody)
			constructLops(stmtBlock);

		// step through stmt blocks in if stmt elseBody
		for (StatementBlock stmtBlock : elseBody)
			constructLops(stmtBlock);

		// handle if stmt predicate
		Lop l = isb.getPredicateHops().constructLops();
		isb.set_predicateLops(l);
		isb.updatePredicateRecompilationFlag();
	}
	else if (sb instanceof ForStatementBlock) //NOTE: applies to ForStatementBlock and ParForStatementBlock
	{
		ForStatementBlock fsb = (ForStatementBlock) sb;
		ForStatement fs = (ForStatement)sb.getStatement(0);
		ArrayList<StatementBlock> body = fs.getBody();

		if (sb.get_hops() != null && !sb.get_hops().isEmpty() ) {
			LOG.error(sb.printBlockErrorLocation() + "ForStatementBlock should not have hops");
			throw new HopsException(sb.printBlockErrorLocation() + "ForStatementBlock should not have hops");
		}
		// step through stmt blocks in FOR stmt body
		for (StatementBlock stmtBlock : body)
			constructLops(stmtBlock);

		// handle for stmt predicate (from/to/increment are each optional)
		if (fsb.getFromHops() != null){
			Lop llobs = fsb.getFromHops().constructLops();
			fsb.setFromLops(llobs);
		}
		if (fsb.getToHops() != null){
			Lop llobs = fsb.getToHops().constructLops();
			fsb.setToLops(llobs);
		}
		if (fsb.getIncrementHops() != null){
			Lop llobs = fsb.getIncrementHops().constructLops();
			fsb.setIncrementLops(llobs);
		}
		fsb.updatePredicateRecompilationFlags();
	}
	else if (sb instanceof FunctionStatementBlock){
		FunctionStatement functStmt = (FunctionStatement)sb.getStatement(0);
		ArrayList<StatementBlock> body = functStmt.getBody();

		if (sb.get_hops() != null && !sb.get_hops().isEmpty()) {
			LOG.error(sb.printBlockErrorLocation() + "FunctionStatementBlock should not have hops");
			throw new HopsException(sb.printBlockErrorLocation() + "FunctionStatementBlock should not have hops");
		}
		// step through stmt blocks in function body
		for (StatementBlock stmtBlock : body){
			constructLops(stmtBlock);
		}
	}
	// handle default case for regular StatementBlock
	else {
		// a block without hops gets an empty hop list (and thus an empty lop list)
		if (sb.get_hops() == null)
			sb.set_hops(new ArrayList<Hop>());

		ArrayList<Lop> lops = new ArrayList<Lop>();
		for (Hop hop : sb.get_hops()) {
			lops.add(hop.constructLops());
		}
		sb.setLops(lops);
		sb.updateRecompilationFlag();
	}
} // end method
/**
 * Prints the LOP DAGs of the whole program (functions first, then the main
 * program blocks). Does nothing unless debug logging is enabled.
 *
 * @param dmlp the program whose LOP DAGs are printed
 * @throws ParseException propagated from per-block printing
 * @throws LanguageException declared for callers
 * @throws HopsException propagated from per-block printing
 * @throws LopsException propagated from per-block printing
 */
public void printLops(DMLProgram dmlp) throws ParseException, LanguageException, HopsException, LopsException {
	if (!LOG.isDebugEnabled())
		return;

	// function program blocks of all namespaces
	for (String ns : dmlp.getNamespaces().keySet()) {
		for (String functionName : dmlp.getFunctionStatementBlocks(ns).keySet())
			printLops(dmlp.getFunctionStatementBlock(ns, functionName));
	}

	// main program blocks
	for (int idx = 0; idx < dmlp.getNumStatementBlocks(); idx++)
		printLops(dmlp.getStatementBlock(idx));
}
/**
 * Prints (at debug level) the LOP DAGs of a statement block.
 * <p>
 * This recurses into the child blocks of function, while, if and for/parfor
 * blocks. Predicate LOPs are printed before the child blocks; if a predicate
 * HOP has no LOP yet, {@code constructLops()} is called on it for printing.
 * The block's own output LOPs are printed last. Does nothing unless debug
 * logging is enabled.
 *
 * @param current the statement block to print
 * @throws ParseException declared for callers
 * @throws HopsException if a while/if/for block holds more than one statement,
 *         or propagated from LOP construction
 * @throws LopsException propagated from LOP construction or printing
 */
public void printLops(StatementBlock current) throws ParseException, HopsException, LopsException {
	if (!LOG.isDebugEnabled())
		return;

	ArrayList<Lop> lopsDAG = current.getLops();

	LOG.debug("\n********************** LOPS DAG FOR BLOCK *******************");

	if (current instanceof FunctionStatementBlock) {
		if (current.getNumStatements() > 1)
			LOG.debug("Function statement block has more than 1 stmt");
		FunctionStatement fstmt = (FunctionStatement)current.getStatement(0);
		for (StatementBlock child : fstmt.getBody()){
			printLops(child);
		}
	}

	if (current instanceof WhileStatementBlock) {
		// print predicate lops
		WhileStatementBlock wstb = (WhileStatementBlock) current;
		Hop predicateHops = wstb.getPredicateHops();
		LOG.debug("\n********************** PREDICATE LOPS *******************");
		Lop predicateLops = predicateHops.getLops();
		if (predicateLops == null)
			predicateLops = predicateHops.constructLops();
		predicateLops.printMe();

		if (wstb.getNumStatements() > 1){
			LOG.error(wstb.printBlockErrorLocation() + "WhileStatementBlock has more than 1 statement");
			throw new HopsException(wstb.printBlockErrorLocation() + "WhileStatementBlock has more than 1 statement");
		}
		WhileStatement ws = (WhileStatement)wstb.getStatement(0);
		for (StatementBlock sb : ws.getBody()){
			printLops(sb);
		}
	}

	if (current instanceof IfStatementBlock) {
		// print predicate lops
		IfStatementBlock istb = (IfStatementBlock) current;
		Hop predicateHops = istb.getPredicateHops();
		LOG.debug("\n********************** PREDICATE LOPS *******************");
		Lop predicateLops = predicateHops.getLops();
		if (predicateLops == null)
			predicateLops = predicateHops.constructLops();
		predicateLops.printMe();

		if (istb.getNumStatements() > 1){
			LOG.error(istb.printBlockErrorLocation() + "IfStatmentBlock has more than 1 statement");
			throw new HopsException(istb.printBlockErrorLocation() + "IfStatmentBlock has more than 1 statement");
		}
		IfStatement is = (IfStatement)istb.getStatement(0);

		LOG.debug("\n**** LOPS DAG FOR IF BODY ****");
		for (StatementBlock sb : is.getIfBody()){
			printLops(sb);
		}
		if ( !is.getElseBody().isEmpty() ){
			LOG.debug("\n**** LOPS DAG FOR ELSE BODY ****");
			for (StatementBlock sb : is.getElseBody()){
				printLops(sb);
			}
		}
	}

	if (current instanceof ForStatementBlock) {
		// print predicate lops (from/to/increment are each optional)
		ForStatementBlock fsb = (ForStatementBlock) current;
		LOG.debug("\n********************** PREDICATE LOPS *******************");
		if( fsb.getFromHops() != null ){
			LOG.debug("FROM:");
			Lop llops = fsb.getFromLops();
			if( llops == null )
				llops = fsb.getFromHops().constructLops();
			llops.printMe();
		}
		if( fsb.getToHops() != null ){
			LOG.debug("TO:");
			Lop llops = fsb.getToLops();
			if( llops == null )
				llops = fsb.getToHops().constructLops();
			llops.printMe();
		}
		if( fsb.getIncrementHops() != null ){
			LOG.debug("INCREMENT:");
			Lop llops = fsb.getIncrementLops();
			if( llops == null )
				llops = fsb.getIncrementHops().constructLops();
			llops.printMe();
		}

		if (fsb.getNumStatements() > 1){
			LOG.error(fsb.printBlockErrorLocation() + "ForStatementBlock has more than 1 statement");
			throw new HopsException(fsb.printBlockErrorLocation() + "ForStatementBlock has more than 1 statement");
		}
		ForStatement ws = (ForStatement)fsb.getStatement(0);
		for (StatementBlock sb : ws.getBody()){
			printLops(sb);
		}
	}

	// output lops of this block itself
	if (lopsDAG != null && !lopsDAG.isEmpty() ) {
		for (Lop lop : lopsDAG) {
			LOG.debug("\n********************** OUTPUT LOPS *******************");
			lop.printMe();
		}
	}
}
/**
 * Prints the HOP DAGs of the whole program (functions first, then the main
 * program blocks). Does nothing unless debug logging is enabled.
 *
 * @param dmlp the program whose HOP DAGs are printed
 * @throws ParseException propagated from per-block printing
 * @throws LanguageException declared for callers
 * @throws HopsException propagated from per-block printing
 */
public void printHops(DMLProgram dmlp) throws ParseException, LanguageException, HopsException {
	if (!LOG.isDebugEnabled())
		return;

	// function program blocks of all namespaces
	for (String ns : dmlp.getNamespaces().keySet()) {
		for (String functionName : dmlp.getFunctionStatementBlocks(ns).keySet())
			printHops(dmlp.getFunctionStatementBlock(ns, functionName));
	}

	// main program blocks
	for (int idx = 0; idx < dmlp.getNumStatementBlocks(); idx++)
		printHops(dmlp.getStatementBlock(idx));
}
/**
 * Prints (at debug level) the HOP DAGs of a statement block.
 * <p>
 * This recurses into the child blocks of function, while, if and for/parfor
 * blocks. Predicate HOPs are printed before the child blocks, and the block's
 * own output HOPs are printed last. Multi-statement control-flow blocks are
 * only reported via a debug message, not rejected. Does nothing unless debug
 * logging is enabled.
 *
 * @param current the statement block to print
 * @throws ParseException declared for callers
 * @throws HopsException propagated from printing
 */
public void printHops(StatementBlock current) throws ParseException, HopsException {
	if (!LOG.isDebugEnabled())
		return;

	ArrayList<Hop> hopsDAG = current.get_hops();
	LOG.debug("\n********************** HOPS DAG FOR BLOCK *******************");

	if (current instanceof FunctionStatementBlock) {
		if (current.getNumStatements() > 1)
			LOG.debug("Function statement block has more than 1 stmt");
		FunctionStatement fstmt = (FunctionStatement)current.getStatement(0);
		for (StatementBlock child : fstmt.getBody()){
			printHops(child);
		}
	}

	if (current instanceof WhileStatementBlock) {
		// print predicate hops
		WhileStatementBlock wstb = (WhileStatementBlock) current;
		Hop predicateHops = wstb.getPredicateHops();
		LOG.debug("\n********************** PREDICATE HOPS *******************");
		predicateHops.printMe();

		if (wstb.getNumStatements() > 1)
			LOG.debug("While statement block has more than 1 stmt");
		WhileStatement ws = (WhileStatement)wstb.getStatement(0);
		for (StatementBlock sb : ws.getBody()){
			printHops(sb);
		}
	}

	if (current instanceof IfStatementBlock) {
		// print predicate hops
		IfStatementBlock istb = (IfStatementBlock) current;
		Hop predicateHops = istb.getPredicateHops();
		LOG.debug("\n********************** PREDICATE HOPS *******************");
		predicateHops.printMe();

		if (istb.getNumStatements() > 1)
			LOG.debug("If statement block has more than 1 stmt");
		IfStatement is = (IfStatement)istb.getStatement(0);
		for (StatementBlock sb : is.getIfBody()){
			printHops(sb);
		}
		for (StatementBlock sb : is.getElseBody()){
			printHops(sb);
		}
	}

	if (current instanceof ForStatementBlock) {
		// print predicate hops (from/to/increment are each optional)
		ForStatementBlock fsb = (ForStatementBlock) current;
		LOG.debug("\n********************** PREDICATE HOPS *******************");
		if (fsb.getFromHops() != null) fsb.getFromHops().printMe();
		if (fsb.getToHops() != null) fsb.getToHops().printMe();
		if (fsb.getIncrementHops() != null) fsb.getIncrementHops().printMe();

		if (fsb.getNumStatements() > 1)
			LOG.debug("For statement block has more than 1 stmt");
		ForStatement ws = (ForStatement)fsb.getStatement(0);
		for (StatementBlock sb : ws.getBody()){
			printHops(sb);
		}
	}

	// output hops of this block itself
	if (hopsDAG != null && !hopsDAG.isEmpty()) {
		for (Hop hop : hopsDAG) {
			LOG.debug("\n********************** OUTPUT HOPS *******************");
			hop.printMe();
		}
	}
}
/**
 * Recomputes the memory estimates of all HOPs in the program (functions
 * first, then the main program blocks).
 *
 * @param dmlp the program whose HOP memory estimates are refreshed
 * @throws ParseException propagated from per-block refresh
 * @throws LanguageException declared for callers
 * @throws HopsException propagated from per-block refresh
 */
public void refreshMemEstimates(DMLProgram dmlp) throws ParseException, LanguageException, HopsException {
	// function program blocks of all namespaces
	for (String ns : dmlp.getNamespaces().keySet()) {
		for (String functionName : dmlp.getFunctionStatementBlocks(ns).keySet())
			refreshMemEstimates(dmlp.getFunctionStatementBlock(ns, functionName));
	}

	// statement blocks of the "main" program
	for (int idx = 0; idx < dmlp.getNumStatementBlocks(); idx++)
		refreshMemEstimates(dmlp.getStatementBlock(idx));
}
/**
 * Recomputes the memory estimates of all HOPs in a statement block,
 * recursing into the child blocks of function, while, if and for/parfor blocks.
 * <p>
 * All root HOPs of the block share one {@link MemoTable}; each predicate HOP
 * gets a fresh one.
 *
 * @param current the statement block to process
 * @throws ParseException declared for callers
 * @throws HopsException propagated from the per-HOP estimate computation
 */
public void refreshMemEstimates(StatementBlock current) throws ParseException, HopsException {

	MemoTable memo = new MemoTable();
	ArrayList<Hop> hopsDAG = current.get_hops();
	if (hopsDAG != null && !hopsDAG.isEmpty()) {
		for (Hop hop : hopsDAG) {
			hop.refreshMemEstimates(memo);
		}
	}

	if (current instanceof FunctionStatementBlock) {
		FunctionStatement fstmt = (FunctionStatement)current.getStatement(0);
		for (StatementBlock sb : fstmt.getBody()){
			refreshMemEstimates(sb);
		}
	}

	if (current instanceof WhileStatementBlock) {
		// handle predicate
		WhileStatementBlock wstb = (WhileStatementBlock) current;
		wstb.getPredicateHops().refreshMemEstimates(new MemoTable());

		if (wstb.getNumStatements() > 1)
			LOG.debug("While statement block has more than 1 stmt");
		WhileStatement ws = (WhileStatement)wstb.getStatement(0);
		for (StatementBlock sb : ws.getBody()){
			refreshMemEstimates(sb);
		}
	}

	if (current instanceof IfStatementBlock) {
		// handle predicate
		IfStatementBlock istb = (IfStatementBlock) current;
		istb.getPredicateHops().refreshMemEstimates(new MemoTable());

		if (istb.getNumStatements() > 1)
			LOG.debug("If statement block has more than 1 stmt");
		IfStatement is = (IfStatement)istb.getStatement(0);
		for (StatementBlock sb : is.getIfBody()){
			refreshMemEstimates(sb);
		}
		for (StatementBlock sb : is.getElseBody()){
			refreshMemEstimates(sb);
		}
	}

	if (current instanceof ForStatementBlock) {
		// handle predicate (from/to/increment are each optional)
		ForStatementBlock fsb = (ForStatementBlock) current;
		if (fsb.getFromHops() != null)
			fsb.getFromHops().refreshMemEstimates(new MemoTable());
		if (fsb.getToHops() != null)
			fsb.getToHops().refreshMemEstimates(new MemoTable());
		if (fsb.getIncrementHops() != null)
			fsb.getIncrementHops().refreshMemEstimates(new MemoTable());

		if (fsb.getNumStatements() > 1)
			LOG.debug("For statement block has more than 1 stmt");
		ForStatement ws = (ForStatement)fsb.getStatement(0);
		for (StatementBlock sb : ws.getBody()){
			refreshMemEstimates(sb);
		}
	}
}
/**
 * Resets the visit status of all HOP DAGs in the program (functions first,
 * then the main program blocks).
 *
 * @param dmlp the program whose HOP DAGs are reset
 * @throws ParseException propagated from per-block reset
 * @throws LanguageException declared for callers
 * @throws HopsException propagated from per-block reset
 */
public static void resetHopsDAGVisitStatus(DMLProgram dmlp) throws ParseException, LanguageException, HopsException {
	// function program blocks of all namespaces
	for (String ns : dmlp.getNamespaces().keySet()) {
		for (String functionName : dmlp.getFunctionStatementBlocks(ns).keySet())
			resetHopsDAGVisitStatus(dmlp.getFunctionStatementBlock(ns, functionName));
	}

	// statement blocks of the "main" program
	for (int idx = 0; idx < dmlp.getNumStatementBlocks(); idx++)
		resetHopsDAGVisitStatus(dmlp.getStatementBlock(idx));
}
/**
 * Resets the visit status of the HOP DAGs of a statement block, including
 * predicate HOPs, and recurses into the child blocks of function, while, if
 * and for/parfor blocks.
 *
 * @param current the statement block to process
 * @throws ParseException declared for callers
 * @throws HopsException propagated from the HOP visit-status reset
 */
public static void resetHopsDAGVisitStatus(StatementBlock current) throws ParseException, HopsException {

	ArrayList<Hop> hopsDAG = current.get_hops();
	if (hopsDAG != null && !hopsDAG.isEmpty() ) {
		Hop.resetVisitStatus(hopsDAG);
	}

	if (current instanceof FunctionStatementBlock) {
		FunctionStatement fstmt = (FunctionStatement)current.getStatement(0);
		for (StatementBlock sb : fstmt.getBody()){
			resetHopsDAGVisitStatus(sb);
		}
	}

	if (current instanceof WhileStatementBlock) {
		// handle predicate
		WhileStatementBlock wstb = (WhileStatementBlock) current;
		wstb.getPredicateHops().resetVisitStatus();

		if (wstb.getNumStatements() > 1)
			LOG.debug("While stmt block has more than 1 stmt");
		WhileStatement ws = (WhileStatement)wstb.getStatement(0);
		for (StatementBlock sb : ws.getBody()){
			resetHopsDAGVisitStatus(sb);
		}
	}

	if (current instanceof IfStatementBlock) {
		// handle predicate
		IfStatementBlock istb = (IfStatementBlock) current;
		istb.getPredicateHops().resetVisitStatus();

		if (istb.getNumStatements() > 1)
			LOG.debug("If statement block has more than 1 stmt");
		IfStatement is = (IfStatement)istb.getStatement(0);
		for (StatementBlock sb : is.getIfBody()){
			resetHopsDAGVisitStatus(sb);
		}
		for (StatementBlock sb : is.getElseBody()){
			resetHopsDAGVisitStatus(sb);
		}
	}

	if (current instanceof ForStatementBlock) {
		// handle predicate (from/to/increment are each optional)
		ForStatementBlock fsb = (ForStatementBlock) current;
		if (fsb.getFromHops() != null)
			fsb.getFromHops().resetVisitStatus();
		if (fsb.getToHops() != null)
			fsb.getToHops().resetVisitStatus();
		if (fsb.getIncrementHops() != null)
			fsb.getIncrementHops().resetVisitStatus();

		if (fsb.getNumStatements() > 1)
			LOG.debug("For statement block has more than 1 stmt");
		ForStatement ws = (ForStatement)fsb.getStatement(0);
		for (StatementBlock sb : ws.getBody()){
			resetHopsDAGVisitStatus(sb);
		}
	}
}
/**
 * Resets the visit status of all LOP DAGs in the program (functions first,
 * then the main program blocks).
 *
 * @param dmlp the program whose LOP DAGs are reset
 * @throws HopsException propagated from per-block reset
 * @throws LanguageException declared for callers
 */
public void resetLopsDAGVisitStatus(DMLProgram dmlp) throws HopsException, LanguageException {
	// function program blocks of all namespaces
	for (String ns : dmlp.getNamespaces().keySet()) {
		for (String functionName : dmlp.getFunctionStatementBlocks(ns).keySet())
			resetLopsDAGVisitStatus(dmlp.getFunctionStatementBlock(ns, functionName));
	}

	// statement blocks of the "main" program
	for (int idx = 0; idx < dmlp.getNumStatementBlocks(); idx++)
		resetLopsDAGVisitStatus(dmlp.getStatementBlock(idx));
}
/**
 * Resets the visit status of the LOP DAGs of a statement block, including
 * predicate LOPs, and recurses into the child blocks of function, while, if
 * and for/parfor blocks.
 * <p>
 * NOTE(review): the LOPs of root HOPs and while/if predicate LOPs are
 * dereferenced without null checks, so this presumably must run only after
 * LOP construction — confirm against callers.
 *
 * @param current the statement block to process
 * @throws HopsException declared for callers
 */
public void resetLopsDAGVisitStatus(StatementBlock current) throws HopsException {

	ArrayList<Hop> hopsDAG = current.get_hops();
	if (hopsDAG != null && !hopsDAG.isEmpty() ) {
		for (Hop currentHop : hopsDAG) {
			currentHop.getLops().resetVisitStatus();
		}
	}

	if (current instanceof FunctionStatementBlock) {
		FunctionStatementBlock fsb = (FunctionStatementBlock) current;
		FunctionStatement fs = (FunctionStatement)fsb.getStatement(0);
		for (StatementBlock sb : fs.getBody()){
			resetLopsDAGVisitStatus(sb);
		}
	}

	if (current instanceof WhileStatementBlock) {
		WhileStatementBlock wstb = (WhileStatementBlock) current;
		wstb.get_predicateLops().resetVisitStatus();
		if (wstb.getNumStatements() > 1)
			LOG.debug("While statement block has more than 1 stmt");
		WhileStatement ws = (WhileStatement)wstb.getStatement(0);
		for (StatementBlock sb : ws.getBody()){
			resetLopsDAGVisitStatus(sb);
		}
	}

	if (current instanceof IfStatementBlock) {
		IfStatementBlock istb = (IfStatementBlock) current;
		istb.get_predicateLops().resetVisitStatus();
		if (istb.getNumStatements() > 1)
			LOG.debug("If statement block has more than 1 stmt");
		IfStatement is = (IfStatement)istb.getStatement(0);
		for (StatementBlock sb : is.getIfBody()){
			resetLopsDAGVisitStatus(sb);
		}
		for (StatementBlock sb : is.getElseBody()){
			resetLopsDAGVisitStatus(sb);
		}
	}

	if (current instanceof ForStatementBlock) {
		// from/to/increment lops are each optional
		ForStatementBlock fsb = (ForStatementBlock) current;
		if (fsb.getFromLops() != null)
			fsb.getFromLops().resetVisitStatus();
		if (fsb.getToLops() != null)
			fsb.getToLops().resetVisitStatus();
		if (fsb.getIncrementLops() != null)
			fsb.getIncrementLops().resetVisitStatus();

		if (fsb.getNumStatements() > 1)
			LOG.debug("For statement block has more than 1 stmt");
		ForStatement ws = (ForStatement)fsb.getStatement(0);
		for (StatementBlock sb : ws.getBody()){
			resetLopsDAGVisitStatus(sb);
		}
	}
}
public void constructHops(StatementBlock sb)
throws ParseException, LanguageException {
if (sb instanceof WhileStatementBlock) {
constructHopsForWhileControlBlock((WhileStatementBlock) sb);
return;
}
if (sb instanceof IfStatementBlock) {
constructHopsForIfControlBlock((IfStatementBlock) sb);
return;
}
if (sb instanceof ForStatementBlock) { //NOTE: applies to ForStatementBlock and ParForStatementBlock
constructHopsForForControlBlock((ForStatementBlock) sb);
return;
}
if (sb instanceof FunctionStatementBlock) {
constructHopsForFunctionControlBlock((FunctionStatementBlock) sb);
return;
}
HashMap ids = new HashMap();
ArrayList output = new ArrayList();
VariableSet liveIn = sb.liveIn();
VariableSet liveOut = sb.liveOut();
VariableSet updated = sb._updated;
VariableSet gen = sb._gen;
VariableSet updatedLiveOut = new VariableSet();
// handle liveout variables that are updated --> target identifiers for Assignment
HashMap liveOutToTemp = new HashMap();
for (int i = 0; i < sb.getNumStatements(); i++) {
Statement current = sb.getStatement(i);
if (current instanceof AssignmentStatement) {
AssignmentStatement as = (AssignmentStatement) current;
DataIdentifier target = as.getTarget();
if (liveOut.containsVariable(target.getName())) {
liveOutToTemp.put(target.getName(), Integer.valueOf(i));
}
}
if (current instanceof MultiAssignmentStatement) {
MultiAssignmentStatement mas = (MultiAssignmentStatement) current;
for (DataIdentifier target : mas.getTargetList()){
if (liveOut.containsVariable(target.getName())) {
liveOutToTemp.put(target.getName(), Integer.valueOf(i));
}
}
}
}
// only create transient read operations for variables either updated or read-before-update
// (i.e., from LV analysis, updated and gen sets)
if ( !liveIn.getVariables().values().isEmpty() ) {
for (String varName : liveIn.getVariables().keySet()) {
if (updated.containsVariable(varName) || gen.containsVariable(varName)){
DataIdentifier var = liveIn.getVariables().get(varName);
long actualDim1 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim1() : var.getDim1();
long actualDim2 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim2() : var.getDim2();
DataOp read = new DataOp(var.getName(), var.getDataType(), var.getValueType(), DataOpTypes.TRANSIENTREAD, null, actualDim1, actualDim2, var.getNnz(), var.getRowsInBlock(), var.getColumnsInBlock());
read.setAllPositions(var.getBeginLine(), var.getBeginColumn(), var.getEndLine(), var.getEndColumn());
ids.put(varName, read);
}
}
}
for( int i = 0; i < sb.getNumStatements(); i++ ) {
Statement current = sb.getStatement(i);
if (current instanceof OutputStatement) {
OutputStatement os = (OutputStatement) current;
DataExpression source = os.getSource();
DataIdentifier target = os.getIdentifier();
//error handling unsupported indexing expression in write statement
if( target instanceof IndexedIdentifier ) {
throw new LanguageException(source.printErrorLocation()+": Unsupported indexing expression in write statement. " +
"Please, assign the right indexing result to a variable and write this variable.");
}
DataOp ae = (DataOp)processExpression(source, target, ids);
String formatName = os.getExprParam(DataExpression.FORMAT_TYPE).toString();
ae.setInputFormatType(Expression.convertFormatType(formatName));
if (ae.getDataType() == DataType.SCALAR ) {
ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), -1, -1);
}
else {
switch(ae.getInputFormatType()) {
case TEXT:
case MM:
case CSV:
// write output in textcell format
ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), -1, -1);
break;
case BINARY:
// write output in binary block format
ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize());
break;
default:
throw new LanguageException("Unrecognized file format: " + ae.getInputFormatType());
}
}
output.add(ae);
}
if (current instanceof PrintStatement) {
PrintStatement ps = (PrintStatement) current;
Expression source = ps.getExpression();
PRINTTYPE ptype = ps.getType();
DataIdentifier target = createTarget();
target.setDataType(DataType.SCALAR);
target.setValueType(ValueType.STRING);
target.setAllPositions(current.getFilename(), current.getBeginLine(), target.getBeginColumn(), current.getEndLine(), current.getEndColumn());
Hop ae = processExpression(source, target, ids);
try {
Hop.OpOp1 op = (ptype == PRINTTYPE.PRINT ? Hop.OpOp1.PRINT : Hop.OpOp1.STOP);
Hop printHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae);
printHop.setAllPositions(current.getBeginLine(), current.getBeginColumn(), current.getEndLine(), current.getEndColumn());
output.add(printHop);
} catch ( HopsException e ) {
throw new LanguageException(e);
}
}
if (current instanceof AssignmentStatement) {
AssignmentStatement as = (AssignmentStatement) current;
DataIdentifier target = as.getTarget();
Expression source = as.getSource();
// CASE: regular assignment statement -- source is DML expression that is NOT user-defined or external function
if (!(source instanceof FunctionCallIdentifier)){
// CASE: target is regular data identifier
if (!(target instanceof IndexedIdentifier)) {
Hop ae = processExpression(source, target, ids);
ids.put(target.getName(), ae);
target.setProperties(source.getOutput());
Integer statementId = liveOutToTemp.get(target.getName());
if ((statementId != null) && (statementId.intValue() == i)) {
DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, DataOpTypes.TRANSIENTWRITE, null);
transientwrite.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), ae.getRowsInBlock(), ae.getColsInBlock());
transientwrite.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndLine());
updatedLiveOut.addVariable(target.getName(), target);
output.add(transientwrite);
}
} // end if (!(target instanceof IndexedIdentifier)) {
// CASE: target is indexed identifier (left-hand side indexed expression)
else {
Hop ae = processLeftIndexedExpression(source, (IndexedIdentifier)target, ids);
ids.put(target.getName(), ae);
// obtain origDim values BEFORE they are potentially updated during setProperties call
// (this is incorrect for LHS Indexing)
long origDim1 = ((IndexedIdentifier)target).getOrigDim1();
long origDim2 = ((IndexedIdentifier)target).getOrigDim2();
target.setProperties(source.getOutput());
((IndexedIdentifier)target).setOriginalDimensions(origDim1, origDim2);
// preserve data type matrix of any index identifier
// (required for scalar input to left indexing)
if( target.getDataType() != DataType.MATRIX ) {
target.setDataType(DataType.MATRIX);
target.setValueType(ValueType.DOUBLE);
target.setBlockDimensions(ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize());
}
Integer statementId = liveOutToTemp.get(target.getName());
if ((statementId != null) && (statementId.intValue() == i)) {
DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, DataOpTypes.TRANSIENTWRITE, null);
transientwrite.setOutputParams(origDim1, origDim2, ae.getNnz(), ae.getUpdateType(), ae.getRowsInBlock(), ae.getColsInBlock());
transientwrite.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndColumn());
updatedLiveOut.addVariable(target.getName(), target);
output.add(transientwrite);
}
}
}
else
{
//assignment, function call
FunctionCallIdentifier fci = (FunctionCallIdentifier) source;
FunctionStatementBlock fsb = this._dmlProg.getFunctionStatementBlock(fci.getNamespace(),fci.getName());
//error handling missing function
if (fsb == null){
String error = source.printErrorLocation() + "function " + fci.getName() + " is undefined in namespace " + fci.getNamespace();
LOG.error(error);
throw new LanguageException(error);
}
//error handling unsupported function call in indexing expression
if( target instanceof IndexedIdentifier ) {
String fkey = DMLProgram.constructFunctionKey(fci.getNamespace(),fci.getName());
throw new LanguageException("Unsupported function call to '"+fkey+"' in left indexing expression. " +
"Please, assign the function output to a variable.");
}
ArrayList finputs = new ArrayList();
for (ParameterExpression paramName : fci.getParamExprs()){
Hop in = processExpression(paramName.getExpr(), null, ids);
finputs.add(in);
}
//create function op
FunctionType ftype = fsb.getFunctionOpType();
FunctionOp fcall = new FunctionOp(ftype, fci.getNamespace(), fci.getName(), finputs, new String[]{target.getName()});
output.add(fcall);
//TODO function output dataops (phase 3)
//DataOp trFoutput = new DataOp(target.getName(), target.getDataType(), target.getValueType(), fcall, DataOpTypes.FUNCTIONOUTPUT, null);
//DataOp twFoutput = new DataOp(target.getName(), target.getDataType(), target.getValueType(), trFoutput, DataOpTypes.TRANSIENTWRITE, null);
}
}
else if (current instanceof MultiAssignmentStatement) {
//multi-assignment, by definition a function call
MultiAssignmentStatement mas = (MultiAssignmentStatement) current;
Expression source = mas.getSource();
if ( source instanceof FunctionCallIdentifier ) {
FunctionCallIdentifier fci = (FunctionCallIdentifier) source;
FunctionStatementBlock fsb = this._dmlProg.getFunctionStatementBlock(fci.getNamespace(),fci.getName());
FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
if (fstmt == null){
LOG.error(source.printErrorLocation() + "function " + fci.getName() + " is undefined in namespace " + fci.getNamespace());
throw new LanguageException(source.printErrorLocation() + "function " + fci.getName() + " is undefined in namespace " + fci.getNamespace());
}
ArrayList finputs = new ArrayList();
for (ParameterExpression paramName : fci.getParamExprs()){
Hop in = processExpression(paramName.getExpr(), null, ids);
finputs.add(in);
}
//create function op
String[] foutputs = new String[mas.getTargetList().size()];
int count = 0;
for ( DataIdentifier paramName : mas.getTargetList() ){
foutputs[count++]=paramName.getName();
}
FunctionType ftype = fsb.getFunctionOpType();
FunctionOp fcall = new FunctionOp(ftype, fci.getNamespace(), fci.getName(), finputs, foutputs);
output.add(fcall);
//TODO function output dataops (phase 3)
/*for ( DataIdentifier paramName : mas.getTargetList() ){
DataOp twFoutput = new DataOp(paramName.getName(), paramName.getDataType(), paramName.getValueType(), fcall, DataOpTypes.TRANSIENTWRITE, null);
output.add(twFoutput);
}*/
}
else if ( source instanceof BuiltinFunctionExpression && ((BuiltinFunctionExpression)source).multipleReturns() ) {
// construct input hops
Hop fcall = processMultipleReturnBuiltinFunctionExpression((BuiltinFunctionExpression)source, mas.getTargetList(), ids);
output.add(fcall);
}
else if ( source instanceof ParameterizedBuiltinFunctionExpression && ((ParameterizedBuiltinFunctionExpression)source).multipleReturns() ) {
// construct input hops
Hop fcall = processMultipleReturnParameterizedBuiltinFunctionExpression((ParameterizedBuiltinFunctionExpression)source, mas.getTargetList(), ids);
output.add(fcall);
}
else
throw new LanguageException("Class \"" + source.getClass() + "\" is not supported in Multiple Assignment statements");
}
}
sb.updateLiveVariablesOut(updatedLiveOut);
sb.set_hops(output);
}
/**
 * Constructs the Hops of an if statement block: first its conditional predicate, then
 * (recursively) every statement block of the if-branch, followed by every statement block
 * of the else-branch.
 *
 * @param sb if statement block whose first statement is the {@link IfStatement}
 * @throws ParseException propagated from predicate or body translation
 * @throws LanguageException propagated from body translation
 */
public void constructHopsForIfControlBlock(IfStatementBlock sb) throws ParseException, LanguageException {
	IfStatement stmt = (IfStatement) sb.getStatement(0);
	ArrayList<StatementBlock> thenBlocks = stmt.getIfBody();
	ArrayList<StatementBlock> elseBlocks = stmt.getElseBody();

	constructHopsForConditionalPredicate(sb);

	// if-branch first, else-branch second
	for (StatementBlock child : thenBlocks) {
		constructHops(child);
	}
	for (StatementBlock child : elseBlocks) {
		constructHops(child);
	}
}
/**
 * Constructs the Hops of a {@link ForStatementBlock} (or ParForStatementBlock): first the
 * FROM/TO/INCREMENT predicate Hops, then (recursively) the Hops of every nested statement block.
 *
 * @param sb for/parfor statement block whose first statement is the {@link ForStatement}
 * @throws ParseException propagated from predicate or body translation
 * @throws LanguageException propagated from body translation
 */
public void constructHopsForForControlBlock(ForStatementBlock sb)
	throws ParseException, LanguageException
{
	ArrayList<StatementBlock> loopBody = ((ForStatement) sb.getStatement(0)).getBody();

	constructHopsForIterablePredicate(sb);

	for (StatementBlock child : loopBody) {
		constructHops(child);
	}
}
/**
 * Constructs the Hops of every statement block in a function body.
 *
 * @param fsb function statement block whose first statement is the {@link FunctionStatement}
 * @throws ParseException propagated from body translation
 * @throws LanguageException propagated from body translation
 */
public void constructHopsForFunctionControlBlock(FunctionStatementBlock fsb) throws ParseException, LanguageException {
	FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
	ArrayList<StatementBlock> functionBody = fstmt.getBody();
	for (StatementBlock child : functionBody) {
		constructHops(child);
	}
}
/**
 * Constructs the Hops of a while statement block: first its conditional predicate, then
 * (recursively) every statement block of the loop body.
 *
 * @param sb while statement block whose first statement is the {@link WhileStatement}
 * @throws ParseException propagated from predicate or body translation
 * @throws LanguageException propagated from body translation
 */
public void constructHopsForWhileControlBlock(WhileStatementBlock sb)
	throws ParseException, LanguageException {
	WhileStatement loop = (WhileStatement) sb.getStatement(0);
	ArrayList<StatementBlock> loopBody = loop.getBody();

	constructHopsForConditionalPredicate(sb);

	for (StatementBlock child : loopBody) {
		constructHops(child);
	}
}
/**
 * Constructs the predicate Hops of a while or if statement block and stores them on the
 * block via {@code setPredicateHops}.
 * <p>
 * Each variable read by the predicate is bound to a new TRANSIENTREAD {@link DataOp} built
 * from the block's live-in variable. Numeric constant predicates are rewritten in place on the
 * {@link ConditionalPredicate}: 0 becomes FALSE, 1 becomes TRUE, and any other number becomes
 * TRUE (with a warning). String constant predicates are rejected.
 *
 * @param passedSB a {@link WhileStatementBlock} or {@link IfStatementBlock}
 * @throws ParseException if the block is of another type, a predicate variable is not live-in,
 *         the predicate is a string constant, or translating the predicate fails
 */
public void constructHopsForConditionalPredicate(StatementBlock passedSB) throws ParseException {
	HashMap<String, Hop> _ids = new HashMap<String, Hop>();

	// obtain the conditional predicate of the while/if statement
	ConditionalPredicate cp = null;
	if (passedSB instanceof WhileStatementBlock) {
		WhileStatement ws = (WhileStatement) ((WhileStatementBlock)passedSB).getStatement(0);
		cp = ws.getConditionalPredicate();
	}
	else if (passedSB instanceof IfStatementBlock) {
		IfStatement ifs = (IfStatement) ((IfStatementBlock)passedSB).getStatement(0);
		cp = ifs.getConditionalPredicate();
	}
	else {
		throw new ParseException("ConditionalPredicate expected only for while or if statements.");
	}

	// create transient reads for all live-in variables read by the predicate
	VariableSet varsRead = cp.variablesRead();
	for (String varName : varsRead.getVariables().keySet()) {
		DataIdentifier var = passedSB.liveIn().getVariables().get(varName);
		if (var == null) {
			String msg = "variable " + varName + " not live variable for conditional predicate";
			LOG.error(msg);
			throw new ParseException(msg);
		}
		// for indexed identifiers, read with the original (pre-indexing) dimensions
		long actualDim1 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim1() : var.getDim1();
		long actualDim2 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim2() : var.getDim2();
		DataOp read = new DataOp(var.getName(), var.getDataType(), var.getValueType(), DataOpTypes.TRANSIENTREAD,
				null, actualDim1, actualDim2, var.getNnz(), var.getRowsInBlock(), var.getColumnsInBlock());
		read.setAllPositions(var.getBeginLine(), var.getBeginColumn(), var.getEndLine(), var.getEndColumn());
		_ids.put(varName, read);
	}

	// boolean scalar target spanning the whole statement block
	DataIdentifier target = new DataIdentifier(Expression.getTempName());
	target.setDataType(DataType.SCALAR);
	target.setValueType(ValueType.BOOLEAN);
	target.setAllPositions(passedSB.getFilename(), passedSB.getBeginLine(), passedSB.getBeginColumn(), passedSB.getEndLine(), passedSB.getEndColumn());

	Hop predicateHops = null;
	Expression predicate = cp.getPredicate();

	if (predicate instanceof RelationalExpression) {
		predicateHops = processRelationalExpression((RelationalExpression) predicate, target, _ids);
	}
	else if (predicate instanceof BooleanExpression) {
		predicateHops = processBooleanExpression((BooleanExpression) predicate, target, _ids);
	}
	else if (predicate instanceof DataIdentifier) {
		// handle data identifier predicate
		predicateHops = processExpression(predicate, null, _ids);
	}
	else if (predicate instanceof ConstIdentifier) {
		// a) translate numeric constants to boolean constants (0 --> FALSE, 1 --> TRUE,
		//    any other number --> TRUE with a warning); b) disallow string values
		if (predicate instanceof StringIdentifier) {
			String msg = predicate.printErrorLocation() + "String value '" + predicate.toString() + "' is not allowed for conditional predicate";
			LOG.error(msg);
			throw new ParseException(msg);
		}
		if (predicate instanceof IntIdentifier || predicate instanceof DoubleIdentifier) {
			boolean isZero = (predicate instanceof IntIdentifier)
				? ((IntIdentifier)predicate).getValue() == 0
				: ((DoubleIdentifier)predicate).getValue() == 0.0;
			boolean isOne = (predicate instanceof IntIdentifier)
				? ((IntIdentifier)predicate).getValue() == 1
				: ((DoubleIdentifier)predicate).getValue() == 1.0;
			cp.setPredicate(new BooleanIdentifier(!isZero,
					predicate.getFilename(),
					predicate.getBeginLine(), predicate.getBeginColumn(),
					predicate.getEndLine(), predicate.getEndColumn()));
			if (!isZero && !isOne) {
				LOG.warn(predicate.printWarningLocation() + "Numerical value '" + predicate.toString() + "' (!= 0/1) is converted to boolean TRUE by DML");
			}
		}
		// translate the (possibly replaced) predicate
		predicateHops = processExpression(cp.getPredicate(), null, _ids);
	}
	// NOTE(review): predicates of any other expression kind (e.g. arithmetic or builtin
	// function calls) are not translated here and leave predicateHops null -- confirm that
	// validation rejects them before this point.

	// the block type was validated above, so it is either a while or an if block
	if (passedSB instanceof WhileStatementBlock)
		((WhileStatementBlock)passedSB).setPredicateHops(predicateHops);
	else
		((IfStatementBlock)passedSB).setPredicateHops(predicateHops);
}
/**
 * Constructs all predicate Hops (for FROM, TO, INCREMENT) of an iterable predicate
 * and assigns these Hops to the passed statement block.
 * <p>
 * Method used for both ForStatementBlock and ParForStatementBlock. The expressions are
 * processed in the order FROM, TO, INCREMENT, sharing one variable-to-Hop map. Before each
 * expression is translated, every variable it reads is bound to a new TRANSIENTREAD of the
 * corresponding live-in variable. The INCREMENT expression is skipped when absent.
 *
 * @param fsb for/parfor statement block whose first statement is the {@link ForStatement}
 * @throws ParseException if a referenced variable is not live-in or expression translation fails
 */
public void constructHopsForIterablePredicate(ForStatementBlock fsb)
	throws ParseException
{
	HashMap<String, Hop> _ids = new HashMap<String, Hop>();

	ForStatement fs = (ForStatement) fsb.getStatement(0);
	IterablePredicate ip = fs.getIterablePredicate();

	fsb.setFromHops( processIterablePredicateExpression(fsb, ip.getFromExpr(), _ids) );
	fsb.setToHops( processIterablePredicateExpression(fsb, ip.getToExpr(), _ids) );
	if( ip.getIncrementExpr() != null )
		fsb.setIncrementHops( processIterablePredicateExpression(fsb, ip.getIncrementExpr(), _ids) );
}

/**
 * Binds the variables read by one iterable-predicate expression to transient reads
 * (accumulated into {@code ids}) and translates the expression into an INT scalar Hop.
 *
 * @param fsb statement block providing the live-in variables
 * @param expr FROM, TO, or INCREMENT expression
 * @param ids variable-to-Hop map shared across the predicate's expressions
 * @return the root Hop of the translated expression
 * @throws ParseException if a referenced variable is not live-in or translation fails
 */
private Hop processIterablePredicateExpression(ForStatementBlock fsb, Expression expr, HashMap<String, Hop> ids)
	throws ParseException
{
	// variablesRead() may be null (the original loop guarded against it)
	VariableSet varsRead = expr.variablesRead();
	if( varsRead != null ) {
		for (String varName : varsRead.getVariables().keySet()) {
			DataIdentifier var = fsb.liveIn().getVariable(varName);
			if (var == null) {
				LOG.error("variable '" + varName + "' is not available for iterable predicate");
				throw new ParseException("variable '" + varName + "' is not available for iterable predicate");
			}
			// for indexed identifiers, read with the original (pre-indexing) dimensions
			long actualDim1 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim1() : var.getDim1();
			long actualDim2 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim2() : var.getDim2();
			DataOp read = new DataOp(var.getName(), var.getDataType(), var.getValueType(), DataOpTypes.TRANSIENTREAD,
					null, actualDim1, actualDim2, var.getNnz(), var.getRowsInBlock(), var.getColumnsInBlock());
			read.setAllPositions(var.getBeginLine(), var.getBeginColumn(), var.getEndLine(), var.getEndColumn());
			ids.put(varName, read);
		}
	}
	return processTempIntExpression(expr, ids);
}
/**
 * Construct Hops from parse tree: translates an expression (e.g., the right-hand side of an
 * assignment statement) into a Hop DAG by dispatching on the expression kind.
 *
 * @param source expression to translate
 * @param target output identifier for the resulting Hop; sub-expressions are translated with
 *        {@code null} here, and the kind-specific translators handle that case
 * @param hops mapping from variable names to their current defining Hops
 * @return the root Hop; for a plain variable reference the Hop bound in {@code hops} (which is
 *         {@code null} if the variable is unbound); {@code null} for unhandled expression kinds
 * @throws ParseException if translation fails; Hop-construction failures are wrapped with
 *         the original exception kept as the cause
 */
private Hop processExpression(Expression source, DataIdentifier target, HashMap<String, Hop> hops) throws ParseException {
	if (source.getKind() == Expression.Kind.BinaryOp) {
		return processBinaryExpression((BinaryExpression) source, target, hops);
	} else if (source.getKind() == Expression.Kind.RelationalOp) {
		return processRelationalExpression((RelationalExpression) source, target, hops);
	} else if (source.getKind() == Expression.Kind.BooleanOp) {
		return processBooleanExpression((BooleanExpression) source, target, hops);
	} else if (source.getKind() == Expression.Kind.Data) {
		// IndexedIdentifier is a DataIdentifier, so it must be checked first
		if (source instanceof IndexedIdentifier) {
			return processIndexingExpression((IndexedIdentifier) source, target, hops);
		} else if (source instanceof IntIdentifier) {
			IntIdentifier sourceInt = (IntIdentifier) source;
			return attachLiteralParams(new LiteralOp(sourceInt.getValue()), sourceInt);
		} else if (source instanceof DoubleIdentifier) {
			DoubleIdentifier sourceDouble = (DoubleIdentifier) source;
			return attachLiteralParams(new LiteralOp(sourceDouble.getValue()), sourceDouble);
		} else if (source instanceof DataIdentifier) {
			// plain variable reference: reuse the Hop that currently defines the variable
			return hops.get(((DataIdentifier) source).getName());
		} else if (source instanceof BooleanIdentifier) {
			BooleanIdentifier sourceBoolean = (BooleanIdentifier) source;
			return attachLiteralParams(new LiteralOp(sourceBoolean.getValue()), sourceBoolean);
		} else if (source instanceof StringIdentifier) {
			StringIdentifier sourceString = (StringIdentifier) source;
			return attachLiteralParams(new LiteralOp(sourceString.getValue()), sourceString);
		}
	} else if (source.getKind() == Expression.Kind.BuiltinFunctionOp) {
		try {
			return processBuiltinFunctionExpression((BuiltinFunctionExpression) source, target, hops);
		} catch (HopsException e) {
			throw wrapAsParseException(e);
		}
	} else if (source.getKind() == Expression.Kind.ParameterizedBuiltinFunctionOp ) {
		try {
			return processParameterizedBuiltinFunctionExpression((ParameterizedBuiltinFunctionExpression)source, target, hops);
		} catch ( HopsException e ) {
			throw wrapAsParseException(e);
		}
	} else if (source.getKind() == Expression.Kind.DataOp ) {
		try {
			Hop ae = (Hop)processDataExpression((DataExpression)source, target, hops);
			if (ae instanceof DataOp) {
				String formatName = ((DataExpression)source).getVarParam(DataExpression.FORMAT_TYPE).toString();
				((DataOp)ae).setInputFormatType(Expression.convertFormatType(formatName));
			}
			return ae;
		} catch ( Exception e ) {
			throw wrapAsParseException(e);
		}
	}
	return null;
} // end method processExpression

/**
 * Copies the source positions and identifier properties of {@code id} onto {@code litop}.
 *
 * @return {@code litop}, for call chaining
 */
private LiteralOp attachLiteralParams(LiteralOp litop, Identifier id) {
	litop.setAllPositions(id.getBeginLine(), id.getBeginColumn(), id.getEndLine(), id.getEndColumn());
	setIdentifierParams(litop, id);
	return litop;
}

/**
 * Wraps {@code cause} in a ParseException with the same message, keeping {@code cause}
 * attached so its stack trace is not lost.
 */
private static ParseException wrapAsParseException(Exception cause) {
	ParseException pe = new ParseException(cause.getMessage());
	pe.initCause(cause);
	return pe;
}
/**
 * Returns a target identifier for the output of {@code source}. If that output is a plain
 * {@link DataIdentifier} (and not a {@link DataExpression}), it is returned as-is. Otherwise a
 * new temporarily-named {@link DataIdentifier} is created and given the output's properties.
 *
 * @param source expression whose output determines the target
 * @return the reusable output identifier, or a new temporary identifier
 */
private DataIdentifier createTarget(Expression source) {
	Identifier out = source.getOutput();
	boolean reusable = (out instanceof DataIdentifier) && !(out instanceof DataExpression);
	if (reusable) {
		return (DataIdentifier) out;
	}
	DataIdentifier tmp = new DataIdentifier(Expression.getTempName());
	tmp.setProperties(out);
	return tmp;
}
/**
 * Creates a new {@link DataIdentifier} named by {@link Expression#getTempName()}, without
 * setting any further properties.
 *
 * @return the new temporary identifier
 */
private DataIdentifier createTarget() {
	return new DataIdentifier(Expression.getTempName());
}
/**
 * Translates an expression that is expected to evaluate to an INT scalar (for example the
 * FROM/TO/INCREMENT expressions of a for loop) into Hops.
 * <p>
 * A new SCALAR/INT temporary target is created, installed as the output of {@code source},
 * and used as the target of the translation.
 *
 * @param source expression to translate; note that its output identifier is replaced
 * @param hops mapping from variable names to their current defining Hops
 * @return the root Hop of the translated expression
 * @throws ParseException if the expression cannot be translated
 */
private Hop processTempIntExpression( Expression source, HashMap hops )
	throws ParseException
{
	DataIdentifier intScalarOut = createTarget();
	intScalarOut.setDataType(DataType.SCALAR);
	intScalarOut.setValueType(ValueType.INT);
	source.setOutput(intScalarOut);

	Hop translated = processExpression(source, intScalarOut, hops);
	return translated;
}
/**
 * Constructs the Hops for a left-indexed assignment {@code target[rl:ru, cl:cu] = source}.
 * <p>
 * Missing lower bounds default to 1. A missing upper bound becomes a literal of the target's
 * original dimension when the target's dimension is known, and otherwise an NROW/NCOL Hop over
 * the target's current defining Hop.
 *
 * @param source right-hand side expression
 * @param target left-hand side indexed identifier
 * @param hops mapping from variable names to their current defining Hops
 * @return the {@link LeftIndexingOp}, with its dimensions set to the target's original dimensions
 * @throws ParseException if the target matrix has no defining Hop or sub-expression translation fails
 */
private Hop processLeftIndexedExpression(Expression source, IndexedIdentifier target, HashMap<String, Hop> hops)
	throws ParseException {
	// process target indexed expressions
	Hop rowLowerHops = null, rowUpperHops = null, colLowerHops = null, colUpperHops = null;

	if (target.getRowLowerBound() != null)
		rowLowerHops = processExpression(target.getRowLowerBound(),null,hops);
	else
		rowLowerHops = new LiteralOp(1);

	if (target.getRowUpperBound() != null)
		rowUpperHops = processExpression(target.getRowUpperBound(),null,hops);
	else
	{
		// NOTE(review): this guard tests getDim1() while the literal uses getOrigDim1()
		// (processIndexingExpression tests getOrigDim1()); kept unchanged -- confirm intent.
		if ( target.getDim1() != -1 )
			rowUpperHops = new LiteralOp(target.getOrigDim1());
		else
		{
			try {
				rowUpperHops = new UnaryOp(target.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NROW, hops.get(target.getName()));
				rowUpperHops.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndColumn());
			} catch (HopsException e) {
				String msg = target.printErrorLocation() + "error processing row upper index for indexed expression " + target.toString();
				LOG.error(msg);
				throw new RuntimeException(msg, e);
			}
		}
	}

	if (target.getColLowerBound() != null)
		colLowerHops = processExpression(target.getColLowerBound(),null,hops);
	else
		colLowerHops = new LiteralOp(1);

	if (target.getColUpperBound() != null)
		colUpperHops = processExpression(target.getColUpperBound(),null,hops);
	else
	{
		// NOTE(review): same getDim2()/getOrigDim2() asymmetry as the row case above
		if ( target.getDim2() != -1 )
			colUpperHops = new LiteralOp(target.getOrigDim2());
		else
		{
			try {
				colUpperHops = new UnaryOp(target.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NCOL, hops.get(target.getName()));
				colUpperHops.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndColumn());
			} catch (HopsException e) {
				String msg = target.printErrorLocation() + " error processing column upper index for indexed expression " + target.toString();
				LOG.error(msg);
				throw new RuntimeException(msg, e);
			}
		}
	}

	// process the source expression to get source Hops
	Hop sourceOp = processExpression(source, target, hops);

	// process the target to get targetHops
	Hop targetOp = hops.get(target.getName());
	if (targetOp == null){
		LOG.error(target.printErrorLocation() + " must define matrix " + target.getName() + " before indexing operations are allowed ");
		throw new ParseException(target.printErrorLocation() + " must define matrix " + target.getName() + " before indexing operations are allowed ");
	}

	// a source Hop typed MATRIX whose parsed output is SCALAR is relabeled as SCALAR
	// TODO: verify this post-processing (a cleaner approach is probably needed)
	if( sourceOp.getDataType() == DataType.MATRIX && source.getOutput().getDataType() == DataType.SCALAR )
		sourceOp.setDataType(DataType.SCALAR);

	Hop leftIndexOp = new LeftIndexingOp(target.getName(), target.getDataType(), target.getValueType(),
			targetOp, sourceOp, rowLowerHops, rowUpperHops, colLowerHops, colUpperHops,
			target.getRowLowerEqualsUpper(), target.getColLowerEqualsUpper());

	setIdentifierParams(leftIndexOp, target);
	leftIndexOp.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndColumn());
	// the result has the dimensions of the whole (original) target matrix
	leftIndexOp.setDim1(target.getOrigDim1());
	leftIndexOp.setDim2(target.getOrigDim2());

	return leftIndexOp;
}
/**
 * Constructs the Hops for a right-indexing expression {@code source[rl:ru, cl:cu]}.
 * <p>
 * Missing lower bounds default to 1. A missing upper bound becomes a literal of the source's
 * original dimension when known, and otherwise an NROW/NCOL Hop over the source's current
 * defining Hop. The target's nnz is reset to unknown (-1) because it cannot be inferred after
 * range indexing.
 *
 * @param source indexed identifier being read
 * @param target output identifier; if {@code null}, one is derived from {@code source}
 * @param hops mapping from variable names to their current defining Hops
 * @return the {@link IndexingOp}, carrying the source expression's position
 * @throws ParseException if translating a bound expression fails
 */
private Hop processIndexingExpression(IndexedIdentifier source, DataIdentifier target, HashMap<String, Hop> hops)
	throws ParseException {

	// process Hops for indexes (for source)
	Hop rowLowerHops = null, rowUpperHops = null, colLowerHops = null, colUpperHops = null;

	if (source.getRowLowerBound() != null)
		rowLowerHops = processExpression(source.getRowLowerBound(),null,hops);
	else
		rowLowerHops = new LiteralOp(1);

	if (source.getRowUpperBound() != null)
		rowUpperHops = processExpression(source.getRowUpperBound(),null,hops);
	else if ( source.getOrigDim1() != -1 )
		rowUpperHops = new LiteralOp(source.getOrigDim1());
	else {
		try {
			rowUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NROW, hops.get(source.getName()));
			rowUpperHops.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
		} catch (HopsException e) {
			String msg = source.printErrorLocation() + "error processing row upper index for indexed identifier " + source.toString();
			LOG.error(msg, e);
			throw new RuntimeException(msg, e);
		}
	}

	if (source.getColLowerBound() != null)
		colLowerHops = processExpression(source.getColLowerBound(),null,hops);
	else
		colLowerHops = new LiteralOp(1);

	if (source.getColUpperBound() != null)
		colUpperHops = processExpression(source.getColUpperBound(),null,hops);
	else if ( source.getOrigDim2() != -1 )
		colUpperHops = new LiteralOp(source.getOrigDim2());
	else {
		try {
			colUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NCOL, hops.get(source.getName()));
			colUpperHops.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
		} catch (HopsException e) {
			String msg = source.printErrorLocation() + "error processing column upper index for indexed indentifier " + source.toString();
			LOG.error(msg, e);
			throw new RuntimeException(msg, e);
		}
	}

	if (target == null) {
		target = createTarget(source);
	}

	// unknown nnz after range indexing (applies to indexing op but also
	// data dependent operations)
	target.setNnz(-1);

	Hop indexOp = new IndexingOp(target.getName(), target.getDataType(), target.getValueType(),
			hops.get(source.getName()), rowLowerHops, rowUpperHops, colLowerHops, colUpperHops,
			source.getRowLowerEqualsUpper(), source.getColLowerEqualsUpper());

	// position info comes from the indexing expression itself (not from the freshly created op)
	indexOp.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
	setIdentifierParams(indexOp, target);

	return indexOp;
}
/**
 * Construct Hops from parse tree: translates a binary arithmetic or matrix-multiply expression.
 * <p>
 * Both operands are translated with a {@code null} target. The target's value type is always
 * overwritten with the inferred value type of {@code source}'s output, so the type is not
 * dictated by an outer target (e.g., a STRING target for print).
 *
 * @param source binary expression to translate
 * @param target output identifier; if {@code null}, one is derived from {@code source}
 * @param hops mapping from variable names to their current defining Hops
 * @return a {@link BinaryOp}, or an {@link AggBinaryOp} (MULT/SUM) for matrix multiplication
 * @throws ParseException if an operand cannot be translated or the operator is unsupported
 */
private Hop processBinaryExpression(BinaryExpression source, DataIdentifier target, HashMap<String, Hop> hops)
	throws ParseException
{
	Hop left = processExpression(source.getLeft(), null, hops);
	Hop right = processExpression(source.getRight(), null, hops);

	// a null operand means it could not be translated (e.g., an unbound variable); fail with a
	// parse error here instead of building an operator over a null input
	if (left == null || right == null) {
		throw new ParseException(source.printErrorLocation() + "Missing input in binary expression ("
				+ source.toString() + "): " + ((left == null) ? source.getLeft() : source.getRight()));
	}

	//prepare target identifier and ensure that output type is of inferred type
	//(type should not be determined by target (e.g., string for print)
	if (target == null) {
		target = createTarget(source);
	}
	target.setValueType(source.getOutput().getValueType());

	Hop currBop = null;
	if (source.getOpCode() == Expression.BinaryOp.MATMULT) {
		// matrix multiplication: sum-of-products aggregate binary op
		currBop = new AggBinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MULT, AggOp.SUM, left, right);
	}
	else {
		OpOp2 op = null;
		if (source.getOpCode() == Expression.BinaryOp.PLUS)
			op = OpOp2.PLUS;
		else if (source.getOpCode() == Expression.BinaryOp.MINUS)
			op = OpOp2.MINUS;
		else if (source.getOpCode() == Expression.BinaryOp.MULT)
			op = OpOp2.MULT;
		else if (source.getOpCode() == Expression.BinaryOp.DIV)
			op = OpOp2.DIV;
		else if (source.getOpCode() == Expression.BinaryOp.MODULUS)
			op = OpOp2.MODULUS;
		else if (source.getOpCode() == Expression.BinaryOp.INTDIV)
			op = OpOp2.INTDIV;
		else if (source.getOpCode() == Expression.BinaryOp.POW)
			op = OpOp2.POW;
		else
			throw new ParseException("Unsupported parsing of binary expression: "+source.getOpCode());

		currBop = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, left, right);
	}

	setIdentifierParams(currBop, source.getOutput());
	currBop.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
	return currBop;
}
/**
 * Construct Hops from parse tree: translates a relational comparison ({@code <, <=, >, >=, ==, !=}).
 * <p>
 * If no target is given, one is derived from {@code source` with value type BOOLEAN and data type
 * MATRIX when either operand is a matrix (SCALAR otherwise). A given target is used unchanged.
 *
 * @param source relational expression to translate
 * @param target output identifier, or {@code null}
 * @param hops mapping from variable names to their current defining Hops
 * @return the comparison {@link BinaryOp}
 * @throws ParseException if an operand cannot be translated or the operator is unsupported
 */
private Hop processRelationalExpression(RelationalExpression source, DataIdentifier target,
		HashMap<String, Hop> hops) throws ParseException {

	Hop left = processExpression(source.getLeft(), null, hops);
	Hop right = processExpression(source.getRight(), null, hops);

	// a null operand means it could not be translated (e.g., an unbound variable)
	if (left == null || right == null) {
		throw new ParseException(source.printErrorLocation() + "Missing input in relational expression ("
				+ source.toString() + "): " + ((left == null) ? source.getLeft() : source.getRight()));
	}

	if (target == null) {
		target = createTarget(source);
		// matrix relational comparison if either side is a matrix, scalar comparison otherwise
		boolean matrixCompare = left.getDataType() == DataType.MATRIX || right.getDataType() == DataType.MATRIX;
		target.setDataType(matrixCompare ? DataType.MATRIX : DataType.SCALAR);
		target.setValueType(ValueType.BOOLEAN);
	}

	OpOp2 op = null;
	if (source.getOpCode() == Expression.RelationalOp.LESS)
		op = OpOp2.LESS;
	else if (source.getOpCode() == Expression.RelationalOp.LESSEQUAL)
		op = OpOp2.LESSEQUAL;
	else if (source.getOpCode() == Expression.RelationalOp.GREATER)
		op = OpOp2.GREATER;
	else if (source.getOpCode() == Expression.RelationalOp.GREATEREQUAL)
		op = OpOp2.GREATEREQUAL;
	else if (source.getOpCode() == Expression.RelationalOp.EQUAL)
		op = OpOp2.EQUAL;
	else if (source.getOpCode() == Expression.RelationalOp.NOTEQUAL)
		op = OpOp2.NOTEQUAL;
	else
		throw new ParseException("Unsupported parsing of relational expression: " + source.getOpCode());

	Hop currBop = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, left, right);
	currBop.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
	return currBop;
}
/**
 * Construct Hops from parse tree: translates a boolean expression (NOT, AND, OR).
 * <p>
 * Boolean expressions with a constant operand are rejected. The target's value type is always
 * set to BOOLEAN, so the type is not dictated by an outer target (e.g., a STRING target for print).
 *
 * @param source boolean expression; a {@code null} right operand denotes unary NOT
 * @param target output identifier; if {@code null}, one is derived from {@code source}
 * @param hops mapping from variable names to their current defining Hops
 * @return a NOT {@link UnaryOp}, or an AND/OR {@link BinaryOp}
 * @throws ParseException if translating an operand or creating the NOT Hop fails
 *         (Hop-construction failures keep the original exception as the cause)
 * @throws RuntimeException if an operand is a constant or the binary operator is unknown
 */
private Hop processBooleanExpression(BooleanExpression source, DataIdentifier target, HashMap<String, Hop> hops)
	throws ParseException
{
	// Boolean Not has a single parameter, so the right operand may be absent
	boolean constLeft = (source.getLeft().getOutput() instanceof ConstIdentifier);
	boolean constRight = (source.getRight() != null)
			&& (source.getRight().getOutput() instanceof ConstIdentifier);
	if (constLeft || constRight) {
		LOG.error(source.printErrorLocation() + "Boolean expression with constant unsupported");
		throw new RuntimeException(source.printErrorLocation() + "Boolean expression with constant unsupported");
	}

	Hop left = processExpression(source.getLeft(), null, hops);
	Hop right = (source.getRight() != null) ? processExpression(source.getRight(), null, hops) : null;

	//prepare target identifier and ensure that output type is boolean
	//(type should not be determined by target (e.g., string for print)
	if (target == null) {
		target = createTarget(source);
	}
	target.setValueType(ValueType.BOOLEAN);

	if (source.getRight() == null) {
		try {
			Hop currUop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NOT, left);
			currUop.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
			return currUop;
		} catch (HopsException e) {
			ParseException pe = new ParseException(e.getMessage());
			pe.initCause(e);
			throw pe;
		}
	}

	OpOp2 op = null;
	if (source.getOpCode() == Expression.BooleanOp.LOGICALAND) {
		op = OpOp2.AND;
	} else if (source.getOpCode() == Expression.BooleanOp.LOGICALOR) {
		op = OpOp2.OR;
	} else {
		LOG.error(source.printErrorLocation() + "Unknown boolean operation " + source.getOpCode());
		throw new RuntimeException(source.printErrorLocation() + "Unknown boolean operation " + source.getOpCode());
	}

	Hop currBop = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, left, right);
	currBop.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
	return currBop;
}
/**
 * Builds the ParameterizedBuiltinOp for a distribution function (cdf/invcdf and
 * the p-/q- variants of normal, t, f, chisq and exp).
 *
 * For the distribution-specific opcodes a "dist" literal naming the distribution
 * is added to {@code paramHops}; for CDF/INVCDF no entry is added here (the
 * distribution parameter hops are expected to be in {@code paramHops} already).
 *
 * @throws HopsException if {@code op} is not a distribution function
 */
private Hop constructDfHop(String name, DataType dt, ValueType vt, ParameterizedBuiltinFunctionOp op, HashMap paramHops) throws HopsException {
	String distName;
	switch (op) {
		case QNORM:
		case PNORM:
			distName = "normal";
			break;
		case QT:
		case PT:
			distName = "t";
			break;
		case QF:
		case PF:
			distName = "f";
			break;
		case QCHISQ:
		case PCHISQ:
			distName = "chisq";
			break;
		case QEXP:
		case PEXP:
			distName = "exp";
			break;
		case CDF:
		case INVCDF:
			distName = null; // distribution comes in via the existing parameters
			break;
		default:
			throw new HopsException("Invalid operation: " + op);
	}

	if (distName != null) {
		paramHops.put("dist", new LiteralOp(distName));
	}
	return new ParameterizedBuiltinOp(name, dt, vt, ParameterizedBuiltinFunctionExpression.pbHopMap.get(op), paramHops);
}
/**
 * Constructs the hop for a parameterized builtin function with multiple return
 * values used in a multi-assignment statement. Currently only transformencode
 * is supported; it yields a MATRIX/DOUBLE output and a FRAME/STRING output.
 *
 * @param source parameterized builtin expression from the parse tree
 * @param targetList left-hand-side identifiers of the multi-assignment
 * @param hops hops constructed so far, passed through to argument translation
 * @return the FunctionOp (MULTIRETURN_BUILTIN) producing all outputs
 * @throws ParseException for an unsupported opcode or a wrong number of targets
 */
private Hop processMultipleReturnParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFunctionExpression source, ArrayList targetList,
		HashMap hops) throws ParseException
{
	FunctionType ftype = FunctionType.MULTIRETURN_BUILTIN;
	String nameSpace = DMLProgram.INTERNAL_NAMESPACE;

	// Output hops of this function call; the exact list depends on the opcode.
	ArrayList<Hop> outputs = new ArrayList<Hop>();

	// Construct the hop for the current builtin function based on its type
	Hop currBuiltinOp = null;
	switch (source.getOpCode()) {
		case TRANSFORMENCODE: {
			// transformencode has exactly two outputs (matrix + frame); fail with a
			// clear error instead of an index exception or null output names
			if (targetList.size() != 2) {
				throw new ParseException(source.printErrorLocation()
					+ "transformencode expects exactly 2 output targets, but got " + targetList.size());
			}
			ArrayList<Hop> inputs = new ArrayList<Hop>();
			inputs.add( processExpression(source.getVarParam("target"), null, hops) );
			inputs.add( processExpression(source.getVarParam("spec"), null, hops) );
			String[] outputNames = new String[targetList.size()];
			outputNames[0] = ((DataIdentifier)targetList.get(0)).getName();
			outputNames[1] = ((DataIdentifier)targetList.get(1)).getName();
			outputs.add(new DataOp(outputNames[0], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[0]));
			outputs.add(new DataOp(outputNames[1], DataType.FRAME, ValueType.STRING, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[1]));
			currBuiltinOp = new FunctionOp(ftype, nameSpace, source.getOpCode().toString(), inputs, outputNames, outputs);
			break;
		}
		default:
			throw new ParseException("Invalid Opcode in DMLTranslator:processMultipleReturnParameterizedBuiltinFunctionExpression(): " + source.getOpCode());
	}

	// set properties for created hops based on outputs of source expression
	for ( int i=0; i < source.getOutputs().length; i++ ) {
		setIdentifierParams( outputs.get(i), source.getOutputs()[i]);
		outputs.get(i).setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
	}
	currBuiltinOp.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
	return currBuiltinOp;
}
/**
 * Construct Hops from parse tree: process a ParameterizedBuiltinFunction
 * expression in an assignment statement.
 *
 * All named arguments are translated into hops and kept in a map by name.
 * Distribution functions are delegated to constructDfHop, order() becomes a
 * SORT reorg, and the remaining supported opcodes map one-to-one onto a
 * ParameterizedBuiltinOp.
 *
 * @throws ParseException for an unknown opcode or argument translation errors
 * @throws HopsException if hop construction fails
 */
private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFunctionExpression source, DataIdentifier target,
		HashMap hops) throws ParseException, HopsException {
	// named arguments -> hops, keyed by argument name
	HashMap<String, Hop> paramHops = new HashMap<String, Hop>();
	for ( String argName : source.getVarParams().keySet() ) {
		paramHops.put(argName, processExpression(source.getVarParam(argName), null, hops));
	}

	if (target == null) {
		target = createTarget(source);
	}

	Hop result = null;
	ParamBuiltinOp pbOp = null; // set for opcodes with a direct ParameterizedBuiltinOp counterpart
	switch (source.getOpCode()) {
		case CDF:
		case INVCDF:
		case QNORM:
		case QT:
		case QF:
		case QCHISQ:
		case QEXP:
		case PNORM:
		case PT:
		case PF:
		case PCHISQ:
		case PEXP:
			result = constructDfHop(target.getName(), target.getDataType(), target.getValueType(), source.getOpCode(), paramHops);
			break;
		case ORDER: {
			// order() is compiled into a SORT reorg with positional inputs
			ArrayList<Hop> sortInputs = new ArrayList<Hop>();
			sortInputs.add(paramHops.get("target"));
			sortInputs.add(paramHops.get("by"));
			sortInputs.add(paramHops.get("decreasing"));
			sortInputs.add(paramHops.get("index.return"));
			result = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), ReOrgOp.SORT, sortInputs);
			break;
		}
		case GROUPEDAGG:      pbOp = ParamBuiltinOp.GROUPEDAGG;      break;
		case RMEMPTY:         pbOp = ParamBuiltinOp.RMEMPTY;         break;
		case REPLACE:         pbOp = ParamBuiltinOp.REPLACE;         break;
		case TRANSFORM:       pbOp = ParamBuiltinOp.TRANSFORM;       break;
		case TRANSFORMAPPLY:  pbOp = ParamBuiltinOp.TRANSFORMAPPLY;  break;
		case TRANSFORMDECODE: pbOp = ParamBuiltinOp.TRANSFORMDECODE; break;
		case TRANSFORMMETA:   pbOp = ParamBuiltinOp.TRANSFORMMETA;   break;
		case TOSTRING:        pbOp = ParamBuiltinOp.TOSTRING;        break;
		default:
			LOG.error(source.printErrorLocation() +
					"processParameterizedBuiltinFunctionExpression() -- Unknown operation: "
					+ source.getOpCode());
			throw new ParseException(source.printErrorLocation() +
					"processParameterizedBuiltinFunctionExpression() -- Unknown operation: "
					+ source.getOpCode());
	}

	if (pbOp != null) {
		result = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), target.getValueType(), pbOp, paramHops);
	}

	setIdentifierParams(result, source.getOutput());
	result.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
	return result;
}
/**
 * Construct Hops from parse tree: process a DataExpression of a
 * read/write/rand/matrix statement.
 *
 * All named parameters are translated into hops keyed by name; READ and WRITE
 * become persistent DataOps, RAND becomes a DataGenOp (SINIT when the min
 * parameter is string-valued), and MATRIX becomes a RESHAPE reorg.
 *
 * @throws ParseException for an unknown opcode or parameter translation errors
 * @throws HopsException if hop construction fails
 */
private Hop processDataExpression(DataExpression source, DataIdentifier target,
		HashMap hops) throws ParseException, HopsException {
	// named parameters -> hops, keyed by parameter name
	HashMap<String, Hop> paramHops = new HashMap<String, Hop>();
	for ( String pname : source.getVarParams().keySet() ) {
		paramHops.put(pname, processExpression(source.getVarParam(pname), null, hops));
	}

	if (target == null) {
		target = createTarget(source);
	}

	Hop result = null;
	switch (source.getOpCode()) {
		case READ: {
			DataOp readOp = new DataOp(target.getName(), target.getDataType(), target.getValueType(), DataOpTypes.PERSISTENTREAD, paramHops);
			readOp.setFileName(((StringIdentifier)source.getVarParam(DataExpression.IO_FILENAME)).getValue());
			result = readOp;
			break;
		}
		case WRITE: {
			// The written hop is looked up by the target name. No file name is set
			// here: the former constant-filename check was disabled to allow dynamic write.
			String name = target.getName();
			result = new DataOp(
					target.getName(), target.getDataType(), target.getValueType(), DataOpTypes.PERSISTENTWRITE, hops.get(name), paramHops);
			break;
		}
		case RAND: {
			// a string-valued min parameter selects string initialization instead of rand
			DataGenMethod method = DataGenMethod.RAND;
			if (paramHops.get(DataExpression.RAND_MIN).getValueType() == ValueType.STRING) {
				method = DataGenMethod.SINIT;
			}
			result = new DataGenOp(method, target, paramHops);
			break;
		}
		case MATRIX: {
			// matrix(data, rows, cols, byrow) is compiled into a RESHAPE reorg
			ArrayList<Hop> reshapeInputs = new ArrayList<Hop>();
			reshapeInputs.add(paramHops.get(DataExpression.RAND_DATA));
			reshapeInputs.add(paramHops.get(DataExpression.RAND_ROWS));
			reshapeInputs.add(paramHops.get(DataExpression.RAND_COLS));
			reshapeInputs.add(paramHops.get(DataExpression.RAND_BY_ROW));
			result = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), ReOrgOp.RESHAPE, reshapeInputs);
			break;
		}
		default:
			LOG.error(source.printErrorLocation() +
					"processDataExpression():: Unknown operation: "
					+ source.getOpCode());
			throw new ParseException(source.printErrorLocation() +
					"processDataExpression():: Unknown operation: "
					+ source.getOpCode());
	}

	// copy identifier metadata (including dimensions and block sizes) onto the hop
	setIdentifierParams(result, source.getOutput());
	if ( source.getOpCode() == DataExpression.DataOp.READ ) {
		((DataOp)result).setInputBlockSizes(target.getRowsInBlock(), target.getColumnsInBlock());
	}
	result.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
	return result;
}
/**
 * Construct Hops from parse tree: process a BuiltinFunction expression with
 * multiple return values (qr, lu, eigen) in a multi-assignment statement.
 * All other builtin function expressions are handled by
 * processBuiltinFunctionExpression().
 *
 * @param source builtin function expression from the parse tree
 * @param targetList left-hand-side identifiers of the multi-assignment
 * @param hops hops constructed so far, passed through to argument translation
 * @return the FunctionOp (MULTIRETURN_BUILTIN) producing all outputs
 * @throws ParseException for an unsupported opcode or argument translation errors
 */
private Hop processMultipleReturnBuiltinFunctionExpression(BuiltinFunctionExpression source, ArrayList targetList,
		HashMap hops) throws ParseException {
	// Construct hops for all (up to three) positional inputs
	ArrayList<Hop> inputs = new ArrayList<Hop>();
	inputs.add( processExpression(source.getFirstExpr(), null, hops) );
	if ( source.getSecondExpr() != null )
		inputs.add( processExpression(source.getSecondExpr(), null, hops) );
	if ( source.getThirdExpr() != null )
		inputs.add( processExpression(source.getThirdExpr(), null, hops) );

	FunctionType ftype = FunctionType.MULTIRETURN_BUILTIN;
	String nameSpace = DMLProgram.INTERNAL_NAMESPACE;

	// Output hops of this function call; the exact list depends on the opcode.
	ArrayList<Hop> outputs = new ArrayList<Hop>();

	// Construct the hop for the current builtin function based on its type
	Hop currBuiltinOp = null;
	switch (source.getOpCode()) {
		case QR:
		case LU:
		case EIGEN: {
			// one MATRIX/DOUBLE function output per identifier in targetList
			String[] outputNames = new String[targetList.size()];
			for ( int i=0; i < targetList.size(); i++ ) {
				outputNames[i] = ((DataIdentifier)targetList.get(i)).getName();
				Hop output = new DataOp(outputNames[i], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[i]);
				outputs.add(output);
			}
			// Create the hop for the current function call
			currBuiltinOp = new FunctionOp(ftype, nameSpace, source.getOpCode().toString(), inputs, outputNames, outputs);
			break;
		}
		default:
			throw new ParseException("Invalid Opcode in DMLTranslator:processMultipleReturnBuiltinFunctionExpression(): " + source.getOpCode());
	}

	// set properties for created hops based on outputs of source expression
	for ( int i=0; i < source.getOutputs().length; i++ ) {
		setIdentifierParams( outputs.get(i), source.getOutputs()[i]);
		outputs.get(i).setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
	}
	currBuiltinOp.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
	return currBuiltinOp;
}
/**
 * Construct Hops from parse tree: process a BuiltinFunction expression in an
 * assignment statement.
 *
 * The (up to three) positional arguments are translated first; expr2/expr3 stay
 * null when absent. The opcode then selects the hop to build (AggUnaryOp,
 * UnaryOp, BinaryOp, TernaryOp, ReorgOp, DataGenOp, ConvolutionOp, or a LiteralOp
 * when nrow/ncol/length are known at compile time). Finally, output metadata is
 * copied from source.getOutput() and the source positions are recorded.
 *
 * @param source builtin function expression from the parse tree
 * @param target output identifier; if null, one is created from source
 * @param hops hops constructed so far, passed through to argument translation
 * @return the hop computing the builtin function
 * @throws ParseException e.g. for an unknown ppred operator, an invalid number of
 *         table() arguments, or an unsupported opcode
 * @throws HopsException e.g. for invalid sample()/outer() arguments or average pooling
 */
private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, DataIdentifier target,
HashMap hops) throws ParseException, HopsException {
// translate positional arguments; optional ones remain null when not given
Hop expr = processExpression(source.getFirstExpr(), null, hops);
Hop expr2 = null;
if (source.getSecondExpr() != null) {
expr2 = processExpression(source.getSecondExpr(), null, hops);
}
Hop expr3 = null;
if (source.getThirdExpr() != null) {
expr3 = processExpression(source.getThirdExpr(), null, hops);
}
Hop currBuiltinOp = null;
if (target == null) {
target = createTarget(source);
}
// Construct the hop based on the type of Builtin function
switch (source.getOpCode()) {
// column-wise aggregates (Direction.Col)
case COLSUM:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM,
Direction.Col, expr);
break;
case COLMAX:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAX,
Direction.Col, expr);
break;
case COLMIN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MIN,
Direction.Col, expr);
break;
case COLMEAN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN,
Direction.Col, expr);
break;
case COLSD:
// colStdDevs = sqrt(colVariances)
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(),
target.getValueType(), AggOp.VAR, Direction.Col, expr);
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),
target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
break;
case COLVAR:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(),
target.getValueType(), AggOp.VAR, Direction.Col, expr);
break;
// row-wise aggregates (Direction.Row)
case ROWSUM:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM,
Direction.Row, expr);
break;
case ROWMAX:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAX,
Direction.Row, expr);
break;
case ROWINDEXMAX:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAXINDEX,
Direction.Row, expr);
break;
case ROWINDEXMIN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MININDEX,
Direction.Row, expr);
break;
case ROWMIN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MIN,
Direction.Row, expr);
break;
case ROWMEAN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN,
Direction.Row, expr);
break;
case ROWSD:
// rowStdDevs = sqrt(rowVariances)
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(),
target.getValueType(), AggOp.VAR, Direction.Row, expr);
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),
target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
break;
case ROWVAR:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(),
target.getValueType(), AggOp.VAR, Direction.Row, expr);
break;
// dimension queries; -1 denotes a dimension unknown at compile time
case NROW:
// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
// Else create a UnaryOp so that a control program instruction is generated
long nRows = expr.getDim1();
if (nRows == -1) {
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NROW, expr);
}
else {
currBuiltinOp = new LiteralOp(nRows);
}
break;
case NCOL:
// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
// Else create a UnaryOp so that a control program instruction is generated
long nCols = expr.getDim2();
if (nCols == -1) {
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NCOL, expr);
}
else {
currBuiltinOp = new LiteralOp(nCols);
}
break;
case LENGTH:
long nRows2 = expr.getDim1();
long nCols2 = expr.getDim2();
/*
* If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
* Else create a UnaryOp so that a control program instruction is generated
*/
if ((nCols2 == -1) || (nRows2 == -1)) {
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.LENGTH, expr);
}
else {
long lval = (nCols2 * nRows2);
currBuiltinOp = new LiteralOp(lval);
}
break;
// full aggregates (Direction.RowCol)
case SUM:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM,
Direction.RowCol, expr);
break;
case MEAN:
if ( expr2 == null ) {
// example: x = mean(Y);
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN,
Direction.RowCol, expr);
}
else {
// example: x = mean(Y,W);
// stable weighted mean is implemented by using centralMoment with order = 0
Hop orderHop = new LiteralOp(0);
currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp3.CENTRALMOMENT, expr, expr2, orderHop);
}
break;
case SD:
// stdDev = sqrt(variance)
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(),
target.getValueType(), AggOp.VAR, Direction.RowCol, expr);
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),
target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
break;
case VAR:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(),
target.getValueType(), AggOp.VAR, Direction.RowCol, expr);
break;
case MIN:
//construct AggUnary for min(X) but BinaryOp for min(X,Y)
if( expr2 == null ) {
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(),
AggOp.MIN, Direction.RowCol, expr);
}
else {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MIN,
expr, expr2);
}
break;
case MAX:
//construct AggUnary for max(X) but BinaryOp for max(X,Y)
if( expr2 == null ) {
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(),
AggOp.MAX, Direction.RowCol, expr);
} else {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MAX,
expr, expr2);
}
break;
case PPRED:
// ppred(X, Y, op): the third argument is a string literal naming the comparison;
// surrounding double quotes are stripped before it is matched
String sop = ((StringIdentifier)source.getThirdExpr()).getValue();
sop = sop.replace("\"", "");
OpOp2 operation;
if ( sop.equalsIgnoreCase(">=") )
operation = OpOp2.GREATEREQUAL;
else if ( sop.equalsIgnoreCase(">") )
operation = OpOp2.GREATER;
else if ( sop.equalsIgnoreCase("<=") )
operation = OpOp2.LESSEQUAL;
else if ( sop.equalsIgnoreCase("<") )
operation = OpOp2.LESS;
else if ( sop.equalsIgnoreCase("==") )
operation = OpOp2.EQUAL;
else if ( sop.equalsIgnoreCase("!=") )
operation = OpOp2.NOTEQUAL;
else {
LOG.error(source.printErrorLocation() + "Unknown argument (" + sop + ") for PPRED.");
throw new ParseException(source.printErrorLocation() + "Unknown argument (" + sop + ") for PPRED.");
}
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), operation, expr, expr2);
break;
case PROD:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.PROD,
Direction.RowCol, expr);
break;
case TRACE:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.TRACE,
Direction.RowCol, expr);
break;
// reorganizations and appends
case TRANS:
currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.ReOrgOp.TRANSPOSE, expr);
break;
case REV:
currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.ReOrgOp.REV, expr);
break;
case CBIND:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp2.CBIND, expr, expr2);
break;
case RBIND:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp2.RBIND, expr, expr2);
break;
case DIAG:
currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.ReOrgOp.DIAG, expr);
break;
case TABLE:
// Always a TertiaryOp is created for table().
// - create a hop for weights, if not provided in the function call.
// 2/4 arguments: no weights (literal 1.0 is used); 3/5 arguments: explicit weights.
// The 4/5-argument forms additionally pass the two trailing arguments as output dims.
int numTableArgs = source._args.length;
switch(numTableArgs) {
case 2:
case 4:
// example DML statement: F = ctable(A,B) or F = ctable(A,B,10,15)
// here, weight is interpreted as 1.0
Hop weightHop = new LiteralOp(1.0);
// set dimensions
weightHop.setDim1(0);
weightHop.setDim2(0);
weightHop.setNnz(-1);
weightHop.setRowsInBlock(0);
weightHop.setColsInBlock(0);
if ( numTableArgs == 2 )
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop);
else {
Hop outDim1 = processExpression(source._args[2], null, hops);
Hop outDim2 = processExpression(source._args[3], null, hops);
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop, outDim1, outDim2);
}
break;
case 3:
case 5:
// example DML statement: F = ctable(A,B,W) or F = ctable(A,B,W,10,15)
if (numTableArgs == 3)
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, expr3);
else {
Hop outDim1 = processExpression(source._args[3], null, hops);
Hop outDim2 = processExpression(source._args[4], null, hops);
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, expr3, outDim1, outDim2);
}
break;
default:
throw new ParseException("Invalid number of arguments "+ numTableArgs + " to table() function.");
}
break;
//data type casts
case CAST_AS_SCALAR:
currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), Hop.OpOp1.CAST_AS_SCALAR, expr);
break;
case CAST_AS_MATRIX:
currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), Hop.OpOp1.CAST_AS_MATRIX, expr);
break;
case CAST_AS_FRAME:
currBuiltinOp = new UnaryOp(target.getName(), DataType.FRAME, target.getValueType(), Hop.OpOp1.CAST_AS_FRAME, expr);
break;
//value type casts
case CAST_AS_DOUBLE:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.DOUBLE, Hop.OpOp1.CAST_AS_DOUBLE, expr);
break;
case CAST_AS_INT:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.INT, Hop.OpOp1.CAST_AS_INT, expr);
break;
case CAST_AS_BOOLEAN:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.BOOLEAN, Hop.OpOp1.CAST_AS_BOOLEAN, expr);
break;
// unary math / cumulative functions: each opcode maps 1:1 onto a Hop.OpOp1
case ABS:
case SIN:
case COS:
case TAN:
case ASIN:
case ACOS:
case ATAN:
case SIGN:
case SQRT:
case EXP:
case ROUND:
case CEIL:
case FLOOR:
case CUMSUM:
case CUMPROD:
case CUMMIN:
case CUMMAX:
Hop.OpOp1 mathOp1;
switch (source.getOpCode()) {
case ABS:
mathOp1 = Hop.OpOp1.ABS;
break;
case SIN:
mathOp1 = Hop.OpOp1.SIN;
break;
case COS:
mathOp1 = Hop.OpOp1.COS;
break;
case TAN:
mathOp1 = Hop.OpOp1.TAN;
break;
case ASIN:
mathOp1 = Hop.OpOp1.ASIN;
break;
case ACOS:
mathOp1 = Hop.OpOp1.ACOS;
break;
case ATAN:
mathOp1 = Hop.OpOp1.ATAN;
break;
case SIGN:
mathOp1 = Hop.OpOp1.SIGN;
break;
case SQRT:
mathOp1 = Hop.OpOp1.SQRT;
break;
case EXP:
mathOp1 = Hop.OpOp1.EXP;
break;
case ROUND:
mathOp1 = Hop.OpOp1.ROUND;
break;
case CEIL:
mathOp1 = Hop.OpOp1.CEIL;
break;
case FLOOR:
mathOp1 = Hop.OpOp1.FLOOR;
break;
case CUMSUM:
mathOp1 = Hop.OpOp1.CUMSUM;
break;
case CUMPROD:
mathOp1 = Hop.OpOp1.CUMPROD;
break;
case CUMMIN:
mathOp1 = Hop.OpOp1.CUMMIN;
break;
case CUMMAX:
mathOp1 = Hop.OpOp1.CUMMAX;
break;
default:
LOG.error(source.printErrorLocation() +
"processBuiltinFunctionExpression():: Could not find Operation type for builtin function: "
+ source.getOpCode());
throw new ParseException(source.printErrorLocation() +
"processBuiltinFunctionExpression():: Could not find Operation type for builtin function: "
+ source.getOpCode());
}
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp1, expr);
break;
// log(X) -> unary LOG; log(X, base) -> binary LOG. The inner switches only ever
// see LOG here, so their default branches are unreachable safeguards.
case LOG:
if (expr2 == null) {
Hop.OpOp1 mathOp2;
switch (source.getOpCode()) {
case LOG:
mathOp2 = Hop.OpOp1.LOG;
break;
default:
LOG.error(source.printErrorLocation() +
"processBuiltinFunctionExpression():: Could not find Operation type for builtin function: "
+ source.getOpCode());
throw new ParseException(source.printErrorLocation() +
"processBuiltinFunctionExpression():: Could not find Operation type for builtin function: "
+ source.getOpCode());
}
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp2,
expr);
} else {
Hop.OpOp2 mathOp3;
switch (source.getOpCode()) {
case LOG:
mathOp3 = Hop.OpOp2.LOG;
break;
default:
LOG.error(source.printErrorLocation() +
"processBuiltinFunctionExpression():: Could not find Operation type for builtin function: "
+ source.getOpCode());
throw new ParseException(source.printErrorLocation() +
"processBuiltinFunctionExpression():: Could not find Operation type for builtin function: "
+ source.getOpCode());
}
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp3,
expr, expr2);
}
break;
// statistics: the variant without a third argument maps to a BinaryOp,
// the variant with a third argument to a TernaryOp
case MOMENT:
if (expr3 == null){
currBuiltinOp=new BinaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp2.CENTRALMOMENT, expr, expr2);
}
else {
currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp3.CENTRALMOMENT, expr, expr2,expr3);
}
break;
case COV:
if (expr3 == null){
currBuiltinOp=new BinaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp2.COVARIANCE, expr, expr2);
}
else {
currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp3.COVARIANCE, expr, expr2,expr3);
}
break;
case QUANTILE:
if (expr3 == null){
currBuiltinOp=new BinaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp2.QUANTILE, expr, expr2);
}
else {
currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp3.QUANTILE, expr, expr2,expr3);
}
break;
case INTERQUANTILE:
if ( expr3 == null ) {
currBuiltinOp=new BinaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp2.INTERQUANTILE, expr, expr2);
}
else {
currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp3.INTERQUANTILE, expr, expr2,expr3);
}
break;
// single-argument form -> UnaryOp, two-argument form -> BinaryOp
case IQM:
if ( expr2 == null ) {
currBuiltinOp=new UnaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp1.IQM, expr);
}
else {
currBuiltinOp=new BinaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp2.IQM, expr, expr2);
}
break;
case MEDIAN:
if ( expr2 == null ) {
currBuiltinOp=new UnaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp1.MEDIAN, expr);
}
else {
currBuiltinOp=new BinaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp2.MEDIAN, expr, expr2);
}
break;
// seq(from, to[, incr]) -> DataGenOp SEQ; incr defaults to literal 1 here
case SEQ:
HashMap randParams = new HashMap();
randParams.put(Statement.SEQ_FROM, expr);
randParams.put(Statement.SEQ_TO, expr2);
randParams.put(Statement.SEQ_INCR, (expr3!=null)?expr3 : new LiteralOp(1));
//note incr: default -1 (for from>to) handled during runtime
currBuiltinOp = new DataGenOp(DataGenMethod.SEQ, target, randParams);
break;
case SAMPLE:
{
Expression[] in = source.getAllExpr();
// arguments: range/size/replace/seed; defaults: replace=FALSE
// (the "replace" flag is passed under the RAND_PDF key and the range under RAND_MAX)
HashMap tmpparams = new HashMap();
tmpparams.put(DataExpression.RAND_MAX, expr); //range
tmpparams.put(DataExpression.RAND_ROWS, expr2);
tmpparams.put(DataExpression.RAND_COLS, new LiteralOp(1));
if ( in.length == 4 )
{
tmpparams.put(DataExpression.RAND_PDF, expr3);
Hop seed = processExpression(in[3], null, hops);
tmpparams.put(DataExpression.RAND_SEED, seed);
}
else if ( in.length == 3 )
{
// check if the third argument is "replace" or "seed"
if ( expr3.getValueType() == ValueType.BOOLEAN )
{
tmpparams.put(DataExpression.RAND_PDF, expr3);
tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) );
}
else if ( expr3.getValueType() == ValueType.INT )
{
tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false));
tmpparams.put(DataExpression.RAND_SEED, expr3 );
}
else
throw new HopsException("Invalid input type " + expr3.getValueType() + " in sample().");
}
else if ( in.length == 2 )
{
tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false));
tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) );
}
currBuiltinOp = new DataGenOp(DataGenMethod.SAMPLE, target, tmpparams);
break;
}
// linear algebra
case SOLVE:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.SOLVE, expr, expr2);
break;
case INVERSE:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp1.INVERSE, expr);
break;
case CHOLESKY:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp1.CHOLESKY, expr);
break;
// outer(X, Y, "op"): the operator must be a literal string mapped onto an OpOp2
case OUTER:
if( !(expr3 instanceof LiteralOp) )
throw new HopsException("Operator for outer builtin function must be a constant: "+expr3);
OpOp2 op = Hop.getOpOp2ForOuterVectorOperation(((LiteralOp)expr3).getStringValue());
if( op == null )
throw new HopsException("Unsupported outer vector binary operation: "+((LiteralOp)expr3).getStringValue());
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, expr, expr2);
((BinaryOp)currBuiltinOp).setOuterVectorOperation(true); //flag op as specific outer vector operation
currBuiltinOp.refreshSizeInformation(); //force size reevaluation according to 'outer' flag otherwise danger of incorrect dims
break;
// convolution / pooling ops: the first argument is the image; the op inherits
// block sizes and line numbers from it and its sizes are refreshed
case CONV2D:
{
Hop image = expr;
ArrayList inHops1 = getALHopsForConvOp(image, source, 1, hops);
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D, inHops1);
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
case AVG_POOL:
case MAX_POOL:
{
// NOTE(review): the argument hops are built before AVG_POOL is rejected below;
// only MAX_POOL is supported
Hop image = expr;
ArrayList inHops1 = getALHopsForPoolingForwardIM2COL(image, source, 1, hops);
if(source.getOpCode() == BuiltinFunctionOp.MAX_POOL)
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING, inHops1);
else
throw new HopsException("Average pooling is not implemented");
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
case MAX_POOL_BACKWARD:
{
Hop image = expr;
ArrayList inHops1 = getALHopsForConvOpPoolingCOL2IM(image, source, 1, hops); // process dout as well
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING_BACKWARD, inHops1);
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
case CONV2D_BACKWARD_FILTER:
{
Hop image = expr;
ArrayList inHops1 = getALHopsForConvOp(image, source, 1, hops);
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D_BACKWARD_FILTER, inHops1);
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
case CONV2D_BACKWARD_DATA:
{
Hop image = expr;
ArrayList inHops1 = getALHopsForConvOp(image, source, 1, hops);
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D_BACKWARD_DATA, inHops1);
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
default:
throw new ParseException("Unsupported builtin function type: "+source.getOpCode());
}
// copy output metadata from the validated output identifier and record source positions
setIdentifierParams(currBuiltinOp, source.getOutput());
currBuiltinOp.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
return currBuiltinOp;
}
/**
 * Finishes configuring a freshly created output hop from its input hop.
 * It copies line numbers from {@code in} to {@code out}, sets the output
 * block sizes to the input's row/column block sizes, and then refreshes
 * the output's size information.
 *
 * @param in  hop supplying line numbers and block sizes
 * @param out hop being configured
 */
private void setBlockSizeAndRefreshSizeInfo(Hop in, Hop out) {
	// position info does not interact with sizing, so it can be copied first
	HopRewriteUtils.copyLineNumbers(in, out);
	HopRewriteUtils.setOutputBlocksizes(out, in.getRowsInBlock(), in.getColsInBlock());
	// refresh only once all metadata above is in place
	out.refreshSizeInformation();
}
/**
 * Builds the input-hop list for a pooling-backward (COL2IM) convolution hop.
 *
 * The result is {@code first}, followed by one hop for each expression of
 * {@code source} from index {@code skip} onward. One slot is special: the
 * expression at index 11 is replaced by the expression at index 7, so that
 * the number of channels of the images and of the filter are the same.
 *
 * @param first  hop placed first in the result
 * @param source builtin call whose expressions are converted
 * @param skip   index of the first expression of {@code source} to convert
 * @param hops   hop map passed through to expression processing
 * @return the ordered list of input hops
 * @throws ParseException as propagated from expression processing
 */
private ArrayList<Hop> getALHopsForConvOpPoolingCOL2IM(Hop first, BuiltinFunctionExpression source, int skip, HashMap<String, Hop> hops) throws ParseException {
	// slot holding the image channel count, and the (filter channel) slot overwritten with it
	final int imgChannelsIndex = 7;
	final int filterChannelsIndex = 11;
	ArrayList<Hop> ret = new ArrayList<Hop>();
	ret.add(first);
	Expression[] allExpr = source.getAllExpr();
	for(int i = skip; i < allExpr.length; i++) {
		int srcIndex = (i == filterChannelsIndex) ? imgChannelsIndex : i;
		ret.add(processExpression(allExpr[srcIndex], null, hops));
	}
	return ret;
}
/**
 * Builds the input-hop list for a forward pooling hop.
 *
 * The result is {@code first}, followed by one hop for each expression of
 * {@code source} from index {@code skip} onward. The expression at index 10
 * is replaced by the one at index 6 (the channel count). This presumably
 * forces equal image/filter channel counts, mirroring
 * getALHopsForConvOpPoolingCOL2IM (TODO confirm the argument layout).
 *
 * @param first  hop placed first in the result
 * @param source builtin call whose expressions are converted
 * @param skip   index of the first expression to convert; only 1 is supported
 * @param hops   hop map passed through to expression processing
 * @return the ordered list of input hops
 * @throws ParseException if {@code skip != 1}, or as propagated from expression processing
 */
private ArrayList<Hop> getALHopsForPoolingForwardIM2COL(Hop first, BuiltinFunctionExpression source, int skip, HashMap<String, Hop> hops) throws ParseException {
	// slot holding the channel count, and the slot overwritten with it
	final int imgChannelsIndex = 6;
	final int filterChannelsIndex = 10;
	ArrayList<Hop> ret = new ArrayList<Hop>();
	ret.add(first);
	Expression[] allExpr = source.getAllExpr();
	if(skip != 1) {
		throw new ParseException("Unsupported skip");
	}
	Expression numChannels = allExpr[imgChannelsIndex];
	for(int i = skip; i < allExpr.length; i++) {
		Expression expr = (i == filterChannelsIndex) ? numChannels : allExpr[i];
		ret.add(processExpression(expr, null, hops));
	}
	return ret;
}
/**
 * Builds the input-hop list for an IM2COL-style convolution/pooling hop, in
 * which the image count and the channel count are folded into one dimension.
 *
 * The result is {@code first}, followed by one hop for each expression of
 * {@code source} from index {@code skip} onward. The exception is the pair
 * (numImg, numChannels) at indices skip+4 and skip+5, which is replaced by
 * the pair (numImg * numChannels, 1).
 *
 * @param first  hop placed first in the result
 * @param source builtin call whose expressions are converted
 * @param skip   index of the first expression to convert; only 1 and 2 are supported
 * @param hops   hop map passed through to expression processing
 * @return the ordered list of input hops
 * @throws ParseException if {@code skip} is not 1 or 2, or as propagated from expression processing
 */
@SuppressWarnings("unused") //TODO remove if not used
private ArrayList<Hop> getALHopsForConvOpPoolingIM2COL(Hop first, BuiltinFunctionExpression source, int skip, HashMap<String, Hop> hops) throws ParseException {
	ArrayList<Hop> ret = new ArrayList<Hop>();
	ret.add(first);
	Expression[] allExpr = source.getAllExpr();
	if(skip != 1 && skip != 2) {
		throw new ParseException("Unsupported skip");
	}
	// skip=1 ==> numImg at index 5, skip=2 ==> numImg at index 6; numChannels directly follows
	int numImgIndex = skip + 4;
	int i = skip;
	while(i < allExpr.length) {
		if(i == numImgIndex) {
			Expression numImg = allExpr[numImgIndex];
			Expression numChannels = allExpr[numImgIndex+1];
			BinaryExpression tmp = new BinaryExpression(org.apache.sysml.parser.Expression.BinaryOp.MULT,
					numImg.getFilename(), numImg.getBeginLine(), numImg.getBeginColumn(), numImg.getEndLine(), numImg.getEndColumn());
			tmp.setLeft(numImg);
			tmp.setRight(numChannels);
			ret.add(processTempIntExpression(tmp, hops));
			ret.add(processExpression(new IntIdentifier(1, numImg.getFilename(), numImg.getBeginLine(), numImg.getBeginColumn(),
					numImg.getEndLine(), numImg.getEndColumn()), null, hops));
			// numChannels was consumed by the product above, so step over both slots
			i += 2;
		}
		else {
			ret.add(processExpression(allExpr[i], null, hops));
			i++;
		}
	}
	return ret;
}
/**
 * Builds the input-hop list for a convolution hop: {@code first}, followed by
 * one hop for each expression of {@code source} from index {@code skip}
 * onward, in order.
 *
 * @param first  hop placed first in the result
 * @param source builtin call whose expressions are converted
 * @param skip   index of the first expression of {@code source} to convert
 * @param hops   hop map passed through to expression processing
 * @return the ordered list of input hops
 * @throws ParseException as propagated from expression processing
 */
private ArrayList<Hop> getALHopsForConvOp(Hop first, BuiltinFunctionExpression source, int skip, HashMap<String, Hop> hops) throws ParseException {
	Expression[] allExpr = source.getAllExpr();
	// presize: first plus every converted expression
	ArrayList<Hop> ret = new ArrayList<Hop>(1 + Math.max(0, allExpr.length - skip));
	ret.add(first);
	for(int i = skip; i < allExpr.length; i++) {
		ret.add(processExpression(allExpr[i], null, hops));
	}
	return ret;
}
/**
 * Copies size metadata from a DML identifier onto a hop.
 *
 * The row/column block sizes are always copied. The dimensions and the
 * non-zero count are copied only when the identifier's value is
 * non-negative; otherwise the hop keeps its current value.
 *
 * @param h  hop to update
 * @param id identifier supplying the metadata
 */
public void setIdentifierParams(Hop h, Identifier id) {
	h.setRowsInBlock(id.getRowsInBlock());
	h.setColsInBlock(id.getColumnsInBlock());
	if( id.getDim1() >= 0 ) {
		h.setDim1(id.getDim1());
	}
	if( id.getDim2() >= 0 ) {
		h.setDim2(id.getDim2());
	}
	if( id.getNnz() >= 0 ) {
		h.setNnz(id.getNnz());
	}
}
/**
 * Unconditionally copies size metadata from one hop to another: row/column
 * block sizes, both dimensions, and the non-zero count.
 *
 * @param h      hop to update
 * @param source hop supplying the metadata
 */
public void setIdentifierParams(Hop h, Hop source) {
	// blocking
	h.setRowsInBlock(source.getRowsInBlock());
	h.setColsInBlock(source.getColsInBlock());
	// dimensions and sparsity
	h.setDim1(source.getDim1());
	h.setDim2(source.getDim2());
	h.setNnz(source.getNnz());
}
/**
 * Propagates metadata of persistent writes into subsequent reads of the same
 * file (read-after-write) in the main program.
 *
 * Only the main program's statement blocks are processed, in order. Function
 * statement blocks are deliberately not visited, because read-after-write is
 * currently supported only in the main program.
 *
 * @param prog    program whose main statement blocks are processed
 * @param pWrites map from filename string to the written identifier; updated in place
 * @return true if at least one read was updated with write metadata
 * @throws LanguageException as declared by this method's signature
 */
private boolean prepareReadAfterWrite( DMLProgram prog, HashMap<String, DataIdentifier> pWrites )
	throws LanguageException
{
	boolean ret = false;
	//process functions: intentionally skipped (MB: for the moment we only support
	//read-after-write in the main program, i.e., prog.getFunctionStatementBlocks()
	//is not traversed here)
	//process main program
	for( StatementBlock sb : prog.getStatementBlocks() )
		ret |= prepareReadAfterWrite(sb, pWrites);
	return ret;
}
/**
 * Recursively propagates metadata of persistent writes into reads of the same
 * file (read-after-write) within a statement block.
 *
 * Control-flow blocks (function, while, if/else, for incl. parfor) are
 * handled by recursing into their child statement blocks. In last-level
 * blocks:
 * - every {@link OutputStatement} is recorded in {@code pWrites} under the
 *   string form of its filename expression;
 * - every read {@link DataExpression} whose (non-blank) filename string was
 *   recorded earlier in the traversal gets the write's format, known
 *   dimensions, value type and data type added as read parameters.
 *
 * @param sb      statement block to process
 * @param pWrites map from filename string to the written identifier seen so far;
 *                updated in place
 * @return true if at least one read was updated with write metadata
 */
private boolean prepareReadAfterWrite( StatementBlock sb, HashMap<String, DataIdentifier> pWrites )
{
	boolean ret = false;
	if(sb instanceof FunctionStatementBlock)
	{
		FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
		FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
		for (StatementBlock csb : fstmt.getBody())
			ret |= prepareReadAfterWrite(csb, pWrites);
	}
	else if(sb instanceof WhileStatementBlock)
	{
		WhileStatementBlock wsb = (WhileStatementBlock) sb;
		WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
		for (StatementBlock csb : wstmt.getBody())
			ret |= prepareReadAfterWrite(csb, pWrites);
	}
	else if(sb instanceof IfStatementBlock)
	{
		IfStatementBlock isb = (IfStatementBlock) sb;
		IfStatement istmt = (IfStatement)isb.getStatement(0);
		for (StatementBlock csb : istmt.getIfBody())
			ret |= prepareReadAfterWrite(csb, pWrites);
		for (StatementBlock csb : istmt.getElseBody())
			ret |= prepareReadAfterWrite(csb, pWrites);
	}
	else if(sb instanceof ForStatementBlock) //incl parfor
	{
		ForStatementBlock fsb = (ForStatementBlock) sb;
		ForStatement fstmt = (ForStatement)fsb.getStatement(0);
		for (StatementBlock csb : fstmt.getBody())
			ret |= prepareReadAfterWrite(csb, pWrites);
	}
	else //generic (last-level)
	{
		for( Statement s : sb.getStatements() )
		{
			if( s instanceof OutputStatement )
			{
				//collect persistent write information
				OutputStatement os = (OutputStatement) s;
				String pfname = os.getExprParam(DataExpression.IO_FILENAME).toString();
				DataIdentifier di = (DataIdentifier) os.getSource().getOutput();
				pWrites.put(pfname, di);
			}
			else if( s instanceof AssignmentStatement
				&& ((AssignmentStatement)s).getSource() instanceof DataExpression )
			{
				//propagate size information into reads-after-write
				DataExpression dexpr = (DataExpression) ((AssignmentStatement)s).getSource();
				if( dexpr.isRead() ) {
					String pfname = dexpr.getVarParam(DataExpression.IO_FILENAME).toString();
					//single lookup; a null entry is treated as "no recorded write"
					DataIdentifier di = pWrites.get(pfname);
					if( di != null && !pfname.trim().isEmpty() ) //found read-after-write
					{
						addWriteMetaDataToRead(dexpr, di);
						ret = true;
					}
				}
			}
		}
	}
	return ret;
}

/**
 * Adds the metadata of a persistent-write target to a read expression. The
 * format is always added (TEXT if the write has none). Dimensions, value type
 * and data type are added only when known (non-negative / not UNKNOWN). Every
 * created identifier carries the write target's source position.
 *
 * @param dexpr read expression to update
 * @param di    identifier of the matching persistent write
 */
private static void addWriteMetaDataToRead( DataExpression dexpr, DataIdentifier di )
{
	String fname = di.getFilename();
	int bl = di.getBeginLine();
	int bc = di.getBeginColumn();
	int el = di.getEndLine();
	int ec = di.getEndColumn();
	FormatType ft = (di.getFormatType()!=null) ? di.getFormatType() : FormatType.TEXT;
	dexpr.addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(ft.toString(), fname, bl, bc, el, ec));
	if( di.getDim1()>=0 )
		dexpr.addVarParam(DataExpression.READROWPARAM, new IntIdentifier(di.getDim1(), fname, bl, bc, el, ec));
	if( di.getDim2()>=0 )
		dexpr.addVarParam(DataExpression.READCOLPARAM, new IntIdentifier(di.getDim2(), fname, bl, bc, el, ec));
	if( di.getValueType()!=ValueType.UNKNOWN )
		dexpr.addVarParam(DataExpression.VALUETYPEPARAM, new StringIdentifier(di.getValueType().toString(), fname, bl, bc, el, ec));
	if( di.getDataType()!=DataType.UNKNOWN )
		dexpr.addVarParam(DataExpression.DATATYPEPARAM, new StringIdentifier(di.getDataType().toString(), fname, bl, bc, el, ec));
}
}