All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.runtime.matrix.data.LibMatrixAgg Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.matrix.data;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.apache.sysml.lops.PartialAggregate.CorrectionLocationType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinFunctionCode;
import org.apache.sysml.runtime.functionobjects.CM;
import org.apache.sysml.runtime.functionobjects.IndexFunction;
import org.apache.sysml.runtime.functionobjects.KahanFunction;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.functionobjects.KahanPlusSq;
import org.apache.sysml.runtime.functionobjects.Mean;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.ReduceAll;
import org.apache.sysml.runtime.functionobjects.ReduceCol;
import org.apache.sysml.runtime.functionobjects.ReduceDiag;
import org.apache.sysml.runtime.functionobjects.ReduceRow;
import org.apache.sysml.runtime.functionobjects.ValueFunction;
import org.apache.sysml.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
import org.apache.sysml.runtime.util.UtilFunctions;

/**
 * MB:
 * Library for matrix aggregations including ak+, uak+ for all
 * combinations of dense and sparse representations, and corrections.
 * These operations are performance-critical because they are used
 * in combiners/reducers of important operations like tsmm, mvmult,
 * and indexing, but also for basic sum/min/max/mean, row*, col*, etc.
 * Operations that are not sparse-safe in particular require dedicated
 * handling to avoid unnecessarily worse asymptotic behavior.
 *
 * This library currently covers the following opcodes:
 * ak+, uak+, uark+, uack+, uasqk+, uarsqk+, uacsqk+,
 * uamin, uarmin, uacmin, uamax, uarmax, uacmax,
 * ua*, uamean, uarmean, uacmean, uavar, uarvar, uacvar,
 * uarimax, uaktrace, cumk+, cummin, cummax, cum*, tak+.
 * 
 * TODO next opcode extensions: a+, colindexmax
 */
public class LibMatrixAgg 
{
	//internal configuration parameters
	
	/** NaN-awareness flag for the aggregation kernels (disabled). */
	private static final boolean NAN_AWARENESS = false;
	
	/** Inputs with fewer cells than this (1M) are aggregated single-threaded. */
	private static final long PAR_NUMCELL_THRESHOLD = 1024L * 1024;
	
	/** Upper bound in bytes (2MB) on the k per-thread partial results (out.clen*8*k)
	 *  before multi-threaded aggregation falls back to the sequential version. */
	private static final long PAR_INTERMEDIATE_SIZE_THRESHOLD = 2L * 1024 * 1024;
	
	/**
	 * Internal classification of an aggregate operator into the aggregation
	 * kernel family that handles it; INVALID marks unsupported operators.
	 * NOTE: constant order is kept stable on purpose.
	 */
	private enum AggType {
		KAHAN_SUM, KAHAN_SUM_SQ,
		CUM_KAHAN_SUM, CUM_MIN, CUM_MAX, CUM_PROD,
		MIN, MAX, MEAN, VAR,
		MAX_INDEX, MIN_INDEX,
		PROD,
		INVALID
	}

	/** Static utility class; instantiation is prevented. */
	private LibMatrixAgg() {
		//intentionally empty
	}
	
	////////////////////////////////
	// public matrix agg interface
	////////////////////////////////
	
	/**
	 * Core incremental matrix aggregate (ak+) as used in mapmult, tsmm, cpmm, etc.
	 * We try to keep the running aggregate values and corrections in dense format,
	 * so that the input can be accessed efficiently whether it is dense or sparse.
	 * 
	 * The kernel is selected by two properties: whether the input is sparse, and
	 * whether both aggVal and aggCorr are currently dense.
	 * 
	 * @param in input matrix
	 * @param aggVal current aggregate values (in/out)
	 * @param aggCorr current aggregate correction (in/out)
	 * @throws DMLRuntimeException propagated from the aggregation kernels
	 */
	public static void aggregateBinaryMatrix(MatrixBlock in, MatrixBlock aggVal, MatrixBlock aggCorr) 
		throws DMLRuntimeException
	{
		final boolean denseAggregates = !aggVal.sparse && !aggCorr.sparse;
		
		if( in.sparse ) {
			if( denseAggregates )
				aggregateBinaryMatrixSparseDense(in, aggVal, aggCorr);
			else //any aggVal, aggCorr
				aggregateBinaryMatrixSparseGeneric(in, aggVal, aggCorr);
		}
		else {
			if( denseAggregates )
				aggregateBinaryMatrixAllDense(in, aggVal, aggCorr);
			else //any aggVal, aggCorr
				aggregateBinaryMatrixDenseGeneric(in, aggVal, aggCorr);
		}
	}
	
	/**
	 * Core incremental matrix aggregate (ak+) with embedded correction values,
	 * as used for uack+ and acrk+. The correction is embedded in the last row
	 * or the last column of the blocks, as given by the operator's correction
	 * location.
	 * 
	 * @param in input matrix; must have the same dimensions as aggVal
	 * @param aggVal current aggregate values incl. embedded correction (in/out)
	 * @param aop aggregate operator; only the LASTROW and LASTCOLUMN correction
	 *        locations are supported
	 * @throws DMLRuntimeException if the dimensions of in and aggVal differ, or if
	 *         the correction location is neither LASTROW nor LASTCOLUMN
	 */
	public static void aggregateBinaryMatrix(MatrixBlock in, MatrixBlock aggVal, AggregateOperator aop) 
		throws DMLRuntimeException
	{	
		//sanity check matching dimensions 
		if( in.getNumRows()!=aggVal.getNumRows() || in.getNumColumns()!=aggVal.getNumColumns() )
			throw new DMLRuntimeException("Dimension mismatch on aggregate: "+in.getNumRows()+"x"+in.getNumColumns()+
					" vs "+aggVal.getNumRows()+"x"+aggVal.getNumColumns());
		
		//core aggregation, dispatched by correction location and input format
		CorrectionLocationType corrLoc = aop.correctionLocation;
		if( corrLoc == CorrectionLocationType.LASTROW ) {
			if( in.sparse )
				aggregateBinaryMatrixLastRowSparseGeneric(in, aggVal);
			else
				aggregateBinaryMatrixLastRowDenseGeneric(in, aggVal);
		}
		else if( corrLoc == CorrectionLocationType.LASTCOLUMN ) {
			if( in.sparse )
				aggregateBinaryMatrixLastColSparseGeneric(in, aggVal);
			else
				aggregateBinaryMatrixLastColDenseGeneric(in, aggVal);
		}
		else {
			//any other correction location previously fell through to the
			//last-column sparse kernel regardless of the input format; reject
			//it explicitly instead of silently computing garbage or failing later
			throw new DMLRuntimeException("Unsupported correction location for embedded aggregate: "+corrLoc);
		}
	}
	
	/**
	 * Single-threaded unary aggregate (e.g., sum, min, max, mean and their
	 * row/column variants) of the input matrix into the given output block.
	 * 
	 * Empty inputs are delegated to a dedicated handler, which also covers
	 * sparse-unsafe operations. Otherwise the output is reset to a dense block of
	 * its current dimensions and all input rows are aggregated. Finally, the
	 * output's non-zero count and representation are refreshed.
	 * 
	 * @param in input matrix
	 * @param out output matrix whose rlen/clen define the result dimensions
	 *        (overwritten)
	 * @param uaop aggregate unary operator providing value and index functions
	 * @throws DMLRuntimeException propagated from the aggregation kernels
	 */
	public static void aggregateUnaryMatrix(MatrixBlock in, MatrixBlock out, AggregateUnaryOperator uaop) 
		throws DMLRuntimeException
	{
		final AggType aggtype = getAggType(uaop);
		
		//empty input (incl special handling for sparse-unsafe operations)
		if( in.isEmptyBlock(false) ) {
			aggregateUnaryMatrixEmpty(in, out, aggtype, uaop.indexFn);
			return;
		}
		
		//output is always dense with its pre-set dimensions
		final int outRows = out.rlen;
		final int outCols = out.clen;
		out.reset(outRows, outCols, false);
		out.allocateDenseBlock();
		
		//aggregate all rows [0, rlen) of the input
		final ValueFunction vFn = uaop.aggOp.increOp.fn;
		final IndexFunction ixFn = uaop.indexFn;
		final int numRows = in.rlen;
		if( in.sparse )
			aggregateUnaryMatrixSparse(in, out, aggtype, vFn, ixFn, 0, numRows);
		else
			aggregateUnaryMatrixDense(in, out, aggtype, vFn, ixFn, 0, numRows);
		
		//cleanup output and change representation (if necessary)
		out.recomputeNonZeros();
		out.examSparsity();
	}
	
	/**
	 * 
	 * @param in
	 * @param out
	 * @param uaop
	 * @param k
	 * @throws DMLRuntimeException
	 */
	public static void aggregateUnaryMatrix(MatrixBlock in, MatrixBlock out, AggregateUnaryOperator uaop, int k) 
		throws DMLRuntimeException
	{
		//fall back to sequential version if necessary
		if(    k <= 1 || (long)in.rlen*in.clen < PAR_NUMCELL_THRESHOLD || in.rlen <= k
			|| (!(uaop.indexFn instanceof ReduceCol) &&  out.clen*8*k > PAR_INTERMEDIATE_SIZE_THRESHOLD ) ) {
			aggregateUnaryMatrix(in, out, uaop);
			return;
		}
		
		//prepare meta data
		AggType aggtype = getAggType(uaop);
		final int m = in.rlen;
		final int m2 = out.rlen;
		final int n2 = out.clen;
		
		//filter empty input blocks (incl special handling for sparse-unsafe operations)
		if( in.isEmptyBlock(false) ){
			aggregateUnaryMatrixEmpty(in, out, aggtype, uaop.indexFn);
			return;
		}	
		
		//Timing time = new Timing(true);
		
		//allocate output arrays (if required)
		if( uaop.indexFn instanceof ReduceCol ) {
			out.reset(m2, n2, false); //always dense
			out.allocateDenseBlock();
		}
		
		//core multi-threaded unary aggregate computation
		//(currently: always parallelization over number of rows)
		try {
			ExecutorService pool = Executors.newFixedThreadPool( k );
			ArrayList tasks = new ArrayList();
			int blklen = (int)(Math.ceil((double)m/k));
			for( int i=0; i tasks = new ArrayList();
			int blklen = (int)(Math.ceil((double)in1.rlen/k));
			for( int i=0; i1);
		if( k <= 1 || (long)target.rlen*target.clen < PAR_NUMCELL_THRESHOLD || rowVector || target.clen==1 ) {
			groupedAggregate(groups, target, weights, result, numGroups, op);
			return;
		}
		
		if( !(op instanceof CMOperator || op instanceof AggregateOperator) ) {
			throw new DMLRuntimeException("Invalid operator (" + op + ") encountered while processing groupedAggregate.");
		}
		
		//preprocessing
		result.sparse = false;
		result.allocateDenseBlock();
		
		//core multi-threaded grouped aggregate computation
		//(currently: parallelization over columns to avoid additional memory requirements)
		try {
			ExecutorService pool = Executors.newFixedThreadPool( k );
			ArrayList tasks = new ArrayList();
			int blklen = (int)(Math.ceil((double)target.clen/k));
			for( int i=0; i1);
		int numCols = (!rowVector) ? target.getNumColumns() : 1;
		double w = 1; //default weight
		
		//skip empty blocks (sparse-safe operation)
		if( target.isEmptyBlock(false) ) 
			return;
		
		//init group buffers
		int numCols2 = cu-cl;
		KahanObject[][] buffer = new KahanObject[numGroups][numCols2];
		for( int i=0; i numGroups )
							continue;
						if ( weights != null )
							w = weights.quickGetValue(aix[j],0);
						aggop.increOp.fn.execute(buffer[g-1][0], avals[j]*w);						
					}
				}
					
			}
			else //DENSE target
			{
				for ( int i=0; i < target.getNumColumns(); i++ ) {
					double d = target.denseBlock[ i ];
					if( d != 0 ) //sparse-safe
					{
						int g = (int) groups.quickGetValue(i, 0);		
						if ( g > numGroups )
							continue;
						if ( weights != null )
							w = weights.quickGetValue(i,0);
						// buffer is 0-indexed, whereas range of values for g = [1,numGroups]
						aggop.increOp.fn.execute(buffer[g-1][0], d*w);
					}
				}
			}
		}
		else //column vector or matrix 
		{
			if( target.sparse ) //SPARSE target
			{
				SparseRow[] a = target.sparseRows;
				
				for( int i=0; i < groups.getNumRows(); i++ ) 
				{
					int g = (int) groups.quickGetValue(i, 0);		
					if ( g > numGroups )
						continue;
					
					if( a[i] != null && !a[i].isEmpty() )
					{
						int len = a[i].size();
						int[] aix = a[i].getIndexContainer();
						double[] avals = a[i].getValueContainer();	
						int j = (cl==0) ? 0 : a[i].searchIndexesFirstGTE(cl);
						j = (j>=0) ? j : len;
						
						for( ; j numGroups )
						continue;
				
					for( int j=cl; j < cu; j++ ) {
						double d = a[ aix+j ];
						if( d != 0 ) { //sparse-safe
							if ( weights != null )
								w = weights.quickGetValue(i,0);
							// buffer is 0-indexed, whereas range of values for g = [1,numGroups]
							aggop.increOp.fn.execute(buffer[g-1][j-cl], d*w);
						}
					}
				}
			}
		}
		
		// extract the results from group buffers
		for( int i=0; i < numGroups; i++ )
			for( int j=0; j < numCols2; j++ )
				result.appendValue(i, j+cl, buffer[i][j]._sum);
	}

	/**
	 * 
	 * @param target
	 * @param weights
	 * @param result
	 * @param cmOp
	 * @throws DMLRuntimeException
	 */
	private static void groupedAggregateCM( MatrixBlock groups, MatrixBlock target, MatrixBlock weights, MatrixBlock result, int numGroups, CMOperator cmOp, int cl, int cu ) 
		throws DMLRuntimeException
	{
		CM cmFn = CM.getCMFnObject(((CMOperator) cmOp).getAggOpType());
		double w = 1; //default weight
		
		//init group buffers
		int numCols2 = cu-cl;
		CM_COV_Object[][] cmValues = new CM_COV_Object[numGroups][numCols2];
		for ( int i=0; i < numGroups; i++ )
			for( int j=0; j < numCols2; j++  )
				cmValues[i][j] = new CM_COV_Object();
		
		//column vector or matrix
		if( target.sparse ) //SPARSE target
		{
			SparseRow[] a = target.sparseRows;
			
			for( int i=0; i < groups.getNumRows(); i++ ) 
			{
				int g = (int) groups.quickGetValue(i, 0);		
				if ( g > numGroups )
					continue;
				
				if( a[i] != null && !a[i].isEmpty() )
				{
					int len = a[i].size();
					int[] aix = a[i].getIndexContainer();
					double[] avals = a[i].getValueContainer();	
					int j = (cl==0) ? 0 : a[i].searchIndexesFirstGTE(cl);
					j = (j>=0) ? j : len;
					
					for( ; j numGroups )
					continue;
			
				for( int j=cl; j numGroups )
				continue;
			tmp[g-1]++;
		}
		
		//copy counts into result
		for( int i=0; i=0; j--, ix-- )
						if( aix[j]!=ix )
							break;
					c[cix+0] = ix + 1; //max index (last)
					c[cix+1] = 0; //max value
				}
			}
			else //if( arow==null )
			{
				//correction (not sparse-safe)	
				c[cix+0] = n; //max index (last)
				c[cix+1] = 0; //max value
			}
		}
	}
	
	/**
	 * ROWINDEXMIN, opcode: uarimin, sparse input.
	 * 
	 * @param a
	 * @param c
	 * @param m
	 * @param n
	 * @param init
	 * @param builtin
	 */
	private static void s_uarimin( SparseRow[] a, double[] c, int m, int n, double init, Builtin builtin, int rl, int ru ) 
	{
		for( int i=rl, cix=rl*2; i=0; j--, ix-- )
						if( aix[j]!=ix )
							break;
					c[cix+0] = ix + 1; //min index (last)
					c[cix+1] = 0; //min value
				}
			}
			else //if( arow==null )
			{
				//correction (not sparse-safe)	
				c[cix+0] = n; //min index (last)
				c[cix+1] = 0; //min value
			}
		}
	}
	
	/**
	 * MEAN, opcode: uamean, sparse input.
	 * 
	 * @param a
	 * @param c
	 * @param m
	 * @param n
	 * @param kbuff
	 * @param kplus
	 */
	private static void s_uamean( SparseRow[] a, double[] c, int m, int n, KahanObject kbuff, Mean kmean, int rl, int ru )
	{
		int len = (ru-rl) * n;
		int count = 0;
		
		//correction remaining tuples (not sparse-safe)
		//note: before aggregate computation in order to
		//exploit 0 sum (noop) and better numerical stability
		for( int i=rl; i=maxval) ? i-ai : maxindex;
			maxval = (a[i]>=maxval) ? a[i] : maxval;
		}

		return maxindex;
	}
	
	/**
	 * 
	 * @param a
	 * @param ai
	 * @param init
	 * @param len
	 * @param aggop
	 * @return
	 */
	private static int indexmin( double[] a, int ai, final double init, final int len, Builtin aggop ) 
	{
		double minval = init;
		int minindex = -1;
		
		for( int i=ai; i {}
	
	/**
	 * Task that aggregates the input rows [rl, ru) directly into the output block
	 * handed to its constructor; no private partial result is created.
	 * NOTE(review): presumably safe to run concurrently only because tasks write
	 * disjoint output cells (row-wise aggregates) — confirm at call sites.
	 */
	private static class RowAggTask extends AggTask 
	{
		private MatrixBlock _in;
		private MatrixBlock _ret;
		private AggType _aggtype;
		private AggregateUnaryOperator _uaop;
		private int _rl;
		private int _ru;

		protected RowAggTask( MatrixBlock in, MatrixBlock ret, AggType aggtype, AggregateUnaryOperator uaop, int rl, int ru )
		{
			_in = in;
			_ret = ret;
			_aggtype = aggtype;
			_uaop = uaop;
			_rl = rl;
			_ru = ru;
		}
		
		@Override
		public Object call() throws DMLRuntimeException
		{
			final ValueFunction vFn = _uaop.aggOp.increOp.fn;
			final IndexFunction ixFn = _uaop.indexFn;
			
			if( _in.sparse )
				aggregateUnaryMatrixSparse(_in, _ret, _aggtype, vFn, ixFn, _rl, _ru);
			else
				aggregateUnaryMatrixDense(_in, _ret, _aggtype, vFn, ixFn, _rl, _ru);
			
			return null;
		}
	}

	/**
	 * Task that aggregates the input rows [rl, ru) into its own dense partial
	 * result, which has the same dimensions as the given output block. The caller
	 * retrieves the partial result via getResult() after the task has run.
	 */
	private static class PartialAggTask extends AggTask 
	{
		private MatrixBlock _in;
		private MatrixBlock _ret;
		private AggType _aggtype;
		private AggregateUnaryOperator _uaop;
		private int _rl;
		private int _ru;

		protected PartialAggTask( MatrixBlock in, MatrixBlock ret, AggType aggtype, AggregateUnaryOperator uaop, int rl, int ru ) 
			throws DMLRuntimeException
		{
			//task-local dense buffer shaped like the final output
			MatrixBlock partial = new MatrixBlock(ret.rlen, ret.clen, false);
			partial.allocateDenseBlock();
			
			_in = in;
			_ret = partial;
			_aggtype = aggtype;
			_uaop = uaop;
			_rl = rl;
			_ru = ru;
		}
		
		@Override
		public Object call() throws DMLRuntimeException
		{
			final ValueFunction vFn = _uaop.aggOp.increOp.fn;
			final IndexFunction ixFn = _uaop.indexFn;
			
			if( _in.sparse )
				aggregateUnaryMatrixSparse(_in, _ret, _aggtype, vFn, ixFn, _rl, _ru);
			else
				aggregateUnaryMatrixDense(_in, _ret, _aggtype, vFn, ixFn, _rl, _ru);
			
			//keep nnz of the partial result consistent for later merging
			_ret.recomputeNonZeros();
			
			return null;
		}
		
		/** @return the task-local partial aggregate */
		public MatrixBlock getResult() {
			return _ret;
		}
	}
	
	/**
	 * Task computing the ternary aggregate over the row range [rl, ru) of up to
	 * three inputs (in3 may be null). The scalar result is available via
	 * getResult() after call(); before that it is -1.
	 */
	private static class AggTernaryTask extends AggTask 
	{
		private MatrixBlock _in1;
		private MatrixBlock _in2;
		private MatrixBlock _in3;
		private double _ret = -1;
		private int _rl;
		private int _ru;

		protected AggTernaryTask( MatrixBlock in1, MatrixBlock in2, MatrixBlock in3, int rl, int ru ) 
			throws DMLRuntimeException
		{
			_in1 = in1;
			_in2 = in2;
			_in3 = in3;
			_rl = rl;
			_ru = ru;
		}
		
		@Override
		public Object call() throws DMLRuntimeException
		{
			//dense kernel only if no (non-null) input is sparse
			boolean anySparse = _in1.sparse || _in2.sparse || (_in3 != null && _in3.sparse);
			
			_ret = anySparse ?
				aggregateTernaryGeneric(_in1, _in2, _in3, _rl, _ru) :
				aggregateTernaryDense(_in1, _in2, _in3, _rl, _ru);
			
			return null;
		}
		
		/** @return the partial ternary aggregate of this task's row range */
		public double getResult() {
			return _ret;
		}
	}
	
	/**
	 * Task computing a grouped aggregate for the target column range [cl, cu),
	 * writing into the shared result block. CMOperator and AggregateOperator are
	 * handled; any other operator is a no-op here (the visible multi-threaded
	 * caller rejects such operators before creating tasks).
	 */
	private static class GrpAggTask extends AggTask 
	{
		private MatrixBlock _groups;
		private MatrixBlock _target;
		private MatrixBlock _weights;
		private MatrixBlock _ret;
		private int _numGroups;
		private Operator _op;
		private int _cl;
		private int _cu;

		protected GrpAggTask( MatrixBlock groups, MatrixBlock target, MatrixBlock weights, MatrixBlock ret, int numGroups, Operator op, int cl, int cu ) 
			throws DMLRuntimeException
		{
			_groups = groups;
			_target = target;
			_weights = weights;
			_ret = ret;
			_numGroups = numGroups;
			_op = op;
			_cl = cl;
			_cu = cu;
		}
		
		@Override
		public Object call() throws DMLRuntimeException
		{
			if( _op instanceof CMOperator ) {
				//CM operator for count, mean, variance
				//note: current support only for column vectors
				groupedAggregateCM(_groups, _target, _weights, _ret, _numGroups, 
						(CMOperator) _op, _cl, _cu);
			}
			else if( _op instanceof AggregateOperator ) {
				//aggregate operator for sum (via kahan sum)
				//note: support for row/column vectors and dense/sparse
				groupedAggregateKahanPlus(_groups, _target, _weights, _ret, _numGroups, 
						(AggregateOperator) _op, _cl, _cu);
			}
			
			return null;
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy