All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.runtime.matrix.MatrixCharacteristics Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */


package org.apache.sysml.runtime.matrix;

import java.io.Serializable;
import java.util.HashMap;

import org.apache.sysml.lops.MMTSJ.MMTSJType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.instructions.mr.AggregateBinaryInstruction;
import org.apache.sysml.runtime.instructions.mr.AggregateInstruction;
import org.apache.sysml.runtime.instructions.mr.AggregateUnaryInstruction;
import org.apache.sysml.runtime.instructions.mr.AppendInstruction;
import org.apache.sysml.runtime.instructions.mr.BinaryInstruction;
import org.apache.sysml.runtime.instructions.mr.BinaryMInstruction;
import org.apache.sysml.runtime.instructions.mr.BinaryMRInstructionBase;
import org.apache.sysml.runtime.instructions.mr.CM_N_COVInstruction;
import org.apache.sysml.runtime.instructions.mr.CombineBinaryInstruction;
import org.apache.sysml.runtime.instructions.mr.CombineTernaryInstruction;
import org.apache.sysml.runtime.instructions.mr.CombineUnaryInstruction;
import org.apache.sysml.runtime.instructions.mr.CumulativeAggregateInstruction;
import org.apache.sysml.runtime.instructions.mr.DataGenMRInstruction;
import org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction;
import org.apache.sysml.runtime.instructions.mr.GroupedAggregateMInstruction;
import org.apache.sysml.runtime.instructions.mr.MMTSJMRInstruction;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.instructions.mr.MapMultChainInstruction;
import org.apache.sysml.runtime.instructions.mr.MatrixReshapeMRInstruction;
import org.apache.sysml.runtime.instructions.mr.PMMJMRInstruction;
import org.apache.sysml.runtime.instructions.mr.ParameterizedBuiltinMRInstruction;
import org.apache.sysml.runtime.instructions.mr.QuaternaryInstruction;
import org.apache.sysml.runtime.instructions.mr.RandInstruction;
import org.apache.sysml.runtime.instructions.mr.RangeBasedReIndexInstruction;
import org.apache.sysml.runtime.instructions.mr.ReblockInstruction;
import org.apache.sysml.runtime.instructions.mr.RemoveEmptyMRInstruction;
import org.apache.sysml.runtime.instructions.mr.ReorgInstruction;
import org.apache.sysml.runtime.instructions.mr.ReplicateInstruction;
import org.apache.sysml.runtime.instructions.mr.ScalarInstruction;
import org.apache.sysml.runtime.instructions.mr.SeqInstruction;
import org.apache.sysml.runtime.instructions.mr.TernaryInstruction;
import org.apache.sysml.runtime.instructions.mr.UaggOuterChainInstruction;
import org.apache.sysml.runtime.instructions.mr.UnaryInstruction;
import org.apache.sysml.runtime.instructions.mr.UnaryMRInstructionBase;
import org.apache.sysml.runtime.instructions.mr.ZeroOutInstruction;
import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
import org.apache.sysml.runtime.util.IndexRange;


/**
 * Size metadata of a blocked matrix: number of rows and columns, block sizes
 * (rows/columns per block), and number of non-zero values (nnz).
 * <p>
 * Rows, columns and nnz are initialized to -1 and block sizes to 1.
 * {@link #dimsKnown()}, {@link #rowsKnown()} and {@link #colsKnown()} treat any
 * row/column count {@code <= 0} as unknown, whereas {@link #nnzKnown()} treats
 * only a negative nnz as unknown.
 * <p>
 * Instances are mutable. Since {@link #equals(Object)} and {@link #hashCode()}
 * are value-based, do not modify an instance while it is used as a key in a
 * hash-based collection.
 */
public class MatrixCharacteristics implements Serializable
{
	private static final long serialVersionUID = 8300479822915546000L;

	/** Number of rows; -1 if unknown. */
	private long numRows = -1;
	/** Number of columns; -1 if unknown. */
	private long numColumns = -1;
	/** Number of rows per block. */
	private int numRowsPerBlock = 1;
	/** Number of columns per block. */
	private int numColumnsPerBlock = 1;
	/** Number of non-zero values; negative if unknown. */
	private long nonZero = -1;
	
	/** Creates characteristics with unknown dimensions and nnz, and 1x1 blocks. */
	public MatrixCharacteristics() {
	
	}
	
	/**
	 * Creates characteristics with the given dimensions and block sizes; nnz stays unknown (-1).
	 *
	 * @param nr number of rows
	 * @param nc number of columns
	 * @param bnr rows per block
	 * @param bnc columns per block
	 */
	public MatrixCharacteristics(long nr, long nc, int bnr, int bnc)
	{
		set(nr, nc, bnr, bnc);
	}

	/**
	 * Creates characteristics with the given dimensions, block sizes and nnz.
	 *
	 * @param nr number of rows
	 * @param nc number of columns
	 * @param bnr rows per block
	 * @param bnc columns per block
	 * @param nnz number of non-zero values
	 */
	public MatrixCharacteristics(long nr, long nc, int bnr, int bnc, long nnz)
	{
		set(nr, nc, bnr, bnc, nnz);
	}
	
	/**
	 * Copy constructor: copies all five fields of {@code that}.
	 *
	 * @param that characteristics to copy
	 */
	public MatrixCharacteristics(MatrixCharacteristics that)
	{
		set(that);
	}

	/**
	 * Sets dimensions and block sizes; the nnz value is left unchanged.
	 *
	 * @param nr number of rows
	 * @param nc number of columns
	 * @param bnr rows per block
	 * @param bnc columns per block
	 */
	public void set(long nr, long nc, int bnr, int bnc) {
		numRows = nr;
		numColumns = nc;
		numRowsPerBlock = bnr;
		numColumnsPerBlock = bnc;
	}
	
	/**
	 * Sets dimensions, block sizes and nnz.
	 *
	 * @param nr number of rows
	 * @param nc number of columns
	 * @param bnr rows per block
	 * @param bnc columns per block
	 * @param nnz number of non-zero values
	 */
	public void set(long nr, long nc, int bnr, int bnc, long nnz) {
		numRows = nr;
		numColumns = nc;
		numRowsPerBlock = bnr;
		numColumnsPerBlock = bnc;
		nonZero = nnz;
	}
	
	/**
	 * Copies all five fields (dimensions, block sizes, nnz) from {@code that}.
	 *
	 * @param that source characteristics
	 */
	public void set(MatrixCharacteristics that) {
		numRows = that.numRows;
		numColumns = that.numColumns;
		numRowsPerBlock = that.numRowsPerBlock;
		numColumnsPerBlock = that.numColumnsPerBlock;
		nonZero = that.nonZero;
	}
	
	/** @return number of rows (-1 if unknown) */
	public long getRows(){
		return numRows;
	}

	/** @return number of columns (-1 if unknown) */
	public long getCols(){
		return numColumns;
	}
	
	/** @return rows per block */
	public int getRowsPerBlock() {
		return numRowsPerBlock;
	}
	
	/** @param brlen new rows-per-block value */
	public void setRowsPerBlock( int brlen){
		numRowsPerBlock = brlen;
	} 
	
	/** @return columns per block */
	public int getColsPerBlock() {
		return numColumnsPerBlock;
	}
	
	/** @param bclen new columns-per-block value */
	public void setColsPerBlock( int bclen){
		numColumnsPerBlock = bclen;
	} 
	
	/**
	 * Returns the number of row blocks, {@code ceil(rows / rowsPerBlock)},
	 * computed in double precision.
	 *
	 * @return number of row blocks
	 */
	public long getNumRowBlocks(){
		return (long) Math.ceil((double)getRows() / getRowsPerBlock());
	}
	
	/**
	 * Returns the number of column blocks, {@code ceil(cols / colsPerBlock)},
	 * computed in double precision.
	 *
	 * @return number of column blocks
	 */
	public long getNumColBlocks(){
		return (long) Math.ceil((double)getCols() / getColsPerBlock());
	}
	
	@Override
	public String toString()
	{
		return "["+numRows+" x "+numColumns+", nnz="+nonZero
		+", blocks ("+numRowsPerBlock+" x "+numColumnsPerBlock+")]";
	}
	
	/**
	 * Sets the dimensions; block sizes and nnz are left unchanged.
	 *
	 * @param nr number of rows
	 * @param nc number of columns
	 */
	public void setDimension(long nr, long nc)
	{
		numRows = nr;
		numColumns = nc;
	}
	
	/**
	 * Sets the block sizes; dimensions and nnz are left unchanged.
	 *
	 * @param bnr rows per block
	 * @param bnc columns per block
	 */
	public void setBlockSize(int bnr, int bnc)
	{
		numRowsPerBlock = bnr;
		numColumnsPerBlock = bnc;
	}
	
	/** @param nnz new number of non-zero values (negative means unknown) */
	public void setNonZeros(long nnz) {
		nonZero = nnz;
	}
	
	/** @return number of non-zero values (negative if unknown) */
	public long getNonZeros() {
		return nonZero;
	}
	
	/** @return true if both rows and columns are strictly positive */
	public boolean dimsKnown() {
		return ( numRows > 0 && numColumns > 0 );
	}
	
	/**
	 * @param includeNnz if true, additionally require a non-negative nnz
	 * @return true if rows and columns are strictly positive (and nnz is known, if requested)
	 */
	public boolean dimsKnown(boolean includeNnz) {
		return ( numRows > 0 && numColumns > 0 && (!includeNnz || nonZero>=0));
	}
	
	/** @return true if the number of rows is strictly positive */
	public boolean rowsKnown() {
		return ( numRows > 0 );
	}

	/** @return true if the number of columns is strictly positive */
	public boolean colsKnown() {
		return ( numColumns > 0 );
	}
	
	/** @return true if nnz is non-negative */
	public boolean nnzKnown() {
		return ( nonZero >= 0 );
	}
	
	/**
	 * Heuristically reports whether the matrix might contain blocks without any
	 * non-zero value.
	 * <p>
	 * Returns true if nnz is unknown. Otherwise, returns true iff nnz is smaller
	 * than the total cell count minus the cell count of one block of size
	 * {@code min(rows, rowsPerBlock) x min(cols, colsPerBlock)}.
	 * <p>
	 * NOTE(review): with unknown (negative) dimensions but known nnz, the
	 * arithmetic below yields false; confirm callers only rely on this when
	 * dimensions are known.
	 *
	 * @return true if empty blocks are possible
	 */
	public boolean mightHaveEmptyBlocks() 
	{
		long singleBlk =  Math.min(numRows, numRowsPerBlock) 
				        * Math.min(numColumns, numColumnsPerBlock);
		return !nnzKnown() || (nonZero < numRows*numColumns - singleBlk);
	}
	
	/**
	 * Computes the output characteristics of a reorg operation by delegating to
	 * the operator's function.
	 *
	 * @param dim input characteristics
	 * @param op reorg operator
	 * @param dimOut output characteristics, updated in place
	 * @throws DMLUnsupportedOperationException propagated from the operator function
	 * @throws DMLRuntimeException propagated from the operator function
	 */
	public static void reorg(MatrixCharacteristics dim, ReorgOperator op, 
			MatrixCharacteristics dimOut) throws DMLUnsupportedOperationException, DMLRuntimeException
	{
		op.fn.computeDimension(dim, dimOut);
	}
	
	/**
	 * Computes the output characteristics of an aggregate-unary operation by
	 * delegating to the operator's index function.
	 *
	 * @param dim input characteristics
	 * @param op aggregate-unary operator
	 * @param dimOut output characteristics, updated in place
	 * @throws DMLUnsupportedOperationException propagated from the index function
	 * @throws DMLRuntimeException propagated from the index function
	 */
	public static void aggregateUnary(MatrixCharacteristics dim, AggregateUnaryOperator op, 
			MatrixCharacteristics dimOut) throws DMLUnsupportedOperationException, DMLRuntimeException
	{
		op.indexFn.computeDimension(dim, dimOut);
	}
	
	/**
	 * Computes the output characteristics of an aggregate-binary (matrix
	 * multiply style) operation: {@code rows(dim1) x cols(dim2)}, with the row
	 * block size of {@code dim1} and the column block size of {@code dim2}.
	 * The nnz of {@code dimOut} is left unchanged.
	 *
	 * @param dim1 left input characteristics
	 * @param dim2 right input characteristics
	 * @param op aggregate-binary operator (not used for the computation)
	 * @param dimOut output characteristics, updated in place
	 * @throws DMLUnsupportedOperationException declared for API compatibility
	 */
	public static void aggregateBinary(MatrixCharacteristics dim1, MatrixCharacteristics dim2,
			AggregateBinaryOperator op, MatrixCharacteristics dimOut) 
	throws DMLUnsupportedOperationException
	{
		//set dimension
		dimOut.set(dim1.numRows, dim2.numColumns, dim1.numRowsPerBlock, dim2.numColumnsPerBlock);
	}
	
	/**
	 * Computes the output characteristics of a single MR instruction and stores
	 * them in {@code dims} under the instruction's output index.
	 * <p>
	 * Input characteristics are looked up in {@code dims} by the instruction's
	 * input indices. If no entry exists yet for {@code ins.output}, a new one is
	 * created and registered. Branches that use the 4-argument {@code set(...)}
	 * leave any nnz already stored in the output entry unchanged.
	 * <p>
	 * The order of the {@code instanceof} tests matters: some instruction types
	 * must be matched before more general ones (see the inline comments).
	 * Instructions not handled below get unknown dimensions (-1) and 1x1 blocks.
	 *
	 * @param dims characteristics keyed by matrix index; updated in place
	 * @param ins the instruction whose output characteristics are computed
	 * @throws DMLUnsupportedOperationException propagated from dimension computations
	 * @throws DMLRuntimeException propagated from dimension computations
	 */
	public static void computeDimension(HashMap<Byte, MatrixCharacteristics> dims, MRInstruction ins) 
		throws DMLUnsupportedOperationException, DMLRuntimeException
	{
		MatrixCharacteristics dimOut=dims.get(ins.output);
		if(dimOut==null)
		{
			dimOut=new MatrixCharacteristics();
			dims.put(ins.output, dimOut);
		}
		
		if(ins instanceof ReorgInstruction)
		{
			ReorgInstruction realIns=(ReorgInstruction)ins;
			reorg(dims.get(realIns.input), (ReorgOperator)realIns.getOperator(), dimOut);
		}
		else if(ins instanceof AppendInstruction )
		{
			//cbind concatenates columns, rbind concatenates rows
			AppendInstruction realIns = (AppendInstruction)ins;
			MatrixCharacteristics in_dim1 = dims.get(realIns.input1);
			MatrixCharacteristics in_dim2 = dims.get(realIns.input2);
			if( realIns.isCBind() )
				dimOut.set(in_dim1.numRows, in_dim1.numColumns+in_dim2.numColumns, in_dim1.numRowsPerBlock, in_dim2.numColumnsPerBlock);
			else
				dimOut.set(in_dim1.numRows+in_dim2.numRows, in_dim1.numColumns, in_dim1.numRowsPerBlock, in_dim2.numColumnsPerBlock);
		}
		else if(ins instanceof CumulativeAggregateInstruction)
		{
			//must precede the generic aggregate-unary branch (the cast below implies
			//it is an AggregateUnaryInstruction subtype); one output row per input row block
			AggregateUnaryInstruction realIns=(AggregateUnaryInstruction)ins;
			MatrixCharacteristics in = dims.get(realIns.input);
			dimOut.set(in.getNumRowBlocks(), in.getCols(), in.getRowsPerBlock(), in.getColsPerBlock());
		}
		else if(ins instanceof AggregateUnaryInstruction)
		{
			AggregateUnaryInstruction realIns=(AggregateUnaryInstruction)ins;
			aggregateUnary(dims.get(realIns.input), 
					(AggregateUnaryOperator)realIns.getOperator(), dimOut);
		}
		else if(ins instanceof AggregateBinaryInstruction)
		{
			AggregateBinaryInstruction realIns=(AggregateBinaryInstruction)ins;
			aggregateBinary(dims.get(realIns.input1), dims.get(realIns.input2),
					(AggregateBinaryOperator)realIns.getOperator(), dimOut);
		}
		else if(ins instanceof MapMultChainInstruction)
		{
			//output size independent of chain type
			MapMultChainInstruction realIns=(MapMultChainInstruction)ins;
			MatrixCharacteristics mc1 = dims.get(realIns.getInput1());
			MatrixCharacteristics mc2 = dims.get(realIns.getInput2());
			dimOut.set(mc1.numColumns, mc2.numColumns, mc1.numRowsPerBlock, mc1.numColumnsPerBlock);	
		}
		else if(ins instanceof QuaternaryInstruction)
		{
			QuaternaryInstruction realIns=(QuaternaryInstruction)ins;
			MatrixCharacteristics mc1 = dims.get(realIns.getInput1());
			MatrixCharacteristics mc2 = dims.get(realIns.getInput2());
			MatrixCharacteristics mc3 = dims.get(realIns.getInput3());
			realIns.computeMatrixCharacteristics(mc1, mc2, mc3, dimOut);
		}
		else if(ins instanceof ReblockInstruction)
		{
			//same dimensions and nnz, new block sizes
			ReblockInstruction realIns=(ReblockInstruction)ins;
			MatrixCharacteristics in_dim=dims.get(realIns.input);
			dimOut.set(in_dim.numRows, in_dim.numColumns, realIns.brlen, realIns.bclen, in_dim.nonZero);
		}
		else if( ins instanceof MatrixReshapeMRInstruction )
		{
			//new dimensions from the instruction, block sizes and nnz from the input
			MatrixReshapeMRInstruction mrinst = (MatrixReshapeMRInstruction) ins;
			MatrixCharacteristics in_dim=dims.get(mrinst.input);
			dimOut.set(mrinst.getNumRows(),mrinst.getNumColunms(),in_dim.getRowsPerBlock(), in_dim.getColsPerBlock(), in_dim.getNonZeros());
		}
		else if(ins instanceof RandInstruction
				|| ins instanceof SeqInstruction) 
		{
			DataGenMRInstruction dataIns=(DataGenMRInstruction)ins;
			dimOut.set(dims.get(dataIns.getInput()));
		}
		else if( ins instanceof ReplicateInstruction )
		{
			ReplicateInstruction realIns=(ReplicateInstruction)ins;
			realIns.computeOutputDimension(dims.get(realIns.input), dimOut);
		}
		else if( ins instanceof ParameterizedBuiltinMRInstruction ) //before unary
		{
			ParameterizedBuiltinMRInstruction realIns = (ParameterizedBuiltinMRInstruction)ins;
			realIns.computeOutputCharacteristics(dims.get(realIns.input), dimOut);
		}
		else if(ins instanceof ScalarInstruction 
				|| ins instanceof AggregateInstruction
				||(ins instanceof UnaryInstruction && !(ins instanceof MMTSJMRInstruction))
				|| ins instanceof ZeroOutInstruction)
		{
			//shape-preserving instructions: copy all input characteristics (incl. nnz);
			//MMTSJ is excluded here and handled by its own branch below
			UnaryMRInstructionBase realIns=(UnaryMRInstructionBase)ins;
			dimOut.set(dims.get(realIns.input));
		}
		else if (ins instanceof MMTSJMRInstruction)
		{
			//left type yields ncol x ncol, otherwise nrow x nrow
			MMTSJMRInstruction mmtsj = (MMTSJMRInstruction)ins;
			MMTSJType tstype = mmtsj.getMMTSJType();
			MatrixCharacteristics mc = dims.get(mmtsj.input);
			dimOut.set( tstype.isLeft() ? mc.numColumns : mc.numRows,
					    tstype.isLeft() ? mc.numColumns : mc.numRows,
					     mc.numRowsPerBlock, mc.numColumnsPerBlock );
		}
		else if( ins instanceof PMMJMRInstruction )
		{
			//rows from the instruction, columns and block sizes from the second input
			PMMJMRInstruction pmmins = (PMMJMRInstruction) ins;
			MatrixCharacteristics mc = dims.get(pmmins.input2);
			dimOut.set( pmmins.getNumRows(),
					     mc.numColumns,
					     mc.numRowsPerBlock, mc.numColumnsPerBlock );
		}
		else if( ins instanceof RemoveEmptyMRInstruction )
		{
			//the removed dimension is replaced by the instruction's output length
			RemoveEmptyMRInstruction realIns=(RemoveEmptyMRInstruction)ins;
			MatrixCharacteristics mc = dims.get(realIns.input1);
			if( realIns.isRemoveRows() )
				dimOut.set(realIns.getOutputLen(), mc.getCols(), mc.numRowsPerBlock, mc.numColumnsPerBlock);
			else
				dimOut.set(mc.getRows(), realIns.getOutputLen(), mc.numRowsPerBlock, mc.numColumnsPerBlock);
		}
		else if(ins instanceof UaggOuterChainInstruction) //needs to be checked before binary
		{
			UaggOuterChainInstruction realIns=(UaggOuterChainInstruction)ins;
			MatrixCharacteristics mc1 = dims.get(realIns.input1);
			MatrixCharacteristics mc2 = dims.get(realIns.input2);
			realIns.computeOutputCharacteristics(mc1, mc2, dimOut);
		}
		else if( ins instanceof GroupedAggregateMInstruction )
		{
			GroupedAggregateMInstruction realIns = (GroupedAggregateMInstruction) ins;
			MatrixCharacteristics mc1 = dims.get(realIns.input1);
			realIns.computeOutputCharacteristics(mc1, dimOut);
		}
		else if(ins instanceof BinaryInstruction || ins instanceof BinaryMInstruction || ins instanceof CombineBinaryInstruction )
		{
			BinaryMRInstructionBase realIns=(BinaryMRInstructionBase)ins;
			MatrixCharacteristics mc1 = dims.get(realIns.input1);
			MatrixCharacteristics mc2 = dims.get(realIns.input2);
			if(    mc1.getRows()>1 && mc1.getCols()==1 
				&& mc2.getRows()==1 && mc2.getCols()>1 ) //outer
			{
				//column vector op row vector: rows(mc1) x cols(mc2) result
				dimOut.set(mc1.getRows(), mc2.getCols(), mc1.getRowsPerBlock(), mc2.getColsPerBlock());
			}
			else { //default case
				dimOut.set(mc1);	
			}
		}
		else if (ins instanceof CombineTernaryInstruction ) {
			//checked before the generic ternary branch below
			TernaryInstruction realIns=(TernaryInstruction)ins;
			dimOut.set(dims.get(realIns.input1));
		}
		else if (ins instanceof CombineUnaryInstruction ) {
			dimOut.set( dims.get(((CombineUnaryInstruction) ins).input));
		}
		else if(ins instanceof CM_N_COVInstruction || ins instanceof GroupedAggregateInstruction )
		{
			dimOut.set(1, 1, 1, 1);
		}
		else if(ins instanceof RangeBasedReIndexInstruction)
		{
			//index range bounds are treated as inclusive (hence +1)
			RangeBasedReIndexInstruction realIns=(RangeBasedReIndexInstruction)ins;
			MatrixCharacteristics in_dim=dims.get(realIns.input);
			IndexRange ixrange = realIns.getIndexRange(); 
			long nrow=ixrange.rowEnd-ixrange.rowStart+1;
			long ncol=ixrange.colEnd-ixrange.colStart+1;
			dimOut.set(nrow, ncol, in_dim.numRowsPerBlock, in_dim.numColumnsPerBlock);
		}
		else if (ins instanceof TernaryInstruction) {
			TernaryInstruction realIns = (TernaryInstruction)ins;
			MatrixCharacteristics in_dim=dims.get(realIns.input1);
			dimOut.set(realIns.getOutputDim1(), realIns.getOutputDim2(), in_dim.numRowsPerBlock, in_dim.numColumnsPerBlock);
		}
		else { 
			/*
			 * if ins is none of the above cases then we assume that dim_out dimensions are unknown
			 * NOTE(review): nonZero is not reset here, so a pre-existing entry keeps its old nnz.
			 */
			dimOut.numRows = -1;
			dimOut.numColumns = -1;
			dimOut.numRowsPerBlock=1;
			dimOut.numColumnsPerBlock=1;
		}
	}

	/**
	 * Two characteristics are equal iff rows, columns, both block sizes and nnz
	 * are all equal.
	 */
	@Override
	public boolean equals (Object anObject)
	{
		if( this == anObject )
			return true;
		if( !(anObject instanceof MatrixCharacteristics) )
			return false;
		
		MatrixCharacteristics mc = (MatrixCharacteristics) anObject;
		return numRows == mc.numRows
			&& numColumns == mc.numColumns
			&& numRowsPerBlock == mc.numRowsPerBlock
			&& numColumnsPerBlock == mc.numColumnsPerBlock
			&& nonZero == mc.nonZero;
	}
	
	/**
	 * Value-based hash over the same five fields compared by
	 * {@link #equals(Object)}, so that equal instances have equal hash codes
	 * (the previous identity hash violated the equals/hashCode contract).
	 */
	@Override
	public int hashCode()
	{
		int h = 17;
		h = 31 * h + longHash(numRows);
		h = 31 * h + longHash(numColumns);
		h = 31 * h + numRowsPerBlock;
		h = 31 * h + numColumnsPerBlock;
		h = 31 * h + longHash(nonZero);
		return h;
	}
	
	/** Folds a long into an int the same way {@code Long#hashCode()} does. */
	private static int longHash(long v)
	{
		return (int) (v ^ (v >>> 32));
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy