org.apache.sysml.runtime.functionobjects.ParameterizedBuiltin Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.functionobjects;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.math3.distribution.AbstractRealDistribution;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.distribution.ExponentialDistribution;
import org.apache.commons.math3.distribution.FDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.TDistribution;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.util.UtilFunctions;
/**
* Function object for builtin function that takes a list of name=value parameters.
* This class can not be instantiated elsewhere.
*/
public class ParameterizedBuiltin extends ValueFunction
{
private static final long serialVersionUID = -5966242955816522697L;
public enum ParameterizedBuiltinCode { INVALID, CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, TRANSFORM };
public enum ProbabilityDistributionCode { INVALID, NORMAL, EXP, CHISQ, F, T };
public ParameterizedBuiltinCode bFunc;
public ProbabilityDistributionCode distFunc;
static public HashMap String2ParameterizedBuiltinCode;
static {
String2ParameterizedBuiltinCode = new HashMap();
String2ParameterizedBuiltinCode.put( "cdf", ParameterizedBuiltinCode.CDF);
String2ParameterizedBuiltinCode.put( "invcdf", ParameterizedBuiltinCode.INVCDF);
String2ParameterizedBuiltinCode.put( "rmempty", ParameterizedBuiltinCode.RMEMPTY);
String2ParameterizedBuiltinCode.put( "replace", ParameterizedBuiltinCode.REPLACE);
String2ParameterizedBuiltinCode.put( "rexpand", ParameterizedBuiltinCode.REXPAND);
String2ParameterizedBuiltinCode.put( "transform", ParameterizedBuiltinCode.TRANSFORM);
}
static public HashMap String2DistCode;
static {
String2DistCode = new HashMap();
String2DistCode.put("normal" , ProbabilityDistributionCode.NORMAL);
String2DistCode.put("exp" , ProbabilityDistributionCode.EXP);
String2DistCode.put("chisq" , ProbabilityDistributionCode.CHISQ);
String2DistCode.put("f" , ProbabilityDistributionCode.F);
String2DistCode.put("t" , ProbabilityDistributionCode.T);
}
// We should create one object for every builtin function that we support
private static ParameterizedBuiltin normalObj = null, expObj = null, chisqObj = null, fObj = null, tObj = null;
private static ParameterizedBuiltin inormalObj = null, iexpObj = null, ichisqObj = null, ifObj = null, itObj = null;
private ParameterizedBuiltin(ParameterizedBuiltinCode bf) {
bFunc = bf;
distFunc = ProbabilityDistributionCode.INVALID;
}
private ParameterizedBuiltin(ParameterizedBuiltinCode bf, ProbabilityDistributionCode dist) {
bFunc = bf;
distFunc = dist;
}
public static ParameterizedBuiltin getParameterizedBuiltinFnObject (String str) throws DMLRuntimeException {
return getParameterizedBuiltinFnObject (str, null);
}
public static ParameterizedBuiltin getParameterizedBuiltinFnObject (String str, String str2) throws DMLRuntimeException {
ParameterizedBuiltinCode code = String2ParameterizedBuiltinCode.get(str);
switch ( code )
{
case CDF:
// str2 will point the appropriate distribution
ProbabilityDistributionCode dcode = String2DistCode.get(str2.toLowerCase());
switch(dcode) {
case NORMAL:
if ( normalObj == null )
normalObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.CDF, dcode);
return normalObj;
case EXP:
if ( expObj == null )
expObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.CDF, dcode);
return expObj;
case CHISQ:
if ( chisqObj == null )
chisqObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.CDF, dcode);
return chisqObj;
case F:
if ( fObj == null )
fObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.CDF, dcode);
return fObj;
case T:
if ( tObj == null )
tObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.CDF, dcode);
return tObj;
default:
throw new DMLRuntimeException("Invalid distribution code: " + dcode);
}
case INVCDF:
// str2 will point the appropriate distribution
ProbabilityDistributionCode distcode = String2DistCode.get(str2.toLowerCase());
switch(distcode) {
case NORMAL:
if ( inormalObj == null )
inormalObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.INVCDF, distcode);
return inormalObj;
case EXP:
if ( iexpObj == null )
iexpObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.INVCDF, distcode);
return iexpObj;
case CHISQ:
if ( ichisqObj == null )
ichisqObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.INVCDF, distcode);
return ichisqObj;
case F:
if ( ifObj == null )
ifObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.INVCDF, distcode);
return ifObj;
case T:
if ( itObj == null )
itObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.INVCDF, distcode);
return itObj;
default:
throw new DMLRuntimeException("Invalid distribution code: " + distcode);
}
case RMEMPTY:
return new ParameterizedBuiltin(ParameterizedBuiltinCode.RMEMPTY);
case REPLACE:
return new ParameterizedBuiltin(ParameterizedBuiltinCode.REPLACE);
case REXPAND:
return new ParameterizedBuiltin(ParameterizedBuiltinCode.REXPAND);
case TRANSFORM:
return new ParameterizedBuiltin(ParameterizedBuiltinCode.TRANSFORM);
default:
throw new DMLRuntimeException("Invalid parameterized builtin code: " + code);
}
}
public Object clone() throws CloneNotSupportedException {
// cloning is not supported for singleton classes
throw new CloneNotSupportedException();
}
public double execute(HashMap params) throws DMLRuntimeException {
switch(bFunc) {
case CDF:
case INVCDF:
switch(distFunc) {
case NORMAL:
case EXP:
case CHISQ:
case F:
case T:
return computeFromDistribution(distFunc, params, (bFunc==ParameterizedBuiltinCode.INVCDF));
default:
throw new DMLRuntimeException("Unsupported distribution (" + distFunc + ").");
}
default:
throw new DMLRuntimeException("ParameterizedBuiltin.execute(): Unknown operation: " + bFunc);
}
}
/**
* Helper function to compute distribution-specific cdf (both lowertail and uppertail) and inverse cdf.
*
* @param dcode
* @param params
* @param inverse
* @return
* @throws MathArithmeticException
* @throws DMLRuntimeException
*/
private double computeFromDistribution (ProbabilityDistributionCode dcode, HashMap params, boolean inverse ) throws MathArithmeticException, DMLRuntimeException {
// given value is "quantile" when inverse=false, and it is "probability" when inverse=true
double val = Double.parseDouble(params.get("target"));
boolean lowertail = true;
if(params.get("lower.tail") != null) {
lowertail = Boolean.parseBoolean(params.get("lower.tail"));
}
AbstractRealDistribution distFunction = null;
switch(dcode) {
case NORMAL:
double mean = 0.0, sd = 1.0; // default values for mean and sd
String mean_s = params.get("mean"), sd_s = params.get("sd");
if(mean_s != null) mean = Double.parseDouble(mean_s);
if(sd_s != null) sd = Double.parseDouble(sd_s);
if ( sd <= 0 )
throw new DMLRuntimeException("Standard deviation for Normal distribution must be positive (" + sd + ")");
distFunction = new NormalDistribution(mean, sd);
break;
case EXP:
double exp_rate = 1.0; // default value for 1/mean or rate
if(params.get("rate") != null) exp_rate = Double.parseDouble(params.get("rate"));
if ( exp_rate <= 0 ) {
throw new DMLRuntimeException("Rate for Exponential distribution must be positive (" + exp_rate + ")");
}
// For exponential distribution: mean = 1/rate
distFunction = new ExponentialDistribution(1.0/exp_rate);
break;
case CHISQ:
if ( params.get("df") == null ) {
throw new DMLRuntimeException("" +
"Degrees of freedom must be specified for chi-squared distribution " +
"(e.g., q=qchisq(0.5, df=20); p=pchisq(target=q, df=1.2))");
}
int df = UtilFunctions.parseToInt(params.get("df"));
if ( df <= 0 ) {
throw new DMLRuntimeException("Degrees of Freedom for chi-squared distribution must be positive (" + df + ")");
}
distFunction = new ChiSquaredDistribution(df);
break;
case F:
if ( params.get("df1") == null || params.get("df2") == null ) {
throw new DMLRuntimeException("" +
"Degrees of freedom must be specified for F distribution " +
"(e.g., q = qf(target=0.5, df1=20, df2=30); p=pf(target=q, df1=20, df2=30))");
}
int df1 = UtilFunctions.parseToInt(params.get("df1"));
int df2 = UtilFunctions.parseToInt(params.get("df2"));
if ( df1 <= 0 || df2 <= 0) {
throw new DMLRuntimeException("Degrees of Freedom for F distribution must be positive (" + df1 + "," + df2 + ")");
}
distFunction = new FDistribution(df1, df2);
break;
case T:
if ( params.get("df") == null ) {
throw new DMLRuntimeException("" +
"Degrees of freedom is needed to compute probabilities from t distribution " +
"(e.g., q = qt(target=0.5, df=10); p = pt(target=q, df=10))");
}
int t_df = UtilFunctions.parseToInt(params.get("df"));
if ( t_df <= 0 ) {
throw new DMLRuntimeException("Degrees of Freedom for t distribution must be positive (" + t_df + ")");
}
distFunction = new TDistribution(t_df);
break;
default:
throw new DMLRuntimeException("Invalid distribution code: " + dcode);
}
double ret = Double.NaN;
if(inverse) {
// inverse cdf
ret = distFunction.inverseCumulativeProbability(val);
}
else if(lowertail) {
// cdf (lowertail)
ret = distFunction.cumulativeProbability(val);
}
else {
// cdf (upper tail)
// TODO: more accurate distribution-specific computation of upper tail probabilities
ret = 1.0 - distFunction.cumulativeProbability(val);
}
return ret;
}
}