All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.runtime.functionobjects.Builtin Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.functionobjects;

import java.util.EnumMap;
import java.util.HashMap;

import org.apache.commons.math3.util.FastMath;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLScriptException;


/**
 *  Class with a pre-defined set of function objects, one per builtin code. This class cannot be instantiated elsewhere.
 *  
 *  Notes on commons.math FastMath:
 *  * FastMath uses lookup tables and interpolation instead of native calls.
 *  * The memory overhead for those tables is roughly 48KB in total (acceptable)
 *  * Micro and application benchmarks showed significantly (30%-3x) performance improvements
 *    for most operations; without loss of accuracy.
 *  * atan / sqrt were 20% slower in FastMath and hence, we use Math there (log likewise uses Math)
 *  * round / abs were equivalent in FastMath and hence, we use Math there
 *  * Finally, there is just one argument against FastMath - The comparison heavily depends
 *    on the JVM. For example, currently the IBM JDK JIT compiles to HW instructions for sqrt
 *    which makes this operation very efficient; as soon as other operations like log/exp are
 *    similarly compiled, we should rerun the micro benchmarks, and switch back if necessary.
 *  
 */
public class Builtin extends ValueFunction 
{

	private static final long serialVersionUID = 3836744687789840574L;
	
	/** Codes of all builtin functions that can be represented by a {@code Builtin} object. */
	public enum BuiltinCode { SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID, SELP }
	
	/** The builtin function evaluated by this object. */
	public BuiltinCode bFunc;
	
	// Selects commons-math FastMath over java.lang.Math for the operations where
	// FastMath was measured to be faster (see class comment).
	private static final boolean FASTMATH = true;
	
	/** Maps opcode strings (e.g., "sin", "ucumk+", "sel+") to their builtin codes. */
	static public HashMap<String, BuiltinCode> String2BuiltinCode;
	static {
		String2BuiltinCode = new HashMap<>();
		String2BuiltinCode.put( "sin"     , BuiltinCode.SIN);
		String2BuiltinCode.put( "cos"     , BuiltinCode.COS);
		String2BuiltinCode.put( "tan"     , BuiltinCode.TAN);
		String2BuiltinCode.put( "sinh"    , BuiltinCode.SINH);
		String2BuiltinCode.put( "cosh"    , BuiltinCode.COSH);
		String2BuiltinCode.put( "tanh"    , BuiltinCode.TANH);
		String2BuiltinCode.put( "asin"    , BuiltinCode.ASIN);
		String2BuiltinCode.put( "acos"    , BuiltinCode.ACOS);
		String2BuiltinCode.put( "atan"    , BuiltinCode.ATAN);
		String2BuiltinCode.put( "log"     , BuiltinCode.LOG);
		String2BuiltinCode.put( "log_nz"  , BuiltinCode.LOG_NZ);
		String2BuiltinCode.put( "min"     , BuiltinCode.MIN);
		String2BuiltinCode.put( "max"     , BuiltinCode.MAX);
		String2BuiltinCode.put( "maxindex", BuiltinCode.MAXINDEX);
		String2BuiltinCode.put( "minindex", BuiltinCode.MININDEX);
		String2BuiltinCode.put( "abs"     , BuiltinCode.ABS);
		String2BuiltinCode.put( "sign"    , BuiltinCode.SIGN);
		String2BuiltinCode.put( "sqrt"    , BuiltinCode.SQRT);
		String2BuiltinCode.put( "exp"     , BuiltinCode.EXP);
		String2BuiltinCode.put( "plogp"   , BuiltinCode.PLOGP);
		String2BuiltinCode.put( "print"   , BuiltinCode.PRINT);
		String2BuiltinCode.put( "printf"  , BuiltinCode.PRINTF);
		String2BuiltinCode.put( "nrow"    , BuiltinCode.NROW);
		String2BuiltinCode.put( "ncol"    , BuiltinCode.NCOL);
		String2BuiltinCode.put( "length"  , BuiltinCode.LENGTH);
		String2BuiltinCode.put( "round"   , BuiltinCode.ROUND);
		String2BuiltinCode.put( "stop"    , BuiltinCode.STOP);
		String2BuiltinCode.put( "ceil"    , BuiltinCode.CEIL);
		String2BuiltinCode.put( "floor"   , BuiltinCode.FLOOR);
		String2BuiltinCode.put( "ucumk+"  , BuiltinCode.CUMSUM);
		String2BuiltinCode.put( "ucum*"   , BuiltinCode.CUMPROD);
		String2BuiltinCode.put( "ucummin" , BuiltinCode.CUMMIN);
		String2BuiltinCode.put( "ucummax" , BuiltinCode.CUMMAX);
		String2BuiltinCode.put( "inverse" , BuiltinCode.INVERSE);
		String2BuiltinCode.put( "sprop"   , BuiltinCode.SPROP);
		String2BuiltinCode.put( "sigmoid" , BuiltinCode.SIGMOID);
		String2BuiltinCode.put( "sel+"    , BuiltinCode.SELP);
	}
	
	// Exactly one shared function object per builtin code. The objects are created
	// eagerly in a static initializer; the JVM performs class initialization once,
	// so all callers of getBuiltinFnObject observe the same instance without locking.
	private static final EnumMap<BuiltinCode, Builtin> BUILTIN_OBJECTS = new EnumMap<>(BuiltinCode.class);
	static {
		for( BuiltinCode code : BuiltinCode.values() )
			BUILTIN_OBJECTS.put(code, new Builtin(code));
	}
	
	private Builtin(BuiltinCode bf) {
		bFunc = bf;
	}
	
	/**
	 * Returns the builtin code evaluated by this function object.
	 * 
	 * @return builtin code
	 */
	public BuiltinCode getBuiltinCode() {
		return bFunc;
	}

	/**
	 * Looks up the shared function object for an opcode string.
	 * 
	 * @param str opcode string, e.g., "sin" or "ucumk+"
	 * @return the shared function object, or null if the opcode is unknown (or null)
	 */
	public static Builtin getBuiltinFnObject (String str) 
	{
		BuiltinCode code = String2BuiltinCode.get(str);
		return getBuiltinFnObject( code );
	}

	/**
	 * Returns the shared function object for a builtin code.
	 * 
	 * @param code builtin code
	 * @return the shared function object, or null if code is null
	 */
	public static Builtin getBuiltinFnObject(BuiltinCode code) 
	{
		return (code == null) ? null : BUILTIN_OBJECTS.get(code);
	}

	/**
	 * Evaluates a unary builtin function on a double value.
	 * 
	 * @param in input value
	 * @return the function value
	 * @throws DMLRuntimeException if this builtin is not a unary scalar function
	 */
	@Override
	public double execute (double in) 
		throws DMLRuntimeException 
	{
		switch(bFunc) {
			case SIN:    return FASTMATH ? FastMath.sin(in) : Math.sin(in);
			case COS:    return FASTMATH ? FastMath.cos(in) : Math.cos(in);
			case TAN:    return FASTMATH ? FastMath.tan(in) : Math.tan(in);
			case ASIN:   return FASTMATH ? FastMath.asin(in) : Math.asin(in);
			case ACOS:   return FASTMATH ? FastMath.acos(in) : Math.acos(in);
			case ATAN:   return Math.atan(in); //faster in Math
			// FastMath.*h is faster 98% of time than Math.*h in initial micro-benchmarks
			case SINH:   return FASTMATH ? FastMath.sinh(in) : Math.sinh(in);
			case COSH:   return FASTMATH ? FastMath.cosh(in) : Math.cosh(in);
			case TANH:   return FASTMATH ? FastMath.tanh(in) : Math.tanh(in);
			case CEIL:   return FASTMATH ? FastMath.ceil(in) : Math.ceil(in);
			case FLOOR:  return FASTMATH ? FastMath.floor(in) : Math.floor(in);
			case LOG:    return Math.log(in); //faster in Math
			case LOG_NZ: return (in==0) ? 0 : Math.log(in); //faster in Math
			case ABS:    return Math.abs(in); //no need for FastMath
			case SIGN:   return FASTMATH ? FastMath.signum(in) : Math.signum(in);
			case SQRT:   return Math.sqrt(in); //faster in Math
			case EXP:    return FASTMATH ? FastMath.exp(in) : Math.exp(in);
			// NOTE(review): Math.round returns a long, so NaN maps to 0 and values outside
			// the long range saturate at Long.MIN_VALUE/MAX_VALUE; confirm this is intended.
			case ROUND:  return Math.round(in); //no need for FastMath
			
			case PLOGP:
				//p*log(p), defined as 0 for p=0 and NaN for negative p
				if (in == 0.0)
					return 0.0;
				else if (in < 0)
					return Double.NaN;
				else //faster in Math
					return in * Math.log(in);
			
			case SPROP:
				//sample proportion: P*(1-P)
				return in * (1 - in); 
	
			case SIGMOID:
				//sigmoid: 1/(1+exp(-x))
				return FASTMATH ? 1 / (1 + FastMath.exp(-in))  : 1 / (1 + Math.exp(-in));
			
			case SELP:
				//select positive: x*(x>0)
				return (in > 0) ? in : 0;
				
			default:
				throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
		}
	}

	/**
	 * Evaluates a unary builtin function on a long value by widening it to double.
	 * 
	 * @param in input value
	 * @return the function value
	 * @throws DMLRuntimeException if this builtin is not a unary scalar function
	 */
	@Override
	public double execute (long in) throws DMLRuntimeException {
		return execute((double)in);
	}

	/**
	 * Evaluates a binary builtin function (max/min and their cumulative variants,
	 * max/min index comparison, and log with an explicit base in2).
	 * 
	 * @param in1 first operand
	 * @param in2 second operand
	 * @return the function value (for MAXINDEX/MININDEX a comparison code, see below)
	 * @throws DMLRuntimeException if this builtin is not a supported binary function
	 */
	@Override
	public double execute (double in1, double in2) throws DMLRuntimeException {
		switch(bFunc) {
		
		/*
		 * Arithmetic relational operators (==, !=, <=, >=) must be used instead of
		 * Double.compare() due to the inconsistencies in the way
		 * NaN and -0.0 are handled. The behavior of methods in
		 * Double class are designed mainly to make Java
		 * collections work properly. For more details, see the help for
		 * Double.equals() and Double.compareTo().
		 */
		case MAX:
		case CUMMAX:
			return (in1 >= in2 ? in1 : in2);
		case MIN:
		case CUMMIN:
			return (in1 <= in2 ? in1 : in2);
			
			// *** HACK ALERT *** HACK ALERT *** HACK ALERT ***
			// rowIndexMax() and its siblings require comparing four values, but
			// the aggregation API only allows two values. So the execute()
			// method receives as its argument the two cell values to be
			// compared and performs just the value part of the comparison. We
			// return an integer cast down to a double, since the aggregation
			// API doesn't have any way to return anything but a double. The
			// integer returned takes on three possible values:
			// .     0 => keep the index associated with in1
			// .     1 => use the index associated with in2
			// .     2 => use whichever index is higher (tie in value)
		case MAXINDEX:
			if (in1 == in2) {
				return 2;
			} else if (in1 > in2) {
				return 1;
			} else { // in1 < in2
				return 0;
			}
		case MININDEX:
			if (in1 == in2) {
				return 2;
			} else if (in1 < in2) {
				return 1;
			} else { // in1 > in2
				return 0;
			}
			// *** END HACK ***
		case LOG:
			//logarithm of in1 to base in2; faster in Math
			return (Math.log(in1)/Math.log(in2)); 
		case LOG_NZ:
			//same as LOG, but 0 for in1==0; faster in Math
			return (in1==0) ? 0 : (Math.log(in1)/Math.log(in2)); 
			
		default:
			throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
		}
	}
	
	/**
	 * Simplified version without exception handling.
	 * 
	 * Unlike {@link #execute(double, double)}, MAXINDEX/MININDEX return 1 if
	 * in1 &gt;= in2 (resp. in1 &lt;= in2) and 0 otherwise, i.e., ties yield 1.
	 * 
	 * @param in1 double 1
	 * @param in2 double 2
	 * @return result, or -1 for builtins not supported by this method
	 */
	public double execute2(double in1, double in2) 
	{
		switch(bFunc) {		
			case MAX:
			case CUMMAX:
				return (in1 >= in2 ? in1 : in2);
			case MIN:
			case CUMMIN:
				return (in1 <= in2 ? in1 : in2);
			case MAXINDEX: 
				return (in1 >= in2) ? 1 : 0;	
			case MININDEX: 
				return (in1 <= in2) ? 1 : 0;	
				
			default:
				// For performance reasons, avoid throwing an exception 
				return -1;
		}
	}
	
	/**
	 * Evaluates a binary builtin function on long operands.
	 * 
	 * MAXINDEX/MININDEX return 1 if in1 &gt;= in2 (resp. in1 &lt;= in2) and 0 otherwise.
	 * 
	 * @param in1 first operand
	 * @param in2 second operand
	 * @return the function value as double
	 * @throws DMLRuntimeException if this builtin is not a supported binary function
	 */
	@Override
	public double execute (long in1, long in2) throws DMLRuntimeException {
		switch(bFunc) {
		
		case MAX:    
		case CUMMAX:   return (in1 >= in2 ? in1 : in2); 
		
		case MIN:    
		case CUMMIN:   return (in1 <= in2 ? in1 : in2); 
		
		case MAXINDEX: return (in1 >= in2) ? 1 : 0;
		case MININDEX: return (in1 <= in2) ? 1 : 0;
		
		case LOG:
			//logarithm of in1 to base in2; faster in Math
			return Math.log(in1)/Math.log(in2);
		case LOG_NZ:
			//same as LOG, but 0 for in1==0; faster in Math
			return (in1==0) ? 0 : Math.log(in1)/Math.log(in2);
		
		default:
			throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
		}
	}

	/**
	 * Evaluates a string builtin: PRINT and PRINTF write the input string as-is to
	 * stdout (unless printing is suppressed); STOP aborts the script.
	 * 
	 * @param in1 input string (message to print, or stop message)
	 * @return always null for PRINT/PRINTF
	 * @throws DMLRuntimeException for STOP (as DMLScriptException) or unsupported builtins
	 */
	@Override
	public String execute (String in1) 
		throws DMLRuntimeException 
	{
		switch (bFunc) {
		case PRINT:
		case PRINTF:
			if (!DMLScript.suppressPrint2Stdout())
				System.out.println(in1);
			return null;
		case STOP:
			throw new DMLScriptException(in1);
		default:
			throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy