/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.deeplearning4j.util.NetworkUtils;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
/**
* Base configuration class for neural network layers with trainable parameters (weights and,
* optionally, biases). Holds the layer's activation function, weight initialization scheme,
* regularization, updater configuration, gradient normalization settings and weight noise.
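*
* <p>Illustrative usage (a sketch only - {@code DenseLayer} is one concrete subclass, the
* sizes/values are arbitrary, and imports are omitted):
* <pre>{@code
* BaseLayer layer = new DenseLayer.Builder()
*         .nIn(128).nOut(64)                 //layer input/output sizes
*         .activation(Activation.RELU)       //activation function
*         .weightInit(WeightInit.XAVIER)     //weight initialization scheme
*         .l2(1e-4)                          //L2 regularization on the weights
*         .updater(new Adam(1e-3))           //gradient updater
*         .build();
* }</pre>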
*/
@Data
@EqualsAndHashCode(callSuper = true)
@NoArgsConstructor
public abstract class BaseLayer extends Layer implements Serializable, Cloneable {
protected IActivation activationFn;
protected IWeightInit weightInitFn;
protected double biasInit;
protected double gainInit;
protected List<Regularization> regularization;
protected List<Regularization> regularizationBias;
protected IUpdater iUpdater;
protected IUpdater biasUpdater;
protected IWeightNoise weightNoise;
protected GradientNormalization gradientNormalization = GradientNormalization.None; //Clipping, rescale based on l2 norm, etc
protected double gradientNormalizationThreshold = 1.0; //Threshold for l2 and element-wise gradient clipping
public BaseLayer(Builder builder) {
super(builder);
this.layerName = builder.layerName;
this.activationFn = builder.activationFn;
this.weightInitFn = builder.weightInitFn;
this.biasInit = builder.biasInit;
this.gainInit = builder.gainInit;
this.regularization = builder.regularization;
this.regularizationBias = builder.regularizationBias;
this.iUpdater = builder.iupdater;
this.biasUpdater = builder.biasUpdater;
this.gradientNormalization = builder.gradientNormalization;
this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold;
this.weightNoise = builder.weightNoise;
}
/**
* Reset the learning-related configuration of the layer to its defaults. When the layer is then instantiated
* with a global neural network configuration, the values specified in that global configuration will be used.
* For internal use with the transfer learning API; users should not need to call this method directly.
*/
public void resetLayerDefaultConfig() {
//clear the learning related params for all layers in the origConf and set to defaults
this.setIUpdater(null);
this.setWeightInitFn(null);
this.setBiasInit(Double.NaN);
this.setGainInit(Double.NaN);
this.regularization = null;
this.regularizationBias = null;
this.setGradientNormalization(GradientNormalization.None);
this.setGradientNormalizationThreshold(1.0);
this.iUpdater = null;
this.biasUpdater = null;
}
@Override
public BaseLayer clone() {
BaseLayer clone = (BaseLayer) super.clone();
if (clone.iDropout != null) {
clone.iDropout = clone.iDropout.clone();
}
if(regularization != null){
//Regularization fields are _usually_ thread safe and immutable, but let's clone to be sure
clone.regularization = new ArrayList<>(regularization.size());
for(Regularization r : regularization){
clone.regularization.add(r.clone());
}
}
if(regularizationBias != null){
clone.regularizationBias = new ArrayList<>(regularizationBias.size());
for(Regularization r : regularizationBias){
clone.regularizationBias.add(r.clone());
}
}
return clone;
}
/**
* Get the updater for the given parameter. Typically the same updater will be used for all parameters, but this is
* not necessarily the case
*
* @param paramName Parameter name
* @return IUpdater for the parameter
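*
* <p>For example (an illustrative sketch; "W" and "b" are the conventional weight/bias parameter keys):
* <pre>{@code
* builder.updater(new Adam(1e-3)).biasUpdater(new Sgd(1e-2));
* //for the built layer configuration:
* //  getUpdaterByParam("b") returns the Sgd instance (bias parameter)
* //  getUpdaterByParam("W") returns the Adam instance (weight parameter)
* }</pre>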
*/
@Override
public IUpdater getUpdaterByParam(String paramName) {
if (biasUpdater != null && initializer().isBiasParam(this, paramName)) {
return biasUpdater;
}
return iUpdater;
}
@Override
public GradientNormalization getGradientNormalization() {
return gradientNormalization;
}
@Override
public List<Regularization> getRegularizationByParam(String paramName) {
if(initializer().isWeightParam(this, paramName)){
return regularization;
} else if(initializer().isBiasParam(this, paramName)){
return regularizationBias;
}
return null;
}
@SuppressWarnings("unchecked")
@Getter
@Setter
public abstract static class Builder<T extends Builder<T>> extends Layer.Builder<T> {
/**
* Activation function for the layer. Set via {@link #activation(IActivation)} or
* {@link #activation(Activation)}
*
*/
protected IActivation activationFn = null;
/**
* Weight initialization scheme to use, for initial weight values
*
* @see IWeightInit
*/
protected IWeightInit weightInitFn = null;
/**
* Bias initialization value, for layers with biases. Defaults to 0
*
*/
protected double biasInit = Double.NaN;
/**
* Gain initialization value, for layers with Layer Normalization. Defaults to 1
*
*/
protected double gainInit = Double.NaN;
/**
* Regularization for the parameters (excluding biases).
*/
protected List<Regularization> regularization = new ArrayList<>();
/**
* Regularization for the bias parameters only
*/
protected List<Regularization> regularizationBias = new ArrayList<>();
/**
* Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam} or {@link
* org.nd4j.linalg.learning.config.Nesterovs}
*
*/
protected IUpdater iupdater = null;
/**
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as set by {@link
* #updater(IUpdater)}
*
*/
protected IUpdater biasUpdater = null;
/**
* Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc.
*
* @see GradientNormalization
*/
protected GradientNormalization gradientNormalization = null;
/**
* Threshold for gradient normalization. Only used for GradientNormalization.ClipL2PerLayer,
* GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue;
* not used otherwise. Interpreted as an L2 threshold for the first two types of clipping, or as an
* absolute-value threshold for the last type.
*/
protected double gradientNormalizationThreshold = Double.NaN;
/**
* Weight noise (such as {@link org.deeplearning4j.nn.conf.weightnoise.DropConnect} or {@link
* org.deeplearning4j.nn.conf.weightnoise.WeightNoise}) for this layer
*
*/
protected IWeightNoise weightNoise;
/**
* Set the activation function for the layer. This overload can be used for custom {@link IActivation}
* instances
*
* @param activationFunction Activation function to use for the layer
*/
public T activation(IActivation activationFunction) {
this.setActivationFn(activationFunction);
return (T) this;
}
/**
* Set the activation function for the layer, from an {@link Activation} enumeration value.
*
* @param activation Activation function to use for the layer
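*
* <p>For example (an illustrative sketch; {@code builder} stands for any concrete layer builder):
* <pre>{@code
* builder.activation(Activation.TANH);          //built-in activation, via the enumeration
* builder.activation(new ActivationReLU());     //explicit IActivation instance
* }</pre>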
*/
public T activation(Activation activation) {
return activation(activation.getActivationFunction());
}
/**
* Weight initialization scheme to use, for initial weight values
*
* @see IWeightInit
*/
public T weightInit(IWeightInit weightInit) {
this.setWeightInitFn(weightInit);
return (T) this;
}
/**
* Weight initialization scheme to use, for initial weight values
*
* @see WeightInit
*/
public T weightInit(WeightInit weightInit) {
if (weightInit == WeightInit.DISTRIBUTION) {
throw new UnsupportedOperationException(
"Not supported!, Use weightInit(Distribution distribution) instead!");
}
this.setWeightInitFn(weightInit.getWeightInitFunction());
return (T) this;
}
/**
* Set weight initialization scheme to random sampling via the specified distribution. Equivalent to: {@code
* .weightInit(new WeightInitDistribution(distribution))}
*
* @param distribution Distribution to use for weight initialization
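*
* <p>For example (a sketch; the mean and standard deviation shown are arbitrary):
* <pre>{@code
* builder.weightInit(new NormalDistribution(0.0, 0.01));    //zero-mean Gaussian initial weights
* }</pre>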
*/
public T weightInit(Distribution distribution) {
return weightInit(new WeightInitDistribution(distribution));
}
/**
* Bias initialization value, for layers with biases. Defaults to 0
*
* @param biasInit Value to use for initializing biases
*/
public T biasInit(double biasInit) {
this.setBiasInit(biasInit);
return (T) this;
}
/**
* Gain initialization value, for layers with Layer Normalization. Defaults to 1
*
* @param gainInit Value to use for initializing gain
*/
public T gainInit(double gainInit) {
this.gainInit = gainInit;
return (T) this;
}
/**
* Distribution to sample initial weights from. Equivalent to: {@code .weightInit(new
* WeightInitDistribution(distribution))}
*/
@Deprecated
public T dist(Distribution dist) {
return weightInit(dist);
}
/**
* L1 regularization coefficient (weights only). Use {@link #l1Bias(double)} to configure the l1 regularization
* coefficient for the bias.
*/
public T l1(double l1) {
//Remove any existing L1 regularization; the new coefficient (if > 0) replaces it
NetworkUtils.removeInstances(this.regularization, L1Regularization.class);
if(l1 > 0.0) {
this.regularization.add(new L1Regularization(l1));
}
return (T) this;
}
/**
* L2 regularization coefficient (weights only). Use {@link #l2Bias(double)} to configure the l2 regularization
* coefficient for the bias.
* Note: Generally, {@link WeightDecay} (set via {@link #weightDecay(double,boolean)}) should be preferred to
* L2 regularization. See {@link WeightDecay} javadoc for further details.
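*
* <p>For example (coefficients are illustrative; a coefficient of 0.0 removes the corresponding
* regularization):
* <pre>{@code
* builder.l1(1e-5)     //L1 on the weights
*        .l2(1e-4);    //L2 on the weights; replaces any previously set L2 or weight decay
* }</pre>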
*/
public T l2(double l2) {
//Remove any existing L2 regularization; also remove weight decay, as it doesn't make sense to use both
NetworkUtils.removeInstances(this.regularization, L2Regularization.class);
if(l2 > 0.0) {
NetworkUtils.removeInstancesWithWarning(this.regularization, WeightDecay.class, "WeightDecay regularization removed: incompatible with added L2 regularization");
this.regularization.add(new L2Regularization(l2));
}
return (T) this;
}
/**
* L1 regularization coefficient for the bias. Default: 0. See also {@link #l1(double)}
*/
public T l1Bias(double l1Bias) {
NetworkUtils.removeInstances(this.regularizationBias, L1Regularization.class);
if(l1Bias > 0.0) {
this.regularizationBias.add(new L1Regularization(l1Bias));
}
return (T) this;
}
/**
* L2 regularization coefficient for the bias. Default: 0. See also {@link #l2(double)}
* Note: Generally, {@link WeightDecay} (set via {@link #weightDecayBias(double,boolean)}) should be preferred to
* L2 regularization. See {@link WeightDecay} javadoc for further details.
*/
public T l2Bias(double l2Bias) {
NetworkUtils.removeInstances(this.regularizationBias, L2Regularization.class);
if(l2Bias > 0.0) {
NetworkUtils.removeInstancesWithWarning(this.regularizationBias, WeightDecay.class, "WeightDecay regularization removed: incompatible with added L2 regularization");
this.regularizationBias.add(new L2Regularization(l2Bias));
}
return (T) this;
}
/**
* Add weight decay regularization for the network parameters (excluding biases).
* This overload applies weight decay with the learning rate multiplied in - see {@link WeightDecay} for more details.
*
* @param coefficient Weight decay regularization coefficient
* @see #weightDecay(double, boolean)
*/
public Builder weightDecay(double coefficient) {
return weightDecay(coefficient, true);
}
/**
* Add weight decay regularization for the network parameters (excluding biases). See {@link WeightDecay} for more details.
*
* @param coefficient Weight decay regularization coefficient
* @param applyLR Whether the learning rate should be multiplied in when performing weight decay updates. See {@link WeightDecay} for more details.
* @see #weightDecay(double)
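*
* <p>For example (the coefficient shown is illustrative):
* <pre>{@code
* builder.weightDecay(1e-4, true);    //weight decay, with the learning rate multiplied in
* }</pre>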
*/
public Builder weightDecay(double coefficient, boolean applyLR) {
//Remove any existing weight decay; also remove L2 regularization, as it doesn't make sense to use both
NetworkUtils.removeInstances(this.regularization, WeightDecay.class);
if(coefficient > 0.0) {
NetworkUtils.removeInstancesWithWarning(this.regularization, L2Regularization.class, "L2 regularization removed: incompatible with added WeightDecay regularization");
this.regularization.add(new WeightDecay(coefficient, applyLR));
}
return this;
}
/**
* Weight decay for the biases only - see {@link #weightDecay(double)} for more details.
* This overload applies weight decay with the learning rate multiplied in.
*
* @param coefficient Weight decay regularization coefficient
* @see #weightDecayBias(double, boolean)
*/
public Builder weightDecayBias(double coefficient) {
return weightDecayBias(coefficient, true);
}
/**
* Weight decay for the biases only - see {@link #weightDecay(double)} for more details
*
* @param coefficient Weight decay regularization coefficient
* @param applyLR Whether the learning rate should be multiplied in when performing weight decay updates. See {@link WeightDecay} for more details.
*/
public Builder weightDecayBias(double coefficient, boolean applyLR) {
//Remove any existing weight decay; also remove L2 regularization, as it doesn't make sense to use both
NetworkUtils.removeInstances(this.regularizationBias, WeightDecay.class);
if(coefficient > 0.0) {
NetworkUtils.removeInstancesWithWarning(this.regularizationBias, L2Regularization.class, "L2 regularization removed: incompatible with added WeightDecay regularization");
this.regularizationBias.add(new WeightDecay(coefficient, applyLR));
}
return this;
}
/**
* Set the regularization for the parameters (excluding biases) - for example {@link WeightDecay}
*
* @param regularization Regularization to apply for the network parameters/weights (excluding biases)
*/
public Builder regularization(List<Regularization> regularization) {
this.setRegularization(regularization);
return this;
}
/**
* Set the regularization for the biases only - for example {@link WeightDecay}
*
* @param regularizationBias Regularization to apply for the network biases only
*/
public Builder regularizationBias(List<Regularization> regularizationBias) {
this.setRegularizationBias(regularizationBias);
return this;
}
/**
* Gradient updater. For example, SGD for standard stochastic gradient descent, NESTEROVS for Nesterov momentum,
* RMSPROP for RMSProp, etc.
*
* @see Updater
*/
@Deprecated
public T updater(Updater updater) {
return updater(updater.getIUpdaterWithDefaultConfig());
}
/**
* Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam} or {@link
* org.nd4j.linalg.learning.config.Nesterovs}
*
* @param updater Updater to use
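*
* <p>For example (learning rates are illustrative):
* <pre>{@code
* builder.updater(new Adam(1e-3))        //Adam for all parameters (unless overridden)
*        .biasUpdater(new Sgd(1e-2));    //optionally, a different updater for the biases
* }</pre>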
*/
public T updater(IUpdater updater) {
this.setIupdater(updater);
return (T) this;
}
/**
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as set by {@link
* #updater(IUpdater)}
*
* @param biasUpdater Updater to use for bias parameters
*/
public T biasUpdater(IUpdater biasUpdater) {
this.setBiasUpdater(biasUpdater);
return (T) this;
}
/**
* Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc.
*
* @param gradientNormalization Type of normalization to use. Defaults to None.
* @see GradientNormalization
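*
* <p>For example (the threshold shown is illustrative):
* <pre>{@code
* builder.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
*        .gradientNormalizationThreshold(1.0);    //clip each gradient element to [-1.0, 1.0]
* }</pre>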
*/
public T gradientNormalization(GradientNormalization gradientNormalization) {
this.setGradientNormalization(gradientNormalization);
return (T) this;
}
/**
* Threshold for gradient normalization. Only used for GradientNormalization.ClipL2PerLayer,
* GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue;
* not used otherwise. Interpreted as an L2 threshold for the first two types of clipping, or as an
* absolute-value threshold for the last type.
*/
public T gradientNormalizationThreshold(double threshold) {
this.setGradientNormalizationThreshold(threshold);
return (T) this;
}
/**
* Set the weight noise (such as {@link org.deeplearning4j.nn.conf.weightnoise.DropConnect} and {@link
* org.deeplearning4j.nn.conf.weightnoise.WeightNoise}) for this layer
*
* @param weightNoise Weight noise instance to use
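*
* <p>For example (the retain probability shown is illustrative):
* <pre>{@code
* builder.weightNoise(new DropConnect(0.9));    //DropConnect: keep each weight with probability 0.9
* }</pre>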
*/
public T weightNoise(IWeightNoise weightNoise) {
this.setWeightNoise(weightNoise);
return (T) this;
}
}
}