All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.sysml.hops.codegen.cplan.CNodeTpl Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.codegen.cplan;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;

import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.parser.Expression.DataType;

public abstract class CNodeTpl extends CNode implements Cloneable
{
	/**
	 * Creates a template node over the given inputs and output node.
	 * Each input is registered via {@link #addInput(CNode)}, which by default
	 * skips literals and inputs that are already contained.
	 *
	 * NOTE(review): {@code addInput} is public and overridable, yet it is
	 * invoked here from the constructor, i.e. before subclass state exists.
	 *
	 * @param inputs template inputs; must contain at least one element
	 * @param output output node of the template
	 * @throws RuntimeException if {@code inputs} is empty
	 */
	public CNodeTpl(ArrayList<? extends CNode> inputs, CNode output ) {
		if( inputs.isEmpty() )
			throw new RuntimeException("Cannot pass empty inputs to the CNodeTpl");
		
		for( CNode input : inputs )
			addInput(input);
		_output = output;
	}
	
	/**
	 * Registers {@code in} as an input of this template, unless it is
	 * already contained in the inputs or is a literal.
	 *
	 * @param in candidate input node
	 */
	public void addInput(CNode in) {
		//duplicates and literals are never registered
		boolean skip = containsInput(in) || in.isLiteral();
		if( !skip )
			_inputs.add(in);
	}
	
	/**
	 * Keeps only those inputs that are {@code CNodeData} nodes whose hop ID
	 * is contained in {@code filter}; all other inputs are dropped.
	 * The input list is replaced by a new list rather than modified in place.
	 *
	 * @param filter hop IDs of the data inputs to retain
	 */
	public void cleanupInputs(HashSet<?> filter) {
		ArrayList<CNode> retained = new ArrayList<CNode>();
		for( CNode in : _inputs )
			if( in instanceof CNodeData && filter.contains(((CNodeData) in).getHopID()) )
				retained.add(in);
		_inputs = retained;
	}
	
	/**
	 * Collects the variable names of all registered inputs, in input order.
	 *
	 * @return an array with one variable name per input
	 */
	public String[] getInputNames() {
		String[] names = new String[_inputs.size()];
		int pos = 0;
		for( CNode in : _inputs )
			names[pos++] = in.getVarname();
		return names;
	}
	
	/**
	 * Resets the visit status of the node returned by {@code getOutput()}.
	 */
	public void resetVisitStatusOutputs() {
		this.getOutput().resetVisitStatus();
	}
	
	/**
	 * Generates the code of this template by delegating to the
	 * boolean overload with {@code false}.
	 *
	 * @return the generated code
	 */
	public String codegen() {
		final String src = codegen(false);
		return src;
	}
	
	/**
	 * Returns a clone of this template, implemented by each concrete template.
	 * Unlike {@link Object#clone()}, this override is public, returns a
	 * {@code CNodeTpl} and declares no checked exception.
	 * NOTE(review): deep vs. shallow copy semantics are up to the subclasses.
	 */
	public abstract CNodeTpl clone();
	
	/**
	 * Returns the output dimension type of this template.
	 * NOTE(review): presumably describes how the fused operator derives its
	 * output dimensions; confirm against {@code SpoofOutputDimsType}.
	 */
	public abstract SpoofOutputDimsType getOutputDimType();
	
	/**
	 * Returns a string describing this template.
	 * NOTE(review): content and format are defined by the implementations.
	 */
	public abstract String getTemplateInfo();
	
	/**
	 * Renames the inputs from {@code startIndex} onwards and rewrites their
	 * references below this template's single output node ({@code _output}).
	 * Convenience overload of {@link #renameInputs(List, ArrayList, int)}.
	 *
	 * @param inputs template inputs to rename
	 * @param startIndex index of the first input to rename
	 */
	protected void renameInputs(ArrayList<? extends CNode> inputs, int startIndex) {
		renameInputs(Collections.singletonList(_output), inputs, startIndex);
	}
	
	/**
	 * Assigns positional variable names to the template inputs and rewrites
	 * all references to them within the given output subtrees.
	 * <p>
	 * Inputs before {@code startIndex} and literal data inputs are skipped.
	 * Each remaining input gets a replacement {@code CNodeData}, keyed by the
	 * input's hop ID. The replacement is named {@code scalars[k]} if the input
	 * is a scalar or has zero rows and zero columns, and {@code b[k]} otherwise.
	 * The two name sequences are numbered independently, starting at 0.
	 *
	 * @param outputs root nodes whose subtrees are rewritten
	 * @param inputs template inputs; every entry at or after {@code startIndex}
	 *        is expected to be a {@code CNodeData}
	 * @param startIndex index of the first input to rename
	 * @throws ClassCastException if an input at or after {@code startIndex}
	 *         is not a {@code CNodeData}
	 */
	protected void renameInputs(List<? extends CNode> outputs, ArrayList<? extends CNode> inputs, int startIndex) {
		//create map of hopID to data nodes with new names, used for CSE
		HashMap<Long, CNode> nodes = new HashMap<Long, CNode>();
		int sPos = 0; //next index for b[...]
		int mPos = 0; //next index for scalars[...]
		for( int i=startIndex; i < inputs.size(); i++ ) {
			CNode cnode = inputs.get(i);
			//literals keep their value and are not renamed
			if( cnode instanceof CNodeData && ((CNodeData)cnode).isLiteral() )
				continue;
			//NOTE: a non-CNodeData input fails here with a ClassCastException
			CNodeData cdata = (CNodeData)cnode;
			if( cdata.getDataType() == DataType.SCALAR  || ( cdata.getNumCols() == 0 && cdata.getNumRows() == 0) ) 
				nodes.put(cdata.getHopID(), new CNodeData(cdata, "scalars["+ mPos++ +"]"));
			else
				nodes.put(cdata.getHopID(), new CNodeData(cdata, "b["+ sPos++ +"]"));
		}
		
		//single pass to replace all names
		for( CNode output : outputs )
			rReplaceDataNode(output, nodes, new HashMap<Long, CNode>());
	}
	
	/**
	 * Replaces, within the subtree of {@code root}, data nodes matching the
	 * hop ID of {@code input} by {@code new CNodeData(input, newName)}.
	 * Does nothing if {@code input} is not a {@code CNodeData}.
	 *
	 * @param root root of the subtree to rewrite
	 * @param input data node to rename
	 * @param newName variable name of the replacement node
	 */
	protected void rReplaceDataNode( CNode root, CNode input, String newName ) {
		if( !(input instanceof CNodeData) )
			return;
		
		//create temporary name mapping (hop ID -> renamed data node)
		CNodeData tmp = (CNodeData)input;
		HashMap<Long, CNode> names = new HashMap<Long, CNode>();
		names.put(tmp.getHopID(), new CNodeData(tmp, newName));
		
		rReplaceDataNode(root, names, new HashMap<Long, CNode>());
	}
	
	/**
	 * Replaces, within the subtree of each node in {@code roots}, data nodes
	 * matching the hop ID of {@code input} by {@code new CNodeData(input, newName)}.
	 * Does nothing if {@code input} is not a {@code CNodeData}.
	 *
	 * @param roots roots of the subtrees to rewrite
	 * @param input data node to rename
	 * @param newName variable name of the replacement node
	 */
	protected void rReplaceDataNode( ArrayList<? extends CNode> roots, CNode input, String newName ) {
		if( !(input instanceof CNodeData) )
			return;
		
		//create temporary name mapping, shared by all roots
		CNodeData tmp = (CNodeData)input;
		HashMap<Long, CNode> names = new HashMap<Long, CNode>();
		names.put(tmp.getHopID(), new CNodeData(tmp, newName));
		
		//each root is processed with a fresh lookup memo
		for( CNode root : roots )
			rReplaceDataNode(root, names, new HashMap<Long, CNode>());
	}
	
	/**
	 * Recursively searches for data nodes and replaces them if found.
	 * 
	 * @param node current node in recursive descend
	 * @param dnodes prepared data nodes, identified by own hop id
	 * @param lnodes memoized lookup nodes, identified by data node hop id
	 */
	protected void rReplaceDataNode( CNode node, HashMap dnodes, HashMap lnodes ) 
	{	
		for( int i=0; i memo, UnaryType lookupType ) 
	{
		for( int i=0; i input1, ArrayList input2) {
		boolean ret = (current1.getInput().size() == current2.getInput().size());
		
		//process childs recursively
		for( int i=0; ret && i current1, ArrayList current2, ArrayList input1, ArrayList input2) {
		boolean ret = (current1.size() == current2.size());
		for( int i=0; ret && i inputs, CNodeData probe) {
		for( int i=0; i




© 2015 - 2024 Weber Informatics LLC | Privacy Policy