All downloads are FREE. Search and download functionality uses the official Maven repository.

org.apache.sysml.hops.codegen.cplan.CNodeRow Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.codegen.cplan;

import java.util.ArrayList;

import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType;
import org.apache.sysml.runtime.util.UtilFunctions;

/**
 * CPlan template node for row-wise fused operators.
 * <p>
 * {@link #codegen(boolean)} produces the Java source of a class extending
 * {@code SpoofRowwise} that contains both a dense and a sparse
 * {@code genexec} implementation. The row type, the constant number of
 * output columns and the number of vector intermediates are written into
 * the generated constructor call. They are also part of this node's
 * identity ({@link #hashCode()}, {@link #equals(Object)}).
 */
public class CNodeRow extends CNodeTpl
{
	private static final String TEMPLATE = 
			  "package codegen;\n"
			+ "import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;\n"
			+ "import org.apache.sysml.runtime.codegen.SpoofOperator.SideInput;\n"
			+ "import org.apache.sysml.runtime.codegen.SpoofRowwise;\n"
			+ "import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType;\n"
			+ "import org.apache.commons.math3.util.FastMath;\n"
			+ "\n"
			+ "public final class %TMP% extends SpoofRowwise { \n"
			+ "  public %TMP%() {\n"
			+ "    super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n"
			+ "  }\n"
			+ "  protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, int rix) { \n"
			+ "%BODY_dense%"
			+ "  }\n"
			+ "  protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int alen, int len, int rix) { \n"
			+ "%BODY_sparse%"
			+ "  }\n"
			+ "}\n";

	//per-row output statements appended to the generated bodies (see getOutputStatement)
	private static final String TEMPLATE_ROWAGG_OUT  = "    c[rix] = %IN%;\n";
	private static final String TEMPLATE_FULLAGG_OUT = "    c[0] += %IN%;\n";
	private static final String TEMPLATE_NOAGG_OUT   = "    LibSpoofPrimitives.vectWrite(%IN%, c, ci, %LEN%);\n";
	
	private RowType _type = null; //access pattern 
	private long _constDim2 = -1; //constant number of output columns (initially -1)
	private int _numVectors = -1; //number of intermediate vectors (initially -1)
	
	/**
	 * Creates a row template over the given inputs and output node.
	 *
	 * @param inputs template inputs; the first entry is the main input,
	 *        which {@link #renameInputs()} renames to {@code a}
	 * @param output root of the CNode DAG that is code-generated into the
	 *        {@code genexec} bodies
	 */
	public CNodeRow(ArrayList<CNode> inputs, CNode output ) {
		super(inputs, output);
	}
	
	/**
	 * Sets the row type (access pattern) and invalidates the cached hash code.
	 *
	 * @param type the row type
	 */
	public void setRowType(RowType type) {
		_type = type;
		_hash = 0;
	}
	
	/** @return the row type (access pattern), or {@code null} if not yet set */
	public RowType getRowType() {
		return _type;
	}
	
	/**
	 * Sets the number of vector intermediates and invalidates the cached hash code.
	 *
	 * @param num number of vector intermediates
	 */
	public void setNumVectorIntermediates(int num) {
		_numVectors = num;
		_hash = 0;
	}
	
	/** @return the number of vector intermediates (initially -1) */
	public int getNumVectorIntermediates() {
		return _numVectors;
	}
	
	/**
	 * Sets the constant number of output columns and invalidates the cached hash code.
	 *
	 * @param dim2 constant number of output columns
	 */
	public void setConstDim2(long dim2) {
		_constDim2 = dim2;
		_hash = 0;
	}
	
	/** @return the constant number of output columns (initially -1) */
	public long getConstDim2() {
		return _constDim2;
	}
	
	/**
	 * Renames the data node of the first input to {@code a}, matching the
	 * input array name used by the generated dense {@code genexec}. The
	 * remaining inputs, starting at index 1, are delegated to
	 * {@code renameInputs(_inputs, 1)}.
	 */
	@Override
	public void renameInputs() {
		rRenameDataNode(_output, _inputs.get(0), "a"); // input matrix
		renameInputs(_inputs, 1);
	}
	
	/**
	 * Generates the full Java source of the fused row-wise operator.
	 * <p>
	 * The {@code sparse} flag is ignored. Both the dense and the sparse
	 * {@code genexec} bodies are always generated from the same output DAG,
	 * and its generated state is reset in between.
	 *
	 * @param sparse ignored
	 * @return the generated class source
	 */
	@Override
	public String codegen(boolean sparse) {
		// note: ignore sparse flag, generate both
		String tmp = TEMPLATE;
		
		//generate dense/sparse bodies
		String tmpDense = _output.codegen(false)
			+ getOutputStatement(_output.getVarname());
		//reset so the sparse pass emits the full body again
		_output.resetGenerated();
		String tmpSparse = _output.codegen(true)
			+ getOutputStatement(_output.getVarname());
		tmp = tmp.replace("%TMP%", createVarname());
		tmp = tmp.replace("%BODY_dense%", tmpDense);
		tmp = tmp.replace("%BODY_sparse%", tmpSparse);
		
		//replace outputs 
		tmp = tmp.replace("%OUT%", "c");
		tmp = tmp.replace("%POSOUT%", "0");
		
		//replace size information
		tmp = tmp.replace("%LEN%", "len");
		
		//replace colvector information and number of vector intermediates
		tmp = tmp.replace("%TYPE%", _type.name());
		tmp = tmp.replace("%CONST_DIM2%", String.valueOf(_constDim2));
		tmp = tmp.replace("%TB1%", String.valueOf(
			TemplateUtils.containsBinary(_output, BinType.VECT_MATRIXMULT)));
		tmp = tmp.replace("%VECT_MEM%", String.valueOf(_numVectors));
		
		return tmp;
	}
	
	/**
	 * Builds the statement that writes the per-row result into the output
	 * array {@code c} according to the row type.
	 *
	 * @param varName name of the generated variable holding the row result
	 * @return the output statement, or an empty string for row types without
	 *         a per-row write (per the note below, the column aggregations)
	 */
	private String getOutputStatement(String varName) {
		switch( _type ) {
			case NO_AGG:
			case NO_AGG_B1:
			case NO_AGG_CONST:
				return TEMPLATE_NOAGG_OUT.replace("%IN%", varName)
					.replace("%LEN%", _output.getVarname()+".length");
			case FULL_AGG:
				return TEMPLATE_FULLAGG_OUT.replace("%IN%", varName);
			case ROW_AGG:
				return TEMPLATE_ROWAGG_OUT.replace("%IN%", varName);
			default:
				return ""; //_type.isColumnAgg()
		}
	}

	/** Currently a no-op for row templates (auto-generated stub). */
	@Override
	public void setOutputDims() {
		// TODO Auto-generated method stub
		
	}

	/**
	 * Maps the row type to the output dimension type of the fused operator.
	 *
	 * @return the output dimension type for the current row type
	 * @throws RuntimeException if the row type is not supported
	 */
	@Override
	public SpoofOutputDimsType getOutputDimType() {
		switch( _type ) {
			case NO_AGG:        return SpoofOutputDimsType.INPUT_DIMS;
			case NO_AGG_B1:     return SpoofOutputDimsType.ROW_RANK_DIMS;
			case NO_AGG_CONST:  return SpoofOutputDimsType.INPUT_DIMS_CONST2; 
			case FULL_AGG:      return SpoofOutputDimsType.SCALAR;
			case ROW_AGG:       return SpoofOutputDimsType.ROW_DIMS;
			case COL_AGG:       return SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector
			case COL_AGG_T:     return SpoofOutputDimsType.COLUMN_DIMS_ROWS; //column vector
			case COL_AGG_B1:    return SpoofOutputDimsType.COLUMN_RANK_DIMS; 
			case COL_AGG_B1_T:  return SpoofOutputDimsType.COLUMN_RANK_DIMS_T;
			case COL_AGG_B1R:   return SpoofOutputDimsType.RANK_DIMS_COLS;
			case COL_AGG_CONST: return SpoofOutputDimsType.VECT_CONST2;
			default:
				throw new RuntimeException("Unsupported row type: "+_type.toString());
		}
	}
	
	/**
	 * Creates a copy of this template. The copy is constructed from this
	 * node's input list and output node and carries over the row type, the
	 * constant number of output columns and the number of vector
	 * intermediates.
	 */
	@Override
	public CNodeTpl clone() {
		CNodeRow tmp = new CNodeRow(_inputs, _output);
		tmp.setRowType(_type);
		//constDim2 feeds the generated code, hashCode and equals; a clone
		//without it would fall back to -1 and diverge from the original
		tmp.setConstDim2(_constDim2);
		tmp.setNumVectorIntermediates(_numVectors);
		return tmp;
	}
	
	/**
	 * Computes the hash code lazily and caches it. It combines the super
	 * class hash, the row type, the constant number of output columns and
	 * the number of vector intermediates. The setters reset the cache.
	 */
	@Override
	public int hashCode() {
		if( _hash == 0 ) {
			int h = UtilFunctions.intHashCode(super.hashCode(), _type.hashCode());
			h = UtilFunctions.intHashCode(h, Long.hashCode(_constDim2));
			_hash = UtilFunctions.intHashCode(h, Integer.hashCode(_numVectors));
		}
		return _hash;
	}
	
	/**
	 * Two row templates are equal if the super class considers them equal,
	 * their row type, number of vector intermediates and constant number of
	 * output columns match, and their output DAGs reference equivalent
	 * inputs.
	 */
	@Override 
	public boolean equals(Object o) {
		if(!(o instanceof CNodeRow))
			return false;
		
		CNodeRow that = (CNodeRow)o;
		return super.equals(o)
			&& _type == that._type
			&& _numVectors == that._numVectors
			&& _constDim2 == that._constDim2
			&& equalInputReferences(
				_output, that._output, _inputs, that._inputs);
	}
	
	/** @return a short description with the row type and required vector memory */
	@Override
	public String getTemplateInfo() {
		StringBuilder sb = new StringBuilder();
		sb.append("SPOOF ROWAGGREGATE [type=");
		sb.append(_type.name());
		sb.append(", reqVectMem=");
		sb.append(_numVectors);
		sb.append("]");
		return sb.toString();
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy