org.apache.sysml.runtime.matrix.data.LibMatrixAgg Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.matrix.data;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.sysml.lops.PartialAggregate.CorrectionLocationType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinFunctionCode;
import org.apache.sysml.runtime.functionobjects.CM;
import org.apache.sysml.runtime.functionobjects.IndexFunction;
import org.apache.sysml.runtime.functionobjects.KahanFunction;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.functionobjects.KahanPlusSq;
import org.apache.sysml.runtime.functionobjects.Mean;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.ReduceAll;
import org.apache.sysml.runtime.functionobjects.ReduceCol;
import org.apache.sysml.runtime.functionobjects.ReduceDiag;
import org.apache.sysml.runtime.functionobjects.ReduceRow;
import org.apache.sysml.runtime.functionobjects.ValueFunction;
import org.apache.sysml.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
import org.apache.sysml.runtime.util.UtilFunctions;
/**
* MB:
* Library for matrix aggregations including ak+, uak+ for all
* combinations of dense and sparse representations, and corrections.
* Those are performance-critical operations because they are used
* on combiners/reducers of important operations like tsmm, mvmult,
* indexing, but also basic sum/min/max/mean, row*, col*, etc. Specific
* handling is especially required for all non-sparse-safe operations
* in order to avoid unnecessarily worse asymptotic behavior.
*
* This library currently covers the following opcodes:
* ak+, uak+, uark+, uack+, uasqk+, uarsqk+, uacsqk+,
* uamin, uarmin, uacmin, uamax, uarmax, uacmax,
* ua*, uamean, uarmean, uacmean, uavar, uarvar, uacvar,
* uarimax, uarimin, uaktrace, cumk+, cummin, cummax, cum*, tak+.
*
* TODO next opcode extensions: a+, colindexmax
*/
public class LibMatrixAgg
{
//internal tuning knobs of this library

/** NOTE(review): presumably toggles NaN-specific handling in the kernels; its uses are not visible here - confirm. */
private static final boolean NAN_AWARENESS = false;

/** Inputs with fewer cells than this (1M) are processed by the sequential code path. */
private static final long PAR_NUMCELL_THRESHOLD = 1024L * 1024L;

/**
 * Upper bound (2MB) compared against out.clen*8*k in the multi-threaded unary
 * aggregate; if exceeded for a non-ReduceCol index function, the sequential
 * code path is used instead.
 */
private static final long PAR_INTERMEDIATE_SIZE_THRESHOLD = 2L * 1024L * 1024L;
////////////////////////////////
// public matrix agg interface
////////////////////////////////
/**
 * Internal classification of the supported aggregation kernels. The type is
 * derived from the operator via getAggType and handed to the dense/sparse
 * unary aggregate implementations.
 */
private enum AggType {
	//sum and sum of squares (Kahan)
	KAHAN_SUM, KAHAN_SUM_SQ,
	//cumulative aggregates
	CUM_KAHAN_SUM, CUM_MIN, CUM_MAX, CUM_PROD,
	//min, max, mean, variance
	MIN, MAX, MEAN, VAR,
	//index of max / min
	MAX_INDEX, MIN_INDEX,
	//product
	PROD,
	//NOTE(review): presumably the marker for unsupported operators - confirm in getAggType
	INVALID
}
/**
 * Private constructor: this class is not meant to be instantiated.
 */
private LibMatrixAgg() {
	//intentionally empty
}
/**
 * Core incremental matrix aggregate (ak+) with a separate correction block,
 * as used in mapmult, tsmm, cpmm, etc. The aggregate values and corrections
 * are preferably kept in dense format, which enables the specialized kernels
 * for dense and sparse inputs; any other combination of formats is routed to
 * the generic kernels.
 *
 * @param in input matrix
 * @param aggVal current aggregate values (in/out)
 * @param aggCorr current aggregate correction (in/out)
 * @throws DMLRuntimeException propagated from the selected kernel
 */
public static void aggregateBinaryMatrix(MatrixBlock in, MatrixBlock aggVal, MatrixBlock aggCorr)
	throws DMLRuntimeException
{
	//specialized kernels require both aggregate blocks in dense format
	final boolean denseAgg = !(aggVal.sparse || aggCorr.sparse);

	//dispatch on input format first, then on aggregate format
	if( in.sparse ) {
		if( denseAgg )
			aggregateBinaryMatrixSparseDense(in, aggVal, aggCorr);
		else //any aggVal, aggCorr
			aggregateBinaryMatrixSparseGeneric(in, aggVal, aggCorr);
	}
	else {
		if( denseAgg )
			aggregateBinaryMatrixAllDense(in, aggVal, aggCorr);
		else //any aggVal, aggCorr
			aggregateBinaryMatrixDenseGeneric(in, aggVal, aggCorr);
	}
}
/**
 * Core incremental matrix aggregate (ak+) as used for uack+ and acrk+.
 * The correction values are embedded in the blocks themselves, either as
 * last row or as last column, as indicated by the aggregate operator.
 *
 * @param in input matrix (with embedded correction)
 * @param aggVal current aggregate values with embedded correction (in/out)
 * @param aop aggregate operator; only its correction location is read here
 * @throws DMLRuntimeException if the dimensions of in and aggVal differ, or
 *         if the correction location is neither LASTROW nor LASTCOLUMN
 */
public static void aggregateBinaryMatrix(MatrixBlock in, MatrixBlock aggVal, AggregateOperator aop)
	throws DMLRuntimeException
{
	//sanity check matching dimensions
	if( in.getNumRows()!=aggVal.getNumRows() || in.getNumColumns()!=aggVal.getNumColumns() )
		throw new DMLRuntimeException("Dimension mismatch on aggregate: "+in.getNumRows()+"x"+in.getNumColumns()+
			" vs "+aggVal.getNumRows()+"x"+aggVal.getNumColumns());

	//sanity check supported correction layouts; without this check any other
	//location would silently run a last-column kernel (and dense inputs would
	//even be passed to the sparse kernel)
	boolean lastRowCorr = (aop.correctionLocation == CorrectionLocationType.LASTROW);
	boolean lastColCorr = (aop.correctionLocation == CorrectionLocationType.LASTCOLUMN);
	if( !lastRowCorr && !lastColCorr )
		throw new DMLRuntimeException("Unsupported correction location on aggregate: "+aop.correctionLocation);

	//core aggregation
	if( lastRowCorr ) {
		if( !in.sparse )
			aggregateBinaryMatrixLastRowDenseGeneric(in, aggVal);
		else
			aggregateBinaryMatrixLastRowSparseGeneric(in, aggVal);
	}
	else { //lastColCorr
		if( !in.sparse )
			aggregateBinaryMatrixLastColDenseGeneric(in, aggVal);
		else
			aggregateBinaryMatrixLastColSparseGeneric(in, aggVal);
	}
}
/**
 * Single-threaded unary aggregate over all rows of the input block.
 * Empty inputs are delegated to the dedicated empty-block handling;
 * otherwise the output is reset to its current dimensions in dense format,
 * filled by the dense or sparse kernel, and finally cleaned up (non-zero
 * count, representation).
 *
 * @param in input matrix
 * @param out output matrix (reset and overwritten; its dimensions are kept)
 * @param uaop unary aggregate operator providing the value and index functions
 * @throws DMLRuntimeException propagated from the called kernels
 */
public static void aggregateUnaryMatrix(MatrixBlock in, MatrixBlock out, AggregateUnaryOperator uaop)
	throws DMLRuntimeException
{
	//collect meta data up front (output dimensions are read before the reset)
	final AggType type = getAggType(uaop);
	final int numRows = in.rlen;
	final int outRows = out.rlen;
	final int outCols = out.clen;

	//empty inputs take a separate path (incl special handling for sparse-unsafe operations)
	if( in.isEmptyBlock(false) ) {
		aggregateUnaryMatrixEmpty(in, out, type, uaop.indexFn);
		return;
	}

	//output is always computed in dense format
	out.reset(outRows, outCols, false);
	out.allocateDenseBlock();

	//run the kernel matching the input representation over all rows
	if( in.sparse )
		aggregateUnaryMatrixSparse(in, out, type, uaop.aggOp.increOp.fn, uaop.indexFn, 0, numRows);
	else
		aggregateUnaryMatrixDense(in, out, type, uaop.aggOp.increOp.fn, uaop.indexFn, 0, numRows);

	//maintain non-zero count and switch representation if beneficial
	out.recomputeNonZeros();
	out.examSparsity();
}
/**
*
* @param in
* @param out
* @param uaop
* @param k
* @throws DMLRuntimeException
*/
public static void aggregateUnaryMatrix(MatrixBlock in, MatrixBlock out, AggregateUnaryOperator uaop, int k)
throws DMLRuntimeException
{
//fall back to sequential version if necessary
if( k <= 1 || (long)in.rlen*in.clen < PAR_NUMCELL_THRESHOLD || in.rlen <= k
|| (!(uaop.indexFn instanceof ReduceCol) && out.clen*8*k > PAR_INTERMEDIATE_SIZE_THRESHOLD ) ) {
aggregateUnaryMatrix(in, out, uaop);
return;
}
//prepare meta data
AggType aggtype = getAggType(uaop);
final int m = in.rlen;
final int m2 = out.rlen;
final int n2 = out.clen;
//filter empty input blocks (incl special handling for sparse-unsafe operations)
if( in.isEmptyBlock(false) ){
aggregateUnaryMatrixEmpty(in, out, aggtype, uaop.indexFn);
return;
}
//Timing time = new Timing(true);
//allocate output arrays (if required)
if( uaop.indexFn instanceof ReduceCol ) {
out.reset(m2, n2, false); //always dense
out.allocateDenseBlock();
}
//core multi-threaded unary aggregate computation
//(currently: always parallelization over number of rows)
try {
ExecutorService pool = Executors.newFixedThreadPool( k );
ArrayList tasks = new ArrayList();
int blklen = (int)(Math.ceil((double)m/k));
for( int i=0; i tasks = new ArrayList();
int blklen = (int)(Math.ceil((double)in1.rlen/k));
for( int i=0; i1);
if( k <= 1 || (long)target.rlen*target.clen < PAR_NUMCELL_THRESHOLD || rowVector || target.clen==1 ) {
groupedAggregate(groups, target, weights, result, numGroups, op);
return;
}
if( !(op instanceof CMOperator || op instanceof AggregateOperator) ) {
throw new DMLRuntimeException("Invalid operator (" + op + ") encountered while processing groupedAggregate.");
}
//preprocessing
result.sparse = false;
result.allocateDenseBlock();
//core multi-threaded grouped aggregate computation
//(currently: parallelization over columns to avoid additional memory requirements)
try {
ExecutorService pool = Executors.newFixedThreadPool( k );
ArrayList tasks = new ArrayList();
int blklen = (int)(Math.ceil((double)target.clen/k));
for( int i=0; i1);
int numCols = (!rowVector) ? target.getNumColumns() : 1;
double w = 1; //default weight
//skip empty blocks (sparse-safe operation)
if( target.isEmptyBlock(false) )
return;
//init group buffers
int numCols2 = cu-cl;
KahanObject[][] buffer = new KahanObject[numGroups][numCols2];
for( int i=0; i numGroups )
continue;
if ( weights != null )
w = weights.quickGetValue(aix[j],0);
aggop.increOp.fn.execute(buffer[g-1][0], avals[j]*w);
}
}
}
else //DENSE target
{
for ( int i=0; i < target.getNumColumns(); i++ ) {
double d = target.denseBlock[ i ];
if( d != 0 ) //sparse-safe
{
int g = (int) groups.quickGetValue(i, 0);
if ( g > numGroups )
continue;
if ( weights != null )
w = weights.quickGetValue(i,0);
// buffer is 0-indexed, whereas range of values for g = [1,numGroups]
aggop.increOp.fn.execute(buffer[g-1][0], d*w);
}
}
}
}
else //column vector or matrix
{
if( target.sparse ) //SPARSE target
{
SparseRow[] a = target.sparseRows;
for( int i=0; i < groups.getNumRows(); i++ )
{
int g = (int) groups.quickGetValue(i, 0);
if ( g > numGroups )
continue;
if( a[i] != null && !a[i].isEmpty() )
{
int len = a[i].size();
int[] aix = a[i].getIndexContainer();
double[] avals = a[i].getValueContainer();
int j = (cl==0) ? 0 : a[i].searchIndexesFirstGTE(cl);
j = (j>=0) ? j : len;
for( ; j numGroups )
continue;
for( int j=cl; j < cu; j++ ) {
double d = a[ aix+j ];
if( d != 0 ) { //sparse-safe
if ( weights != null )
w = weights.quickGetValue(i,0);
// buffer is 0-indexed, whereas range of values for g = [1,numGroups]
aggop.increOp.fn.execute(buffer[g-1][j-cl], d*w);
}
}
}
}
}
// extract the results from group buffers
for( int i=0; i < numGroups; i++ )
for( int j=0; j < numCols2; j++ )
result.appendValue(i, j+cl, buffer[i][j]._sum);
}
/**
*
* @param target
* @param weights
* @param result
* @param cmOp
* @throws DMLRuntimeException
*/
private static void groupedAggregateCM( MatrixBlock groups, MatrixBlock target, MatrixBlock weights, MatrixBlock result, int numGroups, CMOperator cmOp, int cl, int cu )
throws DMLRuntimeException
{
CM cmFn = CM.getCMFnObject(((CMOperator) cmOp).getAggOpType());
double w = 1; //default weight
//init group buffers
int numCols2 = cu-cl;
CM_COV_Object[][] cmValues = new CM_COV_Object[numGroups][numCols2];
for ( int i=0; i < numGroups; i++ )
for( int j=0; j < numCols2; j++ )
cmValues[i][j] = new CM_COV_Object();
//column vector or matrix
if( target.sparse ) //SPARSE target
{
SparseRow[] a = target.sparseRows;
for( int i=0; i < groups.getNumRows(); i++ )
{
int g = (int) groups.quickGetValue(i, 0);
if ( g > numGroups )
continue;
if( a[i] != null && !a[i].isEmpty() )
{
int len = a[i].size();
int[] aix = a[i].getIndexContainer();
double[] avals = a[i].getValueContainer();
int j = (cl==0) ? 0 : a[i].searchIndexesFirstGTE(cl);
j = (j>=0) ? j : len;
for( ; j numGroups )
continue;
for( int j=cl; j numGroups )
continue;
tmp[g-1]++;
}
//copy counts into result
for( int i=0; i=0; j--, ix-- )
if( aix[j]!=ix )
break;
c[cix+0] = ix + 1; //max index (last)
c[cix+1] = 0; //max value
}
}
else //if( arow==null )
{
//correction (not sparse-safe)
c[cix+0] = n; //max index (last)
c[cix+1] = 0; //max value
}
}
}
/**
* ROWINDEXMIN, opcode: uarimin, sparse input.
*
* @param a
* @param c
* @param m
* @param n
* @param init
* @param builtin
*/
private static void s_uarimin( SparseRow[] a, double[] c, int m, int n, double init, Builtin builtin, int rl, int ru )
{
for( int i=rl, cix=rl*2; i=0; j--, ix-- )
if( aix[j]!=ix )
break;
c[cix+0] = ix + 1; //min index (last)
c[cix+1] = 0; //min value
}
}
else //if( arow==null )
{
//correction (not sparse-safe)
c[cix+0] = n; //min index (last)
c[cix+1] = 0; //min value
}
}
}
/**
* MEAN, opcode: uamean, sparse input.
*
* @param a
* @param c
* @param m
* @param n
* @param kbuff
* @param kmean
*/
private static void s_uamean( SparseRow[] a, double[] c, int m, int n, KahanObject kbuff, Mean kmean, int rl, int ru )
{
int len = (ru-rl) * n;
int count = 0;
//correction remaining tuples (not sparse-safe)
//note: before aggregate computation in order to
//exploit 0 sum (noop) and better numerical stability
for( int i=rl; i=maxval) ? i-ai : maxindex;
maxval = (a[i]>=maxval) ? a[i] : maxval;
}
return maxindex;
}
/**
*
* @param a
* @param ai
* @param init
* @param len
* @param aggop
* @return
*/
private static int indexmin( double[] a, int ai, final double init, final int len, Builtin aggop )
{
double minval = init;
int minindex = -1;
for( int i=ai; i {}
/**
 * Task that runs the unary aggregate kernel (dense or sparse, depending on
 * the input) for the row range given at construction and passes the given
 * result block directly to that kernel.
 * NOTE(review): the upper bound is presumably exclusive, matching the
 * sequential call with 0 and in.rlen - confirm in the kernels.
 */
private static class RowAggTask extends AggTask
{
	private MatrixBlock _in;
	private MatrixBlock _ret;
	private AggType _aggtype;
	private AggregateUnaryOperator _uaop;
	private int _rl;
	private int _ru;

	protected RowAggTask( MatrixBlock in, MatrixBlock ret, AggType aggtype, AggregateUnaryOperator uaop, int rl, int ru )
	{
		this._in = in;
		this._ret = ret;
		this._aggtype = aggtype;
		this._uaop = uaop;
		this._rl = rl;
		this._ru = ru;
	}

	/**
	 * Executes the kernel on the configured row range.
	 *
	 * @return always null
	 */
	@Override
	public Object call() throws DMLRuntimeException
	{
		if( _in.sparse )
			aggregateUnaryMatrixSparse(_in, _ret, _aggtype, _uaop.aggOp.increOp.fn, _uaop.indexFn, _rl, _ru);
		else
			aggregateUnaryMatrixDense(_in, _ret, _aggtype, _uaop.aggOp.increOp.fn, _uaop.indexFn, _rl, _ru);
		return null;
	}
}
/**
 * Task that aggregates its row range into a task-local dense block with the
 * dimensions of the given result block. The partial result is exposed via
 * getResult(); combining the partial results is left to the caller.
 */
private static class PartialAggTask extends AggTask
{
	private MatrixBlock _in;
	private MatrixBlock _ret;
	private AggType _aggtype;
	private AggregateUnaryOperator _uaop;
	private int _rl;
	private int _ru;

	protected PartialAggTask( MatrixBlock in, MatrixBlock ret, AggType aggtype, AggregateUnaryOperator uaop, int rl, int ru )
		throws DMLRuntimeException
	{
		this._in = in;
		this._aggtype = aggtype;
		this._uaop = uaop;
		this._rl = rl;
		this._ru = ru;

		//the passed result only provides the dimensions of the local partial result
		MatrixBlock partial = new MatrixBlock(ret.rlen, ret.clen, false);
		partial.allocateDenseBlock();
		this._ret = partial;
	}

	/**
	 * Executes the kernel on the configured row range and updates the
	 * non-zero count of the partial result.
	 *
	 * @return always null; use getResult() to obtain the partial result
	 */
	@Override
	public Object call() throws DMLRuntimeException
	{
		if( _in.sparse )
			aggregateUnaryMatrixSparse(_in, _ret, _aggtype, _uaop.aggOp.increOp.fn, _uaop.indexFn, _rl, _ru);
		else
			aggregateUnaryMatrixDense(_in, _ret, _aggtype, _uaop.aggOp.increOp.fn, _uaop.indexFn, _rl, _ru);

		//keep the meta data of the partial result consistent
		_ret.recomputeNonZeros();
		return null;
	}

	/**
	 * @return the task-local partial result block
	 */
	public MatrixBlock getResult() {
		return _ret;
	}
}
/**
 * Task that computes the ternary aggregate for the row range given at
 * construction and stores the scalar result for later retrieval via
 * getResult(). The third input may be null.
 */
private static class AggTernaryTask extends AggTask
{
	private MatrixBlock _in1;
	private MatrixBlock _in2;
	private MatrixBlock _in3;
	private double _ret = -1; //value reported until call() has run
	private int _rl;
	private int _ru;

	protected AggTernaryTask( MatrixBlock in1, MatrixBlock in2, MatrixBlock in3, int rl, int ru )
		throws DMLRuntimeException
	{
		this._in1 = in1;
		this._in2 = in2;
		this._in3 = in3;
		this._rl = rl;
		this._ru = ru;
	}

	/**
	 * Runs the dense kernel if all present inputs are dense, and the
	 * generic kernel otherwise.
	 *
	 * @return always null; use getResult() to obtain the aggregate
	 */
	@Override
	public Object call() throws DMLRuntimeException
	{
		if( _in1.sparse || _in2.sparse || (_in3!=null && _in3.sparse) ) //GENERAL CASE
			_ret = aggregateTernaryGeneric(_in1, _in2, _in3, _rl, _ru);
		else //DENSE
			_ret = aggregateTernaryDense(_in1, _in2, _in3, _rl, _ru);
		return null;
	}

	/**
	 * @return the aggregate computed by call(), or -1 if call() has not run yet
	 */
	public double getResult() {
		return _ret;
	}
}
/**
 * Task that runs a grouped aggregate for the column range given at
 * construction, passing the given result block to the selected kernel.
 * Operators other than CMOperator and AggregateOperator are ignored.
 */
private static class GrpAggTask extends AggTask
{
	private MatrixBlock _groups;
	private MatrixBlock _target;
	private MatrixBlock _weights;
	private MatrixBlock _ret;
	private int _numGroups;
	private Operator _op;
	private int _cl;
	private int _cu;

	protected GrpAggTask( MatrixBlock groups, MatrixBlock target, MatrixBlock weights, MatrixBlock ret, int numGroups, Operator op, int cl, int cu )
		throws DMLRuntimeException
	{
		this._groups = groups;
		this._target = target;
		this._weights = weights;
		this._ret = ret;
		this._numGroups = numGroups;
		this._op = op;
		this._cl = cl;
		this._cu = cu;
	}

	/**
	 * Dispatches to the kernel matching the operator type.
	 *
	 * @return always null
	 */
	@Override
	public Object call() throws DMLRuntimeException
	{
		if( _op instanceof CMOperator ) {
			//central moment operator: count, mean, variance
			//note: current support only for column vectors
			groupedAggregateCM(_groups, _target, _weights, _ret, _numGroups, (CMOperator) _op, _cl, _cu);
		}
		else if( _op instanceof AggregateOperator ) {
			//aggregate operator: sum (via kahan sum)
			//note: support for row/column vectors and dense/sparse
			groupedAggregateKahanPlus(_groups, _target, _weights, _ret, _numGroups, (AggregateOperator) _op, _cl, _cu);
		}
		return null;
	}
}
}