org.apache.sysml.hops.codegen.cplan.CNodeTpl Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.codegen.cplan;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.parser.Expression.DataType;
public abstract class CNodeTpl extends CNode implements Cloneable
{
/**
 * Creates a template node over the given inputs and output.
 *
 * <p>Each input is registered through {@link #addInput(CNode)}, so duplicate
 * entries and literal inputs are not added to the template's input list.
 * The emptiness check is applied to {@code inputs} before that filtering.
 *
 * @param inputs candidate template inputs; must contain at least one entry
 * @param output the template's output node
 * @throws RuntimeException if {@code inputs} is empty
 */
public CNodeTpl(ArrayList<CNode> inputs, CNode output ) {
	if( inputs.isEmpty() )
		throw new RuntimeException("Cannot pass empty inputs to the CNodeTpl");
	for( CNode input : inputs )
		addInput(input);
	_output = output;
}
/**
 * Registers {@code in} as a template input unless it is already present
 * (as reported by {@code containsInput}) or is a literal.
 *
 * @param in candidate input node
 */
public void addInput(CNode in) {
	boolean skip = containsInput(in) || in.isLiteral();
	if( !skip )
		_inputs.add(in);
}
/**
 * Retains only those inputs that are {@link CNodeData} nodes whose hop id is
 * contained in {@code filter}; every other input (including non-data nodes)
 * is dropped. A new list is assigned to {@code _inputs}; the previous list
 * object is not modified.
 *
 * @param filter hop ids of the data inputs to keep
 */
public void cleanupInputs(HashSet<Long> filter) {
	ArrayList<CNode> kept = new ArrayList<>();
	for( CNode in : _inputs )
		if( in instanceof CNodeData && filter.contains(((CNodeData) in).getHopID()) )
			kept.add(in);
	_inputs = kept;
}
/**
 * Collects the variable name of every registered input, in input order.
 *
 * @return an array holding one variable name per input
 */
public String[] getInputNames() {
	String[] names = new String[_inputs.size()];
	int pos = 0;
	for( CNode in : _inputs )
		names[pos++] = in.getVarname();
	return names;
}
/**
 * Calls {@code resetVisitStatus()} on the node returned by {@code getOutput()}.
 */
public void resetVisitStatusOutputs() {
	CNode out = getOutput();
	out.resetVisitStatus();
}
/**
 * Generates code for this template by delegating to {@code codegen(false)}.
 *
 * @return the string produced by {@code codegen(false)}
 */
public String codegen() {
	String src = this.codegen(false);
	return src;
}
/**
 * Creates a copy of this template node. Whether the copy is deep or shallow
 * is decided by each concrete subclass.
 */
public abstract CNodeTpl clone();
/**
 * Returns the output-dimension type of this template; subclass-specific.
 */
public abstract SpoofOutputDimsType getOutputDimType();
/**
 * Returns a subclass-specific info string for this template.
 * NOTE(review): presumably a human-readable summary — confirm against subclasses.
 */
public abstract String getTemplateInfo();
/**
 * Renames the data inputs referenced from this template's single output node.
 * Convenience overload that wraps {@code _output} in a one-element list.
 *
 * @param inputs     template inputs to rename
 * @param startIndex index of the first entry in {@code inputs} to rename
 */
protected void renameInputs(ArrayList<CNode> inputs, int startIndex) {
	renameInputs(Collections.singletonList(_output), inputs, startIndex);
}
/**
 * Assigns generated variable names to template inputs and rewrites all
 * references to them below the given output nodes.
 *
 * <p>Starting at {@code startIndex}, every non-literal input receives a new
 * name. Scalar inputs and inputs with 0x0 dimensions are named
 * {@code scalars[k]}; all others are named {@code b[k]}. Each kind has its own
 * running counter. Literal data inputs are skipped and consume no index.
 * Non-literal inputs are cast to {@link CNodeData}, so any other node type
 * causes a {@link ClassCastException}.
 *
 * @param outputs    root nodes whose subtrees are rewritten
 * @param inputs     template inputs; entries before {@code startIndex} are ignored
 * @param startIndex index of the first input to rename
 */
protected void renameInputs(List<CNode> outputs, ArrayList<CNode> inputs, int startIndex) {
	//create map of hopID to data nodes with new names, used for CSE
	HashMap<Long, CNode> nodes = new HashMap<>();
	int scalarPos = 0; //next free index into scalars[...]
	int bPos = 0;      //next free index into b[...]
	for( int i=startIndex; i < inputs.size(); i++ ) {
		CNode cnode = inputs.get(i);
		if( cnode instanceof CNodeData && ((CNodeData)cnode).isLiteral() )
			continue;
		CNodeData cdata = (CNodeData)cnode;
		if( cdata.getDataType() == DataType.SCALAR || ( cdata.getNumCols() == 0 && cdata.getNumRows() == 0) )
			nodes.put(cdata.getHopID(), new CNodeData(cdata, "scalars["+ scalarPos++ +"]"));
		else
			nodes.put(cdata.getHopID(), new CNodeData(cdata, "b["+ bPos++ +"]"));
	}
	//single pass to replace all names; each output gets a fresh lookup memo
	for( CNode output : outputs )
		rReplaceDataNode(output, nodes, new HashMap<Long, CNode>());
}
/**
 * Replaces a single data node, matched by hop id, with a renamed copy
 * throughout the subtree below {@code root}. Does nothing if {@code input}
 * is not a {@link CNodeData}.
 *
 * @param root    root of the subtree to rewrite
 * @param input   data node to replace
 * @param newName variable name given to the replacement copy
 */
protected void rReplaceDataNode( CNode root, CNode input, String newName ) {
	if( !(input instanceof CNodeData) )
		return;
	//create temporary name mapping
	HashMap<Long, CNode> names = new HashMap<>();
	CNodeData tmp = (CNodeData)input;
	names.put(tmp.getHopID(), new CNodeData(tmp, newName));
	rReplaceDataNode(root, names, new HashMap<Long, CNode>());
}
/**
 * Replaces a single data node, matched by hop id, with a renamed copy
 * throughout the subtrees below each of {@code roots}. Does nothing if
 * {@code input} is not a {@link CNodeData}. Every root is processed with a
 * fresh lookup memo, but all roots share the same replacement node.
 *
 * @param roots   roots of the subtrees to rewrite
 * @param input   data node to replace
 * @param newName variable name given to the replacement copy
 */
protected void rReplaceDataNode( ArrayList<CNode> roots, CNode input, String newName ) {
	if( !(input instanceof CNodeData) )
		return;
	//create temporary name mapping
	HashMap<Long, CNode> names = new HashMap<>();
	CNodeData tmp = (CNodeData)input;
	names.put(tmp.getHopID(), new CNodeData(tmp, newName));
	for( CNode root : roots )
		rReplaceDataNode(root, names, new HashMap<Long, CNode>());
}
/**
* Recursively searches for data nodes and replaces them if found.
*
* @param node current node in recursive descend
* @param dnodes prepared data nodes, identified by own hop id
* @param lnodes memoized lookup nodes, identified by data node hop id
*/
protected void rReplaceDataNode( CNode node, HashMap dnodes, HashMap lnodes )
{
for( int i=0; i memo, UnaryType lookupType )
{
for( int i=0; i input1, ArrayList input2) {
boolean ret = (current1.getInput().size() == current2.getInput().size());
//process childs recursively
for( int i=0; ret && i current1, ArrayList current2, ArrayList input1, ArrayList input2) {
boolean ret = (current1.size() == current2.size());
for( int i=0; ret && i inputs, CNodeData probe) {
for( int i=0; i