All downloads are free. The search and download features use the official Maven repository.

com.yahoo.tensor.functions.ScalarFunctions Maven / Gradle / Ivy

// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;

import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Factory of scalar Java functions.
 * The purpose of this is to embellish anonymous functions with a runtime type
 * such that they can be inspected and will return a parsable toString.
 *
 * Every function class overrides {@code toString} to return its expression form and
 * {@code hashCode} to return a value derived from its name and parameters (if any).
 *
 * @author bratseth
 */
public class ScalarFunctions {

    // Factory methods for binary operators
    public static DoubleBinaryOperator add() { return new Add(); }
    public static DoubleBinaryOperator divide() { return new Divide(); }
    public static DoubleBinaryOperator equal() { return new Equal(); }
    public static DoubleBinaryOperator greater() { return new Greater(); }
    public static DoubleBinaryOperator less() { return new Less(); }
    public static DoubleBinaryOperator max() { return new Max(); }
    public static DoubleBinaryOperator min() { return new Min(); }
    public static DoubleBinaryOperator mean() { return new Mean(); }
    public static DoubleBinaryOperator multiply() { return new Multiply(); }
    public static DoubleBinaryOperator pow() { return new Pow(); }
    public static DoubleBinaryOperator squareddifference() { return new SquaredDifference(); }
    public static DoubleBinaryOperator subtract() { return new Subtract(); }
    public static DoubleBinaryOperator hamming() { return new Hamming(); }

    // Factory methods for unary operators
    public static DoubleUnaryOperator abs() { return new Abs(); }
    public static DoubleUnaryOperator acos() { return new Acos(); }
    public static DoubleUnaryOperator asin() { return new Asin(); }
    public static DoubleUnaryOperator atan() { return new Atan(); }
    public static DoubleUnaryOperator ceil() { return new Ceil(); }
    public static DoubleUnaryOperator cos() { return new Cos(); }
    public static DoubleUnaryOperator exp() { return new Exp(); }
    public static DoubleUnaryOperator floor() { return new Floor(); }
    public static DoubleUnaryOperator log() { return new Log(); }
    public static DoubleUnaryOperator neg() { return new Neg(); }
    public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); }
    public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); }
    public static DoubleUnaryOperator sin() { return new Sin(); }
    public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); }
    public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
    public static DoubleUnaryOperator square() { return new Square(); }
    public static DoubleUnaryOperator tan() { return new Tan(); }
    public static DoubleUnaryOperator tanh() { return new Tanh(); }
    public static DoubleUnaryOperator erf() { return new Erf(); }

    // Factory methods for parameterized activation functions; the no-argument variants use the class defaults
    public static DoubleUnaryOperator elu() { return new Elu(); }
    public static DoubleUnaryOperator elu(double alpha) { return new Elu(alpha); }
    public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); }
    public static DoubleUnaryOperator leakyrelu(double alpha) { return new LeakyRelu(alpha); }
    public static DoubleUnaryOperator relu() { return new Relu(); }
    public static DoubleUnaryOperator selu() { return new Selu(); }
    public static DoubleUnaryOperator selu(double scale, double alpha) { return new Selu(scale, alpha); }

    // Factory methods for functions taking a list of long values
    public static Function<List<Long>, Double> random() { return new Random(); }
    public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
    public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
    public static Function<List<Long>, Double> constant(double value) { return new Constant(value); }

    // Binary operators -----------------------------------------------------------------------------

    /** Returns left + right. */
    public static class Add implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return left + right; }
        @Override
        public String toString() { return "f(a,b)(a + b)"; }
        @Override
        public int hashCode() { return "add".hashCode(); }
    }

    /** Returns 1 if left == right (primitive double comparison), 0 otherwise. */
    public static class Equal implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; }
        @Override
        public String toString() { return "f(a,b)(a==b)"; }
        @Override
        public int hashCode() { return "equal".hashCode(); }
    }

    /** Returns 1 if left is greater than right, 0 otherwise. */
    public static class Greater implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return left > right ? 1 : 0; }
        @Override
        public String toString() { return "f(a,b)(a > b)"; }
        @Override
        public int hashCode() { return "greater".hashCode(); }
    }

    /** Returns 1 if left is less than right, 0 otherwise. */
    public static class Less implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return left < right ? 1 : 0; }
        @Override
        public String toString() { return "f(a,b)(a < b)"; }
        @Override
        public int hashCode() { return "less".hashCode(); }
    }

    /** Returns the larger of the two arguments, as computed by {@link Math#max(double, double)}. */
    public static class Max implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return Math.max(left, right); }
        @Override
        public String toString() { return "f(a,b)(max(a, b))"; }
        @Override
        public int hashCode() { return "max".hashCode(); }
    }

    /** Returns the smaller of the two arguments, as computed by {@link Math#min(double, double)}. */
    public static class Min implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return Math.min(left, right); }
        @Override
        public String toString() { return "f(a,b)(min(a, b))"; }
        @Override
        public int hashCode() { return "min".hashCode(); }
    }

    /** Returns the arithmetic mean of the two arguments. */
    public static class Mean implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return (left + right) / 2; }
        @Override
        public String toString() { return "f(a,b)((a + b) / 2)"; }
        @Override
        public int hashCode() { return "mean".hashCode(); }
    }

    /** Returns left * right. */
    public static class Multiply implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return left * right; }
        @Override
        public String toString() { return "f(a,b)(a * b)"; }
        @Override
        public int hashCode() { return "multiply".hashCode(); }
    }

    /** Returns left raised to the power of right, as computed by {@link Math#pow(double, double)}. */
    public static class Pow implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return Math.pow(left, right); }
        @Override
        public String toString() { return "f(a,b)(pow(a, b))"; }
        @Override
        public int hashCode() { return "pow".hashCode(); }
    }

    /** Returns left / right using floating point division. */
    public static class Divide implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return left / right; }
        @Override
        public String toString() { return "f(a,b)(a / b)"; }
        @Override
        public int hashCode() { return "divide".hashCode(); }
    }

    /** Returns the square of (left - right). */
    public static class SquaredDifference implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return (left - right) * (left - right); }
        @Override
        public String toString() { return "f(a,b)((a-b) * (a-b))"; }
        @Override
        public int hashCode() { return "squareddifference".hashCode(); }
    }

    /** Returns left - right. */
    public static class Subtract implements DoubleBinaryOperator {
        @Override
        public double applyAsDouble(double left, double right) { return left - right; }
        @Override
        public String toString() { return "f(a,b)(a - b)"; }
        @Override
        public int hashCode() { return "subtract".hashCode(); }
    }

    /** Returns the number of differing bits between the two arguments after each is cast to a byte. */
    public static class Hamming implements DoubleBinaryOperator {
        /**
         * Computes the bitwise Hamming distance between two values.
         * Both arguments are narrowed to {@code byte} with a Java cast, so only 8 bits of each take part.
         *
         * @param left the first value
         * @param right the second value
         * @return the number of bit positions (0-8) where the two bytes differ
         */
        public static double hamming(double left, double right) {
            byte a = (byte) left;
            byte b = (byte) right;
            // XOR marks the differing bits; the mask drops the sign extension added when bytes are promoted to int
            return Integer.bitCount((a ^ b) & 0xFF);
        }
        @Override
        public double applyAsDouble(double left, double right) { return hamming(left, right); }
        @Override
        public String toString() { return "f(a,b)(hamming(a,b))"; }
        @Override
        public int hashCode() { return "hamming".hashCode(); }
    }


    // Unary operators ------------------------------------------------------------------------------

    /** Returns the absolute value of the operand. */
    public static class Abs implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.abs(operand); }
        @Override
        public String toString() { return "f(a)(fabs(a))"; }
        @Override
        public int hashCode() { return "abs".hashCode(); }
    }

    /** Returns the arc cosine of the operand, as computed by {@link Math#acos(double)}. */
    public static class Acos implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.acos(operand); }
        @Override
        public String toString() { return "f(a)(acos(a))"; }
        @Override
        public int hashCode() { return "acos".hashCode(); }
    }

    /** Returns the arc sine of the operand, as computed by {@link Math#asin(double)}. */
    public static class Asin implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.asin(operand); }
        @Override
        public String toString() { return "f(a)(asin(a))"; }
        @Override
        public int hashCode() { return "asin".hashCode(); }
    }

    /** Returns the arc tangent of the operand, as computed by {@link Math#atan(double)}. */
    public static class Atan implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.atan(operand); }
        @Override
        public String toString() { return "f(a)(atan(a))"; }
        @Override
        public int hashCode() { return "atan".hashCode(); }
    }

    /** Returns the operand rounded up, as computed by {@link Math#ceil(double)}. */
    public static class Ceil implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.ceil(operand); }
        @Override
        public String toString() { return "f(a)(ceil(a))"; }
        @Override
        public int hashCode() { return "ceil".hashCode(); }
    }

    /** Returns the cosine of the operand, as computed by {@link Math#cos(double)}. */
    public static class Cos implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.cos(operand); }
        @Override
        public String toString() { return "f(a)(cos(a))"; }
        @Override
        public int hashCode() { return "cos".hashCode(); }
    }

    /** Returns the operand unchanged when it is non-negative, otherwise alpha * (exp(operand) - 1). */
    public static class Elu implements DoubleUnaryOperator {
        private final double alpha;
        /** Creates this with alpha = 1.0. */
        public Elu() {
            this(1.0);
        }
        public Elu(double alpha) {
            this.alpha = alpha;
        }
        @Override
        public double applyAsDouble(double operand) { return operand < 0 ? alpha * (Math.exp(operand) - 1) : operand; }
        @Override
        public String toString() { return "f(a)(if(a < 0, " + alpha + " * (exp(a)-1), a))"; }
        @Override
        public int hashCode() { return Objects.hash("elu", alpha); }
    }

    /** Returns e raised to the power of the operand, as computed by {@link Math#exp(double)}. */
    public static class Exp implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.exp(operand); }
        @Override
        public String toString() { return "f(a)(exp(a))"; }
        @Override
        public int hashCode() { return "exp".hashCode(); }
    }

    /** Returns the operand rounded down, as computed by {@link Math#floor(double)}. */
    public static class Floor implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.floor(operand); }
        @Override
        public String toString() { return "f(a)(floor(a))"; }
        @Override
        public int hashCode() { return "floor".hashCode(); }
    }

    /** Returns the natural logarithm of the operand, as computed by {@link Math#log(double)}. */
    public static class Log implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.log(operand); }
        @Override
        public String toString() { return "f(a)(log(a))"; }
        @Override
        public int hashCode() { return "log".hashCode(); }
    }

    /** Returns the negated operand. */
    public static class Neg implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return -operand; }
        @Override
        public String toString() { return "f(a)(-a)"; }
        @Override
        public int hashCode() { return "neg".hashCode(); }
    }

    /** Returns 1 / operand. */
    public static class Reciprocal implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return 1.0 / operand; }
        @Override
        public String toString() { return "f(a)(1 / a)"; }
        @Override
        public int hashCode() { return "reciprocal".hashCode(); }
    }

    /** Returns the larger of the operand and 0. */
    public static class Relu implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.max(operand, 0); }
        @Override
        public String toString() { return "f(a)(max(0, a))"; }
        @Override
        public int hashCode() { return "relu".hashCode(); }
    }

    /** Returns scale * operand when the operand is non-negative, otherwise scale * alpha * (exp(operand) - 1). */
    public static class Selu implements DoubleUnaryOperator {
        // See https://arxiv.org/abs/1706.02515
        private final double scale; // 1.0507009873554804934193349852946;
        private final double alpha; // 1.6732632423543772848170429916717;
        /** Creates this with the default scale and alpha given in the field comments above. */
        public Selu() {
            this(1.0507009873554804934193349852946, 1.6732632423543772848170429916717);
        }
        public Selu(double scale, double alpha) {
            this.scale = scale;
            this.alpha = alpha;
        }
        @Override
        public double applyAsDouble(double operand) { return scale * (operand >= 0.0 ? operand : alpha * (Math.exp(operand)-1)); }
        @Override
        public String toString() { return "f(a)(" + scale + " * if(a >= 0, a, " + alpha + " * (exp(a) - 1)))"; }
        @Override
        public int hashCode() { return Objects.hash("selu", scale, alpha); }
    }

    /** Returns the larger of alpha * operand and the operand itself. */
    public static class LeakyRelu implements DoubleUnaryOperator {
        private final double alpha;
        /** Creates this with alpha = 0.01. */
        public LeakyRelu() {
            this(0.01);
        }
        public LeakyRelu(double alpha) {
            this.alpha = alpha;
        }
        @Override
        public double applyAsDouble(double operand) { return Math.max(alpha * operand, operand); }
        @Override
        public String toString() { return "f(a)(max(" + alpha + " * a, a))"; }
        @Override
        public int hashCode() { return Objects.hash("leakyrelu", alpha); }
    }

    /** Returns the sine of the operand, as computed by {@link Math#sin(double)}. */
    public static class Sin implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.sin(operand); }
        @Override
        public String toString() { return "f(a)(sin(a))"; }
        @Override
        public int hashCode() { return "sin".hashCode(); }
    }

    /** Returns 1 / sqrt(operand). */
    public static class Rsqrt implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); }
        @Override
        public String toString() { return "f(a)(1.0 / sqrt(a))"; }
        @Override
        public int hashCode() { return "rsqrt".hashCode(); }
    }

    /** Returns 1 / (1 + exp(-operand)). */
    public static class Sigmoid implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return 1.0 / (1.0 + Math.exp(-operand)); }
        @Override
        public String toString() { return "f(a)(1 / (1 + exp(-a)))"; }
        @Override
        public int hashCode() { return "sigmoid".hashCode(); }
    }

    /** Returns the square root of the operand, as computed by {@link Math#sqrt(double)}. */
    public static class Sqrt implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.sqrt(operand); }
        @Override
        public String toString() { return "f(a)(sqrt(a))"; }
        @Override
        public int hashCode() { return "sqrt".hashCode(); }
    }

    /** Returns operand * operand. */
    public static class Square implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return operand * operand; }
        @Override
        public String toString() { return "f(a)(a * a)"; }
        @Override
        public int hashCode() { return "square".hashCode(); }
    }

    /** Returns the tangent of the operand, as computed by {@link Math#tan(double)}. */
    public static class Tan implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.tan(operand); }
        @Override
        public String toString() { return "f(a)(tan(a))"; }
        @Override
        public int hashCode() { return "tan".hashCode(); }
    }

    /** Returns the hyperbolic tangent of the operand, as computed by {@link Math#tanh(double)}. */
    public static class Tanh implements DoubleUnaryOperator {
        @Override
        public double applyAsDouble(double operand) { return Math.tanh(operand); }
        @Override
        public String toString() { return "f(a)(tanh(a))"; }
        @Override
        public int hashCode() { return "tanh".hashCode(); }
    }

    /** Returns the error function of the operand, using a different series depending on operand magnitude. */
    public static class Erf implements DoubleUnaryOperator {
        /** Orders values by increasing magnitude, so that the smallest terms are summed first. */
        static final Comparator<Double> byAbs = (x,y) -> Double.compare(Math.abs(x), Math.abs(y));

        /**
         * Sums the series 1 + (a/b) z + (a(a+1))/(b(b+1)) z^2/2! + ...
         * Terms are generated until their magnitude no longer exceeds {@link Double#MIN_NORMAL},
         * and are then added together two at a time, always combining the two values of
         * smallest magnitude.
         */
        static double kummer(double a, double b, double z) {
            PriorityQueue<Double> terms = new PriorityQueue<>(byAbs);
            double term = 1.0;
            long n = 0;
            while (Math.abs(term) > Double.MIN_NORMAL) {
                terms.add(term);
                term *= (a+n);
                term /= (b+n);
                ++n;
                term *= z;
                term /= n;
            }
            // The queue always holds at least the initial term 1.0, so this remove cannot fail
            double sum = terms.remove();
            while (! terms.isEmpty()) {
                // Add the two smallest values and put the partial sum back so it is ordered with the rest
                sum += terms.remove();
                terms.add(sum);
                sum = terms.remove();
            }
            return sum;
        }

        /**
         * Approximates the complementary error function by an alternating series in 1/(2 x^2),
         * scaled by exp(-x^2) / (x sqrt(pi)).
         * The summation stops when a term no longer changes the sum, or when the terms start
         * to grow, in which case only half of the last applied term is kept.
         * Returns NaN if x is NaN.
         */
        static double approx_erfc(double x) {
            if (Double.isNaN(x)) {
                // With NaN the loop condition below is always true and no early return triggers
                return Double.NaN;
            }
            double sq = x*x;
            double mult = Math.exp(-sq) / (x * Math.sqrt(Math.PI));
            double term = 1.0;
            long n = 1;
            double sum = 0.0;
            while ((sum + term) != sum) {
                double pterm = term;
                sum += term;
                term = 0.5 * pterm * n / sq;
                if (term > pterm) {
                    // The series has started to diverge: keep only half of the term just added
                    sum -= 0.5 * pterm;
                    return sum*mult;
                }
                n += 2;
                pterm = term;
                sum -= term;
                term = 0.5 * pterm * n / sq;
                if (term > pterm) {
                    // The series has started to diverge: keep only half of the term just subtracted
                    sum += 0.5 * pterm;
                    return sum*mult;
                }
                n += 2;
            }
            return sum*mult;
        }

        @Override
        public double applyAsDouble(double operand) { return erf(operand); }
        @Override
        public String toString() { return "f(a)(erf(a))"; }
        @Override
        public int hashCode() { return "erf".hashCode(); }

        static final double nearZeroMultiplier = 2.0 / Math.sqrt(Math.PI);

        /**
         * Computes the error function of v.
         *
         * @param v the argument; negative values are handled by symmetry, erf(-v) = -erf(v)
         * @return the error function value, or NaN if v is NaN
         */
        public static double erf(double v) {
            if (Double.isNaN(v)) {
                // NaN fails every range check below and must not reach the series summation
                return Double.NaN;
            }
            if (v < 0) {
                return -erf(Math.abs(v));
            }
            if (v < 1.0e-10) {
                // Just use the derivative when very near zero:
                return v * nearZeroMultiplier;
            }
            if (v <= 1.0) {
                // works best when v is small
                return v * nearZeroMultiplier * kummer(0.5, 1.5, -v*v);
            }
            if (v < 4.3) {
                // slower, but works with bigger v
                return v * nearZeroMultiplier * Math.exp(-v*v) * kummer(1.0, 1.5, v*v);
            }
            // works only with "very big" v
            return 1.0 - approx_erfc(v);
        }
    }

    // Variable-length operators -----------------------------------------------------------------------------

    /** Returns 1 if all the given values are equal (or there are none), 0 otherwise. */
    public static class EqualElements implements Function<List<Long>, Double> {
        /** The argument names, used only to produce the string form of this. */
        private final List<String> argumentNames;
        private EqualElements(List<String> argumentNames) {
            this.argumentNames = List.copyOf(argumentNames);
        }

        @Override
        public Double apply(List<Long> values) {
            if (values.isEmpty()) return 1.0;
            for (Long value : values)
                if ( ! value.equals(values.get(0)))
                    return 0.0;
            return 1.0;
        }
        /**
         * Returns "1" for fewer than two argument names, "a==b" for exactly two,
         * and a product of pairwise comparisons of neighbours, "(a==b)*(b==c)...", for more.
         */
        @Override
        public String toString() {
            if (argumentNames.size() == 0) return "1";
            if (argumentNames.size() == 1) return "1";
            if (argumentNames.size() == 2) return argumentNames.get(0) + "==" + argumentNames.get(1);

            StringBuilder b = new StringBuilder();
            for (int i = 0; i < argumentNames.size() -1; i++) {
                b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")");
                if ( i < argumentNames.size() -2)
                    b.append("*");
            }
            return b.toString();
        }
        @Override
        public int hashCode() { return Objects.hash("equal", argumentNames); }
    }

    /** Ignores its argument and returns the next double from the current thread's {@link ThreadLocalRandom}. */
    public static class Random implements Function<List<Long>, Double> {
        @Override
        public Double apply(List<Long> values) {
            return ThreadLocalRandom.current().nextDouble();
        }
        @Override
        public String toString() { return "random"; }
        @Override
        public int hashCode() { return "random".hashCode(); }
    }

    /** Returns the sum of the given values, accumulated as a long and converted to double. */
    public static class SumElements implements Function<List<Long>, Double> {
        /** The argument names, used only to produce the string form of this. */
        private final List<String> argumentNames;
        private SumElements(List<String> argumentNames) {
            this.argumentNames = List.copyOf(argumentNames);
        }

        @Override
        public Double apply(List<Long> values) {
            long sum = 0;
            for (Long value : values)
                sum += value;
            return (double)sum;
        }
        /** Returns the argument names joined by "+". */
        @Override
        public String toString() {
            return String.join("+", argumentNames);
        }
        @Override
        public int hashCode() { return Objects.hash("sum", argumentNames); }
    }

    /** Ignores its argument and returns the value given at construction. */
    public static class Constant implements Function<List<Long>, Double> {
        private final double value;

        public Constant(double value) {
            this.value = value;
        }
        @Override
        public Double apply(List<Long> values) {
            return value;
        }
        @Override
        public String toString() { return Double.toString(value); }
        @Override
        public int hashCode() { return Objects.hash("constant", value); }
    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy