All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hivemall.utils.math.FastMath Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2012-2015 Jeff Hain
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*
 * =============================================================================
 * Notice of fdlibm package this program is partially derived from:
 *
 * Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
 *
 * Developed at SunSoft, a Sun Microsystems, Inc. business.
 * Permission to use, copy, modify, and distribute this
 * software is freely granted, provided that this notice
 * is preserved.
 * =============================================================================
 */
// This file contains a modified version of Jafama's FastMath:
// https://github.com/jeffhain/jafama/blob/master/src/main/java/net/jafama/FastMath.java
package hivemall.utils.math;

import hivemall.annotations.Experimental;

@Experimental
public final class FastMath {

    // Static utility holder; never instantiated.
    private FastMath() {}

    /**
     * Approximates the square root of {@code x} as {@code x * invSqrt(x)}.
     * <p>
     * Arguments outside the positive normal finite range (zero, negatives, subnormals, infinities,
     * NaN) are delegated to {@link Math#sqrt(double)}: for those the bit-hack initial guess of
     * {@link #invSqrt(float)} is either meaningless or too far off for the fixed number of
     * Newton passes to converge.
     *
     * @param x a float value
     * @return approximately {@code sqrt(x)}
     * @deprecated approximate only; {@link Math#sqrt(double)} returns the correctly rounded result.
     */
    @Deprecated
    public static float sqrt(final float x) {
        if (!inFastInvSqrtDomain(x)) {
            return (float) Math.sqrt(x);
        }
        return x * invSqrt(x);
    }

    /**
     * Approximates the square root of {@code x} as {@code x * invSqrt(x)}.
     * <p>
     * Arguments outside the positive normal finite range (zero, negatives, subnormals, infinities,
     * NaN) are delegated to {@link Math#sqrt(double)}.
     *
     * @param x a double value
     * @return approximately {@code sqrt(x)}
     * @deprecated approximate only; {@link Math#sqrt(double)} returns the correctly rounded result.
     */
    @Deprecated
    public static double sqrt(final double x) {
        if (!inFastInvSqrtDomain(x)) {
            return Math.sqrt(x);
        }
        return x * invSqrt(x);
    }

    /**
     * Approximates {@code 1 / sqrt(x)} with the "fast inverse square root" bit hack followed by
     * three Newton-Raphson refinement passes.
     * <p>
     * See https://en.wikipedia.org/wiki/Fast_inverse_square_root. Arguments outside the positive
     * normal finite range fall back to {@code 1 / Math.sqrt(x)}, so e.g. {@code invSqrt(0) = +Inf}
     * and {@code invSqrt(-1) = NaN}.
     *
     * @param x a float value
     * @return approximately {@code 1 / sqrt(x)}
     * @deprecated approximate only; prefer {@code 1 / Math.sqrt(x)}.
     */
    @Deprecated
    public static float invSqrt(final float x) {
        if (!inFastInvSqrtDomain(x)) {
            return 1f / (float) Math.sqrt(x);
        }
        final float hx = 0.5f * x;
        // Initial guess: subtracting the halved bit pattern from the magic constant roughly
        // negates and halves the exponent.
        int i = 0x5f375a86 - (Float.floatToRawIntBits(x) >>> 1);
        float y = Float.intBitsToFloat(i);
        y *= (1.5f - hx * y * y); // pass 1
        y *= (1.5f - hx * y * y); // pass 2
        y *= (1.5f - hx * y * y); // pass 3
        // Additional passes would further improve accuracy at extra cost.
        return y;
    }

    /**
     * Approximates {@code 1 / sqrt(x)} with the "fast inverse square root" bit hack (double
     * precision magic constant) followed by four Newton-Raphson refinement passes.
     * <p>
     * See https://en.wikipedia.org/wiki/Fast_inverse_square_root. Arguments outside the positive
     * normal finite range fall back to {@code 1 / Math.sqrt(x)}.
     *
     * @param x a double value
     * @return approximately {@code 1 / sqrt(x)}
     * @deprecated approximate only; prefer {@code 1 / Math.sqrt(x)}.
     */
    @Deprecated
    public static double invSqrt(final double x) {
        if (!inFastInvSqrtDomain(x)) {
            return 1d / Math.sqrt(x);
        }
        final double hx = 0.5d * x;
        long i = 0x5fe6eb50c7b537a9L - (Double.doubleToRawLongBits(x) >>> 1);
        double y = Double.longBitsToDouble(i);
        y *= (1.5d - hx * y * y); // pass 1
        y *= (1.5d - hx * y * y); // pass 2
        y *= (1.5d - hx * y * y); // pass 3
        y *= (1.5d - hx * y * y); // pass 4
        // Additional passes would further improve accuracy at extra cost.
        return y;
    }

    /**
     * @return true iff {@code x} is positive, normal and finite, i.e. the bit hack in
     *         {@link #invSqrt(float)} yields a usable initial guess. NaN fails both comparisons.
     */
    private static boolean inFastInvSqrtDomain(final float x) {
        return x >= Float.MIN_NORMAL && x < Float.POSITIVE_INFINITY;
    }

    /**
     * @return true iff {@code x} is positive, normal and finite, i.e. the bit hack in
     *         {@link #invSqrt(double)} yields a usable initial guess. NaN fails both comparisons.
     */
    private static boolean inFastInvSqrtDomain(final double x) {
        return x >= Double.MIN_NORMAL && x < Double.POSITIVE_INFINITY;
    }

    /**
     * Table-based natural logarithm.
     *
     * @param x a double value
     * @return ln(x); NaN for negative or NaN input, -Infinity for 0, +Infinity for +Infinity
     */
    public static double log(final double x) {
        return JafamaMath.log(x);
    }

    /**
     * Table-based {@code log(1+x)}, more accurate than {@code log(1+x)} for x close to zero.
     *
     * @param x a double value
     * @return log(1+x); NaN for x &lt; -1 or NaN, -Infinity for x == -1
     */
    public static double log1p(final double x) {
        return JafamaMath.log1p(x);
    }

    /**
     * Table-based exponential (Jafama's algorithm: exp of the integer part and of a 1/1024-step
     * approximation of the fractional part are read from tables, and the residual is handled
     * with a short Taylor polynomial).
     *
     * @param x a double value
     * @return e^x; +Infinity above the overflow limit (~709.78), 0.0 below the underflow limit
     *         (~-745.13), NaN for NaN
     */
    public static double exp(final double x) {
        return JafamaMath.exp(x);
    }

    /**
     * Table-based {@code exp(x)-1}, more accurate than {@code exp(x)-1} for x close to zero.
     *
     * @param x a double value
     * @return exp(x)-1
     */
    public static double expm1(final double x) {
        return JafamaMath.expm1(x);
    }

    /**
     * Logistic sigmoid.
     *
     * @param x a double value
     * @return 1 / (1 + e^-x)
     */
    public static double sigmoid(final double x) {
        return 1 / (1 + exp(-x));
    }

    /**
     * Table-based approximations of exp, expm1, log and log1p, based on Jafama
     * (https://github.com/jeffhain/jafama/) version 2.2.
     */
    private static final class JafamaMath {

        static final double TWO_POW_52 = twoPow(52);

        /**
         * Smallest positive normal double; same value as Double.MIN_NORMAL (Java 6+).
         */
        static final double DOUBLE_MIN_NORMAL = Double.longBitsToDouble(0x0010000000000000L); // 2.2250738585072014E-308

        // Not storing float/double mantissa size in constants,
        // for 23 and 52 are shorter to read and more
        // bitwise-explicit than some constant's name.

        static final int MIN_DOUBLE_EXPONENT = -1074;
        static final int MAX_DOUBLE_EXPONENT = 1023;

        static final double LOG_2 = StrictMath.log(2.0);

        //--------------------------------------------------------------------------
        // CONSTANTS AND TABLES FOR EXP AND EXPM1
        //--------------------------------------------------------------------------

        static final double EXP_OVERFLOW_LIMIT = Double.longBitsToDouble(0x40862E42FEFA39EFL); // 7.09782712893383973096e+02
        static final double EXP_UNDERFLOW_LIMIT = Double.longBitsToDouble(0xC0874910D52D3051L); // -7.45133219101941108420e+02
        // The "lo" tables cover [-EXP_LO_DISTANCE_TO_ZERO, EXP_LO_DISTANCE_TO_ZERO] = [-1, 1].
        static final int EXP_LO_DISTANCE_TO_ZERO_POT = 0;
        static final int EXP_LO_DISTANCE_TO_ZERO = (1 << EXP_LO_DISTANCE_TO_ZERO_POT);
        static final int EXP_LO_TAB_SIZE_POT = 11;
        static final int EXP_LO_TAB_SIZE = (1 << EXP_LO_TAB_SIZE_POT) + 1;
        static final int EXP_LO_TAB_MID_INDEX = ((EXP_LO_TAB_SIZE - 1) / 2);
        // Number of table entries per unit of x (here 1024, i.e. a 1/1024 step).
        static final int EXP_LO_INDEXING = EXP_LO_TAB_MID_INDEX / EXP_LO_DISTANCE_TO_ZERO;
        // log2(EXP_LO_INDEXING): shifting an index right by this yields the integer part.
        static final int EXP_LO_INDEXING_DIV_SHIFT =
                EXP_LO_TAB_SIZE_POT - 1 - EXP_LO_DISTANCE_TO_ZERO_POT;

        /** Lazily initialized (holder idiom) tables for exp and expm1. */
        static final class MyTExp {
            // exp(i) for every integer i in [(int)EXP_UNDERFLOW_LIMIT, (int)EXP_OVERFLOW_LIMIT].
            static final double[] expHiTab =
                    new double[1 + (int) EXP_OVERFLOW_LIMIT - (int) EXP_UNDERFLOW_LIMIT];
            // exp(x) on the [-1, 1] grid.
            static final double[] expLoPosTab = new double[EXP_LO_TAB_SIZE];
            // 1-exp(-x) on the [-1, 1] grid.
            static final double[] expLoNegTab = new double[EXP_LO_TAB_SIZE];

            static {
                init();
            }

            private static strictfp void init() {
                for (int i = (int) EXP_UNDERFLOW_LIMIT; i <= (int) EXP_OVERFLOW_LIMIT; i++) {
                    expHiTab[i - (int) EXP_UNDERFLOW_LIMIT] = StrictMath.exp(i);
                }
                for (int i = 0; i < EXP_LO_TAB_SIZE; i++) {
                    // x: in [-EXP_LO_DISTANCE_TO_ZERO,EXP_LO_DISTANCE_TO_ZERO].
                    double x = -EXP_LO_DISTANCE_TO_ZERO + i / (double) EXP_LO_INDEXING;
                    // exp(x)
                    expLoPosTab[i] = StrictMath.exp(x);
                    // 1-exp(-x), accurately computed
                    expLoNegTab[i] = -StrictMath.expm1(-x);
                }
            }
        }

        //--------------------------------------------------------------------------
        // CONSTANTS AND TABLES FOR LOG AND LOG1P
        //--------------------------------------------------------------------------

        static final int LOG_BITS = 12;
        static final int LOG_TAB_SIZE = (1 << LOG_BITS);

        /** Lazily initialized (holder idiom) tables for log and log1p. */
        static final class MyTLog {
            // For x = 1 + i/LOG_TAB_SIZE: log(x), x, and 1/x.
            static final double[] logXLogTab = new double[LOG_TAB_SIZE];
            static final double[] logXTab = new double[LOG_TAB_SIZE];
            static final double[] logXInvTab = new double[LOG_TAB_SIZE];

            static {
                init();
            }

            private static strictfp void init() {
                for (int i = 0; i < LOG_TAB_SIZE; i++) {
                    // Exact to use inverse of tab size, since it is a power of two.
                    double x = 1 + i * (1.0 / LOG_TAB_SIZE);
                    logXLogTab[i] = StrictMath.log(x);
                    logXTab[i] = x;
                    logXInvTab[i] = 1 / x;
                }
            }
        }

        /**
         * @param value A double value.
         * @return e^value; +Infinity if value &gt; EXP_OVERFLOW_LIMIT, 0.0 if value &lt;
         *         EXP_UNDERFLOW_LIMIT, NaN if value is NaN.
         */
        static double exp(final double value) {
            // exp(x) = exp([x])*exp(y)
            // with [x] the integer part of x, and y = x-[x]
            // ===>
            // We find an approximation of y, called z.
            // ===>
            // exp(x) = exp([x])*(exp(z)*exp(epsilon))
            // with epsilon = y - z
            // ===>
            // We have exp([x]) and exp(z) pre-computed in tables, we "just" have to compute exp(epsilon).
            //
            // We use the same indexing (cast to int) to compute x integer part and the
            // table index corresponding to z, to avoid two int casts.
            // Also, to optimize index multiplication and division, we use powers of two,
            // so that we can do it with bits shifts.

            if (value > EXP_OVERFLOW_LIMIT) {
                return Double.POSITIVE_INFINITY;
            } else if (!(value >= EXP_UNDERFLOW_LIMIT)) {
                // Also reached for NaN, since every comparison with NaN is false.
                return (value != value) ? Double.NaN : 0.0;
            }

            final int indexes = (int) (value * EXP_LO_INDEXING);

            // Integer part, truncated toward zero (shift applied to the magnitude).
            final int valueInt;
            if (indexes >= 0) {
                valueInt = (indexes >> EXP_LO_INDEXING_DIV_SHIFT);
            } else {
                valueInt = -((-indexes) >> EXP_LO_INDEXING_DIV_SHIFT);
            }
            final double hiTerm = MyTExp.expHiTab[valueInt - (int) EXP_UNDERFLOW_LIMIT];

            final int zIndex = indexes - (valueInt << EXP_LO_INDEXING_DIV_SHIFT);
            final double y = (value - valueInt);
            final double z = zIndex * (1.0 / EXP_LO_INDEXING);
            final double eps = y - z;
            final double expZ = MyTExp.expLoPosTab[zIndex + EXP_LO_TAB_MID_INDEX];
            // Taylor expansion of exp(eps) up to eps^4/4!.
            final double expEps =
                    (1 + eps * (1 + eps * (1.0 / 2 + eps * (1.0 / 6 + eps * (1.0 / 24)))));
            final double loTerm = expZ * expEps;

            return hiTerm * loTerm;
        }

        /**
         * Much more accurate than exp(value)-1, for arguments (and results) close to zero.
         *
         * @param value A double value.
         * @return e^value-1.
         */
        static double expm1(final double value) {
            // If |value| >= EXP_LO_DISTANCE_TO_ZERO, we use exp(value)-1.
            //
            // If value is close to zero, we use the following formula:
            // exp(value)-1
            // = exp(valueApprox)*exp(epsilon)-1
            // = exp(valueApprox)*(exp(epsilon)-exp(-valueApprox))
            // = exp(valueApprox)*(1+epsilon+epsilon^2/2!+...-exp(-valueApprox))
            // = exp(valueApprox)*((1-exp(-valueApprox))+epsilon+epsilon^2/2!+...)
            // exp(valueApprox) and 1-exp(-valueApprox) being stored in tables.

            if (Math.abs(value) < EXP_LO_DISTANCE_TO_ZERO) {
                // Taking int part instead of rounding, which takes too long.
                int i = (int) (value * EXP_LO_INDEXING);
                double delta = value - i * (1.0 / EXP_LO_INDEXING);
                return MyTExp.expLoPosTab[i + EXP_LO_TAB_MID_INDEX] * (MyTExp.expLoNegTab[i
                        + EXP_LO_TAB_MID_INDEX]
                        + delta * (1 + delta * (1.0 / 2
                                + delta * (1.0 / 6 + delta * (1.0 / 24 + delta * (1.0 / 120))))));
            } else {
                return exp(value) - 1;
            }
        }

        /**
         * @param value A double value.
         * @return Value logarithm (base e); NaN for negative or NaN value, -Infinity for 0,
         *         +Infinity for +Infinity.
         */
        static double log(double value) {
            if (value > 0.0) {
                if (value == Double.POSITIVE_INFINITY) {
                    return Double.POSITIVE_INFINITY;
                }

                // For normal values not close to 1.0, we use the following formula:
                // log(value)
                // = log(2^exponent*1.mantissa)
                // = log(2^exponent) + log(1.mantissa)
                // = exponent * log(2) + log(1.mantissa)
                // = exponent * log(2) + log(1.mantissaApprox) + log(1.mantissa/1.mantissaApprox)
                // = exponent * log(2) + log(1.mantissaApprox) + log(1+epsilon)
                // = exponent * log(2) + log(1.mantissaApprox) + epsilon-epsilon^2/2+epsilon^3/3-epsilon^4/4+...
                // with:
                // 1.mantissaApprox <= 1.mantissa,
                // log(1.mantissaApprox) in table,
                // epsilon = (1.mantissa/1.mantissaApprox)-1
                //
                // To avoid bad relative error for small results,
                // values close to 1.0 are treated aside, with the formula:
                // log(x) = z*(2+z^2*((2.0/3)+z^2*((2.0/5)+z^2*((2.0/7)+...))))
                // with z=(x-1)/(x+1)

                double h;
                if (value > 0.95) {
                    if (value < 1.14) {
                        double z = (value - 1.0) / (value + 1.0);
                        double z2 = z * z;
                        return z * (2 + z2 * ((2.0 / 3) + z2 * ((2.0 / 5)
                                + z2 * ((2.0 / 7) + z2 * ((2.0 / 9) + z2 * ((2.0 / 11)))))));
                    }
                    h = 0.0;
                } else if (value < DOUBLE_MIN_NORMAL) {
                    // Ensuring value is normal.
                    value *= TWO_POW_52;
                    // log(x*2^52) = log(x)+52*ln(2)
                    // so
                    // log(x) = log(x*2^52)-52*ln(2),
                    // hence the -52*ln(2) correction added below.
                    h = -52 * LOG_2;
                } else {
                    h = 0.0;
                }

                int valueBitsHi = (int) (Double.doubleToRawLongBits(value) >> 32);
                int valueExp = (valueBitsHi >> 20) - MAX_DOUBLE_EXPONENT;
                // Getting the first LOG_BITS bits of the mantissa.
                int xIndex = ((valueBitsHi << 12) >>> (32 - LOG_BITS));

                // 1.mantissa/1.mantissaApprox - 1
                double z = (value * twoPowNormalOrSubnormal(-valueExp)) * MyTLog.logXInvTab[xIndex]
                        - 1;

                // log(1+z) ~= z - z^2/2 + z^3/3
                z *= (1 - z * ((1.0 / 2) - z * ((1.0 / 3))));

                return h + valueExp * LOG_2 + (MyTLog.logXLogTab[xIndex] + z);

            } else if (value == 0.0) {
                return Double.NEGATIVE_INFINITY;
            } else { // value < 0.0, or value is NaN
                return Double.NaN;
            }
        }

        /**
         * Much more accurate than log(1+value), for arguments (and results) close to zero.
         *
         * @param value A double value.
         * @return Logarithm (base e) of (1+value); -Infinity for -1, NaN for values below -1 or
         *         NaN.
         */
        static double log1p(final double value) {
            if (value > -1.0) {
                if (value == Double.POSITIVE_INFINITY) {
                    return Double.POSITIVE_INFINITY;
                }

                // ln'(x) = 1/x
                // so
                // log(x+epsilon) ~= log(x) + epsilon/x
                //
                // Let u be 1+value rounded:
                // 1+value = u+epsilon
                //
                // log(1+value)
                // = log(u+epsilon)
                // ~= log(u) + epsilon/u
                // We compute log(u) as done in log(double), and then add the corrective term.

                double valuePlusOne = 1.0 + value;
                if (valuePlusOne == 1.0) {
                    // |value| below half an ulp of 1: log(1+value) ~= value.
                    return value;
                } else if (Math.abs(value) < 0.15) {
                    // Same series as in log(double), with z=((1+value)-1)/((1+value)+1).
                    double z = value / (value + 2.0);
                    double z2 = z * z;
                    return z * (2 + z2 * ((2.0 / 3) + z2 * ((2.0 / 5)
                            + z2 * ((2.0 / 7) + z2 * ((2.0 / 9) + z2 * ((2.0 / 11)))))));
                }

                int valuePlusOneBitsHi =
                        (int) (Double.doubleToRawLongBits(valuePlusOne) >> 32) & 0x7FFFFFFF;
                int valuePlusOneExp = (valuePlusOneBitsHi >> 20) - MAX_DOUBLE_EXPONENT;
                // Getting the first LOG_BITS bits of the mantissa.
                int xIndex = ((valuePlusOneBitsHi << 12) >>> (32 - LOG_BITS));

                // 1.mantissa/1.mantissaApprox - 1
                double z = (valuePlusOne * twoPowNormalOrSubnormal(-valuePlusOneExp))
                        * MyTLog.logXInvTab[xIndex] - 1;

                z *= (1 - z * ((1.0 / 2) - z * (1.0 / 3)));

                // Adding epsilon/valuePlusOne to z,
                // with
                // epsilon = value - (valuePlusOne-1)
                // (valuePlusOne + epsilon ~= 1+value (not rounded))

                return valuePlusOneExp * LOG_2 + MyTLog.logXLogTab[xIndex]
                        + (z + (value - (valuePlusOne - 1)) / valuePlusOne);
            } else if (value == -1.0) {
                return Double.NEGATIVE_INFINITY;
            } else { // value < -1.0, or value is NaN
                return Double.NaN;
            }
        }

        /**
         * @param power Must be in normal or subnormal values range, i.e. in [-1074,1023]; no
         *        range check is performed.
         * @return 2^power as a double.
         */
        private static double twoPowNormalOrSubnormal(final int power) {
            if (power <= -MAX_DOUBLE_EXPONENT) { // Not normal.
                return Double.longBitsToDouble(
                    0x0008000000000000L >> (-(power + MAX_DOUBLE_EXPONENT)));
            } else { // Normal.
                return Double.longBitsToDouble(((long) (power + MAX_DOUBLE_EXPONENT)) << 52);
            }
        }

        /**
         * Returns the exact result, provided it's in double range, i.e. if power is in
         * [-1074,1023].
         *
         * @param power An int power.
         * @return 2^power as a double, 0.0 in case of underflow (power &lt; -1074), or +Infinity in
         *         case of overflow (power &gt; 1023).
         */
        private static double twoPow(final int power) {
            if (power <= -MAX_DOUBLE_EXPONENT) { // Not normal.
                if (power >= MIN_DOUBLE_EXPONENT) { // Subnormal.
                    return Double.longBitsToDouble(
                        0x0008000000000000L >> (-(power + MAX_DOUBLE_EXPONENT)));
                } else { // Underflow.
                    return 0.0;
                }
            } else if (power > MAX_DOUBLE_EXPONENT) { // Overflow.
                return Double.POSITIVE_INFINITY;
            } else { // Normal.
                return Double.longBitsToDouble(((long) (power + MAX_DOUBLE_EXPONENT)) << 52);
            }
        }

    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy