// org.deeplearning4j.nn.weights.WeightInitUtil Maven / Gradle / Ivy
/*-
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.weights;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;
/**
 * Weight initialization utility.
 *
 * <p>Static helpers that create weight arrays for a given {@link WeightInit} scheme and write
 * them into a caller-supplied parameter view. The class is not instantiable.
 *
 * @author Adam Gibson
 */
public class WeightInitUtil {

    /**
     * Default order for the arrays created by WeightInitUtil.
     */
    public static final char DEFAULT_WEIGHT_INIT_ORDER = 'f';

    private WeightInitUtil() {}

    /**
     * Creates a new random array of the given shape by delegating to
     * {@code Nd4j.rand(shape, min, max, Nd4j.getRandom())}.
     *
     * @param shape the shape of the array to create
     * @param min   lower bound passed through to {@code Nd4j.rand}
     * @param max   upper bound passed through to {@code Nd4j.rand}
     * @return the newly created random array
     */
    public static INDArray initWeights(int[] shape, float min, float max) {
        return Nd4j.rand(shape, min, max, Nd4j.getRandom());
    }

    /**
     * Initializes a matrix with the given weight initialization scheme.
     * Note: Defaults to fortran ('f') order arrays for the weights. Use
     * {@link #initWeights(double, double, int[], WeightInit, Distribution, char, INDArray)}
     * to control this.
     *
     * @param fanIn      fan-in value used by the scaling formulas
     * @param fanOut     fan-out value used by the scaling formulas
     * @param shape      the shape of the matrix
     * @param initScheme the scheme to use
     * @param dist       distribution to sample from; only read for {@link WeightInit#DISTRIBUTION}
     * @param paramView  parameter array that receives the initialized values
     * @return {@code paramView} reshaped to {@code shape}, holding values drawn according to
     *         the initialization scheme
     */
    public static INDArray initWeights(double fanIn, double fanOut, int[] shape, WeightInit initScheme,
                    Distribution dist, INDArray paramView) {
        return initWeights(fanIn, fanOut, shape, initScheme, dist, DEFAULT_WEIGHT_INIT_ORDER, paramView);
    }

    /**
     * Initializes weights with the given scheme, writes them into {@code paramView} and returns
     * {@code paramView} reshaped to {@code shape}.
     *
     * <p>Scaling applied per scheme (as implemented below):
     * <ul>
     *   <li>DISTRIBUTION: {@code dist.sample(shape)}; fanIn/fanOut are not used</li>
     *   <li>RELU: randn * sqrt(2 / fanIn)</li>
     *   <li>RELU_UNIFORM: U(-u, u) with u = sqrt(6 / fanIn)</li>
     *   <li>SIGMOID_UNIFORM: U(-r, r) with r = 4 * sqrt(6 / (fanIn + fanOut))</li>
     *   <li>UNIFORM: U(-a, a) with a = 1 / sqrt(fanIn)</li>
     *   <li>XAVIER: randn * sqrt(2 / (fanIn + fanOut))</li>
     *   <li>XAVIER_UNIFORM: U(-s, s) with s = sqrt(6) / sqrt(fanIn + fanOut)</li>
     *   <li>XAVIER_FAN_IN: randn / sqrt(fanIn)</li>
     *   <li>XAVIER_LEGACY: randn / sqrt(shape[0] + shape[1])</li>
     *   <li>ZERO: {@code Nd4j.create(shape, order)}</li>
     * </ul>
     *
     * @param fanIn      fan-in value used by the scaling formulas
     * @param fanOut     fan-out value used by the scaling formulas
     * @param shape      the shape of the weights
     * @param initScheme the scheme to use
     * @param dist       distribution to sample from; only read for {@link WeightInit#DISTRIBUTION}
     * @param order      order used to flatten the generated values and to reshape {@code paramView}
     * @param paramView  parameter array that receives the initialized values; its length must
     *                   equal the number of generated values
     * @return {@code paramView.reshape(order, shape)} after the values have been assigned
     * @throws NullPointerException  if {@code initScheme} is DISTRIBUTION and {@code dist} is null
     * @throws IllegalStateException if {@code initScheme} is not one of the handled schemes
     * @throws RuntimeException      if the length of {@code paramView} differs from the number of
     *                               generated values
     */
    public static INDArray initWeights(double fanIn, double fanOut, int[] shape, WeightInit initScheme,
                    Distribution dist, char order, INDArray paramView) {
        // Only the randn(...) and create(...) calls take 'order' directly. Every result is
        // flattened using 'order' below before being copied into paramView.
        INDArray ret;
        switch (initScheme) {
            case DISTRIBUTION:
                if (dist == null) {
                    // Same exception type as the implicit failure, but with an actionable message.
                    throw new NullPointerException(
                                    "Weight init scheme DISTRIBUTION requires a non-null Distribution");
                }
                ret = dist.sample(shape);
                break;
            case RELU:
                ret = Nd4j.randn(order, shape).muli(FastMath.sqrt(2.0 / fanIn)); //N(0, 2/nIn)
                break;
            case RELU_UNIFORM:
                //U(-sqrt(6/fanIn), sqrt(6/fanIn))
                ret = symmetricUniform(shape, Math.sqrt(6.0 / fanIn));
                break;
            case SIGMOID_UNIFORM:
                ret = symmetricUniform(shape, 4.0 * Math.sqrt(6.0 / (fanIn + fanOut)));
                break;
            case UNIFORM:
                ret = symmetricUniform(shape, 1.0 / Math.sqrt(fanIn));
                break;
            case XAVIER:
                ret = Nd4j.randn(order, shape).muli(FastMath.sqrt(2.0 / (fanIn + fanOut)));
                break;
            case XAVIER_UNIFORM:
                //As per Glorot and Bengio 2010: Uniform distribution U(-s,s) with s = sqrt(6/(fanIn + fanOut))
                //Eq 16: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
                ret = symmetricUniform(shape, Math.sqrt(6.0) / Math.sqrt(fanIn + fanOut));
                break;
            case XAVIER_FAN_IN:
                ret = Nd4j.randn(order, shape).divi(FastMath.sqrt(fanIn));
                break;
            case XAVIER_LEGACY:
                // Reads shape[0] and shape[1] instead of fanIn/fanOut, so shape needs >= 2 entries.
                ret = Nd4j.randn(order, shape).divi(FastMath.sqrt(shape[0] + shape[1]));
                break;
            case ZERO:
                ret = Nd4j.create(shape, order);
                break;
            default:
                throw new IllegalStateException("Illegal weight init value: " + initScheme);
        }

        INDArray flat = Nd4j.toFlattened(order, ret);
        if (flat.length() != paramView.length()) {
            throw new RuntimeException("ParamView length does not match initialized weights length (view length: "
                            + paramView.length() + ", view shape: " + Arrays.toString(paramView.shape())
                            + "; flattened length: " + flat.length() + ")");
        }

        paramView.assign(flat);
        return paramView.reshape(order, shape);
    }

    /**
     * Samples an array of the given shape from the uniform distribution U(-bound, bound).
     */
    private static INDArray symmetricUniform(int[] shape, double bound) {
        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-bound, bound));
    }

    /**
     * Reshape the parameters view, without modifying the paramsView array values.
     * Uses {@link #DEFAULT_WEIGHT_INIT_ORDER} as the reshape order.
     *
     * @param shape      Shape to reshape
     * @param paramsView Parameters array view
     * @return the result of {@code paramsView.reshape(DEFAULT_WEIGHT_INIT_ORDER, shape)}
     */
    public static INDArray reshapeWeights(int[] shape, INDArray paramsView) {
        return reshapeWeights(shape, paramsView, DEFAULT_WEIGHT_INIT_ORDER);
    }

    /**
     * Reshape the parameters view, without modifying the paramsView array values.
     *
     * @param shape           Shape to reshape
     * @param paramsView      Parameters array view
     * @param flatteningOrder Order in which parameters are flattened/reshaped
     * @return the result of {@code paramsView.reshape(flatteningOrder, shape)}
     */
    public static INDArray reshapeWeights(int[] shape, INDArray paramsView, char flatteningOrder) {
        return paramsView.reshape(flatteningOrder, shape);
    }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy