All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.api.rng.distribution.BaseDistribution Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.rng.distribution;

import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.solvers.UnivariateSolverUtils;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Iterator;

/**
 * Base distribution derived from apache commons math
 * http://commons.apache.org/proper/commons-math/
 * 

* (specifically the {@link org.apache.commons.math3.distribution.AbstractRealDistribution} * * @author Adam Gibson */ public abstract class BaseDistribution implements Distribution { protected Random random; protected double solverAbsoluteAccuracy; public BaseDistribution(Random rng) { this.random = rng; } public BaseDistribution() { this(Nd4j.getRandom()); } /** * For a random variable {@code X} whose values are distributed according * to this distribution, this method returns {@code P(x0 < X <= x1)}. * * @param x0 Lower bound (excluded). * @param x1 Upper bound (included). * @return the probability that a random variable with this distribution * takes a value between {@code x0} and {@code x1}, excluding the lower * and including the upper endpoint. * @throws org.apache.commons.math3.exception.NumberIsTooLargeException if {@code x0 > x1}. *

* The default implementation uses the identity * {@code P(x0 < X <= x1) = P(X <= x1) - P(X <= x0)} * @since 3.1 */ public double probability(double x0, double x1) { if (x0 > x1) { throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT, x0, x1, true); } return cumulativeProbability(x1) - cumulativeProbability(x0); } /** * {@inheritDoc} *

* The default implementation returns *

    *
  • {@link #getSupportLowerBound()} for {@code p = 0},
  • *
  • {@link #getSupportUpperBound()} for {@code p = 1}.
  • *
*/ @Override public double inverseCumulativeProbability(final double p) throws OutOfRangeException { /* * IMPLEMENTATION NOTES * -------------------- * Where applicable, use is made of the one-sided Chebyshev inequality * to bracket the root. This inequality states that * P(X - mu >= k * sig) <= 1 / (1 + k^2), * mu: mean, sig: standard deviation. Equivalently * 1 - P(X < mu + k * sig) <= 1 / (1 + k^2), * F(mu + k * sig) >= k^2 / (1 + k^2). * * For k = sqrt(p / (1 - p)), we find * F(mu + k * sig) >= p, * and (mu + k * sig) is an upper-bound for the root. * * Then, introducing Y = -X, mean(Y) = -mu, sd(Y) = sig, and * P(Y >= -mu + k * sig) <= 1 / (1 + k^2), * P(-X >= -mu + k * sig) <= 1 / (1 + k^2), * P(X <= mu - k * sig) <= 1 / (1 + k^2), * F(mu - k * sig) <= 1 / (1 + k^2). * * For k = sqrt((1 - p) / p), we find * F(mu - k * sig) <= p, * and (mu - k * sig) is a lower-bound for the root. * * In cases where the Chebyshev inequality does not apply, geometric * progressions 1, 2, 4, ... and -1, -2, -4, ... are used to bracket * the root. */ if (p < 0.0 || p > 1.0) { throw new OutOfRangeException(p, 0, 1); } double lowerBound = getSupportLowerBound(); if (p == 0.0) { return lowerBound; } double upperBound = getSupportUpperBound(); if (p == 1.0) { return upperBound; } final double mu = getNumericalMean(); final double sig = FastMath.sqrt(getNumericalVariance()); final boolean chebyshevApplies; chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) || Double.isInfinite(sig) || Double.isNaN(sig)); if (lowerBound == Double.NEGATIVE_INFINITY) { if (chebyshevApplies) { lowerBound = mu - sig * FastMath.sqrt((1. - p) / p); } else { lowerBound = -1.0; while (cumulativeProbability(lowerBound) >= p) { lowerBound *= 2.0; } } } if (upperBound == Double.POSITIVE_INFINITY) { if (chebyshevApplies) { upperBound = mu + sig * FastMath.sqrt(p / (1. - p)); } else { upperBound = 1.0; while (cumulativeProbability(upperBound) < p) { upperBound *= 2.0; } } } final UnivariateFunction toSolve = new UnivariateFunction() { public double value(final double x) { return cumulativeProbability(x) - p; } }; double x = UnivariateSolverUtils.solve(toSolve, lowerBound, upperBound, getSolverAbsoluteAccuracy()); if (!isSupportConnected()) { /* Test for plateau. */ final double dx = getSolverAbsoluteAccuracy(); if (x - dx >= getSupportLowerBound()) { double px = cumulativeProbability(x); if (cumulativeProbability(x - dx) == px) { upperBound = x; while (upperBound - lowerBound > dx) { final double midPoint = 0.5 * (lowerBound + upperBound); if (cumulativeProbability(midPoint) < px) { lowerBound = midPoint; } else { upperBound = midPoint; } } return upperBound; } } } return x; } /** * Returns the solver absolute accuracy for inverse cumulative computation. * You can override this method in order to use a Brent solver with an * absolute accuracy different from the default. * * @return the maximum absolute error in inverse cumulative probability estimates */ protected double getSolverAbsoluteAccuracy() { return solverAbsoluteAccuracy; } /** * {@inheritDoc} */ @Override public void reseedRandomGenerator(long seed) { random.setSeed(seed); } /** * {@inheritDoc} *

* The default implementation uses the * * inversion method. * */ @Override public double sample() { return inverseCumulativeProbability(random.nextDouble()); } /** * {@inheritDoc} *

* The default implementation generates the sample by calling * {@link #sample()} in a loop. */ @Override public double[] sample(long sampleSize) { if (sampleSize <= 0) { throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize); } // FIXME: int cast double[] out = new double[(int) sampleSize]; for (int i = 0; i < sampleSize; i++) { out[i] = sample(); } return out; } /** * {@inheritDoc} * * @return zero. * @since 3.1 */ @Override public double probability(double x) { return 0d; } @Override public INDArray sample(int[] shape) { INDArray ret = Nd4j.create(shape); return sample(ret); } @Override public INDArray sample(long[] shape) { INDArray ret = Nd4j.create(shape); return sample(ret); } @Override public INDArray sample(INDArray target) { Iterator idxIter = new NdIndexIterator(target.shape()); //For consistent values irrespective of c vs. fortran ordering long len = target.length(); for (long i = 0; i < len; i++) { target.putScalar(idxIter.next(), sample()); } return target; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy