All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.runtime.matrix.data.LibMatrixMult Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */


package org.apache.sysml.runtime.matrix.data;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.apache.commons.math3.util.FastMath;

import org.apache.sysml.lops.MapMultChain.ChainType;
import org.apache.sysml.lops.WeightedCrossEntropy.WCeMMType;
import org.apache.sysml.lops.WeightedDivMM.WDivMMType;
import org.apache.sysml.lops.WeightedSigmoid.WSigmoidType;
import org.apache.sysml.lops.WeightedSquaredLoss.WeightsType;
import org.apache.sysml.lops.WeightedUnaryMM.WUMMType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.functionobjects.ValueFunction;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
import org.apache.sysml.runtime.util.UtilFunctions;

/**
 * MB:
 * Library for matrix multiplications including MM, MV, VV for all
 * combinations of dense, sparse, ultrasparse representations and special
 * operations such as transpose-self matrix multiplication.
 * 
 * In general all implementations use internally dense outputs
 * for direct access, but change the final result to sparse if necessary.
 * The only exceptions are ultra-sparse matrix mult, wsloss and wsigmoid.  
 * 
 * NOTES on BLAS:
 * * Experiments in 04/2013 showed that even on dense-dense this implementation 
 *   is 3x faster than f2j-BLAS-DGEMM, 2x faster than f2c-BLAS-DGEMM, and
 *   level (+10% after JIT) with a native C implementation. 
 * Calling native BLAS would lose platform independence and would require 
 *   JNI calls incl data transfer. Furthermore, BLAS does not support sparse 
 *   matrices (except Sparse BLAS, with dedicated function calls and matrix formats) 
 *   and would be an external dependency. 
 * * Experiments in 02/2014 showed that on dense-dense this implementation now achieves
 *   almost 30% peak FP performance. Compared to Intel MKL 11.1 (dgemm, N=1000) it is
 *   just 3.2x (sparsity=1.0) and 1.9x (sparsity=0.5) slower, respectively.  
 *  
 */
public class LibMatrixMult 
{
	//internal configuration
	public static final boolean LOW_LEVEL_OPTIMIZATION = true;
	public static final long MEM_OVERHEAD_THRESHOLD = 2L*1024*1024; //MAX 2 MB
	private static final long PAR_MINFLOP_THRESHOLD = 2L*1024*1024; //MIN 2 MFLOP
	
	private LibMatrixMult() {
		//prevent instantiation via private constructor
	}
	
	////////////////////////////////
	// public matrix mult interface
	////////////////////////////////
	
	/**
	 * Performs a matrix multiplication ret = m1 %*% m2 and stores the result in
	 * the output matrix, dispatching to a specialized kernel based on the
	 * dense/sparse/ultra-sparse representation of both inputs.
	 * 
	 * All variants use an IKJ access pattern, and internally use dense output.
	 * After the actual computation, we recompute nnz and check for sparse/dense
	 * representation.
	 * 
	 * @param m1 first (left-hand side) matrix
	 * @param m2 second (right-hand side) matrix
	 * @param ret result matrix, overwritten with m1 %*% m2
	 * @throws DMLRuntimeException 
	 */
	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret) 
		throws DMLRuntimeException
	{	
		//check inputs / outputs: any empty input implies an empty (all-zero) result
		if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
			ret.examSparsity(); //turn empty dense into sparse
			return;
		}
		
		//Timing time = new Timing(true);
		
		//pre-processing: optional transpose of rhs and output allocation
		boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
		m2 = prepMatrixMultRightInput(m1, m2);
		ret.sparse = (m1.isUltraSparse() || m2.isUltraSparse());
		if( !ret.sparse )
			ret.allocateDenseBlock();
		
		//prepare row-upper for special cases of vector-matrix
		//(if pm2, the row range refers to rows of m2 instead of m1)
		boolean pm2 = checkParMatrixMultRightInput(m1, m2, Integer.MAX_VALUE);
		int ru = pm2 ? m2.rlen : m1.rlen; 
		
		//core matrix mult computation
		if( m1.isUltraSparse() || m2.isUltraSparse() )
			matrixMultUltraSparse(m1, m2, ret, 0, ru);
		else if(!m1.sparse && !m2.sparse)
			matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru);
		else if(m1.sparse && m2.sparse)
			matrixMultSparseSparse(m1, m2, ret, pm2, 0, ru);
		else if(m1.sparse)
			matrixMultSparseDense(m1, m2, ret, pm2, 0, ru);
		else
			matrixMultDenseSparse(m1, m2, ret, pm2, 0, ru);
		
		//post-processing: recompute nnz and switch representation if necessary
		if( !ret.sparse )
			ret.recomputeNonZeros();
		ret.examSparsity();
		
		//System.out.println("MM ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" +
		//		              "("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop());
	}
	
	/**
	 * Performs a multi-threaded matrix multiplication and stores the result in the output matrix.
	 * The parameter k (k>=1) determines the max parallelism k' with k'=min(k, vcores, m1.rlen).
	 * 
	 * @param m1
	 * @param m2
	 * @param ret
	 * @param k
	 * @throws DMLRuntimeException 
	 */
	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) 
		throws DMLRuntimeException
	{	
		//check inputs / outputs
		if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
			ret.examSparsity(); //turn empty dense into sparse
			return;
		}
		
		//check too high additional vector-matrix memory requirements (fallback to sequential)
		//check too small workload in terms of flops (fallback to sequential too)
		if( m1.rlen == 1 && (8L * m2.clen * k > MEM_OVERHEAD_THRESHOLD || !LOW_LEVEL_OPTIMIZATION || m2.clen==1 || m1.isUltraSparse() || m2.isUltraSparse()) 
			|| 2L * m1.rlen * m1.clen * m2.clen < PAR_MINFLOP_THRESHOLD ) 
		{ 
			matrixMult(m1, m2, ret);
			return;
		}
		
		//Timing time = new Timing(true);
		
		//pre-processing: output allocation (in contrast to single-threaded,
		//we need to allocate sparse as well in order to prevent synchronization)
		boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
		m2 = prepMatrixMultRightInput(m1, m2);
		ret.sparse = (m1.isUltraSparse() || m2.isUltraSparse());
		if( !ret.sparse )
			ret.allocateDenseBlock();
		else
			ret.allocateSparseRowsBlock();
		
		//prepare row-upper for special cases of vector-matrix / matrix-matrix
		boolean pm2 = checkParMatrixMultRightInput(m1, m2, k);
		int ru = pm2 ? m2.rlen : m1.rlen; 
		
		//core multi-threaded matrix mult computation
		//(currently: always parallelization over number of rows)
		try {
			ExecutorService pool = Executors.newFixedThreadPool( k );
			ArrayList tasks = new ArrayList();
			int blklen = (int)(Math.ceil((double)ru/k));
			for( int i=0; i=1) determines the max parallelism k' with k'=min(k, vcores, m1.rlen).
	 * 
	 * NOTE: This multi-threaded mmchain operation has additional memory requirements of k*ncol(X)*8bytes 
 * for partial aggregation. Current max memory: 256KB; otherwise redirected to sequential execution.
	 * 
	 * @param mX
	 * @param mV
	 * @param mW
	 * @param ret
	 * @param ct
	 * @param k
	 * @throws DMLRuntimeException
	 * @throws DMLUnsupportedOperationException
	 */
	public static void matrixMultChain(MatrixBlock mX, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, ChainType ct, int k) 
		throws DMLRuntimeException, DMLUnsupportedOperationException
	{		
		//check inputs / outputs (after that mV and mW guaranteed to be dense)
		if( mX.isEmptyBlock(false) || mV.isEmptyBlock(false) || (mW !=null && mW.isEmptyBlock(false)) ) {
			ret.examSparsity(); //turn empty dense into sparse
			return;
		}

		//check too high additional memory requirements (fallback to sequential)
		//check too small workload in terms of flops (fallback to sequential too)
		if( 8L * mV.rlen * k > MEM_OVERHEAD_THRESHOLD 
			|| 4L * mX.rlen * mX.clen < PAR_MINFLOP_THRESHOLD ) 
		{ 
			matrixMultChain(mX, mV, mW, ret, ct);
			return;
		}
		
		//Timing time = new Timing(true);
				
		//pre-processing
		ret.sparse = false;
		ret.allocateDenseBlock();
		
		//core matrix mult chain computation
		//(currently: always parallelization over number of rows)
		try {
			ExecutorService pool = Executors.newFixedThreadPool( k );
			ArrayList tasks = new ArrayList();
			int blklen = (int)(Math.ceil((double)mX.rlen/k));
			blklen += (blklen%24 != 0)?24-blklen%24:0;
			for( int i=0; i tasks = new ArrayList();
			//load balance via #tasks=2k due to triangular shape 
			int blklen = (int)(Math.ceil((double)ret.rlen/(2*k)));
			for( int i=0; i<2*k & i*blklen tasks = new ArrayList();
			int blklen = (int)(Math.ceil((double)pm1.rlen/k));
			for( int i=0; i tasks = new ArrayList();
			int blklen = (int)(Math.ceil((double)mX.rlen/k));
			for( int i=0; i tasks = new ArrayList();
			int blklen = (int)(Math.ceil((double)mW.rlen/k));
			for( int i=0; i tasks = new ArrayList();			
			//create tasks (for wdivmm-left, parallelization over columns;
			//for wdivmm-right, parallelization over rows; both ensure disjoint results)
			if( wt.isLeft() ) {
				int blklen = (int)(Math.ceil((double)mW.clen/k));
				for( int j=0; j tasks = new ArrayList();
			int blklen = (int)(Math.ceil((double)mW.rlen/k));
			for( int i=0; i tasks = new ArrayList();
			int blklen = (int)(Math.ceil((double)mW.rlen/k));
			for( int i=0; i1 && cd == 1 )  //OUTER PRODUCT
			{
				for( int i=rl, cix=rl*n; i < ru; i++, cix+=n) {
					if( a[i] == 1 )
						System.arraycopy(b, 0, c, cix, n);
				    else if( a[i] != 0 )
						vectMultiplyWrite(a[i], b, c, 0, cix, n);
					else
						Arrays.fill(c, cix, cix+n, 0);
				}
			}
			else if( n==1 && cd == 1 ) //VECTOR-SCALAR
			{
				vectMultiplyWrite(b[0], a, c, rl, rl, ru-rl);
			}
			else if( n==1 )            //MATRIX-VECTOR
			{
				for( int i=rl, aix=rl*cd; i < ru; i++, aix+=cd) 
					c[ i ] = dotProduct(a, b, aix, 0, cd);	
			}
			else if( pm2 && m==1 )     //VECTOR-MATRIX
			{
				//parallelization over rows in rhs matrix
				//rest not aligned to blocks of 2 rows
				final int kn = (ru-rl)%2;
				if( kn == 1 && a[rl] != 0 )
					vectMultiplyAdd(a[rl], b, c, rl*n, 0, n);
				
				//compute blocks of 2 rows (2 instead of 4 for small n<64) 
				for( int k=rl+kn, bix=(rl+kn)*n; k n && cd > 64 && n < 64
				//however, explicit flag required since dimension change m2
				final int n2 = m2.rlen;
				for( int i=rl, aix=rl*cd, cix=rl*n2; i < ru; i++, aix+=cd, cix+=n2 ) 
					for( int j=0, bix=0; j 0 ) //for skipping empty rows
				    			
			    				//rest not aligned to blocks of 4 rows
				    			final int bn = knnz % 4;
				    			switch( bn ){
					    			case 1: vectMultiplyAdd(ta[0], b, c, tbi[0], cixj, bjlen); break;
					    	    	case 2: vectMultiplyAdd2(ta[0],ta[1], b, c, tbi[0], tbi[1], cixj, bjlen); break;
					    			case 3: vectMultiplyAdd3(ta[0],ta[1],ta[2], b, c, tbi[0], tbi[1],tbi[2], cixj, bjlen); break;
				    			}
				    			
				    			//compute blocks of 4 rows (core inner loop)
				    			for( int k = bn; k=0) ? rlix : alen;
					
					for( int k=rlix; k=0) ? k1 : alen;
						int k2 = (ru==cd) ? alen : a[i].searchIndexesFirstGTE(ru);
						k2 = (k2>=0) ? k2 : alen;
						
						//rest not aligned to blocks of 4 rows
		    			final int bn = (k2-k1) % 4;
		    			switch( bn ){
			    			case 1: vectMultiplyAdd(avals[k1], b, c, aix[k1]*n, cix, n); break;
			    	    	case 2: vectMultiplyAdd2(avals[k1],avals[k1+1], b, c, aix[k1]*n, aix[k1+1]*n, cix, n); break;
			    			case 3: vectMultiplyAdd3(avals[k1],avals[k1+1],avals[k1+2], b, c, aix[k1]*n, aix[k1+1]*n, aix[k1+2]*n, cix, n); break;
		    			}
		    			
		    			//compute blocks of 4 rows (core inner loop)
		    			for( int k = k1+bn; k=0) ? rlix : alen;
					
					for( int k=rlix; k=0) ? rlix : alen;
						
						for(int i = rlix; i < alen && aix[i]=0) ? rlix : alen;
						
						for(int i = rlix; i < alen && aix[i]=0) ? rlix : alen;
							
							for(int i = rlix; i < alen && aix[i]=0) ? rlix : alen;
							
							for(int i = rlix; i < alen && aix[i] 0 ) //selected row
			{
				int bpos = (pos-1) % brlen;
				int blk = (pos-1) / brlen;
				
				//allocate and switch to second output block
				//(never happens in cp, correct for multi-threaded usage)
				if( lastblk!=-1 && lastblk 0 ) //selected row
			{
				int bpos = (pos-1) % brlen;
				int blk = (pos-1) / brlen;
				
				//allocate and switch to second output block
				//(never happens in cp, correct for multi-threaded usage)
				if( lastblk!=-1 && lastblk 0 ) //selected row
			{
				int bpos = (pos-1) % brlen;
				int blk = (pos-1) / brlen;
				
				//allocate and switch to second output block
				//(never happens in cp, correct for multi-threaded usage)
				if( lastblk!=-1 && lastblk=0) ? k : wlen;
					for( ; k=0) ? k : wlen;
					for( ; k 1) //X%*%t(X) SPARSE MATRIX
		{	
			//directly via LibMatrixReorg in order to prevent sparsity change
			MatrixBlock tmpBlock = new MatrixBlock(m1.clen, m1.rlen, m1.sparse);
			LibMatrixReorg.reorg(m1, tmpBlock, new ReorgOperator(SwapIndex.getSwapIndexFnObject()));
			ret = tmpBlock;
		}
		
		return ret;
	}
	
	/**
	 * Decides whether the rhs input should be transposed before a dense-dense
	 * matrix multiply (cache-friendly column access for skinny rhs matrices).
	 * 
	 * @param m1 left input matrix
	 * @param m2 right input matrix
	 * @return true if m2 should be transposed via prepMatrixMultRightInput
	 */
	private static boolean checkPrepMatrixMultRightInput( MatrixBlock m1, MatrixBlock m2 )
	{
		//both inputs dense, rhs a skinny matrix (not a vector),
		//and extra memory guarded by the output size (m1.rlen > m2.clen)
		boolean bothDense = !m1.sparse && !m2.sparse;
		boolean skinnyRhs = m2.rlen > 64 && m2.clen > 1 && m2.clen < 64;
		return bothDense && skinnyRhs && m1.rlen > m2.clen;
	}
	
	/**
	 * Decides whether to parallelize over rows of the rhs matrix instead of rows
	 * of the lhs, which only pays off if the lhs/output has very few rows.
	 * 
	 * @param m1 left input matrix
	 * @param m2 right input matrix
	 * @param k degree of parallelism (bounds the partial-output memory)
	 * @return true if parallelization should go over rows of m2
	 */
	private static boolean checkParMatrixMultRightInput( MatrixBlock m1, MatrixBlock m2, int k )
	{
		//case 1: vector-matrix (no ultra-sparse inputs)
		boolean vectorMatrix = m1.rlen == 1 && LOW_LEVEL_OPTIMIZATION && m2.clen > 1
			&& !(m1.isUltraSparse() || m2.isUltraSparse());
		//case 2: small lhs matrix with dense rhs (covers dense-dense and
		//sparse-dense), per-thread partial outputs guarded by memory threshold
		boolean smallMatrixMatrix = m1.rlen <= 16 && LOW_LEVEL_OPTIMIZATION && m2.clen > 1
			&& m2.rlen > m1.rlen && !m2.sparse
			&& (long)k * 8 * m1.rlen * m2.clen < MEM_OVERHEAD_THRESHOLD;
		return vectorMatrix || smallMatrixMatrix;
	}
	
	/**
	 * Returns the rhs input for the core matrix multiply, transposing m2 into a
	 * temporary block if checkPrepMatrixMultRightInput advises it; otherwise
	 * returns m2 unchanged.
	 * 
	 * @param m1 left input matrix
	 * @param m2 right input matrix
	 * @return m2, or a transposed copy of m2
	 * @throws DMLRuntimeException
	 */
	private static MatrixBlock prepMatrixMultRightInput( MatrixBlock m1, MatrixBlock m2 ) 
		throws DMLRuntimeException
	{
		MatrixBlock ret = m2;
		
		//transpose if dense-dense, skinny rhs matrix (not vector), and memory guarded by output 
		if( checkPrepMatrixMultRightInput(m1, m2)  ) {
			//swapped dimensions (rxc -> cxr), reorg via LibMatrixReorg
			MatrixBlock tmpBlock = new MatrixBlock(m2.clen, m2.rlen, m2.sparse);
			LibMatrixReorg.reorg(m2, tmpBlock, new ReorgOperator(SwapIndex.getSwapIndexFnObject()));
			ret = tmpBlock;
		}
		
		return ret;
	}

	/**
	 * Compacts the non-zero values of a[aixi..aixi+bklen) into tmpa and records
	 * the corresponding row-major scan indexes on b into tmpbi.
	 * 
	 * @param a input array
	 * @param aixi start offset into a
	 * @param bk row offset on b
	 * @param bj column offset on b
	 * @param n row length of b
	 * @param tmpa output buffer for non-zero values
	 * @param tmpbi output buffer for scan indexes on b
	 * @param bklen number of entries of a to scan
	 * @return number of non-zeros copied
	 */
	private static int copyNonZeroElements( double[] a, final int aixi, final int bk, final int bj, final int n, double[] tmpa, int[] tmpbi, final int bklen )
	{
		int count = 0;
		for( int pos = 0; pos < bklen; pos++ ) {
			final double val = a[ aixi + pos ];
			if( val == 0 )
				continue; //skip zero entries
			tmpa[ count ] = val;
			tmpbi[ count ] = (bk + pos) * n + bj; //scan index on b
			count++;
		}
		return count;
	}
	
	/**
	 * Compacts the non-zero values of a strided column scan of a (stride n,
	 * bklen entries starting at aixi) into tmpa, recording the corresponding
	 * row-major scan indexes on b (row length nx) into tmpbi.
	 * 
	 * @param a input array
	 * @param aixi start offset into a, advanced by n per entry
	 * @param bk row offset on b
	 * @param bj column offset on b
	 * @param n stride between scanned entries of a
	 * @param nx row length of b
	 * @param tmpa output buffer for non-zero values
	 * @param tmpbi output buffer for scan indexes on b
	 * @param bklen number of entries of a to scan
	 * @return number of non-zeros copied
	 */
	private static int copyNonZeroElements( double[] a, int aixi, final int bk, final int bj, final int n, final int nx, double[] tmpa, int[] tmpbi, final int bklen )
	{
		int count = 0;
		for( int pos = 0; pos < bklen; pos++, aixi += n ) {
			final double val = a[ aixi ];
			if( val != 0 ) {
				tmpa[ count ] = val;
				tmpbi[ count ] = (bk + pos) * nx + bj; //scan index on b
				count++;
			}
		}
		return count;
	}

	/**
	 * Aggregates the scalar partial results of all (already completed) tasks
	 * into the 1x1 output block.
	 * 
	 * Fix: the element type of the task list was lost in extraction (raw
	 * ArrayList), which makes the typed enhanced-for statement a compile error;
	 * restored the ScalarResultTask type parameter.
	 * 
	 * @param tasks completed tasks holding partial scalar results
	 * @param ret 1x1 output block; cell (0,0) is overwritten with the sum
	 */
	private static void sumScalarResults(ArrayList<ScalarResultTask> tasks, MatrixBlock ret)
	{
		//aggregate partial results sequentially (no synchronization needed)
		double val = 0;
		for(ScalarResultTask task : tasks)
			val += task.getScalarResult();
		ret.quickSetValue(0, 0, val);
	}
	
	/**
	 * Aggregates the dense partial result arrays of all tasks into the final
	 * result array: ret[i] += sum_j partret[j][i].
	 * 
	 * Aggregation proceeds in blocks of 2K doubles (16KB, half of a common L1
	 * data cache) across all partial results, which prevents repeated
	 * scans/writes of the full ret array.
	 * 
	 * NOTE(review): the original loop body was destroyed by HTML extraction;
	 * reconstructed as a self-contained cache-blocked aggregation with
	 * identical contract — verify against upstream source.
	 * 
	 * @param partret partial result arrays, each of length ret.length
	 * @param ret final result array, updated in place
	 */
	@SuppressWarnings("unused")
	private static void sumDenseResults( double[][] partret, double[] ret )
	{
		final int len = ret.length;
		final int blocksize = 2 * 1024; //16KB (half of common L1 data)
		
		//cache-conscious aggregation to prevent repeated scans/writes of ret
		for( int bi=0; bi<len; bi+=blocksize ) {
			int bimin = Math.min(len, bi+blocksize);
			for( double[] part : partret )
				for( int i=bi; i<bimin; i++ )
					ret[i] += part[i];
		}
	}

	/**
	 * Task computing a row range of a multi-threaded matrix multiply.
	 * NOTE(review): class header reconstructed (original eaten by HTML
	 * extraction together with its generic parameter); the class name is
	 * established by its constructor.
	 */
	private static class MatrixMultTask implements Callable
	{
		private MatrixBlock _m1  = null; //left input matrix
		private MatrixBlock _m2  = null; //right input matrix
		private MatrixBlock _ret = null; //output (shared, or thread-local if _pm2)
		private boolean _tm2 = false; //transposed m2
		private boolean _pm2 = false; //par over m2
		private int _rl = -1; //row lower index, inclusive
		private int _ru = -1; //row upper index, exclusive
		private long _nnz = -1; //partial nnz of the computed row range

		/**
		 * Creates a task for row range [rl,ru). For pm2 (parallelization over
		 * rows of m2), a thread-local output block is allocated for later
		 * aggregation by the caller; otherwise all tasks write disjoint row
		 * ranges of the shared output block.
		 */
		protected MatrixMultTask( MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean tm2, boolean pm2, int rl, int ru )
		{
			_m1 = m1;
			_m2 = m2;
			_tm2 = tm2;
			_pm2 = pm2;
			_rl = rl;
			_ru = ru;
			
			if( pm2 ) { //vector-matrix / matrix-matrix
				//allocate local result for partial aggregation
				_ret = new MatrixBlock(ret.rlen, ret.clen, false);
			}
			else { //default case
				_ret = ret;
			}
		}
		
		@Override
		public Object call() throws DMLRuntimeException
		{
			//thread-local allocation (deferred to the worker thread)
			if( _pm2 )
				_ret.allocateDenseBlock();
			
			//compute block matrix multiplication, dispatching on the
			//dense/sparse/ultra-sparse representation of both inputs
			if( _m1.isUltraSparse() || _m2.isUltraSparse() )
				matrixMultUltraSparse(_m1, _m2, _ret, _rl, _ru);
			else if(!_m1.sparse && !_m2.sparse)
				matrixMultDenseDense(_m1, _m2, _ret, _tm2, _pm2, _rl, _ru);
			else if(_m1.sparse && _m2.sparse)
				matrixMultSparseSparse(_m1, _m2, _ret, _pm2, _rl, _ru);
			else if(_m1.sparse)
				matrixMultSparseDense(_m1, _m2, _ret, _pm2, _rl, _ru);
			else
				matrixMultDenseSparse(_m1, _m2, _ret, _pm2, _rl, _ru);
			
			//maintain block nnz (upper bounds inclusive)
			if( !_pm2 )
				_nnz = _ret.recomputeNonZeros(_rl, _ru-1, 0, _ret.getNumColumns()-1);
			
			return null;
		}
		
		//partial nnz of rows [_rl,_ru), or -1 if _pm2 (recomputed after aggregation)
		public long getPartialNnz(){
			return _nnz;
		}
		
		//the (possibly thread-local) output block
		public MatrixBlock getResult(){
			return _ret;
		}
	}
	
	/**
	 * Task computing a row range [rl,ru) of a matrix multiplication chain
	 * (pattern selected via ChainType), aggregating into a thread-local result
	 * block that is summed up sequentially by the caller.
	 */
	private static class MatrixMultChainTask implements Callable 
	{
		private MatrixBlock _m1  = null; //input matrix X
		private MatrixBlock _m2  = null; //input vector v
		private MatrixBlock _m3  = null; //optional weights w (may be null)
		private MatrixBlock _ret = null; //thread-local partial result
		private ChainType _ct = null;
		private int _rl = -1; //row lower index, inclusive
		private int _ru = -1; //row upper index, exclusive

		protected MatrixMultChainTask( MatrixBlock mX, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, ChainType ct, int rl, int ru ) 
			throws DMLRuntimeException
		{
			_m1 = mX;
			_m2 = mV;
			_m3 = mW;
			_ct = ct;
			_rl = rl;
			_ru = ru;
			
			//allocate local result for partial aggregation
			_ret = new MatrixBlock(ret.rlen, ret.clen, false);
			_ret.allocateDenseBlock();
		}
		
		@Override
		public Object call() throws DMLRuntimeException
		{
			//dispatch on the representation of X
			if( _m1.sparse )
				matrixMultChainSparse(_m1, _m2, _m3, _ret, _ct, _rl, _ru);
			else
				matrixMultChainDense(_m1, _m2, _m3, _ret, _ct, _rl, _ru);
			
			//NOTE: we dont do global aggregation from concurrent tasks in order
			//to prevent synchronization (sequential aggregation led to better 
			//performance after JIT)
			
			return null;
		}
		
		//thread-local partial result, aggregated by the caller
		public MatrixBlock getResult() {
			return _ret;
		}
	}

	/**
	 * Task computing a range of a transpose-self matrix multiply
	 * (left=true: t(X)%*%X, otherwise X%*%t(X)) over rows [rl,ru).
	 */
	private static class MatrixMultTransposeTask implements Callable 
	{
		private final MatrixBlock _in;   //input matrix
		private final MatrixBlock _out;  //shared output block
		private final boolean _leftTranspose;
		private final int _rowLow;       //inclusive
		private final int _rowUp;        //exclusive

		protected MatrixMultTransposeTask( MatrixBlock m1, MatrixBlock ret, boolean left, int rl, int ru )
		{
			_in = m1;
			_out = ret;
			_leftTranspose = left;
			_rowLow = rl;
			_rowUp = ru;
		}
		
		@Override
		public Object call() throws DMLRuntimeException
		{
			//dispatch on the input representation
			if( _in.sparse )
				matrixMultTransposeSelfSparse(_in, _out, _leftTranspose, _rowLow, _rowUp);
			else
				matrixMultTransposeSelfDense(_in, _out, _leftTranspose, _rowLow, _rowUp);
			return null;
		}
	}
	
	/**
	 * Task computing a row range [rl,ru) of a permutation matrix multiply
	 * pm1 %*% m2, with a primary and a secondary output block.
	 */
	private static class MatrixMultPermuteTask implements Callable 
	{
		private final MatrixBlock _perm; //permutation matrix pm1
		private final MatrixBlock _in;   //input matrix m2
		private final MatrixBlock _out1; //primary output block
		private final MatrixBlock _out2; //secondary output block
		private final int _rowLow;       //inclusive
		private final int _rowUp;        //exclusive

		protected MatrixMultPermuteTask( MatrixBlock pm1, MatrixBlock m2, MatrixBlock ret1, MatrixBlock ret2, int rl, int ru)
		{
			_perm = pm1;
			_in = m2;
			_out1 = ret1;
			_out2 = ret2;
			_rowLow = rl;
			_rowUp = ru;
		}
		
		@Override
		public Object call() throws DMLRuntimeException
		{
			//dispatch on input/output representations (same order as sequential path)
			if( _in.sparse )
				matrixMultPermuteSparse(_perm, _in, _out1, _out2, _rowLow, _rowUp);
			else if( _out1.sparse )
				matrixMultPermuteDenseSparse(_perm, _in, _out1, _out2, _rowLow, _rowUp);
			else
				matrixMultPermuteDense(_perm, _in, _out1, _out2, _rowLow, _rowUp);
			return null;
		}
	}

	/**
	 * Common interface of tasks producing a scalar partial result that is
	 * aggregated via sumScalarResults.
	 */
	private static interface ScalarResultTask extends Callable{
		public double getScalarResult();
	}
	
	/**
	 * Task computing the weighted squared loss over a row range [rl,ru),
	 * aggregating into a thread-local 1x1 block (see sumScalarResults).
	 */
	private static class MatrixMultWSLossTask implements ScalarResultTask
	{
		private MatrixBlock _mX = null;
		private MatrixBlock _mU = null;
		private MatrixBlock _mV = null;
		private MatrixBlock _mW = null; //optional weights (may be null)
		private MatrixBlock _ret = null; //thread-local 1x1 partial result
		private WeightsType _wt = null;
		private int _rl = -1; //row lower index, inclusive
		private int _ru = -1; //row upper index, exclusive

		protected MatrixMultWSLossTask(MatrixBlock mX, MatrixBlock mU, MatrixBlock mV, MatrixBlock mW, WeightsType wt, int rl, int ru) 
			throws DMLRuntimeException
		{
			_mX = mX;
			_mU = mU;
			_mV = mV;
			_mW = mW;
			_wt = wt;
			_rl = rl;
			_ru = ru;
			
			//allocate local result for partial aggregation
			_ret = new MatrixBlock(1, 1, false);
			_ret.allocateDenseBlock();
		}
		
		@Override
		public Object call() throws DMLRuntimeException
		{
			//dispatch: dense path requires all-dense non-empty inputs (dense W if
			//given); sparse path requires sparse X with dense U/V (sparse W if
			//given); everything else falls through to the generic kernel
			if( !_mX.sparse && !_mU.sparse && !_mV.sparse && (_mW==null || !_mW.sparse) 
				&& !_mX.isEmptyBlock() && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() 
				&& (_mW==null || !_mW.isEmptyBlock()))
				matrixMultWSLossDense(_mX, _mU, _mV, _mW, _ret, _wt, _rl, _ru);
			else if( _mX.sparse && !_mU.sparse && !_mV.sparse && (_mW==null || _mW.sparse)
				    && !_mX.isEmptyBlock() && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() 
				    && (_mW==null || !_mW.isEmptyBlock()))
				matrixMultWSLossSparseDense(_mX, _mU, _mV, _mW, _ret, _wt, _rl, _ru);
			else
				matrixMultWSLossGeneric(_mX, _mU, _mV, _mW, _ret, _wt, _rl, _ru);

			return null;
		}
		
		@Override
		public double getScalarResult() {
			return _ret.quickGetValue(0, 0);
		}
	}
	
	/**
	 * Task computing a row range [rl,ru) of a weighted sigmoid operation,
	 * writing directly into the shared output (tasks cover disjoint row ranges).
	 */
	private static class MatrixMultWSigmoidTask implements Callable 
	{
		private MatrixBlock _mW = null;
		private MatrixBlock _mU = null;
		private MatrixBlock _mV = null;
		private MatrixBlock _ret = null; //shared output block
		private WSigmoidType _wt = null;
		private int _rl = -1; //row lower index, inclusive
		private int _ru = -1; //row upper index, exclusive
		private long _nnz = -1; //partial nnz of the computed row range
		
		protected MatrixMultWSigmoidTask(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int rl, int ru) 
			throws DMLRuntimeException
		{
			_mW = mW;
			_mU = mU;
			_mV = mV;
			_ret = ret;
			_wt = wt;
			_rl = rl;
			_ru = ru;
		}
		
		@Override
		public Object call() throws DMLRuntimeException
		{
			//core weighted sigmoid mm computation (dispatch on representation)
			if( !_mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() )
				matrixMultWSigmoidDense(_mW, _mU, _mV, _ret, _wt, _rl, _ru);
			else if( _mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock())
				matrixMultWSigmoidSparseDense(_mW, _mU, _mV, _ret, _wt, _rl, _ru);
			else
				matrixMultWSigmoidGeneric(_mW, _mU, _mV, _ret, _wt, _rl, _ru);
			
			//maintain block nnz (upper bounds inclusive)
			_nnz = _ret.recomputeNonZeros(_rl, _ru-1, 0, _ret.getNumColumns()-1);
			
			return null;
		}
		
		//partial nnz of rows [_rl,_ru), aggregated by the caller
		public long getPartialNnz(){
			return _nnz;
		}
	}
	
	/**
	 * Task computing a row range [rl,ru) (or, for wdivmm-left, a column range
	 * [cl,cu)) of a weighted divide matrix multiplication, writing directly
	 * into the shared output over disjoint ranges.
	 */
	private static class MatrixMultWDivTask implements Callable 
	{
		private MatrixBlock _mW = null;
		private MatrixBlock _mU = null;
		private MatrixBlock _mV = null;
		private MatrixBlock _ret = null; //shared output block
		private WDivMMType _wt = null;
		private int _rl = -1; //row lower index, inclusive
		private int _ru = -1; //row upper index, exclusive
		private int _cl = -1; //column lower index, inclusive
		private int _cu = -1; //column upper index, exclusive
		private long _nnz = -1; //partial nnz of the computed range
		
		protected MatrixMultWDivTask(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WDivMMType wt, int rl, int ru, int cl, int cu) 
			throws DMLRuntimeException
		{
			_mW = mW;
			_mU = mU;
			_mV = mV;
			_wt = wt;
			_rl = rl;
			_ru = ru;
			_cl = cl;
			_cu = cu;
			_ret = ret;	
		}
		
		@Override
		public Object call() throws DMLRuntimeException
		{
			//core weighted div mm computation (dispatch on representation)
			if( !_mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() )
				matrixMultWDivMMDense(_mW, _mU, _mV, _ret, _wt, _rl, _ru, _cl, _cu);
			else if( _mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock())
				matrixMultWDivMMSparseDense(_mW, _mU, _mV, _ret, _wt, _rl, _ru, _cl, _cu);
			else
				matrixMultWDivMMGeneric(_mW, _mU, _mV, _ret, _wt, _rl, _ru, _cl, _cu);
		
			//maintain partial nnz for right (upper bounds inclusive);
			//for left variants, result rows correspond to the column range of W
			int rl = _wt.isLeft() ? _cl : _rl;
			int ru = _wt.isLeft() ? _cu : _ru;
			_nnz = _ret.recomputeNonZeros(rl, ru-1, 0, _ret.getNumColumns()-1);
				
			return null;
		}
		
		/**
		 * Partial nnz of the computed output range, aggregated by the caller.
		 * @return partial number of non-zeros
		 */
		public long getPartialNnz(){
			return _nnz;
		}
	}
	
	/**
	 * Task computing the weighted cross-entropy over a row range [rl,ru),
	 * aggregating into a thread-local 1x1 block (see sumScalarResults).
	 */
	private static class MatrixMultWCeTask implements ScalarResultTask
	{
		private MatrixBlock _mW = null;
		private MatrixBlock _mU = null;
		private MatrixBlock _mV = null;
		private MatrixBlock _ret = null; //thread-local 1x1 partial result
		private WCeMMType _wt = null;
		private int _rl = -1; //row lower index, inclusive
		private int _ru = -1; //row upper index, exclusive

		protected MatrixMultWCeTask(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, WCeMMType wt, int rl, int ru) 
			throws DMLRuntimeException
		{
			_mW = mW;
			_mU = mU;
			_mV = mV;
			_wt = wt;
			_rl = rl;
			_ru = ru;
			
			//allocate local result for partial aggregation
			_ret = new MatrixBlock(1, 1, false);
			_ret.allocateDenseBlock();
		}
		
		@Override
		public Object call() throws DMLRuntimeException
		{
			//core weighted cross-entropy mm computation (dispatch on representation)
			if( !_mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() )
				matrixMultWCeMMDense(_mW, _mU, _mV, _ret, _wt, _rl, _ru);
			else if( _mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock())
				matrixMultWCeMMSparseDense(_mW, _mU, _mV, _ret, _wt, _rl, _ru);
			else
				matrixMultWCeMMGeneric(_mW, _mU, _mV, _ret, _wt, _rl, _ru);
			
			
			return null;
		}
		
		@Override
		public double getScalarResult() {
			return _ret.quickGetValue(0, 0);
		}
	}
	
	/**
	 * Task computing a row range [rl,ru) of a weighted unary matrix
	 * multiplication (unary function given by fn), writing directly into the
	 * shared output over disjoint row ranges.
	 */
	private static class MatrixMultWuTask implements Callable 
	{
		private MatrixBlock _mW = null;
		private MatrixBlock _mU = null;
		private MatrixBlock _mV = null;
		private MatrixBlock _ret = null; //shared output block
		private WUMMType _wt = null;
		private ValueFunction _fn = null; //unary function applied by the kernels
		private int _rl = -1; //row lower index, inclusive
		private int _ru = -1; //row upper index, exclusive
		private long _nnz = -1; //partial nnz of the computed row range
		
		protected MatrixMultWuTask(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WUMMType wt, ValueFunction fn, int rl, int ru) 
			throws DMLRuntimeException
		{
			_mW = mW;
			_mU = mU;
			_mV = mV;
			_ret = ret;
			_wt = wt;
			_fn = fn;
			_rl = rl;
			_ru = ru;
		}
		
		@Override
		public Object call() throws DMLRuntimeException
		{
			//core weighted unary mm computation (dispatch on representation)
			if( !_mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() )
				matrixMultWuMMDense(_mW, _mU, _mV, _ret, _wt, _fn, _rl, _ru);
			else if( _mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock())
				matrixMultWuMMSparseDense(_mW, _mU, _mV, _ret, _wt, _fn, _rl, _ru);
			else
				matrixMultWuMMGeneric(_mW, _mU, _mV, _ret, _wt, _fn, _rl, _ru);
			
			//maintain block nnz (upper bounds inclusive)
			_nnz = _ret.recomputeNonZeros(_rl, _ru-1, 0, _ret.getNumColumns()-1);
			
			return null;
		}
		
		//partial nnz of rows [_rl,_ru), aggregated by the caller
		public long getPartialNnz() {
			return _nnz;
		}
	}
}