smile.stat.distribution.HyperGeometricDistribution Maven / Gradle / Ivy
* Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
* Smile is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
* Smile is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
* You should have received a copy of the GNU General Public License
* along with Smile. If not, see <https://www.gnu.org/licenses/>.
package smile.stat.distribution;
import smile.math.MathEx;
import java.io.Serial;
import static smile.math.MathEx.lchoose;
import static smile.math.MathEx.lfactorial;
* The hypergeometric distribution is a discrete probability distribution that
* describes the number of successes in a sequence of n draws from a finite
* population without replacement, just as the binomial distribution describes
* the number of successes for draws with replacement.
* Suppose you are to draw "n" balls without replacement from an urn containing
* "N" balls in total, "m" of which are white. The hypergeometric distribution
* describes the distribution of the number of white balls drawn from the urn.
* @author Haifeng Li
public class HyperGeometricDistribution extends DiscreteDistribution {
private static final long serialVersionUID = 2L;
/** The number of total samples. */
public final int N;
/** The number of defects. */
public final int m;
/** The number of draws. */
public final int n;
/** The random number generator. */
private RandomNumberGenerator rng;
* Constructor.
* @param N the number of total samples.
* @param m the number of defects.
* @param n the number of draws.
public HyperGeometricDistribution(int N, int m, int n) {
if (N < 0) {
throw new IllegalArgumentException("Invalid N: " + N);
if (m < 0 || m > N) {
throw new IllegalArgumentException("Invalid m: " + m);
if (n < 0 || n > N) {
throw new IllegalArgumentException("Invalid n: " + n);
this.N = N;
this.m = m;
this.n = n;
public int length() {
return 3;
public double mean() {
return (double) m * n / N;
public double variance() {
double r = (double) m / N;
return n * (N - n) * r * (1 - r) / (N - 1);
public double entropy() {
throw new UnsupportedOperationException("Hypergeometric distribution does not support entropy()");
public String toString() {
return String.format("Hypergeometric Distribution(%d, %d, %d)", N, m, n);
public double p(int k) {
if (k < Math.max(0, m + n - N) || k > Math.min(m, n)) {
return 0.0;
} else {
return Math.exp(logp(k));
public double logp(int k) {
if (k < Math.max(0, m + n - N) || k > Math.min(m, n)) {
} else {
return lchoose(m, k) + lchoose(N - m, n - k) - lchoose(N, n);
public double cdf(double k) {
int L = Math.max(0, m + n - N);
if (k < L) {
return 0.0;
} else if (k >= Math.min(m, n)) {
return 1.0;
double p = 0.0;
for (int i = L; i <= k; i++) {
p += p(i);
return p;
public double quantile(double p) {
if (p < 0.0 || p > 1.0) {
throw new IllegalArgumentException("Invalid p: " + p);
if (p == 0.0) {
return Math.max(0, m+n-N);
if (p == 1.0) {
return Math.min(m,n);
// Starting guess near peak of density.
// Expand interval until we bracket.
int kl, ku, inc = 1;
int k = Math.max(0, Math.min(n, (int) (n * p)));
if (p < cdf(k)) {
do {
k = Math.max(k - inc, 0);
inc *= 2;
} while (p < cdf(k) && k > 0);
kl = k;
ku = k + inc / 2;
} else {
do {
k = Math.min(k + inc, n + 1);
inc *= 2;
} while (p > cdf(k));
ku = k;
kl = k - inc / 2;
return quantile(p, kl, ku);
* Uses inversion by chop-down search from the mode when the {@code mean < 20}
* and the patchwork-rejection method when the {@code mean >= 20}.
public double rand() {
if (rng == null) {
int mm = m;
int nn = n;
if (mm > N / 2) {
// invert mm
mm = N - mm;
if (nn > N / 2) {
// invert nn
nn = N - nn;
if ((double) nn * mm >= 20 * N) {
// use ratio-of-uniforms method
rng = new Patchwork(N, m, n);
} else {
// inversion method, using chop-down search from mode
rng = new Inversion(N, m, n);
return rng.rand();
abstract static class RandomNumberGenerator {
protected final int N, m, n;
protected int fak;
protected int addd;
RandomNumberGenerator(int N, int m, int n) {
// transformations
fak = 1; // used for undoing transformations
addd = 0;
if (m > N / 2) {
// invert mm
m = N - m;
fak = -1;
addd = n;
if (n > N / 2) {
// invert nn
n = N - n;
addd += fak * m;
fak = -fak;
if (n > m) {
// swap n and m
int swap = n;
n = m;
m = swap;
this.N = N;
this.m = m;
this.n = n;
public int rand() {
// cases with only one possible result end here
if (n == 0) {
return addd;
int x = random();
// undo transformations
return x * fak + addd;
protected abstract int random();
static class Patchwork extends RandomNumberGenerator {
private final int L, k1, k2, k4, k5;
private final double dl, dr, r1, r2, r4, r5, ll, lr, cPm, f1, f2, f4, f5, p1, p2, p3, p4, p5, p6;
/**
 * Initialize random number generator.
 * <p>
 * Precomputes, for the folded parameters stored by the base class, the
 * reflection points, recurrence constants, tail scale parameters,
 * function values and cumulative region areas used by
 * {@code random()}.
 *
 * @param N the number of total samples.
 * @param mm the number of defects.
 * @param nn the number of draws.
 */
Patchwork(int N, int mm, int nn) {
    super(N, mm, nn);

    final double mPlus1 = m + 1;
    final double nPlus1 = n + 1;
    L = N - m - n;

    final double prob = mPlus1 / (N + 2.);
    final double modeReal = nPlus1 * prob;

    // approximate deviation of reflection points k2, k4 from modeReal - 1/2
    final double spread =
            Math.sqrt(modeReal * (1. - prob) * (1. - (n + 2.) / (N + 3.)) + 0.25);

    // mode, reflection points k2 and k4, and points k1 and k5, which
    // delimit the centre region of h(x)
    // k2 = ceil (modeReal - 1/2 - spread), k1 = 2*k2 - (mode - 1 + delta_ml)
    // k4 = floor(modeReal - 1/2 + spread), k5 = 2*k4 - (mode + 1 - delta_mr)
    final int mode = (int) modeReal;
    final int leftReflection = (int) Math.ceil(modeReal - 0.5 - spread);
    k2 = Math.min(leftReflection, mode - 1); // keep k2 strictly below the mode
    k4 = (int) (modeReal - 0.5 + spread);
    k1 = 2 * k2 - mode + 1; // delta_ml = 0
    k5 = 2 * k4 - mode;     // delta_mr = 1

    // range width of the critical left and right centre region
    dl = k2 - k1;
    dr = k5 - k4;

    // recurrence constants r(k) = p(k)/p(k-1) at k = k1, k2, k4+1, k5+1
    r1 = recurrenceRatio(nPlus1, mPlus1, k1, L);
    r2 = recurrenceRatio(nPlus1, mPlus1, k2, L);
    r4 = recurrenceRatio(nPlus1, mPlus1, k4 + 1, L);
    r5 = recurrenceRatio(nPlus1, mPlus1, k5 + 1, L);

    // reciprocal values of the scale parameters of expon. tail envelopes
    ll = Math.log(r1);  // expon. tail left
    lr = -Math.log(r5); // expon. tail right

    // hypergeom. constant, necessary for computing function values f(k)
    cPm = lnpk(mode, L, m, n);

    // function values f(k) = p(k)/p(mode) at k = k1, k2, k4, k5
    f1 = Math.exp(cPm - lnpk(k1, L, m, n));
    f2 = Math.exp(cPm - lnpk(k2, L, m, n));
    f4 = Math.exp(cPm - lnpk(k4, L, m, n));
    f5 = Math.exp(cPm - lnpk(k5, L, m, n));

    // cumulative areas of the two immediate acceptance regions between
    // k2 and k4, the two centre regions and the two exponential tails
    p1 = f2 * (dl + 1.);      // immed. left
    p2 = f2 * dl + p1;        // centre left
    p3 = f4 * (dr + 1.) + p2; // immed. right
    p4 = f4 * dr + p3;        // centre right
    p5 = f1 / ll + p4;        // expon. tail left
    p6 = f5 / lr + p5;        // expon. tail right
}

/**
 * Evaluates {@code (np / k - 1) * (Mp - k) / (L + k)}, the expression
 * used for the recurrence constants {@code r(k) = p(k)/p(k-1)}.
 *
 * @param np the folded number of draws plus one.
 * @param Mp the folded number of defects plus one.
 * @param k the point at which the ratio is evaluated.
 * @param L the value {@code N - m - n} of the folded parameters.
 * @return the value of the expression at {@code k}.
 */
private static double recurrenceRatio(double np, double Mp, int k, int L) {
    return (np / k - 1.) * (Mp - k) / (L + k);
}
* This method is valid only for {@code mode >= 10} and {@code 0 <= nn <= mm <= N/2}.
* This method is fast when called repeatedly with the same parameters, but
* slow when the parameters change due to a high setup time. The computation
* time hardly depends on the parameters, except that it matters a lot whether
* parameters are within the range where the LnFac function is tabulated.
* Uses the Patchwork Rejection method of Heinz Zechner (HPRS).
* The area below the histogram function f(x) in its body is rearranged by
* two point reflections. Within a large center interval variates are sampled
* efficiently by rejection from uniform hats. Rectangular immediate acceptance
* regions speed up the generation. The remaining tails are covered by
* exponential functions.
* For detailed explanation, see:
* Stadlober, E & Zechner, H: "The Patchwork Rejection Technique for
* Sampling from Unimodal Distributions". ACM Transactions on Modeling
* and Computer Simulation, vol. 9, no. 1, 1999, p. 59-83.
protected int random() {
int Dk, X, V;
double U, Y, W; // (X, Y) <-> (V, W)
while (true) {
// generate uniform number U -- U(0, p6)
// case distinction corresponding to U
if ((U = MathEx.random() * p6) < p2) { // centre left
// immediate acceptance region R2 = [k2, mode) *[0, f2), X = k2, ... mode -1
if ((W = U - p1) < 0.) {
return (k2 + (int) (U / f2));
// immediate acceptance region R1 = [k1, k2)*[0, f1), X = k1, ... k2-1
if ((Y = W / dl) < f1) {
return (k1 + (int) (W / f1));
// computation of candidate X < k2, and its reflected counterpart V > k2
// either squeeze-acceptance of X or acceptance-rejection of V
Dk = (int) (dl * MathEx.random()) + 1;
if (Y <= f2 - Dk * (f2 - f2 / r2)) { // quick accept of
return (k2 - Dk); // X = k2 - Dk
if ((W = f2 + f2 - Y) < 1.) { // quick reject of V
V = k2 + Dk;
if (W <= f2 + Dk * (1. - f2) / (dl + 1.)) { // quick accept of V
return (V);
if (Math.log(W) <= cPm - lnpk(V, L, m, n)) {
return (V); // final accept of V
X = k2 - Dk; // go to final accept/reject
} else if (U < p4) { // centre right
// immediate acceptance region R3 = [mode, k4+1)*[0, f4), X = mode, ... k4
if ((W = U - p3) < 0.) {
return (k4 - (int) ((U - p2) / f4));
// immediate acceptance region R4 = [k4+1, k5+1)*[0, f5)
if ((Y = W / dr) < f5) {
return (k5 - (int) (W / f5));
// computation of candidate X > k4, and its reflected counterpart V < k4
// either squeeze-acceptance of X or acceptance-rejection of V
Dk = (int) (dr * MathEx.random()) + 1;
if (Y <= f4 - Dk * (f4 - f4 * r4)) { // quick accept of
return (k4 + Dk); // X = k4 + Dk
if ((W = f4 + f4 - Y) < 1.) { // quick reject of V
V = k4 - Dk;
if (W <= f4 + Dk * (1. - f4) / dr) { // quick accept of
return V; // V = k4 - Dk
if (Math.log(W) <= cPm - lnpk(V, L, m, n)) {
return (V); // final accept of V
X = k4 + Dk; // go to final accept/reject
} else {
Y = MathEx.random();
if (U < p5) { // expon. tail left
Dk = (int) (1. - Math.log(Y) / ll);
if ((X = k1 - Dk) < 0) {
continue; // 0 <= X <= k1 - 1
Y *= (U - p4) * ll; // Y -- U(0, h(x))
if (Y <= f1 - Dk * (f1 - f1 / r1)) {
return X; // quick accept of X
} else { // expon. tail right
Dk = (int) (1. - Math.log(Y) / lr);
if ((X = k5 + Dk) > n) {
continue; // k5 + 1 <= X <= nn
Y *= (U - p5) * lr; // Y -- U(0, h(x))
if (Y <= f5 - Dk * (f5 - f5 * r5)) {
return X; // quick accept of X
// acceptance-rejection test of candidate X from the original area
// test, whether Y <= f(X), with Y = U*h(x) and U -- U(0, 1)
// log f(X) = log( mode! (mm - mode)! (nn - mode)! (N - mm - nn + mode)! )
// - log( X! (mm - X)! (nn - X)! (N - mm - nn + X)! )
if (Math.log(Y) <= cPm - lnpk(X, L, m, n)) {
return (X);
* subfunction used by random number generator.
private double lnpk(int k, int L, int m, int n) {
return lfactorial(k) + lfactorial(m - k) + lfactorial(n - k) + lfactorial(L + k);
static class Inversion extends RandomNumberGenerator {
private int mode;
private final int mp; // Mode, mode+1
private int bound; // Safety upper bound
private final double fm; // Value at mode
/**
 * Initialize random number generator.
 * <p>
 * Precomputes, for the folded parameters stored by the base class, the
 * mode, the point next to it, the probability at the mode and the
 * upper limit of the upward search.
 *
 * @param N the number of total samples.
 * @param mm the number of defects.
 * @param nn the number of draws.
 */
Inversion(int N, int mm, int nn) {
    super(N, mm, nn);

    final int rest = N - m - n;
    final double mPlus1 = m + 1;
    final double nPlus1 = n + 1;
    final double prob = mPlus1 / (N + 2.);
    final double modeReal = nPlus1 * prob; // mode, real

    final int truncated = (int) modeReal;  // mode, integer
    if (truncated == modeReal && prob == 0.5) {
        // the real mode is an integer and prob is exactly one half:
        // start one lower and keep mp at the truncated value
        mp = truncated;
        mode = truncated - 1;
    } else {
        mode = truncated;
        mp = truncated + 1;
    }

    // mode probability, using the log factorial function; the terms are
    // accumulated in a fixed order
    double logFm = lfactorial(N - m);
    logFm -= lfactorial(rest + mode);
    logFm -= lfactorial(n - mode);
    logFm += lfactorial(m);
    logFm -= lfactorial(m - mode);
    logFm -= lfactorial(mode);
    logFm -= lfactorial(N);
    logFm += lfactorial(N - n);
    logFm += lfactorial(n);
    fm = Math.exp(logFm);

    // safety bound - guarantees at least 17 significant decimal digits
    // bound = min(nn, (int)(modeReal + k*c'))
    final double tail = 11. * Math.sqrt(modeReal * (1. - prob) * (1. - n / (double) N) + 1.);
    bound = Math.min(n, (int) (modeReal + tail));
}
* Hypergeometric distribution by inversion method, using down-up
* search starting at the mode using the chop-down technique.
* Assumes {@code 0 <= n <= m <= N/2}.
* Overflow protection is needed when N > 680 or n > 75.
* This method is faster than the rejection method when the variance is low.
protected int random() {
double L = N - m - n; // Parameter
double Mp, np; // mm + 1, nn + 1
double U; // uniform random
double c, d; // factors in iteration
double divisor; // divisor, eliminated by scaling
double k1, k2; // float version of loop counter
Mp = m + 1;
np = n + 1;
// loop until accepted
while (true) {
U = MathEx.random(); // uniform random number to be converted
// start chop-down search at mode
if ((U -= fm) <= 0.) {
return (mode);
c = d = fm;
// alternating down- and upward search from the mode
k1 = mp - 1;
k2 = mode + 1;
for (int i = 1; i <= mode; i++, k1--, k2++) {
// Downward search from k1 = hyp_mp - 1
divisor = (np - k1) * (Mp - k1);
// Instead of dividing c with divisor, we multiply U and d because
// multiplication is faster. This will give overflow if N > 800
U *= divisor;
d *= divisor;
c *= k1 * (L + k1);
if ((U -= c) <= 0.) {
return (mp - i - 1); // = k1 - 1
// Upward search from k2 = hyp_mode + 1
divisor = k2 * (L + k2);
// re-scale parameters to avoid time-consuming division
U *= divisor;
c *= divisor;
d *= (np - k2) * (Mp - k2);
if ((U -= d) <= 0.) {
return (mode + i); // = k2
} // Values of nn > 75 or N > 680 may give overflow if leave out this.
// overflow protection
if (U > 1.E100) {U *= 1.E-100; c *= 1.E-100; d *= 1.E-100;}
// Upward search from k2 = 2*mode + 1 to bound\
k2 = mp + mode;
for (int i = mp + mode; i <= bound; i++, k2++) {
divisor = k2 * (L + k2);
U *= divisor;
d *= (np - k2) * (Mp - k2);
if ((U -= d) <= 0.) {
return i;
// more overflow protection
if (U > 1.E100) {U *= 1.E-100; d *= 1.E-100;}