Please wait. This may take a few minutes...
Downloading a project requires significant server resources. Please understand that we need to cover our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
com.yahoo.tensor.functions.ScalarFunctions Maven / Gradle / Ivy
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
 * Factory of scalar Java functions.
 * The purpose of this is to embellish anonymous functions with a runtime type
 * such that they can be inspected and will return a parsable toString.
 * <p>
 * Every function class overrides {@code toString} with a parsable expression and {@code hashCode}
 * with a value derived from the function name (plus its parameters, where it has any).
 * NOTE(review): none of the classes override {@code equals}, so equality is object identity.
 *
 * @author bratseth
 */
public class ScalarFunctions {

    // Binary operator factories ---------------------------------------------------------------------

    public static DoubleBinaryOperator add() { return new Add(); }
    public static DoubleBinaryOperator divide() { return new Divide(); }
    public static DoubleBinaryOperator equal() { return new Equal(); }
    public static DoubleBinaryOperator greater() { return new Greater(); }
    public static DoubleBinaryOperator less() { return new Less(); }
    public static DoubleBinaryOperator max() { return new Max(); }
    public static DoubleBinaryOperator min() { return new Min(); }
    public static DoubleBinaryOperator mean() { return new Mean(); }
    public static DoubleBinaryOperator multiply() { return new Multiply(); }
    public static DoubleBinaryOperator pow() { return new Pow(); }
    public static DoubleBinaryOperator squareddifference() { return new SquaredDifference(); }
    public static DoubleBinaryOperator subtract() { return new Subtract(); }
    public static DoubleBinaryOperator hamming() { return new Hamming(); }

    // Unary operator factories ----------------------------------------------------------------------

    public static DoubleUnaryOperator abs() { return new Abs(); }
    public static DoubleUnaryOperator acos() { return new Acos(); }
    public static DoubleUnaryOperator asin() { return new Asin(); }
    public static DoubleUnaryOperator atan() { return new Atan(); }
    public static DoubleUnaryOperator ceil() { return new Ceil(); }
    public static DoubleUnaryOperator cos() { return new Cos(); }
    public static DoubleUnaryOperator exp() { return new Exp(); }
    public static DoubleUnaryOperator floor() { return new Floor(); }
    public static DoubleUnaryOperator log() { return new Log(); }
    public static DoubleUnaryOperator neg() { return new Neg(); }
    public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); }
    public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); }
    public static DoubleUnaryOperator sin() { return new Sin(); }
    public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); }
    public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
    public static DoubleUnaryOperator square() { return new Square(); }
    public static DoubleUnaryOperator tan() { return new Tan(); }
    public static DoubleUnaryOperator tanh() { return new Tanh(); }
    public static DoubleUnaryOperator erf() { return new Erf(); }
    public static DoubleUnaryOperator elu() { return new Elu(); }
    public static DoubleUnaryOperator elu(double alpha) { return new Elu(alpha); }
    public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); }
    public static DoubleUnaryOperator leakyrelu(double alpha) { return new LeakyRelu(alpha); }
    public static DoubleUnaryOperator relu() { return new Relu(); }
    public static DoubleUnaryOperator selu() { return new Selu(); }
    public static DoubleUnaryOperator selu(double scale, double alpha) { return new Selu(scale, alpha); }

    // Variable-length function factories ------------------------------------------------------------

    /** Returns a function ignoring its arguments and producing a uniform random double in [0, 1). */
    public static Function<List<Long>, Double> random() { return new Random(); }

    /** Returns a function producing 1.0 if all argument values are equal, 0.0 otherwise. */
    public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }

    /** Returns a function producing the sum of its argument values. */
    public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }

    /** Returns a function ignoring its arguments and always producing {@code value}. */
    public static Function<List<Long>, Double> constant(double value) { return new Constant(value); }

    // Binary operators -----------------------------------------------------------------------------

    /** a + b */
    public static class Add implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return left + right; }
        @Override
        public String toString() { return "f(a,b)(a + b)"; }
        @Override
        public int hashCode() { return "add".hashCode(); }
    }

    /** 1 if a == b (IEEE comparison, so NaN is never equal), else 0. */
    public static class Equal implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; }
        @Override
        public String toString() { return "f(a,b)(a==b)"; }
        @Override
        public int hashCode() { return "equal".hashCode(); }
    }

    /** 1 if a > b, else 0. */
    public static class Greater implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return left > right ? 1 : 0; }
        @Override
        public String toString() { return "f(a,b)(a > b)"; }
        @Override
        public int hashCode() { return "greater".hashCode(); }
    }

    /** 1 if a < b, else 0. */
    public static class Less implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return left < right ? 1 : 0; }
        @Override
        public String toString() { return "f(a,b)(a < b)"; }
        @Override
        public int hashCode() { return "less".hashCode(); }
    }

    /** max(a, b) as defined by {@link Math#max(double, double)}. */
    public static class Max implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return Math.max(left, right); }
        @Override
        public String toString() { return "f(a,b)(max(a, b))"; }
        @Override
        public int hashCode() { return "max".hashCode(); }
    }

    /** min(a, b) as defined by {@link Math#min(double, double)}. */
    public static class Min implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return Math.min(left, right); }
        @Override
        public String toString() { return "f(a,b)(min(a, b))"; }
        @Override
        public int hashCode() { return "min".hashCode(); }
    }

    /** (a + b) / 2 */
    public static class Mean implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return (left + right) / 2; }
        @Override
        public String toString() { return "f(a,b)((a + b) / 2)"; }
        @Override
        public int hashCode() { return "mean".hashCode(); }
    }

    /** a * b */
    public static class Multiply implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return left * right; }
        @Override
        public String toString() { return "f(a,b)(a * b)"; }
        @Override
        public int hashCode() { return "multiply".hashCode(); }
    }

    /** a raised to the power b, via {@link Math#pow(double, double)}. */
    public static class Pow implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return Math.pow(left, right); }
        @Override
        public String toString() { return "f(a,b)(pow(a, b))"; }
        @Override
        public int hashCode() { return "pow".hashCode(); }
    }

    /** a / b (IEEE division: dividing by zero yields an infinity or NaN, never an exception). */
    public static class Divide implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return left / right; }
        @Override
        public String toString() { return "f(a,b)(a / b)"; }
        @Override
        public int hashCode() { return "divide".hashCode(); }
    }

    /** (a - b)^2 */
    public static class SquaredDifference implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return (left - right) * (left - right); }
        @Override
        public String toString() { return "f(a,b)((a-b) * (a-b))"; }
        @Override
        public int hashCode() { return "squareddifference".hashCode(); }
    }

    /** a - b */
    public static class Subtract implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return left - right; }
        @Override
        public String toString() { return "f(a,b)(a - b)"; }
        @Override
        public int hashCode() { return "subtract".hashCode(); }
    }

    /** Bitwise Hamming distance between the two operands viewed as 8-bit integers. */
    public static class Hamming implements DoubleBinaryOperator {

        /**
         * Returns the number of bit positions (0 to 8) in which the two arguments differ when each is
         * narrowed with a Java {@code (byte)} cast. The cast goes double to int, then int to byte,
         * so fractional parts are discarded and only the low 8 bits of the truncated value are kept.
         */
        public static double hamming(double left, double right) {
            byte a = (byte) left;
            byte b = (byte) right;
            // Bytes are sign-extended when promoted to int. Mask to 8 bits so that a negative/non-negative
            // pair does not also count the 24 differing high bits.
            return Integer.bitCount((a ^ b) & 0xFF);
        }

        @Override
        public double applyAsDouble(double left, double right) { return hamming(left, right); }
        @Override
        public String toString() { return "f(a,b)(hamming(a,b))"; }
        @Override
        public int hashCode() { return "hamming".hashCode(); }
    }

    // Unary operators ------------------------------------------------------------------------------

    /** |a| */
    public static class Abs implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.abs(operand); }
        @Override
        public String toString() { return "f(a)(fabs(a))"; }
        @Override
        public int hashCode() { return "abs".hashCode(); }
    }

    /** arccos(a) */
    public static class Acos implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.acos(operand); }
        @Override
        public String toString() { return "f(a)(acos(a))"; }
        @Override
        public int hashCode() { return "acos".hashCode(); }
    }

    /** arcsin(a) */
    public static class Asin implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.asin(operand); }
        @Override
        public String toString() { return "f(a)(asin(a))"; }
        @Override
        public int hashCode() { return "asin".hashCode(); }
    }

    /** arctan(a) */
    public static class Atan implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.atan(operand); }
        @Override
        public String toString() { return "f(a)(atan(a))"; }
        @Override
        public int hashCode() { return "atan".hashCode(); }
    }

    /** Smallest integer value not less than a. */
    public static class Ceil implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.ceil(operand); }
        @Override
        public String toString() { return "f(a)(ceil(a))"; }
        @Override
        public int hashCode() { return "ceil".hashCode(); }
    }

    /** cos(a) */
    public static class Cos implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.cos(operand); }
        @Override
        public String toString() { return "f(a)(cos(a))"; }
        @Override
        public int hashCode() { return "cos".hashCode(); }
    }

    /** Exponential linear unit: alpha * (e^a - 1) for a < 0, otherwise a. The default alpha is 1.0. */
    public static class Elu implements DoubleUnaryOperator {

        private final double alpha;

        public Elu() {
            this(1.0);
        }

        public Elu(double alpha) {
            this.alpha = alpha;
        }

        @Override
        public double applyAsDouble(double operand) { return operand < 0 ? alpha * (Math.exp(operand) - 1) : operand; }
        @Override
        public String toString() { return "f(a)(if(a < 0, " + alpha + " * (exp(a)-1), a))"; }
        @Override
        public int hashCode() { return Objects.hash("elu", alpha); }
    }

    /** e^a */
    public static class Exp implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.exp(operand); }
        @Override
        public String toString() { return "f(a)(exp(a))"; }
        @Override
        public int hashCode() { return "exp".hashCode(); }
    }

    /** Largest integer value not greater than a. */
    public static class Floor implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.floor(operand); }
        @Override
        public String toString() { return "f(a)(floor(a))"; }
        @Override
        public int hashCode() { return "floor".hashCode(); }
    }

    /** Natural logarithm of a. */
    public static class Log implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.log(operand); }
        @Override
        public String toString() { return "f(a)(log(a))"; }
        @Override
        public int hashCode() { return "log".hashCode(); }
    }

    /** -a */
    public static class Neg implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return -operand; }
        @Override
        public String toString() { return "f(a)(-a)"; }
        @Override
        public int hashCode() { return "neg".hashCode(); }
    }

    /** 1 / a */
    public static class Reciprocal implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return 1.0 / operand; }
        @Override
        public String toString() { return "f(a)(1 / a)"; }
        @Override
        public int hashCode() { return "reciprocal".hashCode(); }
    }

    /** Rectified linear unit: max(a, 0). */
    public static class Relu implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.max(operand, 0); }
        @Override
        public String toString() { return "f(a)(max(0, a))"; }
        @Override
        public int hashCode() { return "relu".hashCode(); }
    }

    /**
     * Scaled exponential linear unit: scale * (a if a >= 0, else alpha * (e^a - 1)).
     * The default scale and alpha are the constants from https://arxiv.org/abs/1706.02515.
     */
    public static class Selu implements DoubleUnaryOperator {

        private final double scale; // default 1.0507009873554804934193349852946
        private final double alpha; // default 1.6732632423543772848170429916717

        public Selu() {
            this(1.0507009873554804934193349852946, 1.6732632423543772848170429916717);
        }

        public Selu(double scale, double alpha) {
            this.scale = scale;
            this.alpha = alpha;
        }

        @Override
        public double applyAsDouble(double operand) { return scale * (operand >= 0.0 ? operand : alpha * (Math.exp(operand)-1)); }
        @Override
        public String toString() { return "f(a)(" + scale + " * if(a >= 0, a, " + alpha + " * (exp(a) - 1)))"; }
        @Override
        public int hashCode() { return Objects.hash("selu", scale, alpha); }
    }

    /**
     * Leaky rectified linear unit, computed as max(alpha * a, a). The default alpha is 0.01.
     * NOTE(review): this equals the usual definition (a if a >= 0, else alpha * a) only when alpha <= 1.
     * For alpha > 1 it returns alpha * a for positive a.
     */
    public static class LeakyRelu implements DoubleUnaryOperator {

        private final double alpha;

        public LeakyRelu() {
            this(0.01);
        }

        public LeakyRelu(double alpha) {
            this.alpha = alpha;
        }

        @Override
        public double applyAsDouble(double operand) { return Math.max(alpha * operand, operand); }
        @Override
        public String toString() { return "f(a)(max(" + alpha + " * a, a))"; }
        @Override
        public int hashCode() { return Objects.hash("leakyrelu", alpha); }
    }

    /** sin(a) */
    public static class Sin implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.sin(operand); }
        @Override
        public String toString() { return "f(a)(sin(a))"; }
        @Override
        public int hashCode() { return "sin".hashCode(); }
    }

    /** Reciprocal square root: 1 / sqrt(a). */
    public static class Rsqrt implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); }
        @Override
        public String toString() { return "f(a)(1.0 / sqrt(a))"; }
        @Override
        public int hashCode() { return "rsqrt".hashCode(); }
    }

    /** Logistic sigmoid: 1 / (1 + e^-a). */
    public static class Sigmoid implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return 1.0 / (1.0 + Math.exp(-operand)); }
        @Override
        public String toString() { return "f(a)(1 / (1 + exp(-a)))"; }
        @Override
        public int hashCode() { return "sigmoid".hashCode(); }
    }

    /** sqrt(a) */
    public static class Sqrt implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.sqrt(operand); }
        @Override
        public String toString() { return "f(a)(sqrt(a))"; }
        @Override
        public int hashCode() { return "sqrt".hashCode(); }
    }

    /** a * a */
    public static class Square implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return operand * operand; }
        @Override
        public String toString() { return "f(a)(a * a)"; }
        @Override
        public int hashCode() { return "square".hashCode(); }
    }

    /** tan(a) */
    public static class Tan implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.tan(operand); }
        @Override
        public String toString() { return "f(a)(tan(a))"; }
        @Override
        public int hashCode() { return "tan".hashCode(); }
    }

    /** tanh(a) */
    public static class Tanh implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.tanh(operand); }
        @Override
        public String toString() { return "f(a)(tanh(a))"; }
        @Override
        public int hashCode() { return "tanh".hashCode(); }
    }

    /** The Gauss error function erf(a). */
    public static class Erf implements DoubleUnaryOperator {

        /** Orders doubles by absolute value, so the queue in {@link #kummer} yields the smallest magnitudes first. */
        static final Comparator<Double> byAbs = (x, y) -> Double.compare(Math.abs(x), Math.abs(y));

        /**
         * Evaluates the power series sum over n of (a)_n / (b)_n * z^n / n!, which is Kummer's
         * confluent hypergeometric function M(a, b, z). Terms are generated until their magnitude
         * drops to {@link Double#MIN_NORMAL}. They are then summed smallest-magnitude-first to
         * limit rounding error.
         * NOTE(review): assumes a finite z. An infinite z makes every term infinite and the generation
         * loop never ends. The callers in this class pass |z| < 18.5.
         */
        static double kummer(double a, double b, double z) {
            PriorityQueue<Double> terms = new PriorityQueue<>(byAbs);
            double term = 1.0;
            long n = 0;
            while (Math.abs(term) > Double.MIN_NORMAL) {
                terms.add(term);
                // term_{n+1} = term_n * (a + n) / (b + n) * z / (n + 1)
                term *= (a+n);
                term /= (b+n);
                ++n;
                term *= z;
                term /= n;
            }
            // Repeatedly add the two smallest-magnitude values and put the partial sum back into the queue.
            double sum = terms.remove();
            while ( ! terms.isEmpty()) {
                sum += terms.remove();
                terms.add(sum);
                sum = terms.remove();
            }
            return sum;
        }

        /**
         * Approximates erfc(x) for large positive x with the asymptotic expansion
         * e^(-x^2) / (x * sqrt(pi)) * sum over k of (-1)^k * (2k-1)!! / (2x^2)^k.
         * The series diverges, so summation stops once the terms start growing. Half of the
         * last term is then added back as a correction.
         * NOTE(review): a NaN argument makes this loop forever. {@link #erf(double)} filters NaN first.
         */
        static double approx_erfc(double x) {
            double sq = x*x;
            double mult = Math.exp(-sq) / (x * Math.sqrt(Math.PI));
            double term = 1.0;
            long n = 1;
            double sum = 0.0;
            // Each iteration handles one positive term and one negative term. The loop ends when adding
            // a term no longer changes the sum.
            while ((sum + term) != sum) {
                double pterm = term;
                sum += term;
                term = 0.5 * pterm * n / sq;
                if (term > pterm) {
                    sum -= 0.5 * pterm;
                    return sum*mult;
                }
                n += 2;
                pterm = term;
                sum -= term;
                term = 0.5 * pterm * n / sq;
                if (term > pterm) {
                    sum += 0.5 * pterm;
                    return sum*mult;
                }
                n += 2;
            }
            return sum*mult;
        }

        @Override
        public double applyAsDouble(double operand) { return erf(operand); }
        @Override
        public String toString() { return "f(a)(erf(a))"; }
        @Override
        public int hashCode() { return "erf".hashCode(); }

        /** 2 / sqrt(pi), the derivative of erf at zero. */
        static final double nearZeroMultiplier = 2.0 / Math.sqrt(Math.PI);

        /**
         * Returns erf(v). NaN maps to NaN, and +/-infinity maps to +/-1. Because erf is odd,
         * negative inputs are computed from their absolute value.
         */
        public static double erf(double v) {
            if (Double.isNaN(v)) {
                // NaN fails every range test below and would reach approx_erfc, whose
                // "(sum + term) != sum" convergence test is always true for NaN, so it would never return.
                return v;
            }
            if (v < 0) {
                return -erf(Math.abs(v));
            }
            if (v < 1.0e-10) {
                // Just use the derivative when very near zero (this also preserves the sign of -0.0):
                return v * nearZeroMultiplier;
            }
            if (v <= 1.0) {
                // erf(v) = 2v/sqrt(pi) * M(1/2, 3/2, -v^2). An alternating series that works best when v is small.
                return v * nearZeroMultiplier * kummer(0.5, 1.5, -v*v);
            }
            if (v < 4.3) {
                // erf(v) = 2v/sqrt(pi) * e^(-v^2) * M(1, 3/2, v^2). Slower (positive terms), but works with bigger v.
                return v * nearZeroMultiplier * Math.exp(-v*v) * kummer(1.0, 1.5, v*v);
            }
            // Works only with "very big" v.
            return 1.0 - approx_erfc(v);
        }
    }

    // Variable-length operators -----------------------------------------------------------------------------

    /** Produces 1.0 if all argument values are equal (an empty argument list counts as equal), otherwise 0.0. */
    public static class EqualElements implements Function<List<Long>, Double> {

        private final List<String> argumentNames;

        private EqualElements(List<String> argumentNames) {
            this.argumentNames = List.copyOf(argumentNames);
        }

        @Override
        public Double apply(List<Long> values) {
            if (values.isEmpty()) return 1.0;
            for (Long value : values)
                if ( ! value.equals(values.get(0)))
                    return 0.0;
            return 1.0;
        }

        /**
         * Returns the equivalent expression over the argument names: "1" for fewer than two names,
         * "a==b" for two, and "(a==b)*(b==c)*..." for more.
         */
        @Override
        public String toString() {
            if (argumentNames.size() == 0) return "1";
            if (argumentNames.size() == 1) return "1";
            if (argumentNames.size() == 2) return argumentNames.get(0) + "==" + argumentNames.get(1);
            StringBuilder b = new StringBuilder();
            for (int i = 0; i < argumentNames.size() -1; i++) {
                b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")");
                if ( i < argumentNames.size() -2)
                    b.append("*");
            }
            return b.toString();
        }

        @Override
        public int hashCode() { return Objects.hash("equal", argumentNames); }
    }

    /** Ignores its arguments and produces a uniform random double in [0, 1) from {@link ThreadLocalRandom}. */
    public static class Random implements Function<List<Long>, Double> {
        @Override
        public Double apply(List<Long> values) {
            return ThreadLocalRandom.current().nextDouble();
        }
        @Override
        public String toString() { return "random"; }
        @Override
        public int hashCode() { return "random".hashCode(); }
    }

    /**
     * Produces the sum of the argument values. The sum is accumulated in a {@code long}, which wraps
     * silently on overflow, and is then converted to double.
     */
    public static class SumElements implements Function<List<Long>, Double> {

        private final List<String> argumentNames;

        private SumElements(List<String> argumentNames) {
            this.argumentNames = List.copyOf(argumentNames);
        }

        @Override
        public Double apply(List<Long> values) {
            long sum = 0;
            for (Long value : values)
                sum += value;
            return (double)sum;
        }

        /** Returns the argument names joined with "+". */
        @Override
        public String toString() {
            return argumentNames.stream().collect(Collectors.joining("+"));
        }

        @Override
        public int hashCode() { return Objects.hash("sum", argumentNames); }
    }

    /** Ignores its arguments and always produces the same value. */
    public static class Constant implements Function<List<Long>, Double> {

        private final double value;

        public Constant(double value) {
            this.value = value;
        }

        @Override
        public Double apply(List<Long> values) {
            return value;
        }

        @Override
        public String toString() { return Double.toString(value); }
        @Override
        public int hashCode() { return Objects.hash("constant", value); }
    }

}