All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.sysml.hops.codegen.SpoofFusedOp Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.codegen;

import java.util.ArrayList;

import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.MultiThreadedHop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.MemoTable;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopProperties.ExecType;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.SpoofFused;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.codegen.SpoofRowwise;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;

public class SpoofFusedOp extends Hop implements MultiThreadedHop
{
	/**
	 * Rules for deriving the output dimensions of a fused operator from its
	 * inputs. "in0" and "in1" below denote the first and second input; the
	 * constant column count is the value passed to setConstDim2.
	 */
	public enum SpoofOutputDimsType {
		INPUT_DIMS,          // rows(in0) x cols(in0)
		INPUT_DIMS_CONST2,   // rows(in0) x constant column count
		ROW_DIMS,            // rows(in0) x 1
		COLUMN_DIMS_ROWS,    // cols(in0) x 1
		COLUMN_DIMS_COLS,    // 1 x cols(in0)
		RANK_DIMS_COLS,      // 1 x cols(in1)
		SCALAR,              // 0 x 0
		MULTI_SCALAR,        // 1 x dim2, with dim2 statically set from outside
		ROW_RANK_DIMS,       // rows(in0) x cols(in1); right wdivmm, row mm
		COLUMN_RANK_DIMS,    // cols(in0) x cols(in1); left wdivmm, row mm
		COLUMN_RANK_DIMS_T,  // cols(in1) x cols(in0)
		VECT_CONST2          // 1 x constant column count
	}
	
	//class of the fused operator; handed to the SpoofFused lop and used
	//for the op string and the output memory estimate
	private Class<?> _class = null;
	//flag returned by allowsAllExecTypes()
	private boolean _distSupported = false;
	//thread limit set via setMaxNumThreads(); initially -1
	private int _numThreads = -1;
	//constant number of output columns, used by the dims types
	//INPUT_DIMS_CONST2 and VECT_CONST2; initially -1
	private long _constDim2 = -1;
	//rule that determines how the output dimensions are derived
	private SpoofOutputDimsType _dimsType;
	
	/**
	 * Creates an empty instance with all operator-specific fields at their
	 * initial values; clone() uses this and populates the fields afterwards.
	 */
	public SpoofFusedOp() {
		//nothing to initialize
	}
	
	/**
	 * Creates a fused operator hop.
	 *
	 * @param name hop name, forwarded to the Hop constructor
	 * @param dt output data type, forwarded to the Hop constructor
	 * @param vt output value type, forwarded to the Hop constructor
	 * @param cla class of the fused operator
	 * @param dist flag later reported by allowsAllExecTypes()
	 * @param type rule used to derive the output dimensions
	 */
	public SpoofFusedOp( String name, DataType dt, ValueType vt, Class<?> cla, boolean dist, SpoofOutputDimsType type ) {
		super(name, dt, vt);
		_class = cla;
		_distSupported = dist;
		_dimsType = type;
	}

	/**
	 * No arity validation is performed for this operator.
	 */
	@Override
	public void checkArity() throws HopsException {
		//nothing to check
	}

	/**
	 * Stores the thread limit that constructLops() later passes through
	 * OptimizerUtils.getConstrainedNumThreads.
	 *
	 * @param k maximum number of threads
	 */
	@Override
	public void setMaxNumThreads(int k) {
		this._numThreads = k;
	}

	/**
	 * @return the thread limit last set via setMaxNumThreads, or the
	 *         initial value -1 if it was never set
	 */
	@Override
	public int getMaxNumThreads() {
		return this._numThreads;
	}

	/**
	 * @return the distributed-support flag supplied at construction
	 */
	@Override
	public boolean allowsAllExecTypes() {
		return this._distSupported;
	}
	
	/**
	 * Sets the constant number of output columns used by the dims types
	 * INPUT_DIMS_CONST2 and VECT_CONST2.
	 *
	 * @param constDim2 constant second output dimension
	 */
	public void setConstDim2(long constDim2) {
		this._constDim2 = constDim2;
	}

	/**
	 * Estimates the output memory. If the direct superclass of the operator
	 * class is SpoofRowwise, OptimizerUtils.estimateSize(dim1, dim2) is used;
	 * otherwise the partitioned size with exact sparsity is computed from
	 * the dimensions, block sizes, and number of non-zeros.
	 *
	 * @param dim1 number of output rows
	 * @param dim2 number of output columns
	 * @param nnz number of non-zeros
	 * @return output memory estimate
	 */
	@Override
	protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
		boolean rowwise = _class.getGenericSuperclass().equals(SpoofRowwise.class);
		if( rowwise )
			return OptimizerUtils.estimateSize(dim1, dim2);
		
		return OptimizerUtils.estimatePartitionedSizeExactSparsity(
			dim1, dim2, getRowsInBlock(), getColsInBlock(), nnz);
	}

	/**
	 * Reports zero intermediate memory for this operator, independent of
	 * the given dimensions and number of non-zeros.
	 *
	 * @return always 0
	 */
	@Override
	protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
		return 0d;
	}
	
	/**
	 * Constructs the SpoofFused lop for this hop, or returns the lop that
	 * was already constructed by an earlier call.
	 *
	 * @return the lop of this hop
	 * @throws HopsException declared; may originate from execution type
	 *         selection or from constructing the input lops
	 * @throws LopsException declared; may originate from lop construction
	 */
	@Override
	public Lop constructLops() throws HopsException, LopsException {
		//reuse the lop if it already exists
		if( getLops() != null )
			return getLops();
		
		ExecType et = optFindExecType();
		
		//construct the lops of all inputs, in input order
		//(typed list instead of a raw ArrayList to avoid unchecked adds)
		ArrayList<Lop> inputs = new ArrayList<>(getInput().size());
		for( Hop c : getInput() )
			inputs.add(c.constructLops());
		
		//create the fused lop with the constrained degree of parallelism
		int k = OptimizerUtils.getConstrainedNumThreads(_numThreads);
		SpoofFused lop = new SpoofFused(inputs, getDataType(), getValueType(), _class, k, et);
		setOutputDimensions(lop);
		setLineNumbers(lop);
		setLops(lop);
	
		return lop;
	}
	
	/**
	 * Selects the execution type of this hop. A forced execution type takes
	 * precedence; otherwise the type is chosen by memory estimate followed
	 * by the CP dimension/size validity check. MR is never returned, it is
	 * replaced by CP.
	 *
	 * @return the selected execution type, which is also stored in _etype
	 * @throws HopsException declared; may originate from the called helpers
	 */
	@Override
	protected ExecType optFindExecType() throws HopsException {
		checkAndSetForcedPlatform();
		
		if( _etypeForced == null ) {
			//no forced type: decide by memory estimate
			_etype = findExecTypeByMemEstimate();
			checkAndSetInvalidCPDimsAndSize();
		}
		else {
			_etype = _etypeForced;
		}
		
		//ensure valid execution plans
		if( ExecType.MR == _etype ) {
			_etype = ExecType.CP;
		}
		
		return _etype;
	}

	/**
	 * @return the operator string "spoof(X)", where X is the simple name
	 *         of the operator class
	 */
	@Override
	public String getOpString() {
		String simpleName = _class.getSimpleName();
		return "spoof(" + simpleName + ")";
	}
	
	/**
	 * Infers the output characteristics from the memo table statistics of
	 * the inputs, according to the configured dims type.
	 *
	 * @param memo memo table providing the input statistics
	 * @return array {rows, cols, -1}, or null if the dimensions of the first
	 *         input (or of the second input, for the types that read it)
	 *         are unknown
	 * @throws RuntimeException if the dims type is not handled
	 */
	@Override
	protected long[] inferOutputCharacteristics( MemoTable memo )
	{
		//get statistics of main input
		MatrixCharacteristics main = memo.getAllInputStats(getInput().get(0));
		
		//nothing can be inferred without the main input dimensions
		if( !main.dimsKnown() )
			return null;
		
		switch(_dimsType)
		{
			case ROW_DIMS:
				return new long[]{main.getRows(), 1, -1};
			case COLUMN_DIMS_ROWS:
				return new long[]{main.getCols(), 1, -1};
			case COLUMN_DIMS_COLS:
				return new long[]{1, main.getCols(), -1};
			case RANK_DIMS_COLS: {
				MatrixCharacteristics second = memo.getAllInputStats(getInput().get(1));
				return second.dimsKnown() ?
					new long[]{1, second.getCols(), -1} : null;
			}
			case INPUT_DIMS:
				return new long[]{main.getRows(), main.getCols(), -1};
			case INPUT_DIMS_CONST2:
				return new long[]{main.getRows(), _constDim2, -1};
			case VECT_CONST2:
				return new long[]{1, _constDim2, -1};
			case SCALAR:
				return new long[]{0, 0, -1};
			case MULTI_SCALAR:
				//dim2 statically set from outside
				return new long[]{1, _dim2, -1};
			case ROW_RANK_DIMS: {
				MatrixCharacteristics second = memo.getAllInputStats(getInput().get(1));
				return second.dimsKnown() ?
					new long[]{main.getRows(), second.getCols(), -1} : null;
			}
			case COLUMN_RANK_DIMS: {
				MatrixCharacteristics second = memo.getAllInputStats(getInput().get(1));
				return second.dimsKnown() ?
					new long[]{main.getCols(), second.getCols(), -1} : null;
			}
			case COLUMN_RANK_DIMS_T: {
				MatrixCharacteristics second = memo.getAllInputStats(getInput().get(1));
				return second.dimsKnown() ?
					new long[]{second.getCols(), main.getCols(), -1} : null;
			}
			default:
				throw new RuntimeException("Failed to infer worst-case size information "
						+ "for type: "+_dimsType.toString());
		}
	}
	
	/**
	 * Updates the output dimensions of this hop from the current dimensions
	 * of its inputs, according to the configured dims type. For MULTI_SCALAR
	 * only the number of rows is set; the number of columns is left as is.
	 *
	 * @throws RuntimeException if the dims type is not handled
	 */
	@Override
	public void refreshSizeInformation() {
		switch(_dimsType)
		{
			case ROW_DIMS: {
				Hop first = getInput().get(0);
				setDim1(first.getDim1());
				setDim2(1);
				break;
			}
			case COLUMN_DIMS_ROWS: {
				Hop first = getInput().get(0);
				setDim1(first.getDim2());
				setDim2(1);
				break;
			}
			case COLUMN_DIMS_COLS: {
				setDim1(1);
				Hop first = getInput().get(0);
				setDim2(first.getDim2());
				break;
			}
			case RANK_DIMS_COLS: {
				setDim1(1);
				Hop second = getInput().get(1);
				setDim2(second.getDim2());
				break;
			}
			case INPUT_DIMS: {
				Hop first = getInput().get(0);
				setDim1(first.getDim1());
				setDim2(first.getDim2());
				break;
			}
			case INPUT_DIMS_CONST2: {
				Hop first = getInput().get(0);
				setDim1(first.getDim1());
				setDim2(_constDim2);
				break;
			}
			case VECT_CONST2:
				setDim1(1);
				setDim2(_constDim2);
				break;
			case SCALAR:
				setDim1(0);
				setDim2(0);
				break;
			case MULTI_SCALAR:
				//row vector; dim2 statically set from outside
				setDim1(1);
				break;
			case ROW_RANK_DIMS: {
				Hop first = getInput().get(0);
				setDim1(first.getDim1());
				Hop second = getInput().get(1);
				setDim2(second.getDim2());
				break;
			}
			case COLUMN_RANK_DIMS: {
				Hop first = getInput().get(0);
				setDim1(first.getDim2());
				Hop second = getInput().get(1);
				setDim2(second.getDim2());
				break;
			}
			case COLUMN_RANK_DIMS_T: {
				Hop second = getInput().get(1);
				setDim1(second.getDim2());
				Hop first = getInput().get(0);
				setDim2(first.getDim2());
				break;
			}
			default:
				throw new RuntimeException("Failed to refresh size information "
					+ "for type: "+_dimsType.toString());
		}
	}

	/**
	 * Creates a copy of this hop: the generic hop attributes are copied via
	 * clone(this, false), followed by all operator-specific fields.
	 *
	 * @return the cloned SpoofFusedOp
	 * @throws CloneNotSupportedException declared by the signature
	 */
	@Override
	public Object clone() throws CloneNotSupportedException 
	{
		SpoofFusedOp ret = new SpoofFusedOp();
		
		//copy generic attributes
		ret.clone(this, false);
		
		//copy specific attributes
		ret._class = _class;
		ret._distSupported = _distSupported;
		ret._numThreads = _numThreads;
		//the constant column count must be carried over as well, otherwise
		//clones of INPUT_DIMS_CONST2 / VECT_CONST2 operators would derive
		//their second dimension from the initial value -1
		ret._constDim2 = _constDim2;
		ret._dimsType = _dimsType;
		return ret;
	}
	
	@Override
	public boolean compare( Hop that )
	{
		if( !(that instanceof SpoofFusedOp) )
			return false;
		
		SpoofFusedOp that2 = (SpoofFusedOp)that;		
		boolean ret = ( _class.equals(that2._class)
				&& _distSupported == that2._distSupported
				&& _numThreads == that2._numThreads
				&& getInput().size() == that2.getInput().size());
		
		if( ret ) {
			for( int i=0; i




© 2015 - 2024 Weber Informatics LLC | Privacy Policy