Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance. Project price only 1 $
You can buy this project and download/modify it how often you want.
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.rng.distribution.impl;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Svd;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.rng.distribution.BaseDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.ArrayUtil;
/**
 * Limited orthogonal distribution implementation.
 *
 * <p>Only array sampling ({@link #sample(long[])}, {@link #sample(int[])} and
 * {@link #sample(INDArray)}) is supported. It produces a matrix with orthonormal rows or
 * columns, scaled by a scalar gain, and reshaped to the requested shape. All scalar
 * distribution queries (mean, variance, density, CDF, inverse CDF, single-value sampling)
 * throw {@link UnsupportedOperationException}.
 *
 * @author [email protected]
 */
@Slf4j
public class OrthogonalDistribution extends BaseDistribution {
    /**
     * Default inverse cumulative probability accuracy. Not used by this class; kept for
     * API compatibility.
     *
     * @since 2.1
     */
    public static final double DEFAULT_INVERSE_ABSOLUTE_ACCURACY = 1e-9;

    /** Serializable version identifier. */
    private static final long serialVersionUID = 8589540077390120676L;

    /** Scalar multiplier applied to every sampled orthogonal matrix. */
    private double gain;

    /**
     * Per-element gains. Nothing in this class assigns it (the constructor that accepted it
     * was removed), so it is always {@code null} here and sampling uses {@link #gain}.
     * Retained so the serialized form is unchanged.
     */
    private INDArray gains;

    /**
     * Creates a distribution whose samples are orthogonal matrices scaled by {@code gain}.
     * Uses the random generator returned by {@link Nd4j#getRandom()}.
     *
     * @param gain scalar multiplier applied to each sampled matrix
     */
    public OrthogonalDistribution(double gain) {
        this.gain = gain;
        this.random = Nd4j.getRandom();
    }

    /**
     * Not supported: this distribution has no scalar mean.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public double getMean() {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported: this distribution has no scalar standard deviation.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public double getStandardDeviation() {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public double density(double x) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public double cumulativeProbability(double x) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     * @since 3.2
     */
    @Override
    public double inverseCumulativeProbability(final double p) throws OutOfRangeException {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     * @deprecated See {@link org.apache.commons.math3.distribution.RealDistribution#cumulativeProbability(double, double)}
     */
    @Override
    @Deprecated
    public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double probability(double x0, double x1) throws NumberIsTooLargeException {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected double getSolverAbsoluteAccuracy() {
        throw new UnsupportedOperationException();
    }

    /**
     * Delegates to {@link #getMean()}, which is unsupported in this class.
     *
     * @throws UnsupportedOperationException unless a subclass overrides {@link #getMean()}
     */
    public double getNumericalMean() {
        return getMean();
    }

    /**
     * Returns the square of {@link #getStandardDeviation()}, which is unsupported in this class.
     *
     * @throws UnsupportedOperationException unless a subclass overrides
     *         {@link #getStandardDeviation()}
     */
    public double getNumericalVariance() {
        final double s = getStandardDeviation();
        return s * s;
    }

    /**
     * Returns {@code Double.NEGATIVE_INFINITY} regardless of parameters.
     *
     * @return {@code Double.NEGATIVE_INFINITY}
     */
    public double getSupportLowerBound() {
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * Returns {@code Double.POSITIVE_INFINITY} regardless of parameters.
     *
     * @return {@code Double.POSITIVE_INFINITY}
     */
    public double getSupportUpperBound() {
        return Double.POSITIVE_INFINITY;
    }

    /**
     * @return {@code false}
     */
    public boolean isSupportLowerBoundInclusive() {
        return false;
    }

    /**
     * @return {@code false}
     */
    public boolean isSupportUpperBoundInclusive() {
        return false;
    }

    /**
     * @return {@code true}
     */
    public boolean isSupportConnected() {
        return true;
    }

    /**
     * Not supported: a single scalar cannot be drawn from this distribution.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double sample() {
        throw new UnsupportedOperationException();
    }

    /**
     * Samples an orthogonal array of the given shape; see {@link #sample(long[])}.
     *
     * @param shape requested output shape
     * @return sampled array of the requested shape
     */
    @Override
    public INDArray sample(int[] shape) {
        return sample(ArrayUtil.toLongArray(shape));
    }

    /**
     * Samples an array of the given shape whose flattened 2-D form has orthonormal rows or
     * columns, scaled by {@link #gain}.
     *
     * <p>The shape is flattened to {@code (numRows, numCols)}. {@code numCols} is the last
     * dimension and {@code numRows} is the product of all other dimensions (1 for a 1-D
     * shape). A standard-normal matrix of that flat shape is drawn and decomposed with
     * {@link Svd}. If {@code numRows >= numCols}, the first {@code numCols} columns of
     * {@code U} are used. Otherwise the first {@code numRows} rows of {@code V} are used.
     * The selection is multiplied by the gain and reshaped to {@code shape}.
     *
     * @param shape requested output shape; must be non-empty with strictly positive dimensions
     * @return sampled array of the requested shape
     * @throws IllegalArgumentException if {@code shape} is null, empty, or has a non-positive
     *         dimension
     * @throws UnsupportedOperationException if per-element gains are set
     */
    @Override
    public INDArray sample(long[] shape) {
        if (shape == null || shape.length == 0) {
            throw new IllegalArgumentException("Shape must have at least one dimension");
        }
        for (int i = 0; i < shape.length; i++) {
            if (shape[i] <= 0) {
                throw new IllegalArgumentException(
                        "All shape dimensions must be positive, got shape[" + i + "] = " + shape[i]);
            }
        }
        // Fail before doing any RNG/SVD work if an unsupported configuration is present.
        if (gains != null) {
            throw new UnsupportedOperationException("Per-element gains are not supported");
        }

        long numRows = 1;
        for (int i = 0; i < shape.length - 1; i++) {
            numRows *= shape[i];
        }
        long numCols = shape[shape.length - 1];

        val dtype = Nd4j.defaultFloatingPointType();
        val flatShape = new long[] {numRows, numCols};
        val flatRng = Nd4j.getExecutioner().exec(
                new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0, 1.0),
                random);

        val m = flatRng.rows();
        val n = flatRng.columns();
        val s = Nd4j.create(dtype, Math.min(m, n));
        val u = Nd4j.create(dtype, m, m);
        val v = Nd4j.create(dtype, new long[] {n, n}, 'f');
        Nd4j.exec(new Svd(flatRng, true, s, u, v));

        // u is numRows x numRows and v is numCols x numCols, so "u can supply a
        // numRows x numCols block" is exactly numRows >= numCols.
        INDArray basis = numRows >= numCols ? u : v;
        return basis.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols))
                .mul(gain)
                .reshape(shape);
    }

    /**
     * Fills {@code target} with an orthogonal sample of the same shape.
     *
     * @param target array to overwrite
     * @return {@code target}, after assignment
     */
    @Override
    public INDArray sample(INDArray target) {
        return target.assign(sample(target.shape()));
    }
}