All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.berkeley.SloppyMath Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.berkeley;


import java.util.List;
import java.util.Map;

/**
 * The class SloppyMath contains methods for performing basic
 * numeric operations. In some cases, such as max and min, they cut a few
 * corners in the implementation for the sake of efficiency. In particular, they
 * may not handle special notions like NaN and -0.0 correctly. This was the
 * origin of the class name, but some other operations are just useful math
 * additions, such as logSum.
 * 
 * @author Christopher Manning
 * @version 2003/01/02
 */
public final class SloppyMath {

    private SloppyMath() {
    }

    public static double abs(double x) {
		    if (x > 0)
		      return x;
		    return -1.0 * x;
		  }

			public static double lambert(double v, double u){
				double x = -(Math.log(-v)+u);//-Math.log(-z);
				double w = -x;
				double diff=1;
				while (Math.abs(diff)<1.0e-5){
				  double z = -x -Math.log(Math.abs(w));
				  diff = z-w;
				  w = z;
				}
				return w;

				/*
				//Use asymptotic expansion w = log(z) - log(log(z)) for most z
				double summand = (z==0) ? 1 : 0;
				double tmp = Math.log(z+summand);// + i*b*6.28318530717958648;
				double w = tmp - Math.log(tmp + summand);

				//For b = 0, use a series expansion when close to the branch point
				//k = find(b == 0 & abs(z + 0.3678794411714423216) <= 1.5);
				tmp = Math.sqrt(5.43656365691809047*z + 2) - 1;// + i*b*6.28318530717958648;
				//w(k) = tmp(k);
		    w = tmp;
		    
				for (int k=1; k<36; k++){
					// Converge with Halley's iterations, about 5 iterations satisfies
					//the tolerance for most z
					double c1 = Math.exp(w);
					double c2 = w*c1 - z;
					summand = (w != -1) ? 1 : 0;
					double w1 = w + summand;
					double dw = c2/(c1*w1 - ((w + 2)*c2/(2*w1)));
					w = w - dw;
			
				  if (Math.abs(dw) < 0.7e-16*(2+Math.abs(w)))
				     break;
				}
				return w;*/
 	}

  /**
	 * Returns the minimum of three int values.
	 */
  public static int max(int a, int b, int c) {
    int ma;
    ma = a;
    if (b > ma) {
      ma = b;
    }
    if (c > ma) {
      ma = c;
    }
    return ma;
  }


  /**
	 * Returns the minimum of three int values.
	 */
  public static int min(int a, int b, int c) {
    int mi;

    mi = a;
    if (b < mi) {
      mi = b;
    }
    if (c < mi) {
      mi = c;
    }
    return mi;

  }


  /**
	 * Returns the greater of two float values. That is, the
	 * result is the argument closer to positive infinity. If the arguments have
	 * the same value, the result is that same value. Does none of the special
	 * checks for NaN or -0.0f that Math.max does.
	 * 
	 * @param a
	 *            an argument.
	 * @param b
	 *            another argument.
	 * @return the larger of a and b.
	 */
  public static float max(float a, float b) {
    return (a >= b) ? a : b;
  }


  /**
	 * Returns the greater of two double values. That is, the
	 * result is the argument closer to positive infinity. If the arguments have
	 * the same value, the result is that same value. Does none of the special
	 * checks for NaN or -0.0f that Math.max does.
	 * 
	 * @param a
	 *            an argument.
	 * @param b
	 *            another argument.
	 * @return the larger of a and b.
	 */
  public static double max(double a, double b) {
    return (a >= b) ? a : b;
  }


  /**
	 * Returns the smaller of two float values. That is, the
	 * result is the value closer to negative infinity. If the arguments have
	 * the same value, the result is that same value. Does none of the special
	 * checks for NaN or -0.0f that Math.max does.
	 * 
	 * @param a
	 *            an argument.
	 * @param b
	 *            another argument.
	 * @return the smaller of a and b.
	 */
  public static float min(float a, float b) {
    return (a <= b) ? a : b;
  }


  /**
	 * Returns the smaller of two double values. That is, the
	 * result is the value closer to negative infinity. If the arguments have
	 * the same value, the result is that same value. Does none of the special
	 * checks for NaN or -0.0f that Math.max does.
	 * 
	 * @param a
	 *            an argument.
	 * @param b
	 *            another argument.
	 * @return the smaller of a and b.
	 */
  public static double min(double a, double b) {
    return (a <= b) ? a : b;
  }


  /**
	 * Returns true if the argument is a "dangerous" double to have around,
	 * namely one that is infinite, NaN or zero.
	 */
  public static boolean isDangerous(double d) {
    return Double.isInfinite(d) || Double.isNaN(d) || d == 0.0;
  }
  public static boolean isDangerous(float d) {
    return Float.isInfinite(d) || Float.isNaN(d) || d == 0.0;
  }

  public static boolean isGreater(double x, double y) {
	    if (x>1) return (((x-y) / x) > -0.01);
	  	return ((x-y) > -0.0001);
  }


  /**
	 * Returns true if the argument is a "very dangerous" double to have around,
	 * namely one that is infinite or NaN.
	 */
  public static boolean isVeryDangerous(double d) {
    return Double.isInfinite(d) || Double.isNaN(d);
  }

  public static double relativeDifferance(double a, double b) {
      a = Math.abs(a);
      b = Math.abs(b);
      double absMin = Math.min(a,b);
      return Math.abs(a-b) / absMin;      
  }

  public static boolean isDiscreteProb(double d, double tol)
  {
	  return d >=0.0 && d <= 1.0 + tol;
  }
  

  /**
	 * If a difference is bigger than this in log terms, then the sum or
	 * difference of them will just be the larger (to 12 or so decimal places
	 * for double, and 7 or 8 for float).
	 */
  public static final double LOGTOLERANCE = 30.0;
  static final float LOGTOLERANCE_F = 10.0f;


  /**
	 * Returns the log of the sum of two numbers, which are themselves input in
	 * log form. This uses natural logarithms. Reasonable care is taken to do
	 * this as efficiently as possible (under the assumption that the numbers
	 * might differ greatly in magnitude), with high accuracy, and without
	 * numerical overflow. Also, handle correctly the case of arguments being
	 * -Inf (e.g., probability 0).
	 * 
	 * @param lx
	 *            First number, in log form
	 * @param ly
	 *            Second number, in log form
	 * @return log(exp(lx) + exp(ly))
	 */
  public static float logAdd(float lx, float ly) {
    float max, negDiff;
    if (lx > ly) {
      max = lx;
      negDiff = ly - lx;
    } else {
      max = ly;
      negDiff = lx - ly;
    }
    if (max == Double.NEGATIVE_INFINITY || negDiff < -LOGTOLERANCE_F) {
      return max;
    } else {
      return max + (float)Math.log(1.0f + Math.exp(negDiff));
    }
  }


  /**
	 * Returns the log of the sum of two numbers, which are themselves input in
	 * log form. This uses natural logarithms. Reasonable care is taken to do
	 * this as efficiently as possible (under the assumption that the numbers
	 * might differ greatly in magnitude), with high accuracy, and without
	 * numerical overflow. Also, handle correctly the case of arguments being
	 * -Inf (e.g., probability 0).
	 * 
	 * @param lx
	 *            First number, in log form
	 * @param ly
	 *            Second number, in log form
	 * @return log(exp(lx) + exp(ly))
	 */
  public static double logAdd(double lx, double ly) {
    double max, negDiff;
    if (lx > ly) {
      max = lx;
      negDiff = ly - lx;
    } else {
      max = ly;
      negDiff = lx - ly;
    }
    if (max == Double.NEGATIVE_INFINITY || negDiff < -LOGTOLERANCE) {
      return max;
    } else {
      return max + Math.log(1.0 + Math.exp(negDiff));
    }
  }

  public static double logAdd(float[] logV) {
    double maxIndex = 0;
    double max = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < logV.length; i++) {
      if (logV[i] > max) {
        max = logV[i];
        maxIndex = i;
      }
    }
    if (max == Double.NEGATIVE_INFINITY) return Double.NEGATIVE_INFINITY;
    // compute the negative difference
    double threshold = max - LOGTOLERANCE;
    double sumNegativeDifferences = 0.0;
    for (int i = 0; i < logV.length; i++) {
      if (i != maxIndex && logV[i] > threshold) {
        sumNegativeDifferences += Math.exp(logV[i] - max);
      }
    }
    if (sumNegativeDifferences > 0.0) {
      return max + Math.log(1.0 + sumNegativeDifferences);
    } else {
      return max;
    }
  }

  public static void logNormalize(double[] logV) {
      double logSum = logAdd(logV);      
      if (Double.isNaN(logSum)) {
        throw new RuntimeException("Bad log-sum");
      }
      if (logSum == 0.0) return;
      for (int i = 0; i < logV.length; i++) {          
        logV[i] -= logSum;
      }
  }

  public static double logAdd(double[] logV) {
    double maxIndex = 0;
    double max = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < logV.length; i++) {
      if (logV[i] > max) {
        max = logV[i];
        maxIndex = i;
      }
    }
    if (max == Double.NEGATIVE_INFINITY) return Double.NEGATIVE_INFINITY;
    // compute the negative difference
    double threshold = max - LOGTOLERANCE;
    double sumNegativeDifferences = 0.0;
    for (int i = 0; i < logV.length; i++) {
      if (i != maxIndex && logV[i] > threshold) {
        sumNegativeDifferences += Math.exp(logV[i] - max);
      }
    }
    if (sumNegativeDifferences > 0.0) {
      return max + Math.log(1.0 + sumNegativeDifferences);
    } else {
      return max;
    }
  }

  public static double logAdd(List logV) {
	    double max = Double.NEGATIVE_INFINITY;
	    double maxIndex = 0;
	    for (int i = 0; i < logV.size(); i++) {
	      if (logV.get(i) > max) {
	        max = logV.get(i);
	        maxIndex = i;
	      }
	    }
	    if (max == Double.NEGATIVE_INFINITY) return Double.NEGATIVE_INFINITY;
	    // compute the negative difference
	    double threshold = max - LOGTOLERANCE;
	    double sumNegativeDifferences = 0.0;
	    for (int i = 0; i < logV.size(); i++) {
	      if (i != maxIndex && logV.get(i) > threshold) {
	        sumNegativeDifferences += Math.exp(logV.get(i) - max);
	      }
	    }
	    if (sumNegativeDifferences > 0.0) {
	      return max + Math.log(1.0 + sumNegativeDifferences);
	    } else {
	      return max;
	    }
	  }

  
  public static float logAdd_Old(float[] logV) {
    float max = Float.NEGATIVE_INFINITY;
    float maxIndex = 0;
    for (int i = 0; i < logV.length; i++) {
      if (logV[i] > max) {
        max = logV[i];
        maxIndex = i;
      }
    }
    if (max == Float.NEGATIVE_INFINITY) return Float.NEGATIVE_INFINITY;
    // compute the negative difference
    float threshold = max - LOGTOLERANCE_F;
    float sumNegativeDifferences = 0.0f;
    for (int i = 0; i < logV.length; i++) {
      if (i != maxIndex && logV[i] > threshold) {
        sumNegativeDifferences += Math.exp(logV[i] - max);
      }
    }
    if (sumNegativeDifferences > 0.0) {
      return max + (float) Math.log(1.0f + sumNegativeDifferences);
    } else {
      return max;
    }
  }

  /*
	 * adds up the entries logV[0], logV[1], ... , logV[lastIndex-1]
	 */
  public static float logAdd(float[] logV, int lastIndex) {
  	if (lastIndex==0) return Float.NEGATIVE_INFINITY;
  	float max = Float.NEGATIVE_INFINITY;
    float maxIndex = 0;
    for (int i = 0; i < lastIndex; i++) {
      if (logV[i] > max) {
        max = logV[i];
        maxIndex = i;
      }
    }
    if (max == Float.NEGATIVE_INFINITY) return Float.NEGATIVE_INFINITY;
    // compute the negative difference
    float threshold = max - LOGTOLERANCE_F;
    double sumNegativeDifferences = 0.0;
    for (int i = 0; i < lastIndex; i++) {
      if (i != maxIndex && logV[i] > threshold) {
        sumNegativeDifferences += Math.exp((logV[i] - max));
      }
    }
    if (sumNegativeDifferences > 0.0) {
      return max + (float) Math.log(1.0 + sumNegativeDifferences);
    } else {
      return max;
    }
  }

  /*
	 * adds up the entries logV[0], logV[1], ... , logV[lastIndex-1]
	 */
  public static double logAdd(double[] logV, int lastIndex) {
  	if (lastIndex==0) return Double.NEGATIVE_INFINITY;
  	double max = Double.NEGATIVE_INFINITY;
  	double maxIndex = 0;
    for (int i = 0; i < lastIndex; i++) {
      if (logV[i] > max) {
        max = logV[i];
        maxIndex = i;
      }
    }
    if (max == Double.NEGATIVE_INFINITY) return Double.NEGATIVE_INFINITY;
    // compute the negative difference
    double threshold = max - LOGTOLERANCE;
    double sumNegativeDifferences = 0.0;
    for (int i = 0; i < lastIndex; i++) {
      if (i != maxIndex && logV[i] > threshold) {
        sumNegativeDifferences += Math.exp((logV[i] - max));
      }
    }
    if (sumNegativeDifferences > 0.0) {
      return max + Math.log(1.0 + sumNegativeDifferences);
    } else {
      return max;
    }
  }
  /**
	 * Similar to logAdd, but without the final log. I.e. Sum_i exp(logV_i)
	 * 
	 * @param logV
	 * @return
	 */
  public static float addExp_Old(float[] logV) {
    float max = Float.NEGATIVE_INFINITY;
    float maxIndex = 0;
    for (int i = 0; i < logV.length; i++) {
      if (logV[i] > max) {
        max = logV[i];
        maxIndex = i;
      }
    }
    if (max == Float.NEGATIVE_INFINITY) return Float.NEGATIVE_INFINITY;
    // compute the negative difference
    float threshold = max - LOGTOLERANCE_F;
    float sumNegativeDifferences = 0.0f;
    for (int i = 0; i < logV.length; i++) {
      if (i != maxIndex && logV[i] > threshold) {
        sumNegativeDifferences += Math.exp(logV[i] - max);
      }
    }
    return (float) Math.exp(max) * (1.0f + sumNegativeDifferences);
  }

  /*
	 * adds up the entries logV[0], logV[1], ... , logV[lastIndex-1]
	 */
  public static float addExp(float[] logV, int lastIndex) {
  	if (lastIndex==0) return Float.NEGATIVE_INFINITY;
  	float max = Float.NEGATIVE_INFINITY;
    float maxIndex = 0;
    for (int i = 0; i < lastIndex; i++) {
      if (logV[i] > max) {
        max = logV[i];
        maxIndex = i;
      }
    }
    if (max == Float.NEGATIVE_INFINITY) return Float.NEGATIVE_INFINITY;
    // compute the negative difference
    float threshold = max - LOGTOLERANCE_F;
    float sumNegativeDifferences = 0.0f;
    for (int i = 0; i < lastIndex; i++) {
      if (i != maxIndex && logV[i] > threshold) {
        sumNegativeDifferences += Math.exp(logV[i] - max);
      }
    }
    return (float) Math.exp(max) * (1.0f + sumNegativeDifferences);
  }
  /**
	 * Computes n choose k in an efficient way. Works with k == 0 or k == n but
	 * undefined if k < 0 or k > n
	 * 
	 * @param n
	 * @param k
	 * @return fact(n) / fact(k) * fact(n-k)
	 */
  public static int nChooseK(int n, int k) {
    k = Math.min(k, n - k);
    if (k == 0) {
      return 1;
    }
    int accum = n;
    for (int i = 1; i < k; i++) {
      accum *= (n - i);
      accum /= i;
    }
    return accum / k;
  }

  /**
	 * exponentiation like we learned in grade school: multiply b by itself e
	 * times. Uses power of two trick. e must be nonnegative!!! no checking!!!
	 * 
	 * @param b
	 *            base
	 * @param e
	 *            exponent
	 * @return b^e
	 */
  public static int intPow(int b, int e) {
    if (e == 0) {
      return 1;
    }
    int result = 1;
    int currPow = b;
    do {
      if ((e & 1) == 1) result *= currPow;
      currPow = currPow * currPow;
      e >>= 1;
    } while (e > 0);
    return result;
  }

  /**
	 * exponentiation like we learned in grade school: multiply b by itself e
	 * times. Uses power of two trick. e must be nonnegative!!! no checking!!!
	 * 
	 * @param b
	 *            base
	 * @param e
	 *            exponent
	 * @return b^e
	 */
  public static float intPow(float b, int e) {
    if (e == 0) {
      return 1;
    }
    float result = 1;
    float currPow = b;
    do {
      if ((e & 1) == 1) result *= currPow;
      currPow = currPow * currPow;
      e >>= 1;
    } while (e > 0);
    return result;
  }

  /**
	 * exponentiation like we learned in grade school: multiply b by itself e
	 * times. Uses power of two trick. e must be nonnegative!!! no checking!!!
	 * 
	 * @param b
	 *            base
	 * @param e
	 *            exponent
	 * @return b^e
	 */
  public static double intPow(double b, int e) {
    if (e == 0) {
      return 1;
    }
    float result = 1;
    double currPow = b;
    do {
      if ((e & 1) == 1) result *= currPow;
      currPow = currPow * currPow;
      e >>= 1;
    } while (e > 0);
    return result;
  }

  /**
	 * Find a hypergeometric distribution. This uses exact math, trying fairly
	 * hard to avoid numeric overflow by interleaving multiplications and
	 * divisions. (To do: make it even better at avoiding overflow, by using
	 * loops that will do either a multiple or divide based on the size of the
	 * intermediate result.)
	 * 
	 * @param k
	 *            The number of black balls drawn
	 * @param n
	 *            The total number of balls
	 * @param r
	 *            The number of black balls
	 * @param m
	 *            The number of balls drawn
	 * @return The hypergeometric value
	 */
  public static double hypergeometric(int k, int n, int r, int m) {
    if (k < 0 || r > n || m > n || n <= 0 || m < 0 | r < 0) {
      throw new IllegalArgumentException("Invalid hypergeometric");
    }

    // exploit symmetry of problem
    if (m > n / 2) {
      m = n - m;
      k = r - k;
    }
    if (r > n / 2) {
      r = n - r;
      k = m - k;
    }
    if (m > r) {
      int temp = m;
      m = r;
      r = temp;
    }
    // now we have that k <= m <= r <= n/2

    if (k < (m + r) - n || k > m) {
      return 0.0;
    }

    // Do limit cases explicitly
    // It's unclear whether this is a good idea. I put it in fearing
    // numerical errors when the numbers seemed off, but actually there
    // was a bug in the Fisher's exact routine.
    if (r == n) {
      if (k == m) {
        return 1.0;
      } else {
        return 0.0;
      }
    } else if (r == n - 1) {
      if (k == m) {
        return (n - m) / (double) n;
      } else if (k == m - 1) {
        return m / (double) n;
      } else {
        return 0.0;
      }
    } else if (m == 1) {
      if (k == 0) {
        return (n - r) / (double) n;
      } else if (k == 1) {
        return r / (double) n;
      } else {
        return 0.0;
      }
    } else if (m == 0) {
      if (k == 0) {
        return 1.0;
      } else {
        return 0.0;
      }
    } else if (k == 0) {
      double ans = 1.0;
      for (int m0 = 0; m0 < m; m0++) {
        ans *= ((n - r) - m0);
        ans /= (n - m0);
      }
      return ans;
    }

    double ans = 1.0;
    // do (n-r)x...x((n-r)-((m-k)-1))/n x...x (n-((m-k-1)))
    // leaving rest of denominator to getFromOrigin to multimply by (n-(m-1))
    // that's k things which goes into next loop
    for (int nr = n - r, n0 = n; nr > (n - r) - (m - k); nr--, n0--) {
      // System.out.println("Multiplying by " + nr);
      ans *= nr;
      // System.out.println("Dividing by " + n0);
      ans /= n0;
    }
    // System.out.println("Done phase 1");
    for (int k0 = 0; k0 < k; k0++) {
      ans *= (m - k0);
      // System.out.println("Multiplying by " + (m-k0));
      ans /= ((n - (m - k0)) + 1);
      // System.out.println("Dividing by " + ((n-(m+k0)+1)));
      ans *= (r - k0);
      // System.out.println("Multiplying by " + (r-k0));
      ans /= (k0 + 1);
      // System.out.println("Dividing by " + (k0+1));
    }
    return ans;
  }


  /**
	 * Find a one tailed exact binomial test probability. Finds the chance of
	 * this or a higher result
	 * 
	 * @param k
	 *            number of successes
	 * @param n
	 *            Number of trials
	 * @param p
	 *            Probability of a success
	 */
  public static double exactBinomial(int k, int n, double p) {
    double total = 0.0;
    for (int m = k; m <= n; m++) {
      double nChooseM = 1.0;
      for (int r = 1; r <= m; r++) {
        nChooseM *= (n - r) + 1;
        nChooseM /= r;
      }
      // System.out.println(n + " choose " + m + " is " + nChooseM);
      // System.out.println("prob contribution is " +
      // (nChooseM * Math.pow(p, m) * Math.pow(1.0-p, n - m)));
      total += nChooseM * Math.pow(p, m) * Math.pow(1.0 - p, n - m);
    }
    return total;
  }


  /**
	 * Find a one-tailed Fisher's exact probability. Chance of having seen this
	 * or a more extreme departure from what you would have expected given
	 * independence. I.e., k >= the value passed in. Warning: this was done just
	 * for collocations, where you are concerned with the case of k being larger
	 * than predicted. It doesn't correctly handle other cases, such as k being
	 * smaller than expected.
	 * 
	 * @param k
	 *            The number of black balls drawn
	 * @param n
	 *            The total number of balls
	 * @param r
	 *            The number of black balls
	 * @param m
	 *            The number of balls drawn
	 * @return The Fisher's exact p-value
	 */
  public static double oneTailedFishersExact(int k, int n, int r, int m) {
    if (k < 0 || k < (m + r) - n || k > r || k > m || r > n || m > n) {
      throw new IllegalArgumentException("Invalid Fisher's exact: " + "k=" + k + " n=" + n + " r=" + r + " m=" + m + " k<0=" + (k < 0) + " k<(m+r)-n=" + (k < (m + r) - n) + " k>r=" + (k > r) + " k>m=" + (k > m) + " r>n=" + (r > n) + "m>n=" + (m > n));
    }
    // exploit symmetry of problem
    if (m > n / 2) {
      m = n - m;
      k = r - k;
    }
    if (r > n / 2) {
      r = n - r;
      k = m - k;
    }
    if (m > r) {
      int temp = m;
      m = r;
      r = temp;
    }
    // now we have that k <= m <= r <= n/2

    double total = 0.0;
    if (k > m / 2) {
      // sum from k to m
      for (int k0 = k; k0 <= m; k0++) {
        // System.out.println("Calling hypg(" + k0 + "; " + n +
        // ", " + r + ", " + m + ")");
        total += SloppyMath.hypergeometric(k0, n, r, m);
      }
    } else {
      // sum from max(0, (m+r)-n) to k-1, and then subtract from 1
      int min = Math.max(0, (m + r) - n);
      for (int k0 = min; k0 < k; k0++) {
        // System.out.println("Calling hypg(" + k0 + "; " + n +
        // ", " + r + ", " + m + ")");
        total += SloppyMath.hypergeometric(k0, n, r, m);
      }
      total = 1.0 - total;
    }
    return total;
  }


  /**
	 * Find a 2x2 chi-square value. Note: could do this more neatly using
	 * simplified formula for 2x2 case.
	 * 
	 * @param k
	 *            The number of black balls drawn
	 * @param n
	 *            The total number of balls
	 * @param r
	 *            The number of black balls
	 * @param m
	 *            The number of balls drawn
	 * @return The Fisher's exact p-value
	 */
  public static double chiSquare2by2(int k, int n, int r, int m) {
    int[][] cg = {{k, r - k}, {m - k, n - (k + (r - k) + (m - k))}};
    int[] cgr = {r, n - r};
    int[] cgc = {m, n - m};
    double total = 0.0;
    for (int i = 0; i < 2; i++) {
      for (int j = 0; j < 2; j++) {
        double exp = (double) cgr[i] * cgc[j] / n;
        total += (cg[i][j] - exp) * (cg[i][j] - exp) / exp;
      }
    }
    return total;
  }

  public static double exp(double logX) {
    // if x is very near one, use the linear approximation
    if (Math.abs(logX) < 0.001)
      return 1 + logX;
    return Math.exp(logX);
  }

  /**
	 * Tests the hypergeometric distribution code, or other cooccurrences provided
	 * in this module.
	 * 
	 * @param args
	 *            Either none, and the log add rountines are tested, or the
	 *            following 4 arguments: k (cell), n (total), r (row), m (col)
	 */
  public static void main(String[] args) {
    
    System.out.println(approxLog(0.0));
//    if (args.length == 0) {
//      System.err.println("Usage: java edu.stanford.nlp.math.SloppyMath " + "[-logAdd|-fishers k n r m|-bionomial r n p");
//    } else if (args[0].equals("-logAdd")) {
//      System.out.println("Log adds of neg infinity numbers, etc.");
//      System.out.println("(logs) -Inf + -Inf = " + logAdd(Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY));
//      System.out.println("(logs) -Inf + -7 = " + logAdd(Double.NEGATIVE_INFINITY, -7.0));
//      System.out.println("(logs) -7 + -Inf = " + logAdd(-7.0, Double.NEGATIVE_INFINITY));
//      System.out.println("(logs) -50 + -7 = " + logAdd(-50.0, -7.0));
//      System.out.println("(logs) -11 + -7 = " + logAdd(-11.0, -7.0));
//      System.out.println("(logs) -7 + -11 = " + logAdd(-7.0, -11.0));
//      System.out.println("real 1/2 + 1/2 = " + logAdd(Math.log(0.5), Math.log(0.5)));
//    } else if (args[0].equals("-fishers")) {
//      int k = Integer.parseInt(args[1]);
//      int n = Integer.parseInt(args[2]);
//      int r = Integer.parseInt(args[3]);
//      int m = Integer.parseInt(args[4]);
//      double ans = SloppyMath.hypergeometric(k, n, r, m);
//      System.out.println("hypg(" + k + "; " + n + ", " + r + ", " + m + ") = " + ans);
//      ans = SloppyMath.oneTailedFishersExact(k, n, r, m);
//      System.out.println("1-tailed Fisher's exact(" + k + "; " + n + ", " + r + ", " + m + ") = " + ans);
//      double ansChi = SloppyMath.chiSquare2by2(k, n, r, m);
//      System.out.println("chiSquare(" + k + "; " + n + ", " + r + ", " + m + ") = " + ansChi);
//
//      System.out.println("Swapping arguments should give same hypg:");
//      ans = SloppyMath.hypergeometric(k, n, r, m);
//      System.out.println("hypg(" + k + "; " + n + ", " + m + ", " + r + ") = " + ans);
//      int othrow = n - m;
//      int othcol = n - r;
//      int cell12 = m - k;
//      int cell21 = r - k;
//      int cell22 = othrow - (r - k);
//      ans = SloppyMath.hypergeometric(cell12, n, othcol, m);
//      System.out.println("hypg(" + cell12 + "; " + n + ", " + othcol + ", " + m + ") = " + ans);
//      ans = SloppyMath.hypergeometric(cell21, n, r, othrow);
//      System.out.println("hypg(" + cell21 + "; " + n + ", " + r + ", " + othrow + ") = " + ans);
//      ans = SloppyMath.hypergeometric(cell22, n, othcol, othrow);
//      System.out.println("hypg(" + cell22 + "; " + n + ", " + othcol + ", " + othrow + ") = " + ans);
//    } else if (args[0].equals("-binomial")) {
//      int k = Integer.parseInt(args[1]);
//      int n = Integer.parseInt(args[2]);
//      double p = Double.parseDouble(args[3]);
//      double ans = SloppyMath.exactBinomial(k, n, p);
//      System.out.println("Binomial p(X >= " + k + "; " + n + ", " + p + ") = " + ans);
//    }
//		else if (args[0].equals("-approxExp"))
//    {
//    	int numTrials = 0;
//    	double sumError = 0;
//    	double maxError = 0;
//    	for (double x = -700; x < 700; x += 0.1)
//    	{
//    		final double approxExp = approxExp(x);
//			final double exp = Math.exp(x);
//			double error = Math.abs((exp - approxExp) / exp);
//    		if (isVeryDangerous(error)) continue;
//    		maxError = Math.max(error,maxError);
//    		sumError += error;
//    		numTrials++;
//    	}
//    	double avgError = sumError / numTrials;
//    	System.out.println("Avg error was: " + avgError);
//    	System.out.println("Max error was: " + maxError);
//    }
//    	else if (args[0].equals("-approxLog"))
//        {
//        	int numTrials = 0;
//        	double sumError = 0;
//        	double maxError = 0;
//        	double x = Double.MIN_VALUE; 
//        	while (x < Double.MAX_VALUE)
//        	{
//				//        		if (Math.abs(x - 1) < 0.3) continue;
//        		final double approxExp = approxLog(x);
//				final double exp = Math.log(x);
//    			double error = Math.abs((exp - approxExp) / exp);
//        		if (isVeryDangerous(error)) continue;
//        		maxError = Math.max(error,maxError);
//        		sumError += error;
//        		numTrials++;
//        		
//        		if (x < Double.MIN_VALUE * 1000000)
//					x *= 4;
//        		else x *= 1.0001;
//        	}
//        	double avgError = sumError / numTrials;
//        	System.out.println("Avg error was: " + avgError);
//        	System.out.println("Max error was: " + maxError);
//        	
//      
//    } else {
//      System.err.println("Unknown option: " + args[0]);
//    }
  }
  
  public static double noNaNDivide(double num, double denom)
	{
		return denom == 0.0 ? 0.0 : num / denom;
	}

  
	public static double approxLog(double val)
	{
    if (val < 0.0) return Double.NaN;
	  if (val == 0.0) return Double.NEGATIVE_INFINITY;
		double r = val - 1;
		if (Math.abs(r) < 0.3)
		{
			// use first few terms of taylor series
			
			final double rSquared = r * r;
			return r - rSquared / 2 + rSquared * r / 3;
		}
		final double x = (Double.doubleToLongBits(val) >> 32);
		return (x - 1072632447) / 1512775;

	}

	public static double approxExp(double val)
	{

		if (Math.abs(val) < 0.1) return 1 + val;
		final long tmp = (long) (1512775 * val + (1072693248 - 60801));
		return Double.longBitsToDouble(tmp << 32);

	}

	public static double approxPow(final double a, final double b)
	{
		final int tmp = (int) (Double.doubleToLongBits(a) >> 32);
		final int tmp2 = (int) (b * (tmp - 1072632447) + 1072632447);
		return Double.longBitsToDouble(((long) tmp2) << 32);
	}
	

	public static double logSubtract(double a, double b)
	{
		if (a > b)
		{
      // logA logB
      // (logA - logB) = (log
			return a + Math.log(1.0 - Math.exp(b - a));

		}
		else
		{
			return b + Math.log(-1.0 + Math.exp(a - b));
		}
	}

  public static double unsafeSubtract(double a, double b) {
    if (a == b) { // inf - inf (or -inf - -inf)
      return 0.0;
    }
    if (a == Double.NEGATIVE_INFINITY) {
      return a;
    }
    return a-b;
  }

  public static double unsafeAdd(double a, double b) {
    if (a == b) { // inf - inf (or -inf - -inf)
      return 0.0;
    }
    if (a == Double.POSITIVE_INFINITY) {
      return a;
    }
    return a+b;
  }

  public static  double logAdd(Counter counts) {
    double[] arr = new double[counts.size()];
    int index = 0;
    for (Map.Entry entry : counts.entrySet()) {
      arr[index++] = entry.getValue();
    }
    return SloppyMath.logAdd(arr);
  }

//	public static double approxLogAdd(double a, double b)
//	{
//		
//		    final long tmp1 = (long) (1512775 * a + (1072693248 - 60801));
//		    double ea = Double.longBitsToDouble(tmp1 << 32);
//		    final long tmp2 = (long) (1512775 * b + (1072693248 - 60801));
//		    double eb = Double.longBitsToDouble(tmp2 << 32);
//		    
//		    final double x = (Double.doubleToLongBits(ea + eb) >> 32);
//		    return (x - 1072632447) / 1512775;
//		
//	}
  

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy