org.apache.sysml.runtime.functionobjects.Builtin Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.functionobjects;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.math3.util.FastMath;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLScriptException;
/**
 * Function object for the built-in scalar functions (trigonometric, logarithmic,
 * rounding, min/max, printing, ...). The factory methods
 * {@link #getBuiltinFnObject(String)} and {@link #getBuiltinFnObject(BuiltinCode)}
 * hand out one shared, pre-defined instance per {@link BuiltinCode}. This class
 * can not be instantiated elsewhere because its constructor is private.
 *
 * Notes on commons.math FastMath:
 * * FastMath uses lookup tables and interpolation instead of native calls.
 * * The memory overhead for those tables is roughly 48KB in total (acceptable)
 * * Micro and application benchmarks showed significantly (30%-3x) performance improvements
 *   for most operations; without loss of accuracy.
 * * atan / sqrt were 20% slower in FastMath and hence, we use Math there
 * * round / abs were equivalent in FastMath and hence, we use Math there
 * * Finally, there is just one argument against FastMath - The comparison heavily depends
 *   on the JVM. For example, currently the IBM JDK JIT compiles to HW instructions for sqrt
 *   which makes this operation very efficient; as soon as other operations like log/exp are
 *   similarly compiled, we should rerun the micro benchmarks, and switch back if necessary.
 *
 */
public class Builtin extends ValueFunction
{
    private static final long serialVersionUID = 3836744687789840574L;

    /** Codes of all supported builtin functions. */
    public enum BuiltinCode { SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID, SELP }

    /** The builtin function computed by this object. */
    public BuiltinCode bFunc;

    /** Switch between commons.math FastMath (true) and java.lang.Math (false). */
    private static final boolean FASTMATH = true;

    /** Maps the string name (opcode) of a builtin function to its code. */
    static public HashMap<String, BuiltinCode> String2BuiltinCode;
    static {
        String2BuiltinCode = new HashMap<>();
        String2BuiltinCode.put("sin", BuiltinCode.SIN);
        String2BuiltinCode.put("cos", BuiltinCode.COS);
        String2BuiltinCode.put("tan", BuiltinCode.TAN);
        String2BuiltinCode.put("sinh", BuiltinCode.SINH);
        String2BuiltinCode.put("cosh", BuiltinCode.COSH);
        String2BuiltinCode.put("tanh", BuiltinCode.TANH);
        String2BuiltinCode.put("asin", BuiltinCode.ASIN);
        String2BuiltinCode.put("acos", BuiltinCode.ACOS);
        String2BuiltinCode.put("atan", BuiltinCode.ATAN);
        String2BuiltinCode.put("log", BuiltinCode.LOG);
        String2BuiltinCode.put("log_nz", BuiltinCode.LOG_NZ);
        String2BuiltinCode.put("min", BuiltinCode.MIN);
        String2BuiltinCode.put("max", BuiltinCode.MAX);
        String2BuiltinCode.put("maxindex", BuiltinCode.MAXINDEX);
        String2BuiltinCode.put("minindex", BuiltinCode.MININDEX);
        String2BuiltinCode.put("abs", BuiltinCode.ABS);
        String2BuiltinCode.put("sign", BuiltinCode.SIGN);
        String2BuiltinCode.put("sqrt", BuiltinCode.SQRT);
        String2BuiltinCode.put("exp", BuiltinCode.EXP);
        String2BuiltinCode.put("plogp", BuiltinCode.PLOGP);
        String2BuiltinCode.put("print", BuiltinCode.PRINT);
        String2BuiltinCode.put("printf", BuiltinCode.PRINTF);
        String2BuiltinCode.put("nrow", BuiltinCode.NROW);
        String2BuiltinCode.put("ncol", BuiltinCode.NCOL);
        String2BuiltinCode.put("length", BuiltinCode.LENGTH);
        String2BuiltinCode.put("round", BuiltinCode.ROUND);
        String2BuiltinCode.put("stop", BuiltinCode.STOP);
        String2BuiltinCode.put("ceil", BuiltinCode.CEIL);
        String2BuiltinCode.put("floor", BuiltinCode.FLOOR);
        String2BuiltinCode.put("ucumk+", BuiltinCode.CUMSUM);
        String2BuiltinCode.put("ucum*", BuiltinCode.CUMPROD);
        String2BuiltinCode.put("ucummin", BuiltinCode.CUMMIN);
        String2BuiltinCode.put("ucummax", BuiltinCode.CUMMAX);
        String2BuiltinCode.put("inverse", BuiltinCode.INVERSE);
        String2BuiltinCode.put("sprop", BuiltinCode.SPROP);
        String2BuiltinCode.put("sigmoid", BuiltinCode.SIGMOID);
        String2BuiltinCode.put("sel+", BuiltinCode.SELP);
    }

    // One shared object for every builtin function that we support. The map is
    // filled once during class initialization and never modified afterwards, so
    // every lookup for the same code yields the same instance.
    private static final Map<BuiltinCode, Builtin> INSTANCES = new EnumMap<>(BuiltinCode.class);
    static {
        for (BuiltinCode code : BuiltinCode.values()) {
            INSTANCES.put(code, new Builtin(code));
        }
    }

    private Builtin(BuiltinCode bf) {
        bFunc = bf;
    }

    /**
     * Returns the code of the builtin function computed by this object.
     *
     * @return builtin code
     */
    public BuiltinCode getBuiltinCode() {
        return bFunc;
    }

    /**
     * Looks up the shared function object by the string name of the builtin function.
     *
     * @param str name as registered in {@link #String2BuiltinCode}
     * @return the shared function object, or null if the name is not registered
     */
    public static Builtin getBuiltinFnObject (String str)
    {
        BuiltinCode code = String2BuiltinCode.get(str);
        return getBuiltinFnObject( code );
    }

    /**
     * Looks up the shared function object for the given builtin code.
     *
     * @param code builtin code, may be null
     * @return the shared function object, or null if code is null
     */
    public static Builtin getBuiltinFnObject(BuiltinCode code)
    {
        if ( code == null ) {
            return null;
        }
        return INSTANCES.get(code);
    }

    /**
     * Applies a unary builtin function to a double value.
     *
     * @param in input value
     * @return function result
     * @throws DMLRuntimeException if this object's builtin code is not a supported unary operation
     */
    @Override
    public double execute (double in)
        throws DMLRuntimeException
    {
        switch(bFunc) {
            case SIN:    return FASTMATH ? FastMath.sin(in) : Math.sin(in);
            case COS:    return FASTMATH ? FastMath.cos(in) : Math.cos(in);
            case TAN:    return FASTMATH ? FastMath.tan(in) : Math.tan(in);
            case ASIN:   return FASTMATH ? FastMath.asin(in) : Math.asin(in);
            case ACOS:   return FASTMATH ? FastMath.acos(in) : Math.acos(in);
            case ATAN:   return Math.atan(in); //faster in Math
            // FastMath.*h is faster 98% of time than Math.*h in initial micro-benchmarks
            case SINH:   return FASTMATH ? FastMath.sinh(in) : Math.sinh(in);
            case COSH:   return FASTMATH ? FastMath.cosh(in) : Math.cosh(in);
            case TANH:   return FASTMATH ? FastMath.tanh(in) : Math.tanh(in);
            case CEIL:   return FASTMATH ? FastMath.ceil(in) : Math.ceil(in);
            case FLOOR:  return FASTMATH ? FastMath.floor(in) : Math.floor(in);
            case LOG:    return Math.log(in); //faster in Math
            case LOG_NZ: return (in==0) ? 0 : Math.log(in); //faster in Math
            case ABS:    return Math.abs(in); //no need for FastMath
            case SIGN:   return FASTMATH ? FastMath.signum(in) : Math.signum(in);
            case SQRT:   return Math.sqrt(in); //faster in Math
            case EXP:    return FASTMATH ? FastMath.exp(in) : Math.exp(in);
            case ROUND:  return Math.round(in); //no need for FastMath

            case PLOGP:
                // x*log(x), defined as 0 for x=0 and NaN for negative x
                if (in == 0.0)
                    return 0.0;
                else if (in < 0)
                    return Double.NaN;
                else //faster in Math
                    return in * Math.log(in);

            case SPROP:
                //sample proportion: P*(1-P)
                return in * (1 - in);

            case SIGMOID:
                //sigmoid: 1/(1+exp(-x))
                return FASTMATH ? 1 / (1 + FastMath.exp(-in)) : 1 / (1 + Math.exp(-in));

            case SELP:
                //select positive: x*(x>0)
                return (in > 0) ? in : 0;

            default:
                throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
        }
    }

    /**
     * Applies a unary builtin function to a long value by delegating to
     * {@link #execute(double)}.
     *
     * @param in input value
     * @return function result
     * @throws DMLRuntimeException if this object's builtin code is not a supported unary operation
     */
    @Override
    public double execute (long in) throws DMLRuntimeException {
        return execute((double)in);
    }

    /**
     * Applies a binary builtin function to two double values.
     *
     * @param in1 first input value
     * @param in2 second input value (the base for LOG and LOG_NZ)
     * @return function result; for MAXINDEX and MININDEX a comparison code 0, 1 or 2
     * @throws DMLRuntimeException if this object's builtin code is not a supported binary operation
     */
    @Override
    public double execute (double in1, double in2) throws DMLRuntimeException {
        switch(bFunc) {
            /*
             * Arithmetic relational operators (==, !=, <=, >=) must be used instead
             * of Double.compare() due to the inconsistencies in the way NaN and -0.0
             * are handled. The methods of class Double are designed mainly to make
             * Java collections work properly. For more details, see the documentation
             * of Double.equals() and Double.compareTo().
             */
            case MAX:
            case CUMMAX:
                return (in1 >= in2 ? in1 : in2);

            case MIN:
            case CUMMIN:
                return (in1 <= in2 ? in1 : in2);

            // *** HACK ALERT *** HACK ALERT *** HACK ALERT ***
            // rowIndexMax() and its siblings require comparing four values, but
            // the aggregation API only allows two values. So the execute()
            // method receives as its argument the two cell values to be
            // compared and performs just the value part of the comparison. We
            // return an integer cast down to a double, since the aggregation
            // API doesn't have any way to return anything but a double. The
            // integer returned takes on three possible values:
            // . 2 => both values are equal (tie in value)
            // . 1 => in1 is strictly better than in2 (larger for MAXINDEX,
            //        smaller for MININDEX)
            // . 0 => otherwise (in2 is better, or a NaN is involved)
            // NOTE(review): this comment previously described 1 as "use the index
            // associated with in2", which does not match the code below; confirm
            // the convention against the callers.
            case MAXINDEX:
                if (in1 == in2) {
                    return 2;
                } else if (in1 > in2) {
                    return 1;
                } else { // in1 < in2
                    return 0;
                }
            case MININDEX:
                if (in1 == in2) {
                    return 2;
                } else if (in1 < in2) {
                    return 1;
                } else { // in1 > in2
                    return 0;
                }
            // *** END HACK ***

            case LOG:
                //logarithm of in1 to base in2; faster in Math
                return (Math.log(in1)/Math.log(in2));
            case LOG_NZ:
                //as LOG, but 0 for a zero input; faster in Math
                return (in1==0) ? 0 : (Math.log(in1)/Math.log(in2));

            default:
                throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
        }
    }

    /**
     * Simplified version without exception handling. Unlike
     * {@link #execute(double, double)}, MAXINDEX and MININDEX only return 1 or 0
     * (no separate tie code), and unsupported operations yield -1.
     *
     * @param in1 double 1
     * @param in2 double 2
     * @return result
     */
    public double execute2(double in1, double in2)
    {
        switch(bFunc) {
            case MAX:
            case CUMMAX:
                return (in1 >= in2 ? in1 : in2);
            case MIN:
            case CUMMIN:
                return (in1 <= in2 ? in1 : in2);
            case MAXINDEX:
                return (in1 >= in2) ? 1 : 0;
            case MININDEX:
                return (in1 <= in2) ? 1 : 0;
            default:
                // For performance reasons, avoid throwing an exception
                return -1;
        }
    }

    /**
     * Applies a binary builtin function to two long values. MAXINDEX and
     * MININDEX return 1 or 0 here (no separate tie code).
     *
     * @param in1 first input value
     * @param in2 second input value (the base for LOG and LOG_NZ)
     * @return function result
     * @throws DMLRuntimeException if this object's builtin code is not a supported binary operation
     */
    @Override
    public double execute (long in1, long in2) throws DMLRuntimeException {
        switch(bFunc) {
            case MAX:
            case CUMMAX:   return (in1 >= in2 ? in1 : in2);

            case MIN:
            case CUMMIN:   return (in1 <= in2 ? in1 : in2);

            case MAXINDEX: return (in1 >= in2) ? 1 : 0;
            case MININDEX: return (in1 <= in2) ? 1 : 0;

            case LOG:
                //faster in Math
                return Math.log(in1)/Math.log(in2);
            case LOG_NZ:
                //faster in Math
                return (in1==0) ? 0 : Math.log(in1)/Math.log(in2);

            default:
                throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
        }
    }

    /**
     * Executes a string builtin: PRINT and PRINTF write the string as a line to
     * standard output unless {@code DMLScript.suppressPrint2Stdout()} returns true;
     * STOP throws a {@link DMLScriptException} carrying the string as its message.
     *
     * @param in1 string to print, or the stop message
     * @return always null for PRINT and PRINTF
     * @throws DMLRuntimeException for STOP (as DMLScriptException) or if this
     *         object's builtin code is not a supported string operation
     */
    @Override
    public String execute (String in1)
        throws DMLRuntimeException
    {
        switch (bFunc) {
            case PRINT:
            case PRINTF:
                // both operations print the string as-is, followed by a newline
                if (!DMLScript.suppressPrint2Stdout())
                    System.out.println(in1);
                return null;
            case STOP:
                throw new DMLScriptException(in1);
            default:
                throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
        }
    }
}