org.apache.sysml.hops.rewrite.RewriteConstantFolding Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.rewrite;
import java.io.IOException;
import java.util.ArrayList;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.DataOpTypes;
import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.Hop.VisitStatus;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.compile.Dag;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
/**
* Rule: Constant Folding. For all statement blocks,
* eliminate simple binary expressions of literals within DAGs by
* computing them once and replacing them with a new literal op.
* For the moment, this only applies within a DAG; later, it should be
* extended across statement blocks (global, inter-procedural).
*/
public class RewriteConstantFolding extends HopRewriteRule
{
//name of the temporary variable from which the folded scalar result
//is read back out of the execution context after evaluation
private static final String TMP_VARNAME = "__cf_tmp";
//reuse basic execution runtime
//(both are static, lazily created by getProgramBlock()/getExecutionContext(),
// and shared by all instances of this rewrite; NOTE(review): initialization is
// unsynchronized — presumably the rewrite runs single-threaded, confirm)
private static ProgramBlock _tmpPB = null;
private static ExecutionContext _tmpEC = null;
@Override
public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus state)
throws HopsException
{
if( roots == null )
return null;
for( int i=0; i 0 ) //broot is NOT a DAG root
{
for( int i=0; i dag = new Dag();
Recompiler.rClearLops(tmpWrite); //prevent lops reuse
Lop lops = tmpWrite.constructLops(); //reconstruct lops
lops.addToDag( dag );
ArrayList inst = dag.getJobs(null, ConfigurationManager.getConfig());
//execute instructions
ExecutionContext ec = getExecutionContext();
ProgramBlock pb = getProgramBlock();
pb.setInstructions( inst );
pb.execute( ec );
//get scalar result (check before invocation)
ScalarObject so = (ScalarObject) ec.getVariable(TMP_VARNAME);
LiteralOp literal = null;
switch( bop.getValueType() ){
case DOUBLE: literal = new LiteralOp(so.getDoubleValue()); break;
case INT: literal = new LiteralOp(so.getLongValue()); break;
case BOOLEAN: literal = new LiteralOp(so.getBooleanValue()); break;
case STRING: literal = new LiteralOp(so.getStringValue()); break;
default:
throw new HopsException("Unsupported literal value type: "+bop.getValueType());
}
//cleanup
tmpWrite.getInput().clear();
bop.getParent().remove(tmpWrite);
pb.setInstructions(null);
ec.getVariables().removeAll();
//set literal properties (scalar)
literal.setDim1(0);
literal.setDim2(0);
literal.setRowsInBlock(-1);
literal.setColsInBlock(-1);
//System.out.println("Constant folded in "+time.stop()+"ms.");
return literal;
}
/**
 * Returns the shared program block that is used to execute the generated
 * instructions of a constant expression, creating it over an empty
 * {@link Program} on first use.
 *
 * NOTE(review): the static field is initialized without synchronization;
 * presumably this rewrite is only invoked single-threaded — confirm.
 *
 * @return the cached (possibly newly created) program block
 * @throws DMLRuntimeException as declared; presumably propagated from the
 *         Program/ProgramBlock constructors
 */
private static ProgramBlock getProgramBlock()
	throws DMLRuntimeException
{
	ProgramBlock block = _tmpPB;
	if( block == null ) {
		block = new ProgramBlock( new Program() );
		_tmpPB = block;
	}
	return block;
}
/**
 * Returns the shared execution context in which constant expressions are
 * evaluated, creating it via {@link ExecutionContextFactory#createContext()}
 * on first use.
 *
 * @return the cached (possibly newly created) execution context
 */
private static ExecutionContext getExecutionContext()
{
	if( _tmpEC != null )
		return _tmpEC;
	_tmpEC = ExecutionContextFactory.createContext();
	return _tmpEC;
}
/**
 * Checks whether the given hop is a binary operation whose two inputs are
 * both literals, and which is therefore a candidate for constant folding.
 *
 * CBIND and RBIND are never folded. The original note states that string
 * append is rejected, although it could be folded, because the introduced
 * newline messes up the explain runtime output. (Presumably string append
 * is represented as CBIND here — confirm.)
 *
 * @param hop the hop to check
 * @return true if hop is a BinaryOp other than CBIND/RBIND whose first and
 *         second inputs are LiteralOps
 */
private boolean isApplicableBinaryOp( Hop hop )
{
	ArrayList<Hop> inputs = hop.getInput();
	if( !(hop instanceof BinaryOp) )
		return false;
	if( !(inputs.get(0) instanceof LiteralOp) || !(inputs.get(1) instanceof LiteralOp) )
		return false;
	OpOp2 op = ((BinaryOp) hop).getOp();
	return op != OpOp2.CBIND && op != OpOp2.RBIND;
}
/**
 * Checks whether the given hop is a unary operation over a literal input
 * whose operator is a value-type cast, as classified by
 * {@link HopRewriteUtils#isValueTypeCast}.
 *
 * @param hop the hop to check
 * @return true if hop is a UnaryOp value-type cast of a LiteralOp
 */
private boolean isApplicableUnaryOp( Hop hop )
{
	ArrayList<Hop> inputs = hop.getInput();
	if( !(hop instanceof UnaryOp) || !(inputs.get(0) instanceof LiteralOp) )
		return false;
	return HopRewriteUtils.isValueTypeCast( ((UnaryOp) hop).getOp() );
}
/**
 * Checks whether the given hop is a scalar logical AND where at least one
 * input is a literal that evaluates to false. Such a conjunction is false
 * regardless of the other operand.
 *
 * Only hops with scalar output qualify. A non-scalar (e.g., matrix) AND
 * cannot be replaced by a scalar literal, so such hops are rejected before
 * any literal is inspected.
 *
 * @param hop the hop to check
 * @return true if the hop can be folded to the literal false
 * @throws HopsException as propagated from {@link LiteralOp#getBooleanValue()}
 */
private boolean isApplicableFalseConjunctivePredicate( Hop hop )
	throws HopsException
{
	ArrayList<Hop> in = hop.getInput();
	return ( hop instanceof BinaryOp
		&& ((BinaryOp)hop).getOp()==OpOp2.AND
		//only a scalar output can be replaced by a scalar literal
		&& hop.getDataType()==DataType.SCALAR
		&& ( (in.get(0) instanceof LiteralOp && !((LiteralOp)in.get(0)).getBooleanValue())
		   ||(in.get(1) instanceof LiteralOp && !((LiteralOp)in.get(1)).getBooleanValue())) );
}
/**
 * Checks whether the given hop is a scalar logical OR where at least one
 * input is a literal that evaluates to true. Such a disjunction is true
 * regardless of the other operand.
 *
 * Only hops with scalar output qualify. A non-scalar (e.g., matrix) OR
 * cannot be replaced by a scalar literal, so such hops are rejected before
 * any literal is inspected.
 *
 * @param hop the hop to check
 * @return true if the hop can be folded to the literal true
 * @throws HopsException as propagated from {@link LiteralOp#getBooleanValue()}
 */
private boolean isApplicableTrueDisjunctivePredicate( Hop hop )
	throws HopsException
{
	ArrayList<Hop> in = hop.getInput();
	return ( hop instanceof BinaryOp
		&& ((BinaryOp)hop).getOp()==OpOp2.OR
		//only a scalar output can be replaced by a scalar literal
		&& hop.getDataType()==DataType.SCALAR
		&& ( (in.get(0) instanceof LiteralOp && ((LiteralOp)in.get(0)).getBooleanValue())
		   ||(in.get(1) instanceof LiteralOp && ((LiteralOp)in.get(1)).getBooleanValue())) );
}
}