All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.runtime.matrix.data.LibMatrixOuterAgg Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.matrix.data;

import java.util.Arrays;

import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.Equals;
import org.apache.sysml.runtime.functionobjects.GreaterThan;
import org.apache.sysml.runtime.functionobjects.GreaterThanEquals;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.functionobjects.LessThan;
import org.apache.sysml.runtime.functionobjects.LessThanEquals;
import org.apache.sysml.runtime.functionobjects.NotEquals;
import org.apache.sysml.runtime.functionobjects.ReduceAll;
import org.apache.sysml.runtime.functionobjects.ReduceCol;
import org.apache.sysml.runtime.functionobjects.ReduceRow;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
import org.apache.sysml.runtime.util.DataConverter;
import org.apache.sysml.runtime.util.SortUtils;

/**
 * ACS:
 * The purpose of this library is to make some of the unary outer aggregate operators more efficient.
 * Today these operators are handled through common operations.
 * This library will expand, per need and priority, to include these operators through this support.
 * To begin with, the first operator handled is the unary aggregate rowsum for less than (<).
 * Others to be added soon are rowsum on >, <=, >=, ==, and != operations.
 */
public class LibMatrixOuterAgg 
{

	private LibMatrixOuterAgg() {
		// static utility class: hidden constructor, never instantiated
	}

	
	/**
	 * Checks whether the given aggregate operator is a RowIndexMax, i.e. whether its
	 * increment function is a {@code Builtin} with function code MAXINDEX.
	 * 
	 * @param uaggOp aggregate unary operator to inspect
	 * @return true if uaggOp is of type RowIndexMax, false otherwise
	 */
	public static boolean isRowIndexMax(AggregateUnaryOperator uaggOp)
	{
		if( !(uaggOp.aggOp.increOp.fn instanceof Builtin) )
			return false;
		
		Builtin builtin = (Builtin) uaggOp.aggOp.increOp.fn;
		return builtin.bFunc == Builtin.BuiltinFunctionCode.MAXINDEX;
	}
	
	/**
	 * Checks whether the given aggregate operator is a RowIndexMin, i.e. whether its
	 * increment function is a {@code Builtin} with function code MININDEX.
	 * 
	 * @param uaggOp aggregate unary operator to inspect
	 * @return true if uaggOp is of type RowIndexMin, false otherwise
	 */
	public static boolean isRowIndexMin(AggregateUnaryOperator uaggOp)
	{
		if( !(uaggOp.aggOp.increOp.fn instanceof Builtin) )
			return false;
		
		Builtin builtin = (Builtin) uaggOp.aggOp.increOp.fn;
		return builtin.bFunc == Builtin.BuiltinFunctionCode.MININDEX;
	}
	
	
	/**
	 * Checks whether the binary operator is one of the six comparison operators
	 * handled by this library.
	 * 
	 * @param bOp binary operator to inspect
	 * @return true if bOp's function is one of {@code <, <=, >, >=, ==, !=}, false otherwise
	 */
	public static boolean isCompareOperator(BinaryOperator bOp)
	{
		Object fn = bOp.fn;
		
		// ordering comparisons: <, <=, >, >=
		boolean ordering = fn instanceof LessThan || fn instanceof LessThanEquals
			|| fn instanceof GreaterThan || fn instanceof GreaterThanEquals;
		
		// equality comparisons: ==, !=
		boolean equality = fn instanceof Equals || fn instanceof NotEquals;
		
		return ordering || equality;
	}
		
		
			/**
	 * Checks whether the combination of aggregate operator and comparison operator
	 * is supported by this library.
	 * 
	 * Supported means: bOp is a comparison operator (see isCompareOperator), the
	 * aggregate is KahanPlus, RowIndexMin or RowIndexMax, and the index function is
	 * ReduceCol, ReduceRow or ReduceAll.
	 * 
	 * @param uaggOp aggregate unary operator
	 * @param bOp binary (comparison) operator
	 * @return true if the combination is supported, false otherwise
	 */
	public static boolean isSupportedUaggOp( AggregateUnaryOperator uaggOp, BinaryOperator bOp )
	{
		if( !isCompareOperator(bOp) )
			return false;
		
		boolean aggSupported = uaggOp.aggOp.increOp.fn instanceof KahanPlus
			|| isRowIndexMin(uaggOp) 
			|| isRowIndexMax(uaggOp);
		if( !aggSupported )
			return false;
		
		return uaggOp.indexFn instanceof ReduceCol
			|| uaggOp.indexFn instanceof ReduceRow
			|| uaggOp.indexFn instanceof ReduceAll;
	}

	/**
	 * Prepares the index vector used by the outer row index max/min aggregates.
	 * Delegates to prepareRowIndicesMax if uaggOp is a RowIndexMax, and to
	 * prepareRowIndicesMin otherwise.
	 * 
	 * @param iCols number of columns (length of vmb)
	 * @param vmb vector values
	 * @param bOp comparison operator
	 * @param uaggOp aggregate unary operator (RowIndexMax or RowIndexMin)
	 * @return prepared index vector
	 * @throws DMLRuntimeException
	 * @throws DMLUnsupportedOperationException
	 */
	public static int[] prepareRowIndices(int iCols, double vmb[], BinaryOperator bOp, AggregateUnaryOperator uaggOp) throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		if( isRowIndexMax(uaggOp) )
			return prepareRowIndicesMax(iCols, vmb, bOp);
		else
			return prepareRowIndicesMin(iCols, vmb, bOp);
	}
	
	/**
	 * This function will return max indices, based on column vector data. 
	 * This indices will be computed based on operator. 
	 * These indices can be used to compute max index for a given input value in subsequent operation.
	 * 
	 *  e.g. Right Vector has data (V1)    :                6   3   9   7   2   4   4   3
	 *       Original indices for this data will be (I1):   1   2   3   4   5   6   7   8		
	 * 
	 *  Sorting this data based on value will be (V2):      2   3   3   4   4   6   7   9	
	 *      Then indices will be ordered as (I2):           5   2   8   6   7   1   4   3
	 * 
	 * CumMax of I2 will be A:  (CumMax(I2))                5   5   8   8   8   8   8   8
	 * CumMax of I2 in reverse order will be B:             8   8   8   7   7   4   4   3
	 * 
	 * Values from vector A is used to compute RowIndexMax for > & >= operators
	 * Values from vector B is used to compute RowIndexMax for < & <= operators
	 * Values from I2 is used to compute RowIndexMax for == operator.
	 * Original values are directly used to compute RowIndexMax for != operator
	 * 
	 * Shifting values from vector A or B is required to compute final indices.
	 * Once indices are shifted from vector A or B, their cell value corresponding to input data will be used. 
	 *  
	 * 
	 * @param iCols
	 * @param vmb
	 * @param bOp
	 * @return vixCumSum
	 */
	public static int[] prepareRowIndicesMax(int iCols, double vmb[], BinaryOperator bOp) throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		int[] vixCumSum = null;
		int[] vix = new int[iCols];
		
		//sort index vector on extracted data (unstable)
		if(!(bOp.fn instanceof NotEquals)){
			for( int i=0; i operator
	 * Values from vector B is used to compute RowIndexMin for <, <= and >= operators
	 * Values from I2 is used to compute RowIndexMax for == operator.
	 * Original values are directly used to compute RowIndexMax for != operator
	 * 
	 * Shifting values from vector A or B is required to compute final indices.
	 * Once indices are shifted from vector A or B, their cell value corresponding to input data will be used. 
	 *  
	 * 
	 * @param iCols
	 * @param vmb
	 * @param bOp
	 * @return vixCumSum
	 */
	public static int[] prepareRowIndicesMin(int iCols, double vmb[], BinaryOperator bOp) throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		int[] vixCumSum = null;
		int[] vix = new int[iCols];
		
		//sort index vector on extracted data (unstable)
		if(!(bOp.fn instanceof NotEquals || bOp.fn instanceof Equals )){
			for( int i=0; i= 0 ){ //match, scan to next val
			while( ix > 0 && value==bv[ix-1]) --ix;
			ix++;	//Readjust index to match subsenquent index calculation.
		}

		cnt = bv.length-Math.abs(ix)+1;

		//cnt = Math.abs(ix) - 1;
		if ((bOp.fn instanceof LessThan) || (bOp.fn instanceof GreaterThan))
			cnt = bv.length - cnt;

		return cnt;
	}

	
	/**
	 * Calculates the sum of number for rowSum of LessThan and GreaterThanEqual, and 
	 * 									colSum of GreaterThan and LessThanEqual operators.
	 * 
	 * @param value
	 * @param bv
	 * @param bOp
	 * @throws DMLRuntimeException
	 */
	private static int sumRowSumLtGeColSumGtLe(double value, double[] bv, BinaryOperator bOp) 
			throws DMLRuntimeException
	{
		int ix = Arrays.binarySearch(bv, value);
		int cnt = 0;
		
		if( ix >= 0 ){ //match, scan to next val
			while( value==bv[ix++] && ix= 0 ){ //match, scan to next val
			while( ix > 0 && value==bv[ix-1]) --ix;
			while( ix= 0 ) 
			ixMax = bvi[ix]+1;
		return ixMax;
	}

	
	/**
	 * Find out rowIndexMax for NotEqual operator.
	 * 
	 * If value equals the last entry of bv, the prepared index stored in bvi[0] is
	 * returned (converted to 1-based); otherwise bv.length is returned. bv is presumably
	 * in original column order for != (prepareRowIndicesMax skips sorting for NotEquals);
	 * bvi[0] is assumed to be set by adjustRowIndicesMax - confirm against callers.
	 * 
	 * @param value input value
	 * @param bv vector values
	 * @param bvi prepared indices
	 * @param bOp binary operator (not used)
	 * @return 1-based row index max
	 * @throws DMLRuntimeException
	 */
	private static int uarimaxNe(double value, double[] bv, int bvi[], BinaryOperator bOp) 
			throws DMLRuntimeException
	{
		boolean lastEqual = (value == bv[bv.length-1]);
		return lastEqual ? bvi[0]+1 : bv.length;
	}

	
	/**
	 * Find out rowIndexMax for GreaterThan operator.
	 * 
	 * {@code value > bv[j]} holds exactly for the first {@code cnt} entries of the sorted
	 * vector bv, where {@code cnt} is the number of entries strictly less than value.
	 * If no entry or every entry satisfies the comparison, bv.length is returned;
	 * otherwise the prepared index at position {@code cnt-1} is returned (1-based).
	 * 
	 * An exact match of value in bv is resolved to its first occurrence, so it yields
	 * the same result as a value lying strictly between the preceding entry and that run.
	 * (Using the raw binarySearch position here could index bvi at -1.)
	 * 
	 * @param value input value
	 * @param bv vector values, sorted ascending
	 * @param bvi prepared indices aligned with bv (assumed to come from prepareRowIndicesMax)
	 * @param bOp binary operator (not used)
	 * @return 1-based row index max
	 * @throws DMLRuntimeException
	 */
	private static int uarimaxGt(double value, double[] bv, int bvi[], BinaryOperator bOp) 
			throws DMLRuntimeException
	{
		int ixMax = bv.length;
		
		// no entry (value <= bv[0]) or every entry (value > max) satisfies value > bv[j]
		if(value <= bv[0] || value > bv[bv.length-1]) 
			return ixMax;
		
		int ix = Arrays.binarySearch(bv, value);
		int cnt; // number of entries strictly less than value
		if( ix >= 0 ) {
			// exact match: binarySearch may return any occurrence, move to the first one
			while( ix > 0 && bv[ix-1] == value ) 
				--ix;
			cnt = ix;
		}
		else
			cnt = -ix - 1; // insertion point
		
		// cnt >= 1 because value > bv[0]
		ixMax = bvi[cnt-1]+1; 
		
		return ixMax;
	}

	
	/**
	 * Find out rowIndexMax for GreaterThanEqual operator.
	 * 
	 * {@code value >= bv[j]} holds exactly for the first {@code cnt} entries of the sorted
	 * vector bv, where {@code cnt} is the number of entries less than or equal to value.
	 * If no entry or every entry satisfies the comparison, bv.length is returned;
	 * otherwise the prepared index at position {@code cnt-1} is returned (1-based).
	 * 
	 * An exact match of value in bv is resolved past its last occurrence, so it yields
	 * the same result as a value lying strictly between that run and the next larger entry.
	 * (Using the raw binarySearch position here could index bvi at -1 or -2.)
	 * 
	 * @param value input value
	 * @param bv vector values, sorted ascending
	 * @param bvi prepared indices aligned with bv (assumed to come from prepareRowIndicesMax)
	 * @param bOp binary operator (not used)
	 * @return 1-based row index max
	 * @throws DMLRuntimeException
	 */
	private static int uarimaxGe(double value, double[] bv, int[] bvi, BinaryOperator bOp) 
			throws DMLRuntimeException
	{
		int ixMax = bv.length;
		
		// no entry (value < bv[0]) or every entry (value >= max) satisfies value >= bv[j]
		if(value < bv[0] || value >= bv[bv.length-1]) 
			return ixMax;
		
		int ix = Arrays.binarySearch(bv, value);
		int cnt; // number of entries less than or equal to value
		if( ix >= 0 ) {
			// exact match: binarySearch may return any occurrence, move past the last one
			cnt = ix + 1;
			while( cnt < bv.length && bv[cnt] == value ) 
				++cnt;
		}
		else
			cnt = -ix - 1; // insertion point
		
		// 1 <= cnt <= bv.length-1 because bv[0] <= value < bv[bv.length-1]
		ixMax = bvi[cnt-1]+1; 
		
		return ixMax;
	}

	
	/**
	 * Find out rowIndexMax for LessThan operator.
	 * 
	 * {@code value < bv[j]} holds exactly for the entries of the sorted vector bv from
	 * position {@code cnt} on, where {@code cnt} is the number of entries less than or
	 * equal to value. If every entry or no entry satisfies the comparison, bv.length is
	 * returned; otherwise the prepared index at position {@code cnt-1} is returned (1-based).
	 * 
	 * An exact match of value in bv is resolved past its last occurrence, so it yields
	 * the same result as a value lying strictly between that run and the next larger entry.
	 * (Using the raw binarySearch position here could index bvi at -1.)
	 * 
	 * @param value input value
	 * @param bv vector values, sorted ascending
	 * @param bvi prepared indices aligned with bv (assumed to come from prepareRowIndicesMax)
	 * @param bOp binary operator (not used)
	 * @return 1-based row index max
	 * @throws DMLRuntimeException
	 */
	private static int uarimaxLt(double value, double[] bv, int[] bvi, BinaryOperator bOp) 
			throws DMLRuntimeException
	{
		int ixMax = bv.length;
		
		// every entry (value < bv[0]) or no entry (value >= max) satisfies value < bv[j]
		if(value < bv[0] || value >= bv[bv.length-1]) 
			return ixMax;
		
		int ix = Arrays.binarySearch(bv, value);
		int cnt; // number of entries less than or equal to value
		if( ix >= 0 ) {
			// exact match: binarySearch may return any occurrence, move past the last one
			cnt = ix + 1;
			while( cnt < bv.length && bv[cnt] == value ) 
				++cnt;
		}
		else
			cnt = -ix - 1; // insertion point
		
		// 1 <= cnt <= bv.length-1 because bv[0] <= value < bv[bv.length-1]
		ixMax = bvi[cnt-1]+1; 
		
		return ixMax;
	}

	/**
	 * Find out rowIndexMax for LessThanEquals operator.
	 * 
	 * {@code value <= bv[j]} holds exactly for the entries of the sorted vector bv from
	 * position {@code cnt} on, where {@code cnt} is the number of entries strictly less
	 * than value. If every entry or no entry satisfies the comparison, bv.length is
	 * returned; otherwise the prepared index at position {@code cnt} is returned (1-based).
	 * 
	 * An exact match of value in bv is resolved to its first occurrence, so it yields
	 * the same result as a value lying strictly between the preceding entry and that run.
	 * value == bv[0] satisfies the comparison for every entry and is treated like
	 * value < bv[0]. (Using the raw binarySearch position here could index bvi at -1.)
	 * 
	 * @param value input value
	 * @param bv vector values, sorted ascending
	 * @param bvi prepared indices aligned with bv (assumed to come from prepareRowIndicesMax)
	 * @param bOp binary operator (not used)
	 * @return 1-based row index max
	 * @throws DMLRuntimeException
	 */
	private static int uarimaxLe(double value, double[] bv, int[] bvi, BinaryOperator bOp) 
			throws DMLRuntimeException
	{
		int ixMax = bv.length;
		
		// every entry (value < bv[0]) or no entry (value > max) satisfies value <= bv[j]
		if(value < bv[0] || value > bv[bv.length-1]) 
			return ixMax;
		
		int ix = Arrays.binarySearch(bv, value);
		int cnt; // number of entries strictly less than value
		if( ix >= 0 ) {
			// exact match: binarySearch may return any occurrence, move to the first one
			while( ix > 0 && bv[ix-1] == value ) 
				--ix;
			cnt = ix;
		}
		else
			cnt = -ix - 1; // insertion point
		
		// value == bv[0]: every entry satisfies the comparison, same as value < bv[0]
		if( cnt == 0 )
			return ixMax;
		
		// cnt <= bv.length-1 because value <= bv[bv.length-1]
		ixMax = bvi[cnt]+1; 
		
		return ixMax;
	}
	

	/**
	 * Find out rowIndexMin for Equal operator.
	 * 
	 * If value equals bv[0], the prepared index stored in bvi[0] is returned (converted
	 * to 1-based); otherwise 1 is returned. bvi[0] is assumed to be set by the index
	 * preparation for == - confirm against callers.
	 * 
	 * @param value input value
	 * @param bv vector values
	 * @param bvi prepared indices
	 * @param bOp binary operator (not used)
	 * @return 1-based row index min
	 * @throws DMLRuntimeException
	 */
	private static int uariminEq(double value, double[] bv, int bvi[], BinaryOperator bOp) 
			throws DMLRuntimeException
	{
		return (bv[0] == value) ? bvi[0]+1 : 1;
	}

	
	/**
	 * Find out rowIndexMin for NotEqual operator.
	 * 
	 * If value differs from bv[0], the prepared index stored in bvi[0] is returned
	 * (converted to 1-based); otherwise 1 is returned.
	 * 
	 * @param value input value
	 * @param bv vector values
	 * @param bvi prepared indices
	 * @param bOp binary operator (not used)
	 * @return 1-based row index min
	 * @throws DMLRuntimeException
	 */
	private static int uariminNe(double value, double[] bv, int bvi[], BinaryOperator bOp) 
			throws DMLRuntimeException
	{
		boolean differs = (value != bv[0]);
		if( !differs )
			return 1;
		return bvi[0]+1;
	}

	
	/**
	 * Find out rowIndexMin for GreaterThan operator.
	 * 
	 * {@code value > bv[j]} holds exactly for the first {@code cnt} entries of the sorted
	 * vector bv, where {@code cnt} is the number of entries strictly less than value.
	 * If no entry or every entry satisfies the comparison, 1 is returned; otherwise the
	 * prepared index at position {@code cnt} is returned (1-based).
	 * 
	 * An exact match of value in bv is resolved to its first occurrence, so it yields
	 * the same result as a value lying strictly between the preceding entry and that run
	 * (the raw binarySearch position could land inside the run or one position early).
	 * 
	 * @param value input value
	 * @param bv vector values, sorted ascending
	 * @param bvi prepared indices aligned with bv (assumed to come from prepareRowIndicesMin)
	 * @param bOp binary operator (not used)
	 * @return 1-based row index min
	 * @throws DMLRuntimeException
	 */
	private static int uariminGt(double value, double[] bv, int bvi[], BinaryOperator bOp) 
			throws DMLRuntimeException
	{
		int ixMin = 1;
		
		// no entry (value <= bv[0]) or every entry (value > max) satisfies value > bv[j]
		if(value <= bv[0] || value > bv[bv.length-1]) 
			return ixMin;
		
		int ix = Arrays.binarySearch(bv, value);
		int cnt; // number of entries strictly less than value
		if( ix >= 0 ) {
			// exact match: binarySearch may return any occurrence, move to the first one
			while( ix > 0 && bv[ix-1] == value ) 
				--ix;
			cnt = ix;
		}
		else
			cnt = -ix - 1; // insertion point
		
		// 1 <= cnt <= bv.length-1 because bv[0] < value <= bv[bv.length-1]
		ixMin = bvi[cnt]+1; 
		
		return ixMin;
	}

	
	/**
	 * Find out rowIndexMin for GreaterThanEqual operator.
	 * 
	 * {@code value >= bv[j]} holds exactly for the first {@code cnt} entries of the sorted
	 * vector bv, where {@code cnt} is the number of entries less than or equal to value.
	 * If no entry (value < bv[0]) or every entry (value >= max) satisfies the comparison,
	 * 1 is returned; otherwise the prepared index at position {@code cnt-1} is returned
	 * (1-based). The early-exit bounds are those of >=, not of > (value == bv[0] matches
	 * at least one entry; value == max matches all of them).
	 * 
	 * An exact match of value in bv is resolved past its last occurrence, so it yields
	 * the same result as a value lying strictly between that run and the next larger entry.
	 * 
	 * @param value input value
	 * @param bv vector values, sorted ascending
	 * @param bvi prepared indices aligned with bv (assumed to come from prepareRowIndicesMin)
	 * @param bOp binary operator (not used)
	 * @return 1-based row index min
	 * @throws DMLRuntimeException
	 */
	private static int uariminGe(double value, double[] bv, int[] bvi, BinaryOperator bOp) 
			throws DMLRuntimeException
	{
		int ixMin = 1;
		
		// no entry (value < bv[0]) or every entry (value >= max) satisfies value >= bv[j]
		if(value < bv[0] || value >= bv[bv.length-1]) 
			return ixMin;
		
		int ix = Arrays.binarySearch(bv, value);
		int cnt; // number of entries less than or equal to value
		if( ix >= 0 ) {
			// exact match: binarySearch may return any occurrence, move past the last one
			cnt = ix + 1;
			while( cnt < bv.length && bv[cnt] == value ) 
				++cnt;
		}
		else
			cnt = -ix - 1; // insertion point
		
		// 1 <= cnt <= bv.length-1 because bv[0] <= value < bv[bv.length-1]
		ixMin = bvi[cnt-1]+1; 
		
		return ixMin;
	}

	
	/**
	 * Find out rowIndexMin for LessThan operator.
	 * 
	 * {@code value < bv[j]} holds exactly for the entries of the sorted vector bv from
	 * position {@code cnt} on, where {@code cnt} is the number of entries less than or
	 * equal to value. If every entry or no entry satisfies the comparison, 1 is returned;
	 * otherwise the prepared index at position {@code cnt-1} is returned (1-based).
	 * 
	 * An exact match of value in bv is resolved past its last occurrence, so it yields
	 * the same result as a value lying strictly between that run and the next larger entry.
	 * (Using the raw binarySearch position here could index bvi at -1.)
	 * 
	 * @param value input value
	 * @param bv vector values, sorted ascending
	 * @param bvi prepared indices aligned with bv (assumed to come from prepareRowIndicesMin)
	 * @param bOp binary operator (not used)
	 * @return 1-based row index min
	 * @throws DMLRuntimeException
	 */
	private static int uariminLt(double value, double[] bv, int[] bvi, BinaryOperator bOp) 
			throws DMLRuntimeException
	{
		int ixMin = 1;
		
		// every entry (value < bv[0]) or no entry (value >= max) satisfies value < bv[j]
		if(value < bv[0] || value >= bv[bv.length-1]) 
			return ixMin;
		
		int ix = Arrays.binarySearch(bv, value);
		int cnt; // number of entries less than or equal to value
		if( ix >= 0 ) {
			// exact match: binarySearch may return any occurrence, move past the last one
			cnt = ix + 1;
			while( cnt < bv.length && bv[cnt] == value ) 
				++cnt;
		}
		else
			cnt = -ix - 1; // insertion point
		
		// 1 <= cnt <= bv.length-1 because bv[0] <= value < bv[bv.length-1]
		ixMin = bvi[cnt-1]+1; 
		
		return ixMin;
	}

	/**
	 * Find out rowIndexMin for LessThanEquals operator.
	 * 
	 * {@code value <= bv[j]} holds exactly for the entries of the sorted vector bv from
	 * position {@code cnt} on, where {@code cnt} is the number of entries strictly less
	 * than value. If every entry or no entry satisfies the comparison, 1 is returned;
	 * otherwise the prepared index at position {@code cnt} is returned (1-based).
	 * 
	 * An exact match of value in bv is resolved to its first occurrence, so it yields
	 * the same result as a value lying strictly between the preceding entry and that run.
	 * value == bv[0] satisfies the comparison for every entry and is treated like
	 * value < bv[0].
	 * 
	 * @param value input value
	 * @param bv vector values, sorted ascending
	 * @param bvi prepared indices aligned with bv (assumed to come from prepareRowIndicesMin)
	 * @param bOp binary operator (not used)
	 * @return 1-based row index min
	 * @throws DMLRuntimeException
	 */
	private static int uariminLe(double value, double[] bv, int[] bvi, BinaryOperator bOp) 
			throws DMLRuntimeException
	{
		int ixMin = 1;
		
		// every entry (value < bv[0]) or no entry (value > max) satisfies value <= bv[j]
		if(value < bv[0] || value > bv[bv.length-1]) 
			return ixMin;
		
		int ix = Arrays.binarySearch(bv, value);
		int cnt; // number of entries strictly less than value
		if( ix >= 0 ) {
			// exact match: binarySearch may return any occurrence, move to the first one
			while( ix > 0 && bv[ix-1] == value ) 
				--ix;
			cnt = ix;
		}
		else
			cnt = -ix - 1; // insertion point
		
		// value == bv[0]: every entry satisfies the comparison, same as value < bv[0]
		if( cnt == 0 )
			return ixMin;
		
		// cnt <= bv.length-1 because value <= bv[bv.length-1]
		ixMin = bvi[cnt]+1; 
		
		return ixMin;
	}
	
	
	/**
	 * Adjusts the prepared index vector so it can be consumed by the uarimaxXX functions.
	 * 
	 * Per operator:
	 * <ul>
	 * <li>{@code <}: indices are shifted via shiftLeft.</li>
	 * <li>{@code >=} and {@code ==}: indices are set via setMaxIndexInPartition.</li>
	 * <li>{@code !=}: vix[0] is set to the position just before the trailing run of
	 *     entries equal to the last value of vmb, or to vix.length-1 if all of vmb
	 *     equals its last value.</li>
	 * <li>other operators: vix is left unchanged.</li>
	 * </ul>
	 * 
	 * @param vix index vector to adjust in place
	 * @param vmb vector values aligned with vix
	 * @param bOp comparison operator
	 */
	public static void adjustRowIndicesMax(int[] vix, double[] vmb, BinaryOperator bOp)
	{
		if( bOp.fn instanceof LessThan ) {
			shiftLeft(vix, vmb);
			return;
		}
		
		if( bOp.fn instanceof GreaterThanEquals || bOp.fn instanceof Equals ) {
			setMaxIndexInPartition(vix, vmb);
			return;
		}
		
		if( bOp.fn instanceof NotEquals ) {
			// find where the trailing run of values equal to the last value starts
			final double last = vmb[vmb.length-1];
			int runStart = vmb.length-1;
			while( runStart > 0 && vmb[runStart-1] == last )
				runStart--;
			
			vix[0] = (runStart > 0) ? runStart-1 : vix.length-1;
		}
	}

	/**
	 * This function adjusts indices to be leveraged in uariminXX functions.
	 * Initially vector containing indices are sorted based on value and then CumMin 
	 * per need for <, <=, >, >= operator, where as just sorted indices based on value for ==, and != operators.
	 * There is need to shift these indices for different operators, which is handled through this function. 
	 * 
	 * @param vix
	 * @param vmb
	 * @param bOp
	 */
	public static void adjustRowIndicesMin(int[] vix, double[] vmb,BinaryOperator bOp)
    {
		if (bOp.fn instanceof GreaterThan) {
			setMinIndexInPartition(vix, vmb);
		}
		else if (bOp.fn instanceof GreaterThanEquals) {
        	shiftLeft(vix, vmb);
    	}
        else if (bOp.fn instanceof LessThanEquals) {
        	shiftRight(vix, vmb);
    	} else if(bOp.fn instanceof Equals) {
    		double dFirstValue = vmb[0];
    		int i=0;
    		while(i0;)
    	{
    		int iPrevInd = i;
    		double dPrevVal = vmb[iPrevInd];
			while(i>=0 && dPrevVal == vmb[i]) --i;
			
			if(i >= 0) {
				for (int j = i+1; j<= iPrevInd; ++j)
					vix[j] = vix[i];
			}
    	}
	}

	
	/**
	 * This function will shift indices from one partition to next in left direction.
	 * 
	 *  For an example, if there are two sorted vector based on value like following, where
	 *   V2 is sorted data, and I2 are its corresponding indices.
	 *   
	 *   Then this function will shift indices to right by one partition I2". 
	 *     Right most partition remained untouched.
	 *   
	 *  Sorting this data based on value will be (V2):      2   3   3   4   4   6   7   9	
	 *      Then indices will be ordered as (I2):           5   2   8   6   7   1   4   3
	 * 
	 *    Shift Left by one partition (I2")                 2   8   6   7   1   4   3   3
	 * 
	 * @param vix
	 * @param vmb
	 */
	
	public static void shiftLeft(int[] vix, double[] vmb)
	{
    	int iCurInd = 0;
		
    	for (int i = 0; i < vix.length;++i)
    	{
    		double dPrevVal = vmb[iCurInd];
			while(i 0;)
    	{
    		while(i>0 && dLastVal == vmb[i]) --i;
    		for (int j=i+1; j 0) {
        		iLastIndex = i;
        		dLastVal = vmb[i];
    		}
    	}
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy