Please wait. This can take a few minutes ...
Many resources are needed to host a downloadable project. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only 1 $
You can buy this project and download/modify it as often as you want.
org.apache.sysml.runtime.matrix.data.LibMatrixMult Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.matrix.data;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.math3.util.FastMath;
import org.apache.sysml.lops.MapMultChain.ChainType;
import org.apache.sysml.lops.WeightedCrossEntropy.WCeMMType;
import org.apache.sysml.lops.WeightedDivMM.WDivMMType;
import org.apache.sysml.lops.WeightedSigmoid.WSigmoidType;
import org.apache.sysml.lops.WeightedSquaredLoss.WeightsType;
import org.apache.sysml.lops.WeightedUnaryMM.WUMMType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.functionobjects.ValueFunction;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
import org.apache.sysml.runtime.util.UtilFunctions;
/**
* MB:
* Library for matrix multiplications including MM, MV, VV for all
* combinations of dense, sparse, ultrasparse representations and special
* operations such as transpose-self matrix multiplication.
*
* In general all implementations use internally dense outputs
* for direct access, but change the final result to sparse if necessary.
* The only exceptions are ultra-sparse matrix mult, wsloss and wsigmoid.
*
* NOTES on BLAS:
* * Experiments in 04/2013 showed that even on dense-dense this implementation
* is 3x faster than f2j-BLAS-DGEMM, 2x faster than f2c-BLAS-DGEMM, and
* level (+10% after JIT) with a native C implementation.
* * Calling native BLAS would loose platform independence and would require
* JNI calls incl data transfer. Furthermore, BLAS does not support sparse
* matrices (except Sparse BLAS, with dedicated function calls and matrix formats)
* and would be an external dependency.
* * Experiments in 02/2014 showed that on dense-dense this implementation now achieves
* almost 30% peak FP performance. Compared to Intel MKL 11.1 (dgemm, N=1000) it is
* just 3.2x (sparsity=1.0) and 1.9x (sparsity=0.5) slower, respectively.
*
*/
public class LibMatrixMult
{
//internal configuration (tuned constants; see class-level NOTES on BLAS)
private static final boolean LOW_LEVEL_OPTIMIZATION = true;
private static final long MEM_OVERHEAD_THRESHOLD = 2L*1024*1024; //MAX 2 MB
private static final long PAR_MINFLOP_THRESHOLD = 2L*1024*1024; //MIN 2 MFLOP
private static final int L2_CACHESIZE = 256 *1024; //256KB (common size)
private LibMatrixMult() {
//prevent instantiation via private constructor (static utility class)
}
////////////////////////////////
// public matrix mult interface
////////////////////////////////
/**
 * Performs a matrix multiplication and stores the result in the output matrix.
 *
 * All variants use a IKJ access pattern, and internally use dense output. After the
 * actual computation, we recompute nnz and check for sparse/dense representation.
 *
 *
 * @param m1 first matrix
 * @param m2 second matrix
 * @param ret result matrix
 * @throws DMLRuntimeException if DMLRuntimeException occurs
 */
public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret)
throws DMLRuntimeException
{
//delegate to the row-range variant over all rows [0, m1.rlen)
matrixMult(m1, m2, ret, 0, m1.rlen);
}
/**
 * This method allows one to disable the output sparsity examination. This feature is useful if matrixMult is used as an intermediate
 * operation (for example: LibMatrixDNN). It makes sense for LibMatrixDNN because the output is internally
 * consumed by another dense instruction, which makes repeated conversion to sparse wasteful.
 * This should be used in rare cases and if you are unsure,
 * use the method 'matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret)' instead.
 *
 * @param m1 first matrix
 * @param m2 second matrix
 * @param ret result matrix
 * @param examSparsity if false, sparsity examination is disabled
 * @throws DMLRuntimeException if DMLRuntimeException occurs
 */
public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean examSparsity)
throws DMLRuntimeException
{
//delegate to the row-range variant over all rows [0, m1.rlen)
matrixMult(m1, m2, ret, 0, m1.rlen, examSparsity);
}
/**
 * Performs a matrix multiplication over a row range of m1, with output
 * sparsity examination enabled.
 *
 * @param m1 first matrix
 * @param m2 second matrix
 * @param ret result matrix
 * @param rl row lower index
 * @param ru row upper index
 * @throws DMLRuntimeException if DMLRuntimeException occurs
 */
public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru)
throws DMLRuntimeException
{
matrixMult(m1, m2, ret, rl, ru, true);
}
/**
 * Core sequential matrix multiplication: validates inputs, prepares the rhs,
 * dispatches to the kernel matching the dense/sparse representations, and
 * finalizes nnz / representation of the output.
 *
 * @param m1 first matrix
 * @param m2 second matrix
 * @param ret result matrix
 * @param rl row lower index
 * @param ru row upper index
 * @param examSparsity if false, sparsity examination of the output is skipped
 * @throws DMLRuntimeException if DMLRuntimeException occurs
 */
public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, boolean examSparsity)
    throws DMLRuntimeException
{
    //guard: empty inputs produce an empty output
    if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
        ret.examSparsity(); //turn empty dense into sparse
        return;
    }

    //pre-processing: optionally transpose rhs, allocate output
    boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
    m2 = prepMatrixMultRightInput(m1, m2);
    boolean ultraSparse = m1.isUltraSparse() || m2.isUltraSparse();
    ret.sparse = ultraSparse;
    if( !ret.sparse )
        ret.allocateDenseBlock();

    //prepare row-upper for special cases of vector-matrix
    boolean pm2 = checkParMatrixMultRightInputRows(m1, m2, Integer.MAX_VALUE);
    int ru2 = (pm2 && ru==m1.rlen) ? m2.rlen : ru;
    int cu = m2.clen;

    //core matrix mult computation
    //NOTE(review): the lower row bound passed to the kernels is 0, not rl — confirm against upstream
    if( ultraSparse )
        matrixMultUltraSparse(m1, m2, ret, 0, ru2);
    else if( !m1.sparse && !m2.sparse )
        matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru2, 0, cu);
    else if( m1.sparse && m2.sparse )
        matrixMultSparseSparse(m1, m2, ret, pm2, 0, ru2);
    else if( m1.sparse )
        matrixMultSparseDense(m1, m2, ret, pm2, 0, ru2);
    else
        matrixMultDenseSparse(m1, m2, ret, pm2, 0, ru2);

    //post-processing: nnz maintenance and optional representation change
    if( !ret.sparse )
        ret.recomputeNonZeros();
    if( examSparsity )
        ret.examSparsity();
}
/**
 * Performs a multi-threaded matrix multiplication and stores the result in the output matrix.
 * The parameter k (k>=1) determines the max parallelism k' with k'=min(k, vcores, m1.rlen).
 *
 * @param m1 first matrix
 * @param m2 second matrix
 * @param ret result matrix
 * @param k maximum parallelism
 * @throws DMLRuntimeException if DMLRuntimeException occurs
 */
//NOTE(review): this method's text appears garbled by HTML extraction — generic type
//parameters and the task-creation loop body were swallowed (see the line containing
//"for( int i=0, lb=0; i> taskret"). Restore from the upstream source before compiling.
public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k)
throws DMLRuntimeException
{
//check inputs / outputs
if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
ret.examSparsity(); //turn empty dense into sparse
return;
}
//check too high additional vector-matrix memory requirements (fallback to sequential)
//check too small workload in terms of flops (fallback to sequential too)
if( m1.rlen == 1 && (8L * m2.clen * k > MEM_OVERHEAD_THRESHOLD || !LOW_LEVEL_OPTIMIZATION || m2.clen==1 || m1.isUltraSparse() || m2.isUltraSparse())
|| 2L * m1.rlen * m1.clen * m2.clen < PAR_MINFLOP_THRESHOLD )
{
matrixMult(m1, m2, ret);
return;
}
//Timing time = new Timing(true);
//pre-processing: output allocation (in contrast to single-threaded,
//we need to allocate sparse as well in order to prevent synchronization)
boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
m2 = prepMatrixMultRightInput(m1, m2);
ret.sparse = (m1.isUltraSparse() || m2.isUltraSparse());
if( !ret.sparse )
ret.allocateDenseBlock();
else
ret.allocateSparseRowsBlock();
if (!ret.isThreadSafe()){
matrixMult(m1, m2, ret);
return;
}
//prepare row-upper for special cases of vector-matrix / matrix-matrix
boolean pm2r = checkParMatrixMultRightInputRows(m1, m2, k);
boolean pm2c = checkParMatrixMultRightInputCols(m1, m2, k, pm2r);
int num = pm2r ? m2.rlen : pm2c ? m2.clen : m1.rlen;
//core multi-threaded matrix mult computation
//(currently: always parallelization over number of rows)
try {
ExecutorService pool = Executors.newFixedThreadPool( k );
ArrayList tasks = new ArrayList();
int nk = (pm2r||pm2c) ? k : UtilFunctions.roundToNext(Math.min(8*k,num/32), k);
ArrayList blklens = getBalancedBlockSizes(num, nk);
for( int i=0, lb=0; i> taskret = pool.invokeAll(tasks);
pool.shutdown();
//aggregate partial results (nnz, ret for vector/matrix)
ret.nonZeros = 0; //reset after execute
for( Future task : taskret ) {
if( pm2r )
vectAdd((double[])task.get(), ret.denseBlock, 0, 0, ret.rlen*ret.clen);
else
ret.nonZeros += (Long)task.get();
}
if( pm2r )
ret.recomputeNonZeros();
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
}
//post-processing (nnz maintained in parallel)
ret.examSparsity();
//System.out.println("MM k="+k+" ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" +
// "("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop());
}
/**
 * Performs a matrix multiplication chain operation of type t(X)%*%(X%*%v)
 * or t(X)%*%(w*(X%*%v)) and stores the result in the output matrix.
 *
 * The computation uses an IKJ access pattern with a dense output buffer;
 * the final representation is adjusted afterwards via examSparsity.
 *
 * @param mX X matrix
 * @param mV v matrix
 * @param mW optional weights matrix (may be null)
 * @param ret result matrix
 * @param ct chain type
 * @throws DMLRuntimeException if DMLRuntimeException occurs
 */
public static void matrixMultChain(MatrixBlock mX, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, ChainType ct)
    throws DMLRuntimeException
{
    //guard: empty inputs yield an empty (sparse) output
    //(after this check, mV and mW are guaranteed to be dense)
    if( mX.isEmptyBlock(false) || mV.isEmptyBlock(false) || (mW !=null && mW.isEmptyBlock(false)) ) {
        ret.examSparsity(); //turn empty dense into sparse
        return;
    }

    //allocate dense output for direct access
    ret.sparse = false;
    ret.allocateDenseBlock();

    //dispatch on the representation of X
    if( !mX.sparse )
        matrixMultChainDense(mX, mV, mW, ret, ct, 0, mX.rlen);
    else
        matrixMultChainSparse(mX, mV, mW, ret, ct, 0, mX.rlen);

    //maintain nnz and pick final representation
    ret.recomputeNonZeros();
    ret.examSparsity();
}
/**
 * Performs a parallel matrix multiplication chain operation of type t(X)%*%(X%*%v) or t(X)%*%(w*(X%*%v)).
 * The parameter k (k>=1) determines the max parallelism k' with k'=min(k, vcores, m1.rlen).
 *
 * NOTE: This multi-threaded mmchain operation has additional memory requirements of k*ncol(X)*8bytes
 * for partial aggregation. Current max memory: 256KB; otherwise we redirect to sequential execution.
 *
 * @param mX X matrix
 * @param mV v matrix
 * @param mW w matrix
 * @param ret result matrix
 * @param ct chain type
 * @param k maximum parallelism
 * @throws DMLRuntimeException if DMLRuntimeException occurs
 */
//NOTE(review): text garbled by HTML extraction — the task-creation loop and generic
//type parameters were swallowed (see "for( int i=0; i> taskret"). Restore from upstream.
public static void matrixMultChain(MatrixBlock mX, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, ChainType ct, int k)
throws DMLRuntimeException
{
//check inputs / outputs (after that mV and mW guaranteed to be dense)
if( mX.isEmptyBlock(false) || mV.isEmptyBlock(false) || (mW !=null && mW.isEmptyBlock(false)) ) {
ret.examSparsity(); //turn empty dense into sparse
return;
}
//check too high additional memory requirements (fallback to sequential)
//check too small workload in terms of flops (fallback to sequential too)
if( 8L * mV.rlen * k > MEM_OVERHEAD_THRESHOLD
|| 4L * mX.rlen * mX.clen < PAR_MINFLOP_THRESHOLD)
{
matrixMultChain(mX, mV, mW, ret, ct);
return;
}
//Timing time = new Timing(true);
//pre-processing (no need to check isThreadSafe)
ret.sparse = false;
ret.allocateDenseBlock();
//core matrix mult chain computation
//(currently: always parallelization over number of rows)
try {
ExecutorService pool = Executors.newFixedThreadPool( k );
ArrayList tasks = new ArrayList();
int blklen = (int)(Math.ceil((double)mX.rlen/k));
blklen += (blklen%24 != 0)?24-blklen%24:0;
for( int i=0; i> taskret = pool.invokeAll(tasks);
pool.shutdown();
//aggregate partial results
for( Future task : taskret )
vectAdd(task.get(), ret.denseBlock, 0, 0, mX.clen);
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
}
//post-processing
ret.recomputeNonZeros();
ret.examSparsity();
//System.out.println("MMChain "+ct.toString()+" k="+k+" ("+mX.isInSparseFormat()+","+mX.getNumRows()+","+mX.getNumColumns()+","+mX.getNonZeros()+")x" +
// "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
}
/**
 * Computes a transpose-self matrix multiplication, i.e. t(X)%*%X
 * (leftTranspose) or X%*%t(X), exploiting the symmetry of the result:
 * only the upper triangle is computed and then mirrored.
 *
 * @param m1 input matrix
 * @param ret result matrix
 * @param leftTranspose true for t(X)%*%X, false for X%*%t(X)
 * @throws DMLRuntimeException if DMLRuntimeException occurs
 */
public static void matrixMultTransposeSelf( MatrixBlock m1, MatrixBlock ret, boolean leftTranspose )
    throws DMLRuntimeException
{
    //guard: empty input yields an empty (sparse) output
    if( m1.isEmptyBlock(false) ) {
        ret.examSparsity(); //turn empty dense into sparse
        return;
    }

    //pre-processing: optionally materialize the transposed input
    m1 = prepMatrixMultTransposeSelfInput(m1, leftTranspose);
    ret.sparse = false;
    ret.allocateDenseBlock();

    //compute the upper triangle with the kernel matching m1's representation
    if( !m1.sparse )
        matrixMultTransposeSelfDense(m1, ret, leftTranspose, 0, ret.rlen );
    else
        matrixMultTransposeSelfSparse(m1, ret, leftTranspose, 0, ret.rlen);

    //post-processing: mirror result, maintain nnz, pick representation
    copyUpperToLowerTriangle( ret );
    ret.recomputeNonZeros();
    ret.examSparsity();
}
//Multi-threaded transpose-self matrix multiplication (t(X)%*%X or X%*%t(X)).
//Falls back to the sequential variant for single-row outputs or small workloads.
//NOTE(review): text garbled by HTML extraction — the task-creation loop and generic
//type parameters were swallowed (see "for( int i=0; i<2*k & i*blklen> rtasks").
//Restore from upstream before compiling.
public static void matrixMultTransposeSelf( MatrixBlock m1, MatrixBlock ret, boolean leftTranspose, int k )
throws DMLRuntimeException
{
//check inputs / outputs
if( m1.isEmptyBlock(false) ) {
ret.examSparsity(); //turn empty dense into sparse
return;
}
//check no parallelization benefit (fallback to sequential)
//check too small workload in terms of flops (fallback to sequential too)
if( ret.rlen == 1
|| leftTranspose && 1L * m1.rlen * m1.clen * m1.clen < PAR_MINFLOP_THRESHOLD
|| !leftTranspose && 1L * m1.clen * m1.rlen * m1.rlen < PAR_MINFLOP_THRESHOLD)
{
matrixMultTransposeSelf(m1, ret, leftTranspose);
return;
}
//Timing time = new Timing(true);
//pre-processing (no need to check isThreadSafe)
m1 = prepMatrixMultTransposeSelfInput(m1, leftTranspose);
ret.sparse = false;
ret.allocateDenseBlock();
//core multi-threaded matrix mult computation
try {
ExecutorService pool = Executors.newFixedThreadPool( k );
ArrayList tasks = new ArrayList();
//load balance via #tasks=2k due to triangular shape
int blklen = (int)(Math.ceil((double)ret.rlen/(2*k)));
for( int i=0; i<2*k & i*blklen> rtasks = pool.invokeAll(tasks);
pool.shutdown();
for( Future rtask : rtasks )
rtask.get(); //error handling
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
}
//post-processing
copyUpperToLowerTriangle( ret );
ret.recomputeNonZeros();
ret.examSparsity();
//System.out.println("TSMM k="+k+" ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+","+leftTranspose+") in "+time.stop());
}
/**
 * Performs a permutation matrix multiplication, writing into the primary
 * output ret1 and, if row selections spill over a block boundary, the
 * optional secondary output ret2.
 *
 * @param pm1 permutation (selection) matrix
 * @param m2 input matrix
 * @param ret1 first result block
 * @param ret2 optional second result block (may be null)
 * @throws DMLRuntimeException if DMLRuntimeException occurs
 */
public static void matrixMultPermute( MatrixBlock pm1, MatrixBlock m2, MatrixBlock ret1, MatrixBlock ret2 )
    throws DMLRuntimeException
{
    //guard: nothing to compute on empty inputs
    if( pm1.isEmptyBlock(false) || m2.isEmptyBlock(false) )
        return;

    //first output inherits sparsity from m2 (or keeps an already-sparse request)
    boolean sparseOut = m2.sparse || ret1.sparse;
    ret1.sparse = sparseOut;
    if( sparseOut )
        ret1.allocateSparseRowsBlock();
    else
        ret1.allocateDenseBlock();

    //dispatch on the representations of m2 and ret1
    if( m2.sparse )
        matrixMultPermuteSparse(pm1, m2, ret1, ret2, 0, pm1.rlen);
    else if( sparseOut )
        matrixMultPermuteDenseSparse(pm1, m2, ret1, ret2, 0, pm1.rlen);
    else
        matrixMultPermuteDense(pm1, m2, ret1, ret2, 0, pm1.rlen);

    //post-processing of first and optional second output
    ret1.recomputeNonZeros();
    ret1.examSparsity();
    if( ret2 != null ) {
        ret2.recomputeNonZeros();
        ret2.examSparsity();
    }
}
//Multi-threaded permutation matrix multiplication.
//NOTE(review): text heavily garbled by HTML extraction — the task-creation loop
//was swallowed and the remainder of this method has been fused with what appears
//to be a multi-threaded wsloss method (the trailing comment mentions MMWSLoss and
//the locals reference mX/mV). Restore both methods from upstream before compiling.
public static void matrixMultPermute( MatrixBlock pm1, MatrixBlock m2, MatrixBlock ret1, MatrixBlock ret2, int k)
throws DMLRuntimeException
{
//check inputs / outputs
if( pm1.isEmptyBlock(false) || m2.isEmptyBlock(false) )
return;
//check no parallelization benefit (fallback to sequential)
if (pm1.rlen == 1) {
matrixMultPermute(pm1, m2, ret1, ret2);
return;
}
//Timing time = new Timing(true);
//allocate first output block (second allocated if needed)
ret1.sparse = false; // no need to check isThreadSafe
ret1.allocateDenseBlock();
try
{
ExecutorService pool = Executors.newFixedThreadPool(k);
ArrayList tasks = new ArrayList();
int blklen = (int)(Math.ceil((double)pm1.rlen/k));
for( int i=0; i tasks = new ArrayList();
int blklen = (int)(Math.ceil((double)mX.rlen/k));
for( int i=0; i> taskret = pool.invokeAll(tasks);
pool.shutdown();
//aggregate partial results
sumScalarResults(taskret, ret);
}
catch( Exception e ) {
throw new DMLRuntimeException(e);
}
//System.out.println("MMWSLoss "+wt.toString()+" k="+k+" ("+mX.isInSparseFormat()+","+mX.getNumRows()+","+mX.getNumColumns()+","+mX.getNonZeros()+")x" +
// "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
}
/**
 * Performs a weighted sigmoid matrix multiplication over W, U, and V and
 * stores the result in the output matrix.
 *
 * @param mW weights matrix W
 * @param mU factor matrix U
 * @param mV factor matrix V
 * @param ret result matrix
 * @param wt weighted sigmoid type
 * @throws DMLRuntimeException if DMLRuntimeException occurs
 */
public static void matrixMultWSigmoid(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt)
    throws DMLRuntimeException
{
    //guard: empty W yields an empty (sparse) result
    if( mW.isEmptyBlock(false) ) {
        ret.examSparsity(); //turn empty dense into sparse
        return;
    }

    //output representation follows W
    ret.sparse = mW.sparse;
    ret.allocateDenseOrSparseBlock();

    //core weighted sigmoid computation: specialized kernels for dense or
    //sparse W with dense non-empty U/V, generic fallback otherwise
    if( !mW.sparse ) {
        if( !mU.sparse && !mV.sparse && !mU.isEmptyBlock() && !mV.isEmptyBlock() )
            matrixMultWSigmoidDense(mW, mU, mV, ret, wt, 0, mW.rlen);
        else
            matrixMultWSigmoidGeneric(mW, mU, mV, ret, wt, 0, mW.rlen);
    }
    else {
        if( !mU.sparse && !mV.sparse && !mU.isEmptyBlock() && !mV.isEmptyBlock() )
            matrixMultWSigmoidSparseDense(mW, mU, mV, ret, wt, 0, mW.rlen);
        else
            matrixMultWSigmoidGeneric(mW, mU, mV, ret, wt, 0, mW.rlen);
    }

    //post-processing
    ret.recomputeNonZeros();
    ret.examSparsity();
}
//Multi-threaded weighted sigmoid matrix multiplication.
//NOTE(review): text garbled by HTML extraction — the task-creation loop and generic
//type parameters were swallowed (see "for( int i=0; i> taskret"). Restore from upstream.
public static void matrixMultWSigmoid(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int k)
throws DMLRuntimeException
{
//check for empty result
if( mW.isEmptyBlock(false) ) {
ret.examSparsity(); //turn empty dense into sparse
return;
}
//check no parallelization benefit (fallback to sequential)
if (mW.rlen == 1 || !MatrixBlock.isThreadSafe(mW.sparse)) {
matrixMultWSigmoid(mW, mU, mV, ret, wt);
return;
}
//Timing time = new Timing(true);
//pre-processing
ret.sparse = mW.sparse;
ret.allocateDenseOrSparseBlock();
try
{
ExecutorService pool = Executors.newFixedThreadPool(k);
ArrayList tasks = new ArrayList();
int blklen = (int)(Math.ceil((double)mW.rlen/k));
for( int i=0; i> taskret = pool.invokeAll(tasks);
pool.shutdown();
//aggregate partial nnz and check for errors
ret.nonZeros = 0; //reset after execute
for( Future task : taskret )
ret.nonZeros += task.get();
}
catch (Exception e) {
throw new DMLRuntimeException(e);
}
//post-processing (nnz maintained in parallel)
ret.examSparsity();
//System.out.println("MMWSig "+wt.toString()+" k="+k+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" +
// "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop() + ".");
}
/**
 * NOTE: This operation has limited NaN support, which is acceptable because all our sparse-safe operations
 * have only limited NaN support. If this is not intended behavior, please disable the rewrite. In detail,
 * this operator will produce for W/(U%*%t(V)) a zero intermediate for each zero in W (even if UVij is zero
 * which would give 0/0=NaN) but INF/-INF for non-zero entries in V where the corresponding cell in (Y%*%X)
 * is zero.
 *
 * @param mW matrix W
 * @param mU matrix U
 * @param mV matrix V
 * @param mX matrix X (may be null, depending on the wdivmm type)
 * @param ret result matrix
 * @param wt weighted divide matrix multiplication type
 * @throws DMLRuntimeException if DMLRuntimeException occurs
 */
public static void matrixMultWDivMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock mX, MatrixBlock ret, WDivMMType wt)
    throws DMLRuntimeException
{
    //guard: empty inputs (depending on type) yield an empty result
    if( mW.isEmptyBlock(false)
        || (wt.isLeft() && mU.isEmptyBlock(false))
        || (wt.isRight() && mV.isEmptyBlock(false))
        || (wt.isBasic() && mW.isEmptyBlock(false))) {
        ret.examSparsity(); //turn empty dense into sparse
        return;
    }

    //output representation: follows W for basic types, dense otherwise
    ret.sparse = wt.isBasic()?mW.sparse:false;
    ret.allocateDenseOrSparseBlock();

    //core weighted div mm computation: specialized kernels for dense or
    //sparse W with dense non-empty U/V, generic fallback otherwise
    boolean scalarX = wt.hasScalar();
    if( !mW.sparse ) {
        if( !mU.sparse && !mV.sparse && (mX==null || !mX.sparse || scalarX) && !mU.isEmptyBlock() && !mV.isEmptyBlock() )
            matrixMultWDivMMDense(mW, mU, mV, mX, ret, wt, 0, mW.rlen, 0, mW.clen);
        else
            matrixMultWDivMMGeneric(mW, mU, mV, mX, ret, wt, 0, mW.rlen, 0, mW.clen);
    }
    else {
        if( !mU.sparse && !mV.sparse && (mX==null || mX.sparse || scalarX) && !mU.isEmptyBlock() && !mV.isEmptyBlock() )
            matrixMultWDivMMSparseDense(mW, mU, mV, mX, ret, wt, 0, mW.rlen, 0, mW.clen);
        else
            matrixMultWDivMMGeneric(mW, mU, mV, mX, ret, wt, 0, mW.rlen, 0, mW.clen);
    }

    //post-processing
    ret.recomputeNonZeros();
    ret.examSparsity();
}
/**
 * NOTE: This operation has limited NaN support, which is acceptable because all our sparse-safe operations
 * have only limited NaN support. If this is not intended behavior, please disable the rewrite. In detail,
 * this operator will produce for W/(U%*%t(V)) a zero intermediate for each zero in W (even if UVij is zero
 * which would give 0/0=NaN) but INF/-INF for non-zero entries in V where the corresponding cell in (Y%*%X)
 * is zero.
 *
 * @param mW matrix W
 * @param mU matrix U
 * @param mV matrix V
 * @param mX matrix X
 * @param ret result matrix
 * @param wt weighted divide matrix multiplication type
 * @param k maximum parallelism
 * @throws DMLRuntimeException if DMLRuntimeException occurs
 */
//NOTE(review): text garbled by HTML extraction — the wdivmm-left/right task-creation
//loops were swallowed (see "for( int j=0; j> taskret"). Restore from upstream.
public static void matrixMultWDivMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock mX, MatrixBlock ret, WDivMMType wt, int k)
throws DMLRuntimeException
{
//check for empty result
if( mW.isEmptyBlock(false)
|| (wt.isLeft() && mU.isEmptyBlock(false))
|| (wt.isRight() && mV.isEmptyBlock(false))
|| (wt.isBasic() && mW.isEmptyBlock(false))) {
ret.examSparsity(); //turn empty dense into sparse
return;
}
//Timing time = new Timing(true);
//pre-processing
ret.sparse = wt.isBasic()?mW.sparse:false;
ret.allocateDenseOrSparseBlock();
if (!ret.isThreadSafe()){
matrixMultWDivMM(mW, mU, mV, mX, ret, wt);
return;
}
try
{
ExecutorService pool = Executors.newFixedThreadPool(k);
ArrayList tasks = new ArrayList();
//create tasks (for wdivmm-left, parallelization over columns;
//for wdivmm-right, parallelization over rows; both ensure disjoint results)
if( wt.isLeft() ) {
int blklen = (int)(Math.ceil((double)mW.clen/k));
for( int j=0; j> taskret = pool.invokeAll(tasks);
pool.shutdown();
//aggregate partial nnz and check for errors
ret.nonZeros = 0; //reset after execute
for( Future task : taskret )
ret.nonZeros += task.get();
}
catch (Exception e) {
throw new DMLRuntimeException(e);
}
//post-processing
ret.examSparsity();
//System.out.println("MMWDiv "+wt.toString()+" k="+k+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" +
// "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
}
/**
 * Performs a weighted cross-entropy matrix multiplication over W, U, and V
 * and stores the (aggregated) result in the output matrix.
 *
 * @param mW weights matrix W
 * @param mU factor matrix U
 * @param mV factor matrix V
 * @param eps epsilon value
 * @param ret result matrix
 * @param wt weighted cross-entropy type
 * @throws DMLRuntimeException if DMLRuntimeException occurs
 */
public static void matrixMultWCeMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, double eps, MatrixBlock ret, WCeMMType wt)
    throws DMLRuntimeException
{
    //guard: empty W yields an empty (sparse) result
    if( mW.isEmptyBlock(false) ) {
        ret.examSparsity(); //turn empty dense into sparse
        return;
    }

    //allocate dense output for direct access
    ret.sparse = false;
    ret.allocateDenseBlock();

    //core weighted cross entropy computation: specialized kernels for dense
    //or sparse W with dense non-empty U/V, generic fallback otherwise
    if( !mW.sparse ) {
        if( !mU.sparse && !mV.sparse && !mU.isEmptyBlock() && !mV.isEmptyBlock() )
            matrixMultWCeMMDense(mW, mU, mV, eps, ret, wt, 0, mW.rlen);
        else
            matrixMultWCeMMGeneric(mW, mU, mV, eps, ret, wt, 0, mW.rlen);
    }
    else {
        if( !mU.sparse && !mV.sparse && !mU.isEmptyBlock() && !mV.isEmptyBlock() )
            matrixMultWCeMMSparseDense(mW, mU, mV, eps, ret, wt, 0, mW.rlen);
        else
            matrixMultWCeMMGeneric(mW, mU, mV, eps, ret, wt, 0, mW.rlen);
    }
}
//Multi-threaded weighted cross-entropy matrix multiplication.
//NOTE(review): text garbled by HTML extraction — the task-creation loop and generic
//type parameters were swallowed (see "for( int i=0; i> taskret"). Restore from upstream.
public static void matrixMultWCeMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, double eps, MatrixBlock ret, WCeMMType wt, int k)
throws DMLRuntimeException
{
//check for empty result
if( mW.isEmptyBlock(false) ) {
ret.examSparsity(); //turn empty dense into sparse
return;
}
//Timing time = new Timing(true);
//pre-processing (no need to check isThreadSafe)
ret.sparse = false;
ret.allocateDenseBlock();
try
{
ExecutorService pool = Executors.newFixedThreadPool(k);
ArrayList tasks = new ArrayList();
int blklen = (int)(Math.ceil((double)mW.rlen/k));
for( int i=0; i> taskret = pool.invokeAll(tasks);
pool.shutdown();
//aggregate partial results
sumScalarResults(taskret, ret);
}
catch( Exception e ) {
throw new DMLRuntimeException(e);
}
//System.out.println("MMWCe "+wt.toString()+" k="+k+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" +
// "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
}
/**
 * Performs a weighted unary matrix multiplication over W, U, and V, applying
 * the given unary value function, and stores the result in the output matrix.
 *
 * @param mW weights matrix W
 * @param mU factor matrix U
 * @param mV factor matrix V
 * @param ret result matrix
 * @param wt weighted unary mm type
 * @param fn unary value function applied per cell
 * @throws DMLRuntimeException if DMLRuntimeException occurs
 */
public static void matrixMultWuMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WUMMType wt, ValueFunction fn)
    throws DMLRuntimeException
{
    //guard: empty W yields an empty (sparse) result
    if( mW.isEmptyBlock(false) ) {
        ret.examSparsity(); //turn empty dense into sparse
        return;
    }

    //output representation follows W
    ret.sparse = mW.sparse;
    ret.allocateDenseOrSparseBlock();

    //core weighted unary mm computation: specialized kernels for dense or
    //sparse W with dense non-empty U/V, generic fallback otherwise
    if( !mW.sparse ) {
        if( !mU.sparse && !mV.sparse && !mU.isEmptyBlock() && !mV.isEmptyBlock() )
            matrixMultWuMMDense(mW, mU, mV, ret, wt, fn, 0, mW.rlen);
        else
            matrixMultWuMMGeneric(mW, mU, mV, ret, wt, fn, 0, mW.rlen);
    }
    else {
        if( !mU.sparse && !mV.sparse && !mU.isEmptyBlock() && !mV.isEmptyBlock() )
            matrixMultWuMMSparseDense(mW, mU, mV, ret, wt, fn, 0, mW.rlen);
        else
            matrixMultWuMMGeneric(mW, mU, mV, ret, wt, fn, 0, mW.rlen);
    }

    //post-processing
    ret.recomputeNonZeros();
    ret.examSparsity();
}
//Multi-threaded weighted unary matrix multiplication.
//NOTE(review): text garbled by HTML extraction — the task-creation loop and generic
//type parameters were swallowed (see "for( int i=0; i> taskret"). Restore from upstream.
public static void matrixMultWuMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WUMMType wt, ValueFunction fn, int k)
throws DMLRuntimeException
{
//check for empty result
if( mW.isEmptyBlock(false) ) {
ret.examSparsity(); //turn empty dense into sparse
return;
}
//check no parallelization benefit (fallback to sequential)
if (mW.rlen == 1 || !MatrixBlock.isThreadSafe(mW.sparse)) {
matrixMultWuMM(mW, mU, mV, ret, wt, fn);
return;
}
//Timing time = new Timing(true);
//pre-processing
ret.sparse = mW.sparse;
ret.allocateDenseOrSparseBlock();
try
{
ExecutorService pool = Executors.newFixedThreadPool(k);
ArrayList tasks = new ArrayList();
int blklen = (int)(Math.ceil((double)mW.rlen/k));
for( int i=0; i> taskret = pool.invokeAll(tasks);
pool.shutdown();
//aggregate partial nnz and check for errors
ret.nonZeros = 0; //reset after execute
for( Future task : taskret )
ret.nonZeros += task.get();
}
catch (Exception e) {
throw new DMLRuntimeException(e);
}
//post-processing (nnz maintained in parallel)
ret.examSparsity();
//System.out.println("MMWu "+wt.toString()+" k="+k+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" +
// "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop() + ".");
}
//////////////////////////////////////////
// optimized matrix mult implementation //
//////////////////////////////////////////
//NOTE(review): the region below is severely corrupted by HTML extraction — any text
//between '<' and '>' (generic type parameters, loop conditions and whole loop bodies)
//was swallowed, fusing fragments of several kernel methods together (dense-dense,
//dense-sparse, sparse-dense, sparse-sparse, ultra-sparse, permute, wdivmm kernels,
//and prepMatrixMultTransposeSelfInput). Do not attempt to compile or edit this
//region; restore all kernels from the upstream source.
private static void matrixMultDenseDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean tm2, boolean pm2, int rl, int ru, int cl, int cu)
throws DMLRuntimeException
{
double[] a = m1.denseBlock;
double[] b = m2.denseBlock;
double[] c = ret.denseBlock;
final int m = m1.rlen;
final int n = m2.clen;
final int cd = m1.clen;
if( LOW_LEVEL_OPTIMIZATION )
{
if( m==1 && n==1 ) //DOT PRODUCT
{
c[0] = dotProduct(a, b, cd);
}
else if( n>1 && cd == 1 ) //OUTER PRODUCT
{
for( int i=rl, cix=rl*n; i < ru; i++, cix+=n) {
if( a[i] == 1 )
System.arraycopy(b, 0, c, cix, n);
else if( a[i] != 0 )
vectMultiplyWrite(a[i], b, c, 0, cix, n);
else
Arrays.fill(c, cix, cix+n, 0);
}
}
else if( n==1 && cd == 1 ) //VECTOR-SCALAR
{
vectMultiplyWrite(b[0], a, c, rl, rl, ru-rl);
}
else if( n==1 && cd<=2*1024 ) //MATRIX-VECTOR (short rhs)
{
for( int i=rl, aix=rl*cd; i < ru; i++, aix+=cd)
c[i] = dotProduct(a, b, aix, 0, cd);
}
else if( n==1 ) //MATRIX-VECTOR (tall rhs)
{
final int blocksizeI = 32;
final int blocksizeK = 2*1024; //16KB vector blocks (L1)
for( int bi=rl; bi n && cd > 64 && n < 64
//however, explicit flag required since dimension change m2
final int n2 = m2.rlen;
for( int i=rl, aix=rl*cd, cix=rl*n2; i < ru; i++, aix+=cd, cix+=n2 )
for( int j=0, bix=0; j 0 ) //for skipping empty rows
//rest not aligned to blocks of 4 rows
final int bn = knnz % 4;
switch( bn ){
case 1: vectMultiplyAdd(ta[0], b, c, tbi[0], cixj, bjlen); break;
case 2: vectMultiplyAdd2(ta[0],ta[1], b, c, tbi[0], tbi[1], cixj, bjlen); break;
case 3: vectMultiplyAdd3(ta[0],ta[1],ta[2], b, c, tbi[0], tbi[1],tbi[2], cixj, bjlen); break;
}
//compute blocks of 4 rows (core inner loop)
for( int k = bn; k=0) ? rlix : alen;
for( int k=rlix; k=0) ? k1 : apos+alen;
int k2 = (ru==cd) ? apos+alen : a.posFIndexGTE(i, ru);
k2 = (k2>=0) ? k2 : apos+alen;
//rest not aligned to blocks of 4 rows
final int bn = (k2-k1) % 4;
switch( bn ){
case 1: vectMultiplyAdd(avals[k1], b, c, aix[k1]*n, cix, n); break;
case 2: vectMultiplyAdd2(avals[k1],avals[k1+1], b, c, aix[k1]*n, aix[k1+1]*n, cix, n); break;
case 3: vectMultiplyAdd3(avals[k1],avals[k1+1],avals[k1+2], b, c, aix[k1]*n, aix[k1+1]*n, aix[k1+2]*n, cix, n); break;
}
//compute blocks of 4 rows (core inner loop)
for( int k = k1+bn; k=0) ? rlix : alen;
for( int k=rlix; k=0) ? rlix : apos+alen;
for(int i = rlix; i < apos+alen && aix[i]=0) ? rlix : apos+alen;
for(int i = rlix; i < apos+alen && aix[i]=0) ? rlix : apos+alen;
for(int i = rlix; i < apos+alen && aix[i]=0) ? rlix : apos+alen;
for(int i = rlix; i < apos+alen && aix[i] 0 ) //selected row
{
int bpos = (pos-1) % brlen;
int blk = (pos-1) / brlen;
//allocate and switch to second output block
//(never happens in cp, correct for multi-threaded usage)
if( lastblk!=-1 && lastblk 0 ) //selected row
{
int bpos = (pos-1) % brlen;
int blk = (pos-1) / brlen;
//allocate and switch to second output block
//(never happens in cp, correct for multi-threaded usage)
if( lastblk!=-1 && lastblk 0 ) //selected row
{
int bpos = (pos-1) % brlen;
int blk = (pos-1) / brlen;
//allocate and switch to second output block
//(never happens in cp, correct for multi-threaded usage)
if( lastblk!=-1 && lastblk=0) ? k : mW.clen;
}
//prepare alignment info if necessary
if( four && !scalar )
for( int i=bi; i=0) ? k : wpos+wlen;
for( ; k 1) //X%*%t(X) SPARSE MATRIX
{
//directly via LibMatrixReorg in order to prevent sparsity change
MatrixBlock tmpBlock = new MatrixBlock(m1.clen, m1.rlen, m1.sparse);
LibMatrixReorg.reorg(m1, tmpBlock, new ReorgOperator(SwapIndex.getSwapIndexFnObject()));
ret = tmpBlock;
}
return ret;
}
/**
 * Decides whether the rhs input of a matrix multiply should be transposed
 * upfront: only for dense-dense inputs with a skinny (but non-vector) rhs
 * that is small enough to stay cache-resident.
 *
 * @param m1 lhs input matrix
 * @param m2 rhs input matrix
 * @return true if m2 should be transposed before the multiply
 */
private static boolean checkPrepMatrixMultRightInput( MatrixBlock m1, MatrixBlock m2 )
{
	//transpose if dense-dense, skinny rhs matrix (not vector), and memory guarded by output
	boolean denseDense = LOW_LEVEL_OPTIMIZATION && !m1.sparse && !m2.sparse;
	boolean skinnyRhs = m1.rlen > m2.clen && m2.rlen > 64
		&& m2.clen > 1 && m2.clen < 64;
	boolean rhsFitsL2 = 8*m2.rlen*m2.clen < 256*1024; //rhs fits in L2 cache
	return denseDense && skinnyRhs && rhsFitsL2;
}
/**
 * Decides whether a multi-threaded matrix multiply should parallelize over
 * the rows of the rhs matrix instead of the rows of the lhs/output, which
 * pays off when the lhs/output has very few rows.
 *
 * @param m1 lhs input matrix
 * @param m2 rhs input matrix
 * @param k  degree of parallelism
 * @return true if parallelization over rhs rows should be used
 */
private static boolean checkParMatrixMultRightInputRows( MatrixBlock m1, MatrixBlock m2, int k ) {
	//parallelize over rows in rhs matrix if number of rows in lhs/output is very small
	if( !LOW_LEVEL_OPTIMIZATION || m2.clen <= 1 )
		return false;
	//case 1: single lhs row, no ultra-sparse inputs
	if( m1.rlen == 1 && !(m1.isUltraSparse() || m2.isUltraSparse()) )
		return true;
	//case 2: few lhs rows, dense-dense / sparse-dense inputs, and
	//memory for k partial output copies guarded by a threshold
	return m1.rlen <= 16 && m2.rlen > m1.rlen
		&& !m1.isUltraSparse() && !m2.sparse
		&& (long)k * 8 * m1.rlen * m2.clen < MEM_OVERHEAD_THRESHOLD;
}
/**
 * Decides whether a multi-threaded matrix multiply should parallelize over
 * the columns of the rhs matrix: only for dense inputs with a very wide rhs
 * and a small lhs that stays cache-resident, and only if parallelization
 * over rhs rows was not already selected.
 *
 * @param m1   lhs input matrix
 * @param m2   rhs input matrix
 * @param k    degree of parallelism
 * @param pm2r true if parallelization over rhs rows is already chosen
 * @return true if parallelization over rhs columns should be used
 */
private static boolean checkParMatrixMultRightInputCols( MatrixBlock m1, MatrixBlock m2, int k, boolean pm2r ) {
	//parallelize over cols in rhs matrix if dense, number of cols in rhs is large, and lhs fits in l2
	if( !LOW_LEVEL_OPTIMIZATION || m1.sparse || m2.sparse || pm2r )
		return false;
	boolean manyRhsCols = m2.clen > k * 1024;
	boolean fewLhsRows = m1.rlen < k * 32;
	boolean lhsFitsL2 = 8*m1.rlen*m1.clen < 256*1024; //lhs fits in L2 cache
	return manyRhsCols && fewLhsRows && lhsFitsL2;
}
/**
 * Prepares the rhs input of a matrix multiply: returns a transposed copy
 * if the heuristic in checkPrepMatrixMultRightInput applies, otherwise
 * returns the input unchanged.
 *
 * @param m1 lhs input matrix
 * @param m2 rhs input matrix
 * @return m2 or a transposed copy of m2
 * @throws DMLRuntimeException if the reorg (transpose) operation fails
 */
private static MatrixBlock prepMatrixMultRightInput( MatrixBlock m1, MatrixBlock m2 )
	throws DMLRuntimeException
{
	//transpose if dense-dense, skinny rhs matrix (not vector), and memory guarded by output
	if( !checkPrepMatrixMultRightInput(m1, m2) )
		return m2;
	MatrixBlock tmp = new MatrixBlock(m2.clen, m2.rlen, m2.sparse);
	LibMatrixReorg.reorg(m2, tmp, new ReorgOperator(SwapIndex.getSwapIndexFnObject()));
	return tmp;
}
/**
 * Compacts the non-zero values of a contiguous row fragment a[aixi..aixi+bklen)
 * into tmpa, recording for each value its corresponding scan index on the rhs
 * matrix b ((bk+k)*n+bj) in tmpbi.
 *
 * @param a     dense input array (row-major)
 * @param aixi  start offset of the fragment in a
 * @param bk    row offset of the corresponding block in b
 * @param bj    column offset of the corresponding block in b
 * @param n     number of columns of b
 * @param tmpa  output buffer for non-zero values
 * @param tmpbi output buffer for scan indexes on b
 * @param bklen fragment length
 * @return number of non-zeros copied
 */
private static int copyNonZeroElements( double[] a, final int aixi, final int bk, final int bj, final int n, double[] tmpa, int[] tmpbi, final int bklen )
{
	int knnz = 0;
	for( int k = 0; k < bklen; k++ ) {
		double val = a[aixi + k];
		if( val != 0 ) {
			tmpa[knnz] = val;
			tmpbi[knnz] = (bk + k) * n + bj; //scan index on b
			knnz++;
		}
	}
	return knnz;
}
/**
 * Compacts the non-zero values of a strided column fragment of a (bklen
 * elements with stride n, starting at aixi) into tmpa, recording for each
 * value its corresponding scan index on the rhs matrix b ((bk+k)*nx+bj)
 * in tmpbi.
 *
 * @param a     dense input array (row-major)
 * @param aixi  start offset of the fragment in a (advanced by n per element)
 * @param bk    row offset of the corresponding block in b
 * @param bj    column offset of the corresponding block in b
 * @param n     number of columns of a (stride)
 * @param nx    number of columns of b
 * @param tmpa  output buffer for non-zero values
 * @param tmpbi output buffer for scan indexes on b
 * @param bklen fragment length
 * @return number of non-zeros copied
 */
private static int copyNonZeroElements( double[] a, int aixi, final int bk, final int bj, final int n, final int nx, double[] tmpa, int[] tmpbi, final int bklen )
{
	int knnz = 0;
	for( int k = 0; k < bklen; k++, aixi += n ) {
		double val = a[aixi];
		if( val != 0 ) {
			tmpa[knnz] = val;
			tmpbi[knnz] = (bk + k) * nx + bj; //scan index on b
			knnz++;
		}
	}
	return knnz;
}
private static void sumScalarResults(List> tasks, MatrixBlock ret)
throws InterruptedException, ExecutionException
{
//aggregate partial results and check for errors
double val = 0;
for(Future task : tasks)
val += task.get();
ret.quickSetValue(0, 0, val);
}
@SuppressWarnings("unused")
private static void sumDenseResults( double[][] partret, double[] ret )
{
final int len = ret.length;
final int k = partret.length;
final int bk = k % 4;
final int blocksize = 2 * 1024; //16KB (half of common L1 data)
//cache-conscious aggregation to prevent repreated scans/writes of ret
for( int bi=0; bi getBalancedBlockSizes(int len, int k) {
ArrayList ret = new ArrayList();
int base = len / k;
int rest = len % k;
for( int i=0; i 0 )
ret.add(val);
}
return ret;
}
/////////////////////////////////////////////////////////
// Task Implementations for Multi-Threaded Operations //
/////////////////////////////////////////////////////////
/**
 * Task for multi-threaded matrix multiplication. Depending on the flags,
 * the [rl,ru) range denotes rows of the lhs/output (default), rows of the
 * rhs (pm2r, with thread-local partial output), or columns of the rhs (pm2c).
 * Restores the Callable type parameter lost in extraction (raw Callable);
 * returns either the partial nnz count (Long) or the partial dense block.
 */
private static class MatrixMultTask implements Callable<Object>
{
	private final MatrixBlock _m1;
	private final MatrixBlock _m2;
	private final MatrixBlock _ret;
	private final boolean _tm2;  //transposed m2
	private final boolean _pm2r; //par over m2 rows
	private final boolean _pm2c; //par over m2 cols

	private final int _rl;
	private final int _ru;

	protected MatrixMultTask( MatrixBlock m1, MatrixBlock m2, MatrixBlock ret,
		boolean tm2, boolean pm2r, boolean pm2c, int rl, int ru )
	{
		_m1 = m1;
		_m2 = m2;
		_tm2 = tm2;
		_pm2r = pm2r;
		_pm2c = pm2c;
		_rl = rl;
		_ru = ru;
		if( pm2r ) { //vector-matrix / matrix-matrix
			//allocate local result for partial aggregation
			_ret = new MatrixBlock(ret.rlen, ret.clen, false);
		}
		else { //default case
			_ret = ret;
		}
	}

	@Override
	public Object call() throws DMLRuntimeException
	{
		//setup target index ranges (pm2c reinterprets [rl,ru) as columns)
		int rl = _pm2c ? 0 : _rl;
		int ru = _pm2c ? _m1.rlen : _ru;
		int cl = _pm2c ? _rl : 0;
		int cu = _pm2c ? _ru : _ret.clen;
		//thread-local allocation
		if( _pm2r )
			_ret.allocateDenseBlock();
		//compute block matrix multiplication (dispatch by sparsity)
		if( _m1.isUltraSparse() || _m2.isUltraSparse() )
			matrixMultUltraSparse(_m1, _m2, _ret, rl, ru);
		else if( !_m1.sparse && !_m2.sparse )
			matrixMultDenseDense(_m1, _m2, _ret, _tm2, _pm2r, rl, ru, cl, cu);
		else if( _m1.sparse && _m2.sparse )
			matrixMultSparseSparse(_m1, _m2, _ret, _pm2r, rl, ru);
		else if( _m1.sparse )
			matrixMultSparseDense(_m1, _m2, _ret, _pm2r, rl, ru);
		else
			matrixMultDenseSparse(_m1, _m2, _ret, _pm2r, rl, ru);
		//maintain block nnz (upper bounds inclusive), or return the
		//thread-local dense block for sequential aggregation by the caller
		if( !_pm2r )
			return _ret.recomputeNonZeros(rl, ru-1, cl, cu-1);
		else
			return _ret.getDenseBlock();
	}
}
/**
 * Task for multi-threaded matrix multiplication chain operations over the
 * row range [rl,ru) of X. Restores the Callable type parameter lost in
 * extraction (raw Callable): call() returns a partial double[] result,
 * so this is Callable&lt;double[]&gt;.
 */
private static class MatrixMultChainTask implements Callable<double[]>
{
	private final MatrixBlock _m1; //X
	private final MatrixBlock _m2; //v
	private final MatrixBlock _m3; //w (usage depends on chain type)
	private final ChainType _ct;
	private final int _rl;
	private final int _ru;

	protected MatrixMultChainTask( MatrixBlock mX, MatrixBlock mV, MatrixBlock mW, ChainType ct, int rl, int ru )
		throws DMLRuntimeException
	{
		_m1 = mX;
		_m2 = mV;
		_m3 = mW;
		_ct = ct;
		_rl = rl;
		_ru = ru;
	}

	@Override
	public double[] call() throws DMLRuntimeException
	{
		//thread-local allocation for partial aggregation
		MatrixBlock ret = new MatrixBlock(1, _m1.clen, false);
		ret.allocateDenseBlock();
		if( _m1.sparse )
			matrixMultChainSparse(_m1, _m2, _m3, ret, _ct, _rl, _ru);
		else
			matrixMultChainDense(_m1, _m2, _m3, ret, _ct, _rl, _ru);
		//NOTE: we dont do global aggregation from concurrent tasks in order
		//to prevent synchronization (sequential aggregation led to better
		//performance after JIT)
		return ret.getDenseBlock();
	}
}
/**
 * Task for multi-threaded transpose-self matrix multiplication
 * (t(X)%*%X or X%*%t(X), selected by the left flag) over the index
 * range [rl,ru). Restores the Callable type parameter lost in
 * extraction (raw Callable); call() produces no result.
 */
private static class MatrixMultTransposeTask implements Callable<Object>
{
	private final MatrixBlock _m1;
	private final MatrixBlock _ret;
	private final boolean _left; //true: t(X)%*%X, false: X%*%t(X)
	private final int _rl;
	private final int _ru;

	protected MatrixMultTransposeTask( MatrixBlock m1, MatrixBlock ret, boolean left, int rl, int ru )
	{
		_m1 = m1;
		_ret = ret;
		_left = left;
		_rl = rl;
		_ru = ru;
	}

	@Override
	public Object call() throws DMLRuntimeException
	{
		//dispatch by input sparsity
		if( _m1.sparse )
			matrixMultTransposeSelfSparse(_m1, _ret, _left, _rl, _ru);
		else
			matrixMultTransposeSelfDense(_m1, _ret, _left, _rl, _ru);
		return null;
	}
}
/**
 * Task for multi-threaded permutation matrix multiplication over the row
 * range [rl,ru); writes into up to two output blocks (ret1/ret2). Restores
 * the Callable type parameter lost in extraction (raw Callable); call()
 * produces no result.
 */
private static class MatrixMultPermuteTask implements Callable<Object>
{
	private final MatrixBlock _pm1; //permutation matrix
	private final MatrixBlock _m2;
	private final MatrixBlock _ret1;
	private final MatrixBlock _ret2;
	private final int _rl;
	private final int _ru;

	protected MatrixMultPermuteTask( MatrixBlock pm1, MatrixBlock m2, MatrixBlock ret1, MatrixBlock ret2, int rl, int ru)
	{
		_pm1 = pm1;
		_m2 = m2;
		_ret1 = ret1;
		_ret2 = ret2;
		_rl = rl;
		_ru = ru;
	}

	@Override
	public Object call() throws DMLRuntimeException
	{
		//dispatch by input/output sparsity
		if( _m2.sparse )
			matrixMultPermuteSparse(_pm1, _m2, _ret1, _ret2, _rl, _ru);
		else if( _ret1.sparse )
			matrixMultPermuteDenseSparse(_pm1, _m2, _ret1, _ret2, _rl, _ru);
		else
			matrixMultPermuteDense(_pm1, _m2, _ret1, _ret2, _rl, _ru);
		return null;
	}
}
/**
 * Task for multi-threaded weighted squared loss computation over the row
 * range [rl,ru); aggregates into a thread-local 1x1 block and returns the
 * partial scalar. Restores the Callable type parameter lost in extraction
 * (raw Callable): call() returns Double, so this is Callable&lt;Double&gt;.
 */
private static class MatrixMultWSLossTask implements Callable<Double>
{
	private final MatrixBlock _mX;
	private final MatrixBlock _mU;
	private final MatrixBlock _mV;
	private final MatrixBlock _mW; //optional weights (may be null)
	private final MatrixBlock _ret;
	private final WeightsType _wt;
	private final int _rl;
	private final int _ru;

	protected MatrixMultWSLossTask(MatrixBlock mX, MatrixBlock mU, MatrixBlock mV, MatrixBlock mW, WeightsType wt, int rl, int ru)
		throws DMLRuntimeException
	{
		_mX = mX;
		_mU = mU;
		_mV = mV;
		_mW = mW;
		_wt = wt;
		_rl = rl;
		_ru = ru;
		//allocate local result for partial aggregation
		_ret = new MatrixBlock(1, 1, false);
		_ret.allocateDenseBlock();
	}

	@Override
	public Double call() throws DMLRuntimeException
	{
		//dispatch by sparsity/emptiness of inputs (W sparsity must match X)
		if( !_mX.sparse && !_mU.sparse && !_mV.sparse && (_mW==null || !_mW.sparse)
			&& !_mX.isEmptyBlock() && !_mU.isEmptyBlock() && !_mV.isEmptyBlock()
			&& (_mW==null || !_mW.isEmptyBlock()))
			matrixMultWSLossDense(_mX, _mU, _mV, _mW, _ret, _wt, _rl, _ru);
		else if( _mX.sparse && !_mU.sparse && !_mV.sparse && (_mW==null || _mW.sparse)
			&& !_mX.isEmptyBlock() && !_mU.isEmptyBlock() && !_mV.isEmptyBlock()
			&& (_mW==null || !_mW.isEmptyBlock()))
			matrixMultWSLossSparseDense(_mX, _mU, _mV, _mW, _ret, _wt, _rl, _ru);
		else
			matrixMultWSLossGeneric(_mX, _mU, _mV, _mW, _ret, _wt, _rl, _ru);
		return _ret.quickGetValue(0, 0);
	}
}
/**
 * Task for multi-threaded weighted sigmoid computation over the row range
 * [rl,ru) of W; writes into the shared (row-disjoint) output and returns
 * the partial nnz count. Restores the Callable type parameter lost in
 * extraction (raw Callable): call() returns Long, so Callable&lt;Long&gt;.
 */
private static class MatrixMultWSigmoidTask implements Callable<Long>
{
	private final MatrixBlock _mW;
	private final MatrixBlock _mU;
	private final MatrixBlock _mV;
	private final MatrixBlock _ret;
	private final WSigmoidType _wt;
	private final int _rl;
	private final int _ru;

	protected MatrixMultWSigmoidTask(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int rl, int ru)
		throws DMLRuntimeException
	{
		_mW = mW;
		_mU = mU;
		_mV = mV;
		_ret = ret;
		_wt = wt;
		_rl = rl;
		_ru = ru;
	}

	@Override
	public Long call() throws DMLRuntimeException
	{
		//core weighted sigmoid mm computation (dispatch by sparsity)
		if( !_mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() )
			matrixMultWSigmoidDense(_mW, _mU, _mV, _ret, _wt, _rl, _ru);
		else if( _mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock())
			matrixMultWSigmoidSparseDense(_mW, _mU, _mV, _ret, _wt, _rl, _ru);
		else
			matrixMultWSigmoidGeneric(_mW, _mU, _mV, _ret, _wt, _rl, _ru);
		//maintain block nnz (upper bounds inclusive)
		return _ret.recomputeNonZeros(_rl, _ru-1, 0, _ret.getNumColumns()-1);
	}
}
/**
 * Task for multi-threaded weighted div matrix multiplication over the row
 * range [rl,ru) and column range [cl,cu) of W; returns the partial nnz count
 * of the written output range. Restores the Callable type parameter lost in
 * extraction (raw Callable): call() returns Long, so Callable&lt;Long&gt;.
 */
private static class MatrixMultWDivTask implements Callable<Long>
{
	private final MatrixBlock _mW;
	private final MatrixBlock _mU;
	private final MatrixBlock _mV;
	private final MatrixBlock _mX; //optional (may be null or scalar per wt)
	private final MatrixBlock _ret;
	private final WDivMMType _wt;
	private final int _rl;
	private final int _ru;
	private final int _cl;
	private final int _cu;

	protected MatrixMultWDivTask(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock mX, MatrixBlock ret, WDivMMType wt, int rl, int ru, int cl, int cu)
		throws DMLRuntimeException
	{
		_mW = mW;
		_mU = mU;
		_mV = mV;
		_mX = mX;
		_wt = wt;
		_rl = rl;
		_ru = ru;
		_cl = cl;
		_cu = cu;
		_ret = ret;
	}

	@Override
	public Long call() throws DMLRuntimeException
	{
		//core weighted div mm computation (dispatch by sparsity)
		boolean scalarX = _wt.hasScalar();
		if( !_mW.sparse && !_mU.sparse && !_mV.sparse && (_mX==null || !_mX.sparse || scalarX) && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() )
			matrixMultWDivMMDense(_mW, _mU, _mV, _mX, _ret, _wt, _rl, _ru, _cl, _cu);
		else if( _mW.sparse && !_mU.sparse && !_mV.sparse && (_mX==null || _mX.sparse || scalarX) && !_mU.isEmptyBlock() && !_mV.isEmptyBlock())
			matrixMultWDivMMSparseDense(_mW, _mU, _mV, _mX, _ret, _wt, _rl, _ru, _cl, _cu);
		else
			matrixMultWDivMMGeneric(_mW, _mU, _mV, _mX, _ret, _wt, _rl, _ru, _cl, _cu);
		//maintain partial nnz for right (upper bounds inclusive);
		//for left-type operations the output rows correspond to [cl,cu)
		int rl = _wt.isLeft() ? _cl : _rl;
		int ru = _wt.isLeft() ? _cu : _ru;
		return _ret.recomputeNonZeros(rl, ru-1, 0, _ret.getNumColumns()-1);
	}
}
/**
 * Task for multi-threaded weighted cross-entropy computation over the row
 * range [rl,ru) of W; aggregates into a thread-local 1x1 block and returns
 * the partial scalar. Restores the Callable type parameter lost in
 * extraction (raw Callable): call() returns Double, so Callable&lt;Double&gt;.
 */
private static class MatrixMultWCeTask implements Callable<Double>
{
	private final MatrixBlock _mW;
	private final MatrixBlock _mU;
	private final MatrixBlock _mV;
	private final double _eps; //epsilon for log stabilization
	private final MatrixBlock _ret;
	private final WCeMMType _wt;
	private final int _rl;
	private final int _ru;

	protected MatrixMultWCeTask(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, double eps, WCeMMType wt, int rl, int ru)
		throws DMLRuntimeException
	{
		_mW = mW;
		_mU = mU;
		_mV = mV;
		_eps = eps;
		_wt = wt;
		_rl = rl;
		_ru = ru;
		//allocate local result for partial aggregation
		_ret = new MatrixBlock(1, 1, false);
		_ret.allocateDenseBlock();
	}

	@Override
	public Double call() throws DMLRuntimeException
	{
		//core weighted cross entropy mm computation (dispatch by sparsity)
		if( !_mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() )
			matrixMultWCeMMDense(_mW, _mU, _mV, _eps, _ret, _wt, _rl, _ru);
		else if( _mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock())
			matrixMultWCeMMSparseDense(_mW, _mU, _mV, _eps, _ret, _wt, _rl, _ru);
		else
			matrixMultWCeMMGeneric(_mW, _mU, _mV, _eps, _ret, _wt, _rl, _ru);
		return _ret.quickGetValue(0, 0);
	}
}
/**
 * Task for multi-threaded weighted unary matrix multiplication over the row
 * range [rl,ru) of W; writes into the shared (row-disjoint) output and
 * returns the partial nnz count. Restores the Callable type parameter lost
 * in extraction (raw Callable): call() returns Long, so Callable&lt;Long&gt;.
 */
private static class MatrixMultWuTask implements Callable<Long>
{
	private final MatrixBlock _mW;
	private final MatrixBlock _mU;
	private final MatrixBlock _mV;
	private final MatrixBlock _ret;
	private final WUMMType _wt;
	private final ValueFunction _fn; //unary function applied per cell
	private final int _rl;
	private final int _ru;

	protected MatrixMultWuTask(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WUMMType wt, ValueFunction fn, int rl, int ru)
		throws DMLRuntimeException
	{
		_mW = mW;
		_mU = mU;
		_mV = mV;
		_ret = ret;
		_wt = wt;
		_fn = fn;
		_rl = rl;
		_ru = ru;
	}

	@Override
	public Long call() throws DMLRuntimeException
	{
		//core weighted unary mm computation (dispatch by sparsity)
		if( !_mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() )
			matrixMultWuMMDense(_mW, _mU, _mV, _ret, _wt, _fn, _rl, _ru);
		else if( _mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock())
			matrixMultWuMMSparseDense(_mW, _mU, _mV, _ret, _wt, _fn, _rl, _ru);
		else
			matrixMultWuMMGeneric(_mW, _mU, _mV, _ret, _wt, _fn, _rl, _ru);
		//maintain block nnz (upper bounds inclusive)
		return _ret.recomputeNonZeros(_rl, _ru-1, 0, _ret.getNumColumns()-1);
	}
}
}