com.yahoo.tensor.functions.ScalarFunctions Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of vespajlib Show documentation
Show all versions of vespajlib Show documentation
Library for use in Java components of Vespa. Shared code that does
not fit anywhere else.
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* Factory of scalar Java functions.
* The purpose of this is to embellish anonymous functions with a runtime type
* such that they can be inspected and will return a parsable toString.
*
* @author bratseth
*/
public class ScalarFunctions {
// Factories for the binary operators. Each call returns a new instance of the
// nested class with the matching name; the trailing comment summarizes what
// that class computes for operands a (left) and b (right).
public static DoubleBinaryOperator add() { return new Add(); } // a + b
public static DoubleBinaryOperator divide() { return new Divide(); } // a / b
public static DoubleBinaryOperator equal() { return new Equal(); } // 1 if a == b, otherwise 0
public static DoubleBinaryOperator greater() { return new Greater(); } // 1 if a > b, otherwise 0
public static DoubleBinaryOperator less() { return new Less(); } // 1 if a < b, otherwise 0
public static DoubleBinaryOperator max() { return new Max(); } // Math.max(a, b)
public static DoubleBinaryOperator min() { return new Min(); } // Math.min(a, b)
public static DoubleBinaryOperator mean() { return new Mean(); } // (a + b) / 2
public static DoubleBinaryOperator multiply() { return new Multiply(); } // a * b
public static DoubleBinaryOperator pow() { return new Pow(); } // Math.pow(a, b)
public static DoubleBinaryOperator squareddifference() { return new SquaredDifference(); } // (a - b) * (a - b)
public static DoubleBinaryOperator subtract() { return new Subtract(); } // a - b
public static DoubleBinaryOperator hamming() { return new Hamming(); } // differing bits among the low 8 bits of a and b, each cast to byte
// Factories for the unary operators. Each call returns a new instance of the
// nested class with the matching name; the trailing comment summarizes what
// that class computes for operand a.
public static DoubleUnaryOperator abs() { return new Abs(); } // Math.abs(a)
public static DoubleUnaryOperator acos() { return new Acos(); } // Math.acos(a)
public static DoubleUnaryOperator asin() { return new Asin(); } // Math.asin(a)
public static DoubleUnaryOperator atan() { return new Atan(); } // Math.atan(a)
public static DoubleUnaryOperator ceil() { return new Ceil(); } // Math.ceil(a)
public static DoubleUnaryOperator cos() { return new Cos(); } // Math.cos(a)
public static DoubleUnaryOperator exp() { return new Exp(); } // Math.exp(a)
public static DoubleUnaryOperator floor() { return new Floor(); } // Math.floor(a)
public static DoubleUnaryOperator log() { return new Log(); } // Math.log(a)
public static DoubleUnaryOperator neg() { return new Neg(); } // -a
public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); } // 1.0 / a
public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); } // 1.0 / Math.sqrt(a)
public static DoubleUnaryOperator sin() { return new Sin(); } // Math.sin(a)
public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); } // 1 / (1 + exp(-a))
public static DoubleUnaryOperator sqrt() { return new Sqrt(); } // Math.sqrt(a)
public static DoubleUnaryOperator square() { return new Square(); } // a * a
public static DoubleUnaryOperator tan() { return new Tan(); } // Math.tan(a)
public static DoubleUnaryOperator tanh() { return new Tanh(); } // Math.tanh(a)
public static DoubleUnaryOperator erf() { return new Erf(); } // Erf.erf(a)
public static DoubleUnaryOperator elu() { return new Elu(); } // as elu(alpha) with alpha = 1.0
public static DoubleUnaryOperator elu(double alpha) { return new Elu(alpha); } // alpha * (exp(a) - 1) if a < 0, otherwise a
public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); } // as leakyrelu(alpha) with alpha = 0.01
public static DoubleUnaryOperator leakyrelu(double alpha) { return new LeakyRelu(alpha); } // Math.max(alpha * a, a)
public static DoubleUnaryOperator relu() { return new Relu(); } // Math.max(a, 0)
public static DoubleUnaryOperator selu() { return new Selu(); } // as selu(scale, alpha) with the defaults set in Selu's no-argument constructor
public static DoubleUnaryOperator selu(double scale, double alpha) { return new Selu(scale, alpha); } // scale * (a if a >= 0, otherwise alpha * (exp(a) - 1))
public static Function, Double> random() { return new Random(); }
public static Function, Double> equal(List argumentNames) { return new EqualElements(argumentNames); }
public static Function, Double> sum(List argumentNames) { return new SumElements(argumentNames); }
public static Function, Double> constant(double value) { return new Constant(value); }
// Binary operators -----------------------------------------------------------------------------
/** Binary addition. */
public static class Add implements DoubleBinaryOperator {

    /** Returns the sum of the two operands. */
    @Override
    public double applyAsDouble(double a, double b) {
        return a + b;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a,b)(a + b)";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "add".hashCode();
    }

}
/** Equality test producing 1 for equal operands and 0 otherwise. */
public static class Equal implements DoubleBinaryOperator {

    /** Returns 1.0 if the operands compare equal with ==, else 0.0. */
    @Override
    public double applyAsDouble(double a, double b) {
        if (a == b) {
            return 1.0;
        }
        return 0.0;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a,b)(a==b)";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "equal".hashCode();
    }

}
/** Greater-than test producing 1 when the left operand is larger and 0 otherwise. */
public static class Greater implements DoubleBinaryOperator {

    /** Returns 1.0 if a &gt; b, else 0.0. */
    @Override
    public double applyAsDouble(double a, double b) {
        if (a > b) {
            return 1.0;
        }
        return 0.0;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a,b)(a > b)";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "greater".hashCode();
    }

}
/** Less-than test producing 1 when the left operand is smaller and 0 otherwise. */
public static class Less implements DoubleBinaryOperator {

    /** Returns 1.0 if a &lt; b, else 0.0. */
    @Override
    public double applyAsDouble(double a, double b) {
        if (a < b) {
            return 1.0;
        }
        return 0.0;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a,b)(a < b)";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "less".hashCode();
    }

}
/** Binary maximum. */
public static class Max implements DoubleBinaryOperator {

    /** Returns the larger operand, as decided by Math.max. */
    @Override
    public double applyAsDouble(double a, double b) {
        double larger = Math.max(a, b);
        return larger;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a,b)(max(a, b))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "max".hashCode();
    }

}
/** Binary minimum. */
public static class Min implements DoubleBinaryOperator {

    /** Returns the smaller operand, as decided by Math.min. */
    @Override
    public double applyAsDouble(double a, double b) {
        double smaller = Math.min(a, b);
        return smaller;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a,b)(min(a, b))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "min".hashCode();
    }

}
/** Arithmetic mean of two values. */
public static class Mean implements DoubleBinaryOperator {

    /** Returns half of the sum of the two operands. */
    @Override
    public double applyAsDouble(double a, double b) {
        double total = a + b;
        return total / 2;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a,b)((a + b) / 2)";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "mean".hashCode();
    }

}
/** Binary multiplication. */
public static class Multiply implements DoubleBinaryOperator {

    /** Returns the product of the two operands. */
    @Override
    public double applyAsDouble(double a, double b) {
        return a * b;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a,b)(a * b)";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "multiply".hashCode();
    }

}
/** Exponentiation. */
public static class Pow implements DoubleBinaryOperator {

    /** Returns the left operand raised to the power of the right operand, using Math.pow. */
    @Override
    public double applyAsDouble(double base, double exponent) {
        return Math.pow(base, exponent);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a,b)(pow(a, b))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "pow".hashCode();
    }

}
/** Binary division. */
public static class Divide implements DoubleBinaryOperator {

    /** Returns the left operand divided by the right operand. */
    @Override
    public double applyAsDouble(double numerator, double denominator) {
        return numerator / denominator;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a,b)(a / b)";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "divide".hashCode();
    }

}
/** Square of the difference between two values. */
public static class SquaredDifference implements DoubleBinaryOperator {

    /** Returns (a - b) multiplied by itself. */
    @Override
    public double applyAsDouble(double a, double b) {
        double difference = a - b;
        return difference * difference;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a,b)((a-b) * (a-b))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "squareddifference".hashCode();
    }

}
/** Binary subtraction. */
public static class Subtract implements DoubleBinaryOperator {

    /** Returns the left operand minus the right operand. */
    @Override
    public double applyAsDouble(double a, double b) {
        return a - b;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a,b)(a - b)";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "subtract".hashCode();
    }

}
/** Hamming distance between two values, each interpreted as a single byte. */
public static class Hamming implements DoubleBinaryOperator {

    /**
     * Counts the bit positions in which the two operands differ, after each has been
     * narrowed to a byte with a plain (byte) cast. Only those 8 bits take part.
     *
     * @param left the first operand
     * @param right the second operand
     * @return the number of differing bits, from 0 to 8
     */
    public static double hamming(double left, double right) {
        // XOR leaves a 1 exactly where the two bytes differ. The mask drops the
        // sign-extension bits introduced when the bytes are promoted to int.
        int differingBits = ((byte) left ^ (byte) right) & 0xFF;
        return Integer.bitCount(differingBits);
    }

    /** Returns hamming(left, right). */
    @Override
    public double applyAsDouble(double left, double right) {
        return hamming(left, right);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a,b)(hamming(a,b))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "hamming".hashCode();
    }

}
// Unary operators ------------------------------------------------------------------------------
/** Absolute value. */
public static class Abs implements DoubleUnaryOperator {

    /** Returns Math.abs of the operand. */
    @Override
    public double applyAsDouble(double a) {
        return Math.abs(a);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(fabs(a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "abs".hashCode();
    }

}
/** Arc cosine. */
public static class Acos implements DoubleUnaryOperator {

    /** Returns Math.acos of the operand. */
    @Override
    public double applyAsDouble(double a) {
        return Math.acos(a);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(acos(a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "acos".hashCode();
    }

}
/** Arc sine. */
public static class Asin implements DoubleUnaryOperator {

    /** Returns Math.asin of the operand. */
    @Override
    public double applyAsDouble(double a) {
        return Math.asin(a);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(asin(a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "asin".hashCode();
    }

}
/** Arc tangent. */
public static class Atan implements DoubleUnaryOperator {

    /** Returns Math.atan of the operand. */
    @Override
    public double applyAsDouble(double a) {
        return Math.atan(a);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(atan(a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "atan".hashCode();
    }

}
/** Rounding up to a whole number. */
public static class Ceil implements DoubleUnaryOperator {

    /** Returns Math.ceil of the operand. */
    @Override
    public double applyAsDouble(double a) {
        return Math.ceil(a);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(ceil(a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "ceil".hashCode();
    }

}
/** Cosine. */
public static class Cos implements DoubleUnaryOperator {

    /** Returns Math.cos of the operand. */
    @Override
    public double applyAsDouble(double a) {
        return Math.cos(a);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(cos(a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "cos".hashCode();
    }

}
/**
 * ELU activation: the identity for operands which are not below zero, and
 * alpha * (exp(a) - 1) for operands below zero.
 */
public static class Elu implements DoubleUnaryOperator {

    /** Multiplier applied on the negative side. */
    private final double alpha;

    /** Creates an ELU function with alpha = 1.0. */
    public Elu() {
        this(1.0);
    }

    /** Creates an ELU function with the given alpha. */
    public Elu(double alpha) {
        this.alpha = alpha;
    }

    /** Applies this function to the operand. */
    @Override
    public double applyAsDouble(double a) {
        if (a < 0) {
            return alpha * (Math.exp(a) - 1);
        }
        return a;
    }

    /** Returns the parsable expression form of this function, with alpha written out. */
    @Override
    public String toString() {
        StringBuilder expression = new StringBuilder("f(a)(if(a < 0, ");
        expression.append(alpha).append(" * (exp(a)-1), a))");
        return expression.toString();
    }

    /** Returns a hash of the function name and alpha. */
    @Override
    public int hashCode() {
        return Objects.hash("elu", alpha);
    }

}
/** Exponential function. */
public static class Exp implements DoubleUnaryOperator {

    /** Returns Math.exp of the operand. */
    @Override
    public double applyAsDouble(double a) {
        return Math.exp(a);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(exp(a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "exp".hashCode();
    }

}
/** Rounding down to a whole number. */
public static class Floor implements DoubleUnaryOperator {

    /** Returns Math.floor of the operand. */
    @Override
    public double applyAsDouble(double a) {
        return Math.floor(a);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(floor(a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "floor".hashCode();
    }

}
/** Logarithm, as computed by Math.log. */
public static class Log implements DoubleUnaryOperator {

    /** Returns Math.log of the operand. */
    @Override
    public double applyAsDouble(double a) {
        return Math.log(a);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(log(a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "log".hashCode();
    }

}
/** Negation. */
public static class Neg implements DoubleUnaryOperator {

    /** Returns the operand with its sign flipped. */
    @Override
    public double applyAsDouble(double a) {
        return -a;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(-a)";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "neg".hashCode();
    }

}
/** Reciprocal value. */
public static class Reciprocal implements DoubleUnaryOperator {

    /** Returns 1.0 divided by the operand. */
    @Override
    public double applyAsDouble(double a) {
        return 1.0 / a;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(1 / a)";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "reciprocal".hashCode();
    }

}
/** ReLU activation. */
public static class Relu implements DoubleUnaryOperator {

    /** Returns the larger of the operand and zero, as decided by Math.max. */
    @Override
    public double applyAsDouble(double a) {
        double floor = 0;
        return Math.max(a, floor);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(max(0, a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "relu".hashCode();
    }

}
/**
 * SELU activation: scale * a for operands which are at least zero, and
 * scale * alpha * (exp(a) - 1) otherwise.
 * See https://arxiv.org/abs/1706.02515
 */
public static class Selu implements DoubleUnaryOperator {

    /** Factor applied to the whole result. */
    private final double scale;

    /** Factor applied to the exponential branch. */
    private final double alpha;

    /** Creates a SELU function with the default scale and alpha given below. */
    public Selu() {
        this(1.0507009873554804934193349852946, 1.6732632423543772848170429916717);
    }

    /** Creates a SELU function with the given scale and alpha. */
    public Selu(double scale, double alpha) {
        this.scale = scale;
        this.alpha = alpha;
    }

    /** Applies this function to the operand. */
    @Override
    public double applyAsDouble(double a) {
        double unscaled;
        if (a >= 0.0) {
            unscaled = a;
        } else {
            unscaled = alpha * (Math.exp(a)-1);
        }
        return scale * unscaled;
    }

    /** Returns the parsable expression form of this function, with scale and alpha written out. */
    @Override
    public String toString() {
        StringBuilder expression = new StringBuilder("f(a)(");
        expression.append(scale).append(" * if(a >= 0, a, ");
        expression.append(alpha).append(" * (exp(a) - 1)))");
        return expression.toString();
    }

    /** Returns a hash of the function name, scale and alpha. */
    @Override
    public int hashCode() {
        return Objects.hash("selu", scale, alpha);
    }

}
/** Leaky ReLU activation: the larger of alpha * a and a. */
public static class LeakyRelu implements DoubleUnaryOperator {

    /** Multiplier of the scaled alternative. */
    private final double alpha;

    /** Creates a leaky ReLU function with alpha = 0.01. */
    public LeakyRelu() {
        this(0.01);
    }

    /** Creates a leaky ReLU function with the given alpha. */
    public LeakyRelu(double alpha) {
        this.alpha = alpha;
    }

    /** Applies this function to the operand. */
    @Override
    public double applyAsDouble(double a) {
        double scaled = alpha * a;
        return Math.max(scaled, a);
    }

    /** Returns the parsable expression form of this function, with alpha written out. */
    @Override
    public String toString() {
        StringBuilder expression = new StringBuilder("f(a)(max(");
        expression.append(alpha).append(" * a, a))");
        return expression.toString();
    }

    /** Returns a hash of the function name and alpha. */
    @Override
    public int hashCode() {
        return Objects.hash("leakyrelu", alpha);
    }

}
/** Sine. */
public static class Sin implements DoubleUnaryOperator {

    /** Returns Math.sin of the operand. */
    @Override
    public double applyAsDouble(double a) {
        return Math.sin(a);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(sin(a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "sin".hashCode();
    }

}
/** Reciprocal square root. */
public static class Rsqrt implements DoubleUnaryOperator {

    /** Returns 1.0 divided by Math.sqrt of the operand. */
    @Override
    public double applyAsDouble(double a) {
        double root = Math.sqrt(a);
        return 1.0 / root;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(1.0 / sqrt(a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "rsqrt".hashCode();
    }

}
/** Logistic sigmoid. */
public static class Sigmoid implements DoubleUnaryOperator {

    /** Returns 1 / (1 + exp(-a)). */
    @Override
    public double applyAsDouble(double a) {
        double denominator = 1.0 + Math.exp(-a);
        return 1.0 / denominator;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(1 / (1 + exp(-a)))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "sigmoid".hashCode();
    }

}
/** Square root. */
public static class Sqrt implements DoubleUnaryOperator {

    /** Returns Math.sqrt of the operand. */
    @Override
    public double applyAsDouble(double a) {
        return Math.sqrt(a);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(sqrt(a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "sqrt".hashCode();
    }

}
/** Squaring. */
public static class Square implements DoubleUnaryOperator {

    /** Returns the operand multiplied by itself. */
    @Override
    public double applyAsDouble(double a) {
        return a * a;
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(a * a)";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "square".hashCode();
    }

}
/** Tangent. */
public static class Tan implements DoubleUnaryOperator {

    /** Returns Math.tan of the operand. */
    @Override
    public double applyAsDouble(double a) {
        return Math.tan(a);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(tan(a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "tan".hashCode();
    }

}
/** Hyperbolic tangent. */
public static class Tanh implements DoubleUnaryOperator {

    /** Returns Math.tanh of the operand. */
    @Override
    public double applyAsDouble(double a) {
        return Math.tanh(a);
    }

    /** Returns the parsable expression form of this function. */
    @Override
    public String toString() {
        return "f(a)(tanh(a))";
    }

    /** Returns a hash derived from the function name only, so all instances hash alike. */
    @Override
    public int hashCode() {
        return "tanh".hashCode();
    }

}
public static class Erf implements DoubleUnaryOperator {
static final Comparator byAbs = (x,y) -> Double.compare(Math.abs(x), Math.abs(y));
static double kummer(double a, double b, double z) {
PriorityQueue terms = new PriorityQueue<>(byAbs);
double term = 1.0;
long n = 0;
while (Math.abs(term) > Double.MIN_NORMAL) {
terms.add(term);
term *= (a+n);
term /= (b+n);
++n;
term *= z;
term /= n;
}
double sum = terms.remove();
while (! terms.isEmpty()) {
sum += terms.remove();
terms.add(sum);
sum = terms.remove();
}
return sum;
}
static double approx_erfc(double x) {
double sq = x*x;
double mult = Math.exp(-sq) / (x * Math.sqrt(Math.PI));
double term = 1.0;
long n = 1;
double sum = 0.0;
while ((sum + term) != sum) {
double pterm = term;
sum += term;
term = 0.5 * pterm * n / sq;
if (term > pterm) {
sum -= 0.5 * pterm;
return sum*mult;
}
n += 2;
pterm = term;
sum -= term;
term = 0.5 * pterm * n / sq;
if (term > pterm) {
sum += 0.5 * pterm;
return sum*mult;
}
n += 2;
}
return sum*mult;
}
@Override
public double applyAsDouble(double operand) { return erf(operand); }
@Override
public String toString() { return "f(a)(erf(a))"; }
@Override
public int hashCode() { return "erf".hashCode(); }
static final double nearZeroMultiplier = 2.0 / Math.sqrt(Math.PI);
public static double erf(double v) {
if (v < 0) {
return -erf(Math.abs(v));
}
if (v < 1.0e-10) {
// Just use the derivate when very near zero:
return v * nearZeroMultiplier;
}
if (v <= 1.0) {
// works best when v is small
return v * nearZeroMultiplier * kummer(0.5, 1.5, -v*v);
}
if (v < 4.3) {
// slower, but works with bigger v
return v * nearZeroMultiplier * Math.exp(-v*v) * kummer(1.0, 1.5, v*v);
}
// works only with "very big" v
return 1.0 - approx_erfc(v);
}
}
// Variable-length operators -----------------------------------------------------------------------------
public static class EqualElements implements Function, Double> {
private final List argumentNames;
private EqualElements(List argumentNames) {
this.argumentNames = List.copyOf(argumentNames);
}
@Override
public Double apply(List values) {
if (values.isEmpty()) return 1.0;
for (Long value : values)
if ( ! value.equals(values.get(0)))
return 0.0;
return 1.0;
}
@Override
public String toString() {
if (argumentNames.size() == 0) return "1";
if (argumentNames.size() == 1) return "1";
if (argumentNames.size() == 2) return argumentNames.get(0) + "==" + argumentNames.get(1);
StringBuilder b = new StringBuilder();
for (int i = 0; i < argumentNames.size() -1; i++) {
b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")");
if ( i < argumentNames.size() -2)
b.append("*");
}
return b.toString();
}
@Override
public int hashCode() { return Objects.hash("equal", argumentNames); }
}
public static class Random implements Function, Double> {
@Override
public Double apply(List values) {
return ThreadLocalRandom.current().nextDouble();
}
@Override
public String toString() { return "random"; }
@Override
public int hashCode() { return "random".hashCode(); }
}
public static class SumElements implements Function, Double> {
private final List argumentNames;
private SumElements(List argumentNames) {
this.argumentNames = List.copyOf(argumentNames);
}
@Override
public Double apply(List values) {
long sum = 0;
for (Long value : values)
sum += value;
return (double)sum;
}
@Override
public String toString() {
return argumentNames.stream().collect(Collectors.joining("+"));
}
@Override
public int hashCode() { return Objects.hash("sum", argumentNames); }
}
public static class Constant implements Function, Double> {
private final double value;
public Constant(double value) {
this.value = value;
}
@Override
public Double apply(List values) {
return value;
}
@Override
public String toString() { return Double.toString(value); }
@Override
public int hashCode() { return Objects.hash("constant", value); }
}
}