org.apache.sysml.runtime.matrix.data.LibMatrixOuterAgg Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.matrix.data;
import java.util.Arrays;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.Equals;
import org.apache.sysml.runtime.functionobjects.GreaterThan;
import org.apache.sysml.runtime.functionobjects.GreaterThanEquals;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.functionobjects.LessThan;
import org.apache.sysml.runtime.functionobjects.LessThanEquals;
import org.apache.sysml.runtime.functionobjects.NotEquals;
import org.apache.sysml.runtime.functionobjects.ReduceAll;
import org.apache.sysml.runtime.functionobjects.ReduceCol;
import org.apache.sysml.runtime.functionobjects.ReduceRow;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
import org.apache.sysml.runtime.util.DataConverter;
import org.apache.sysml.runtime.util.SortUtils;
/**
* ACS:
* The purpose of this library is to make some of the unary outer aggregate operators more efficient.
* Today these operators are handled through common (generic) operations.
* This library will expand, per need and priority, to cover these operators through this specialized support.
* To begin with, the first operator handled is the unary aggregate rowSum over less than (<).
* The remaining comparison operators to be added are rowSum over >, <=, >=, ==, and !=.
*/
public class LibMatrixOuterAgg
{
private LibMatrixOuterAgg() {
	//prevent instantiation via private constructor; this is a utility class of static helpers
}
/**
 * Checks whether the aggregate's increment function is the rowIndexMax builtin,
 * i.e. a {@link Builtin} whose code is {@code MAXINDEX}.
 *
 * @param uaggOp aggregate unary operator
 * @return true if aggregate unary operator is of type rowIndexMax
 */
public static boolean isRowIndexMax(AggregateUnaryOperator uaggOp)
{
	if( !(uaggOp.aggOp.increOp.fn instanceof Builtin) )
		return false;
	Builtin builtin = (Builtin) uaggOp.aggOp.increOp.fn;
	return builtin.bFunc == Builtin.BuiltinCode.MAXINDEX;
}
/**
 * Checks whether the aggregate's increment function is the rowIndexMin builtin,
 * i.e. a {@link Builtin} whose code is {@code MININDEX}.
 *
 * @param uaggOp aggregate unary operator
 * @return true if aggregate unary operator is of type rowIndexMin
 */
public static boolean isRowIndexMin(AggregateUnaryOperator uaggOp)
{
	if( !(uaggOp.aggOp.increOp.fn instanceof Builtin) )
		return false;
	Builtin builtin = (Builtin) uaggOp.aggOp.increOp.fn;
	return builtin.bFunc == Builtin.BuiltinCode.MININDEX;
}
/**
 * Checks whether the given binary operator is one of the six comparison operators
 * handled by this library: {@code <, <=, >, >=, ==} and {@code !=}.
 *
 * @param bOp binary operator
 * @return true if {@code bOp.fn} is LessThan, LessThanEquals, GreaterThan,
 *         GreaterThanEquals, Equals or NotEquals; false otherwise
 */
public static boolean isCompareOperator(BinaryOperator bOp)
{
	return bOp.fn instanceof LessThan || bOp.fn instanceof LessThanEquals        // <, <=
		|| bOp.fn instanceof GreaterThan || bOp.fn instanceof GreaterThanEquals // >, >=
		|| bOp.fn instanceof Equals || bOp.fn instanceof NotEquals;             // ==, !=
}
/**
 * Checks whether the combination of a comparison operator and a unary aggregate
 * can be handled by this library.
 *
 * Supported combinations require all of:
 * <ul>
 *   <li>{@code bOp} is a comparison operator (see {@link #isCompareOperator}),</li>
 *   <li>the aggregate is a KahanPlus sum, rowIndexMin or rowIndexMax,</li>
 *   <li>the aggregate direction is ReduceCol, ReduceRow or ReduceAll.</li>
 * </ul>
 *
 * @param uaggOp aggregate unary operator
 * @param bOp binary operator
 * @return true if the combination is supported
 */
public static boolean isSupportedUaggOp( AggregateUnaryOperator uaggOp, BinaryOperator bOp )
{
	if( !isCompareOperator(bOp) )
		return false;
	
	boolean supportedAgg = uaggOp.aggOp.increOp.fn instanceof KahanPlus
		|| isRowIndexMin(uaggOp)
		|| isRowIndexMax(uaggOp);
	if( !supportedAgg )
		return false;
	
	return uaggOp.indexFn instanceof ReduceCol
		|| uaggOp.indexFn instanceof ReduceRow
		|| uaggOp.indexFn instanceof ReduceAll;
}
/**
 * Prepares the index vector used by the rowIndexMax / rowIndexMin helpers.
 * Dispatches to prepareRowIndicesMax when {@code uaggOp} is a rowIndexMax aggregate,
 * and to prepareRowIndicesMin otherwise.
 *
 * @param iCols number of entries in {@code vmb}
 * @param vmb column-vector values
 * @param bOp binary (comparison) operator
 * @param uaggOp aggregate unary operator (rowIndexMax or rowIndexMin)
 * @return prepared index vector
 * @throws DMLRuntimeException if the delegate throws
 */
public static int[] prepareRowIndices(int iCols, double vmb[], BinaryOperator bOp, AggregateUnaryOperator uaggOp) 
	throws DMLRuntimeException
{
	if( isRowIndexMax(uaggOp) )
		return prepareRowIndicesMax(iCols, vmb, bOp);
	return prepareRowIndicesMin(iCols, vmb, bOp);
}
/**
* This function will return max indices, based on column vector data.
* These indices will be computed based on the operator.
* These indices can be used to compute max index for a given input value in subsequent operation.
*
* e.g. Right Vector has data (V1) : 6 3 9 7 2 4 4 3
* Original indices for this data will be (I1): 1 2 3 4 5 6 7 8
*
* Sorting this data based on value will be (V2): 2 3 3 4 4 6 7 9
* Then indices will be ordered as (I2): 5 2 8 6 7 1 4 3
*
* CumMax of I2 will be A: (CumMax(I2)) 5 5 8 8 8 8 8 8
* CumMax of I2 in reverse order be B: 8 8 8 7 7 4 4 3
*
* Values from vector A is used to compute RowIndexMax for > & >= operators
* Values from vector B is used to compute RowIndexMax for < & <= operators
* Values from I2 is used to compute RowIndexMax for == operator.
* Original values are directly used to compute RowIndexMax for != operator
*
* Shifting values from vector A or B is required to compute final indices.
* Once indices are shifted from vector A or B, their cell value corresponding to input data will be used.
*
* @param iCols ?
* @param vmb ?
* @param bOp binary operator
* @return array of maximum row indices
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
public static int[] prepareRowIndicesMax(int iCols, double vmb[], BinaryOperator bOp) throws DMLRuntimeException
{
int[] vixCumSum = null;
int[] vix = new int[iCols];
//sort index vector on extracted data (unstable)
if(!(bOp.fn instanceof NotEquals)){
for( int i=0; i= 0 ){ //match, scan to next val
while( ix > 0 && value==bv[ix-1]) --ix;
ix++; //Readjust index to match subsenquent index calculation.
}
cnt = bv.length-Math.abs(ix)+1;
//cnt = Math.abs(ix) - 1;
if ((bOp.fn instanceof LessThan) || (bOp.fn instanceof GreaterThan))
cnt = bv.length - cnt;
return cnt;
}
/**
* Calculates the sum of number for rowSum of LessThan and GreaterThanEqual, and
* colSum of GreaterThan and LessThanEqual operators.
*
* @param value ?
* @param bv ?
* @param bOp binary operator
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
private static int sumRowSumLtGeColSumGtLe(double value, double[] bv, BinaryOperator bOp)
throws DMLRuntimeException
{
int ix = Arrays.binarySearch(bv, value);
int cnt = 0;
if( ix >= 0 ){ //match, scan to next val
while( value==bv[ix++] && ix= 0 ){ //match, scan to next val
while( ix > 0 && value==bv[ix-1]) --ix;
while( ix= 0 )
ixMax = bvi[ix]+1;
return ixMax;
}
/**
 * Find out rowIndexMax for NotEqual operator.
 *
 * If the last entry of {@code bv} differs from {@code value}, the result is {@code bv.length};
 * otherwise the precomputed position {@code bvi[0]} (0-based) is returned as a 1-based index.
 * NOTE(review): prepareRowIndicesMax skips sorting for NotEquals and adjustRowIndicesMax stores
 * in bvi[0] the position just before the trailing run equal to the last value, so bv is
 * presumably in original column order here; confirm.
 *
 * @param value input value compared against every entry of bv
 * @param bv column-vector values
 * @param bvi prepared index vector
 * @param bOp binary operator (unused)
 * @return 1-based row index of the maximum
 * @throws DMLRuntimeException declared for interface uniformity; not thrown here
 */
private static int uarimaxNe(double value, double[] bv, int bvi[], BinaryOperator bOp) 
	throws DMLRuntimeException
{
	return ( bv[bv.length-1] == value ) ? bvi[0]+1 : bv.length;
}
/**
 * Find out rowIndexMax for GreaterThan operator, i.e. for a row of entries {@code value > bv[j]}.
 *
 * With {@code bv} sorted ascending, the entries with {@code bv[j] < value} form the prefix
 * {@code bv[0..k-1]}, where k is the number of entries strictly smaller than value.
 * The result is read from {@code bvi[k-1]}. An exact match must yield the same k as a
 * value just below it. Arrays.binarySearch may return any of several equal entries,
 * so a match is moved back to its first occurrence.
 *
 * @param value input value compared against every entry of bv
 * @param bv column-vector values, sorted ascending (required by Arrays.binarySearch)
 * @param bvi prepared index vector (0-based positions), aligned with bv
 * @param bOp binary operator (unused)
 * @return 1-based row index of the maximum; bv.length if the comparison is the same for all entries
 * @throws DMLRuntimeException declared for interface uniformity; not thrown here
 */
private static int uarimaxGt(double value, double[] bv, int bvi[], BinaryOperator bOp) 
	throws DMLRuntimeException
{
	int ixMax = bv.length;
	
	//value > bv[j] is false for all j (value <= min) or true for all j (value > max)
	if(value <= bv[0] || value > bv[bv.length-1]) 
		return ixMax;
	
	//k = number of entries strictly smaller than value (k >= 1 due to the check above)
	int k = Arrays.binarySearch(bv, value);
	if( k >= 0 ) {
		//exact match: scan back to the first occurrence of value
		while( k > 0 && bv[k-1] == value ) --k;
	}
	else {
		//no match: insertion point
		k = -(k+1);
	}
	ixMax = bvi[k-1]+1;
	
	return ixMax;
}
/**
 * Find out rowIndexMax for GreaterThanEqual operator, i.e. for a row of entries {@code value >= bv[j]}.
 *
 * Returns {@code bv.length} when {@code value < bv[0]} or {@code value >= bv[bv.length-1]}.
 * Otherwise, on a binarySearch match it reads bvi at the matched position. Without a match,
 * it reads bvi at the last position holding a value smaller than {@code value}.
 * NOTE(review): adjustRowIndicesMax applies setMaxIndexInPartition for this operator, presumably
 * making equal-value runs in bvi uniform so any matched position works; confirm.
 *
 * @param value input value compared against every entry of bv
 * @param bv column-vector values, sorted ascending (required by Arrays.binarySearch)
 * @param bvi prepared index vector (0-based positions), aligned with bv
 * @param bOp binary operator (unused)
 * @return 1-based row index of the maximum
 * @throws DMLRuntimeException declared for interface uniformity; not thrown here
 */
private static int uarimaxGe(double value, double[] bv, int[] bvi, BinaryOperator bOp) 
	throws DMLRuntimeException
{
	final int len = bv.length;
	boolean uniform = value < bv[0] || value >= bv[len-1];
	if( uniform )
		return len;
	
	int pos = Arrays.binarySearch(bv, value);
	int at = (pos >= 0) ? pos : -pos - 2;
	return bvi[at] + 1;
}
/**
 * Find out rowIndexMax for LessThan operator, i.e. for a row of entries {@code value < bv[j]}.
 *
 * Returns {@code bv.length} when {@code value < bv[0]} or {@code value >= bv[bv.length-1]}.
 * Otherwise, on a binarySearch match it reads bvi at the matched position. Without a match,
 * it reads bvi at the last position holding a value smaller than {@code value}.
 * NOTE(review): adjustRowIndicesMax applies shiftLeft for this operator, presumably so that
 * these positions hold the index of the next larger partition; confirm.
 *
 * @param value input value compared against every entry of bv
 * @param bv column-vector values, sorted ascending (required by Arrays.binarySearch)
 * @param bvi prepared index vector (0-based positions), aligned with bv
 * @param bOp binary operator (unused)
 * @return 1-based row index of the maximum
 * @throws DMLRuntimeException declared for interface uniformity; not thrown here
 */
private static int uarimaxLt(double value, double[] bv, int[] bvi, BinaryOperator bOp) 
	throws DMLRuntimeException
{
	final int len = bv.length;
	if( value < bv[0] || value >= bv[len-1] )
		return len;
	
	int hit = Arrays.binarySearch(bv, value);
	if( hit < 0 )
		hit = -hit - 2;   //last position below the insertion point
	return bvi[hit] + 1;
}
/**
 * Find out rowIndexMax for LessThanEquals operator, i.e. for a row of entries {@code value <= bv[j]}.
 *
 * With {@code bv} sorted ascending, the entries with {@code bv[j] >= value} form the suffix
 * starting at k, where k is the number of entries strictly smaller than value.
 * The result is read from {@code bvi[k]}. Arrays.binarySearch may return any of several
 * equal entries, and adjustRowIndicesMax does not even out equal-value runs for this
 * operator. A match is therefore moved back to its first occurrence, so it yields the
 * same k as a value just below it.
 *
 * @param value input value compared against every entry of bv
 * @param bv column-vector values, sorted ascending (required by Arrays.binarySearch)
 * @param bvi prepared index vector (0-based positions), aligned with bv
 * @param bOp binary operator (unused)
 * @return 1-based row index of the maximum; bv.length if the comparison is the same for all entries
 * @throws DMLRuntimeException declared for interface uniformity; not thrown here
 */
private static int uarimaxLe(double value, double[] bv, int[] bvi, BinaryOperator bOp) 
	throws DMLRuntimeException
{
	int ixMax = bv.length;
	
	//value <= bv[j] is true for all j (value <= min) or false for all j (value > max)
	if(value <= bv[0] || value > bv[bv.length-1]) 
		return ixMax;
	
	//k = number of entries strictly smaller than value = first position with bv[k] >= value
	int k = Arrays.binarySearch(bv, value);
	if( k >= 0 ) {
		//exact match: scan back to the first occurrence of value
		while( k > 0 && bv[k-1] == value ) --k;
	}
	else {
		//no match: insertion point
		k = -(k+1);
	}
	ixMax = bvi[k]+1;
	
	return ixMax;
}
/**
 * Find out rowIndexMin for Equal operator.
 *
 * Returns {@code bvi[0]+1} when {@code value} equals the first entry of {@code bv}, and 1 otherwise.
 * NOTE(review): adjustRowIndicesMin (Equals branch) prepares bvi[0] from the run of values
 * equal to vmb[0], so bv is presumably unsorted (original column order) here; confirm.
 *
 * @param value input value compared against every entry of bv
 * @param bv column-vector values
 * @param bvi prepared index vector
 * @param bOp binary operator (unused)
 * @return 1-based row index of the minimum
 * @throws DMLRuntimeException declared for interface uniformity; not thrown here
 */
private static int uariminEq(double value, double[] bv, int bvi[], BinaryOperator bOp) 
	throws DMLRuntimeException
{
	return ( value == bv[0] ) ? bvi[0]+1 : 1;
}
/**
 * Find out rowIndexMin for NotEqual operator.
 *
 * Returns {@code bvi[0]+1} when the first entry of {@code bv} differs from {@code value},
 * and 1 otherwise.
 * NOTE(review): correctness depends on how prepareRowIndicesMin prepares bv/bvi for
 * NotEquals, which is not visible here; confirm.
 *
 * @param value input value compared against every entry of bv
 * @param bv column-vector values
 * @param bvi prepared index vector
 * @param bOp binary operator (unused)
 * @return 1-based row index of the minimum
 * @throws DMLRuntimeException declared for interface uniformity; not thrown here
 */
private static int uariminNe(double value, double[] bv, int bvi[], BinaryOperator bOp) 
	throws DMLRuntimeException
{
	return ( bv[0] != value ) ? bvi[0]+1 : 1;
}
/**
 * Find out rowIndexMin for GreaterThan operator, i.e. for a row of entries {@code value > bv[j]}.
 *
 * With {@code bv} sorted ascending, let k be the number of entries strictly smaller than value.
 * The result is read from {@code bvi[k]}. An exact match must yield the same k as a value just
 * below it. Arrays.binarySearch may return any of several equal entries, so a match is moved
 * back to its first occurrence. Reading bvi at the match minus one would hit the previous
 * (smaller) partition.
 *
 * @param value input value compared against every entry of bv
 * @param bv column-vector values, sorted ascending (required by Arrays.binarySearch)
 * @param bvi prepared index vector (0-based positions), aligned with bv
 * @param bOp binary operator (unused)
 * @return 1-based row index of the minimum; 1 if the comparison is the same for all entries
 * @throws DMLRuntimeException declared for interface uniformity; not thrown here
 */
private static int uariminGt(double value, double[] bv, int bvi[], BinaryOperator bOp) 
	throws DMLRuntimeException
{
	int ixMin = 1;
	
	//value > bv[j] is false for all j (value <= min) or true for all j (value > max)
	if(value <= bv[0] || value > bv[bv.length-1]) 
		return ixMin;
	
	//k = number of entries strictly smaller than value = first position with bv[k] >= value
	int k = Arrays.binarySearch(bv, value);
	if( k >= 0 ) {
		//exact match: scan back to the first occurrence of value
		while( k > 0 && bv[k-1] == value ) --k;
	}
	else {
		//no match: insertion point
		k = -(k+1);
	}
	ixMin = bvi[k]+1;
	
	return ixMin;
}
/**
 * Find out rowIndexMin for GreaterThanEqual operator, i.e. for a row of entries {@code value >= bv[j]}.
 *
 * With {@code bv} sorted ascending, {@code value >= bv[j]} is false for every entry iff
 * {@code value < bv[0]}, and true for every entry iff {@code value >= bv[bv.length-1]}.
 * In both cases the default index 1 is returned. Otherwise the result is read at the last
 * position j with {@code bv[j] <= value}. On a match, binarySearch may return any of several
 * equal entries, so the position is advanced to the last occurrence. This yields the same
 * position as a value just above it.
 *
 * @param value input value compared against every entry of bv
 * @param bv column-vector values, sorted ascending (required by Arrays.binarySearch)
 * @param bvi prepared index vector (0-based positions), aligned with bv
 * @param bOp binary operator (unused)
 * @return 1-based row index of the minimum; 1 if the comparison is the same for all entries
 * @throws DMLRuntimeException declared for interface uniformity; not thrown here
 */
private static int uariminGe(double value, double[] bv, int[] bvi, BinaryOperator bOp) 
	throws DMLRuntimeException
{
	int ixMin = 1;
	
	//value >= bv[j] is false for all j (value < min) or true for all j (value >= max)
	if(value < bv[0] || value >= bv[bv.length-1]) 
		return ixMin;
	
	//k = last position with bv[k] <= value
	int k = Arrays.binarySearch(bv, value);
	if( k >= 0 ) {
		//exact match: scan forward to the last occurrence of value
		while( k < bv.length-1 && bv[k+1] == value ) ++k;
	}
	else {
		//no match: position just before the insertion point
		k = -(k+1) - 1;
	}
	ixMin = bvi[k]+1;
	
	return ixMin;
}
/**
 * Find out rowIndexMin for LessThan operator, i.e. for a row of entries {@code value < bv[j]}.
 *
 * With {@code bv} sorted ascending, the result is read at the last position j with
 * {@code bv[j] <= value}. On a match, binarySearch may return any of several equal entries,
 * so the position is advanced to the last occurrence. This yields the same position as a
 * value just above it.
 *
 * @param value input value compared against every entry of bv
 * @param bv column-vector values, sorted ascending (required by Arrays.binarySearch)
 * @param bvi prepared index vector (0-based positions), aligned with bv
 * @param bOp binary operator (unused)
 * @return 1-based row index of the minimum; 1 if the comparison is the same for all entries
 * @throws DMLRuntimeException declared for interface uniformity; not thrown here
 */
private static int uariminLt(double value, double[] bv, int[] bvi, BinaryOperator bOp) 
	throws DMLRuntimeException
{
	int ixMin = 1;
	
	//value < bv[j] is true for all j (value < min) or false for all j (value >= max)
	if(value < bv[0] || value >= bv[bv.length-1]) 
		return ixMin;
	
	//k = last position with bv[k] <= value
	int k = Arrays.binarySearch(bv, value);
	if( k >= 0 ) {
		//exact match: scan forward to the last occurrence of value
		while( k < bv.length-1 && bv[k+1] == value ) ++k;
	}
	else {
		//no match: position just before the insertion point
		k = -(k+1) - 1;
	}
	ixMin = bvi[k]+1;
	
	return ixMin;
}
/**
 * Find out rowIndexMin for LessThanEquals operator, i.e. for a row of entries {@code value <= bv[j]}.
 *
 * With {@code bv} sorted ascending, {@code value <= bv[j]} is true for every entry iff
 * {@code value <= bv[0]}, and false for every entry iff {@code value > bv[bv.length-1]}.
 * In both cases the default index 1 is returned. Otherwise the result is read from
 * {@code bvi[k]}, where k is the number of entries strictly smaller than value. On a match,
 * binarySearch may return any of several equal entries, so the position is moved back to
 * the first occurrence.
 *
 * @param value input value compared against every entry of bv
 * @param bv column-vector values, sorted ascending (required by Arrays.binarySearch)
 * @param bvi prepared index vector (0-based positions), aligned with bv
 * @param bOp binary operator (unused)
 * @return 1-based row index of the minimum; 1 if the comparison is the same for all entries
 * @throws DMLRuntimeException declared for interface uniformity; not thrown here
 */
private static int uariminLe(double value, double[] bv, int[] bvi, BinaryOperator bOp) 
	throws DMLRuntimeException
{
	int ixMin = 1;
	
	//value <= bv[j] is true for all j (value <= min) or false for all j (value > max)
	if(value <= bv[0] || value > bv[bv.length-1]) 
		return ixMin;
	
	//k = number of entries strictly smaller than value = first position with bv[k] >= value
	int k = Arrays.binarySearch(bv, value);
	if( k >= 0 ) {
		//exact match: scan back to the first occurrence of value
		while( k > 0 && bv[k-1] == value ) --k;
	}
	else {
		//no match: insertion point
		k = -(k+1);
	}
	ixMin = bvi[k]+1;
	
	return ixMin;
}
/**
 * Adjusts the prepared index vector so that the uarimaxXX helpers can read it directly.
 *
 * The index vector is first sorted by value and then, depending on the operator, turned into a
 * running max for {@code <, <=, >, >=}. For {@code ==} and {@code !=} it holds only the sorted indices.
 * This function applies the per-operator shift:
 * <ul>
 *   <li>{@code <}: shiftLeft</li>
 *   <li>{@code >=} and {@code ==}: setMaxIndexInPartition</li>
 *   <li>{@code !=}: stores in vix[0] the position just before the trailing run of entries equal to
 *       the last value of vmb, or {@code vix.length-1} if all entries equal that value</li>
 *   <li>other operators: no change</li>
 * </ul>
 *
 * @param vix index vector, modified in place
 * @param vmb column-vector values aligned with vix
 * @param bOp binary (comparison) operator
 */
public static void adjustRowIndicesMax(int[] vix, double[] vmb,BinaryOperator bOp)
{
	if( bOp.fn instanceof LessThan ) {
		shiftLeft(vix, vmb);
		return;
	}
	if( bOp.fn instanceof GreaterThanEquals || bOp.fn instanceof Equals ) {
		setMaxIndexInPartition(vix, vmb);
		return;
	}
	if( bOp.fn instanceof NotEquals ) {
		double lastVal = vmb[vmb.length-1];
		//locate the start of the trailing run of entries equal to the last value
		int runStart = vmb.length-1;
		while( runStart > 0 && lastVal == vmb[runStart-1] )
			runStart--;
		vix[0] = (runStart > 0) ? runStart-1 : vix.length-1;
	}
}
/**
* This function adjusts indices to be leveraged in uariminXX functions.
* Initially vector containing indices are sorted based on value and then CumMin
* per need for <, <=, >, >= operator, where as just sorted indices based on value for ==, and != operators.
* There is need to shift these indices for different operators, which is handled through this function.
*
* @param vix ?
* @param vmb ?
* @param bOp binary operator
*/
public static void adjustRowIndicesMin(int[] vix, double[] vmb,BinaryOperator bOp)
{
if (bOp.fn instanceof GreaterThan) {
setMinIndexInPartition(vix, vmb);
}
else if (bOp.fn instanceof GreaterThanEquals) {
shiftLeft(vix, vmb);
}
else if (bOp.fn instanceof LessThanEquals) {
shiftRight(vix, vmb);
} else if(bOp.fn instanceof Equals) {
double dFirstValue = vmb[0];
int i=0;
while(i0;)
{
int iPrevInd = i;
double dPrevVal = vmb[iPrevInd];
while(i>=0 && dPrevVal == vmb[i]) --i;
if(i >= 0) {
for (int j = i+1; j<= iPrevInd; ++j)
vix[j] = vix[i];
}
}
}
/**
* This function will shift indices from one partition to next in left direction.
*
* For example, given the two vectors below, sorted based on value, where
* V2 is sorted data, and I2 are its corresponding indices.
*
* Then this function will shift indices to the left by one partition (I2").
* Right most partition remained untouched.
*
* Sorting this data based on value will be (V2): 2 3 3 4 4 6 7 9
* Then indices will be ordered as (I2): 5 2 8 6 7 1 4 3
*
* Shift Left by one partition (I2") 2 8 6 7 1 4 3 3
*
* @param vix ?
* @param vmb ?
*/
public static void shiftLeft(int[] vix, double[] vmb)
{
int iCurInd = 0;
for (int i = 0; i < vix.length;i++)
{
double dPrevVal = vmb[iCurInd];
while(i 0;)
{
while(i>0 && dLastVal == vmb[i]) --i;
for (int j=i+1; j 0) {
iLastIndex = i;
dLastVal = vmb[i];
}
}
}
}