All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.nn.transferlearning.FineTuneConfiguration Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.transferlearning;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.conf.dropout.IDropout;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.deeplearning4j.util.NetworkUtils;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
import org.nd4j.common.primitives.Optional;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.core.JsonProcessingException;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "type")
@JsonInclude(JsonInclude.Include.NON_NULL)
@NoArgsConstructor
@AllArgsConstructor
@Data
public class FineTuneConfiguration {

    protected IActivation activationFn;
    protected IWeightInit weightInitFn;
    protected Double biasInit;
    protected List regularization;
    protected List regularizationBias;
    protected boolean removeL2 = false;     //For: .l2(0.0) -> user means "no l2" so we should remove it if it is present in the original model...
    protected boolean removeL2Bias = false;
    protected boolean removeL1 = false;
    protected boolean removeL1Bias = false;
    protected boolean removeWD = false;
    protected boolean removeWDBias = false;
    protected Optional dropout;
    protected Optional weightNoise;
    protected IUpdater updater;
    protected IUpdater biasUpdater;
    protected Boolean miniBatch;
    protected Integer maxNumLineSearchIterations;
    protected Long seed;
    protected OptimizationAlgorithm optimizationAlgo;
    protected StepFunction stepFunction;
    protected Boolean minimize;
    protected Optional gradientNormalization;
    protected Double gradientNormalizationThreshold;
    protected ConvolutionMode convolutionMode;
    protected ConvolutionLayer.AlgoMode cudnnAlgoMode;
    protected Optional> constraints;

    protected Boolean pretrain;
    protected Boolean backprop;
    protected BackpropType backpropType;
    protected Integer tbpttFwdLength;
    protected Integer tbpttBackLength;

    protected WorkspaceMode trainingWorkspaceMode;
    protected WorkspaceMode inferenceWorkspaceMode;

    public static Builder builder() {
        return new Builder();
    }

    /*
     * Can't use Lombok @Builder annotation due to optionals (otherwise we have a bunch of ugly .x(Optional value)
      * methods - lombok builder doesn't support excluding fields? :(
     * Note the use of optional here: gives us 3 states...
     * 1. Null: not set
     * 2. Optional (empty): set to null
     * 3. Optional (not empty): set to specific value
     *
     * Obviously, having null only makes sense for some things (dropout, etc) whereas null for other things doesn't
     * make sense
     */
    @ToString
    public static class Builder {
        private IActivation activation;
        private IWeightInit weightInitFn;
        private Double biasInit;
        protected List regularization = new ArrayList<>();
        protected List regularizationBias = new ArrayList<>();
        protected boolean removeL2 = false;     //For: .l2(0.0) -> user means "no l2" so we should remove it if it is present in the original model...
        protected boolean removeL2Bias = false;
        protected boolean removeL1 = false;
        protected boolean removeL1Bias = false;
        protected boolean removeWD = false;
        protected boolean removeWDBias = false;
        private Optional dropout;
        private Optional weightNoise;
        private IUpdater updater;
        private IUpdater biasUpdater;
        private Boolean miniBatch;
        private Integer maxNumLineSearchIterations;
        private Long seed;
        private OptimizationAlgorithm optimizationAlgo;
        private StepFunction stepFunction;
        private Boolean minimize;
        private Optional gradientNormalization;
        private Double gradientNormalizationThreshold;
        private ConvolutionMode convolutionMode;
        private ConvolutionLayer.AlgoMode cudnnAlgoMode;
        private Optional> constraints;
        private Boolean pretrain;
        private Boolean backprop;
        private BackpropType backpropType;
        private Integer tbpttFwdLength;
        private Integer tbpttBackLength;
        private WorkspaceMode trainingWorkspaceMode;
        private WorkspaceMode inferenceWorkspaceMode;

        public Builder() {

        }

        /**
         * Activation function / neuron non-linearity
         */
        public Builder activation(IActivation activationFn) {
            this.activation = activationFn;
            return this;
        }

        /**
         * Activation function / neuron non-linearity
         */
        public Builder activation(Activation activation) {
            this.activation = activation.getActivationFunction();
            return this;
        }

        /**
         * Weight initialization scheme to use, for initial weight values
         *
         * @see IWeightInit
         */
        public Builder weightInit(IWeightInit weightInit) {
            this.weightInitFn = weightInit;
            return this;
        }

        /**
         * Weight initialization scheme to use, for initial weight values
         *
         * @see WeightInit
         */
        public Builder weightInit(WeightInit weightInit) {
            if(weightInit == WeightInit.DISTRIBUTION) {
                throw new UnsupportedOperationException("Not supported!, User weightInit(Distribution distribution) instead!");
            }

            this.weightInitFn = weightInit.getWeightInitFunction();
            return this;
        }


        /**
         * Set weight initialization scheme to random sampling via the specified distribution.
         * Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))}
         *
         * @param distribution Distribution to use for weight initialization
         */
        public Builder weightInit(Distribution distribution){
            return weightInit(new WeightInitDistribution(distribution));
        }

        /**
         * Constant for bias initialization. Default: 0.0
         *
         * @param biasInit Constant for bias initialization
         */
        public Builder biasInit(double biasInit) {
            this.biasInit = biasInit;
            return this;
        }

        /**
         * Distribution to sample initial weights from.
         * Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))}
         */
        @Deprecated
        public Builder dist(Distribution dist) {
            return weightInit(dist);
        }

        /**
         * L1 regularization coefficient for the weights (excluding biases)
         */
        public Builder l1(double l1) {
            NetworkUtils.removeInstances(regularization, L1Regularization.class);
            if(l1 > 0.0) {
                regularization.add(new L1Regularization(l1));
            }
            return this;
        }

        /**
         * L2 regularization coefficient for the weights (excluding biases)
* Note: Generally, {@link WeightDecay} (set via {@link #weightDecay(double,boolean)} should be preferred to * L2 regularization. See {@link WeightDecay} javadoc for further details.
*/ public Builder l2(double l2) { NetworkUtils.removeInstances(regularization, L2Regularization.class); if(l2 > 0.0) { NetworkUtils.removeInstancesWithWarning(regularization, WeightDecay.class, "WeightDecay regularization removed: incompatible with added L2 regularization"); regularization.add(new L2Regularization(l2)); } else { removeL2 = true; } return this; } /** * L1 regularization coefficient for the bias parameters */ public Builder l1Bias(double l1Bias) { NetworkUtils.removeInstances(regularizationBias, L1Regularization.class); if(l1Bias > 0.0) { regularizationBias.add(new L1Regularization(l1Bias)); } else { removeL1Bias = true; } return this; } /** * L2 regularization coefficient for the bias parameters
* Note: Generally, {@link WeightDecay} (set via {@link #weightDecayBias(double,boolean)} should be preferred to * L2 regularization. See {@link WeightDecay} javadoc for further details.
*/ public Builder l2Bias(double l2Bias) { NetworkUtils.removeInstances(regularizationBias, L2Regularization.class); if(l2Bias > 0.0) { NetworkUtils.removeInstancesWithWarning(regularizationBias, WeightDecay.class, "WeightDecay bias regularization removed: incompatible with added L2 regularization"); regularizationBias.add(new L2Regularization(l2Bias)); } else { removeL2Bias = true; } return this; } /** * Add weight decay regularization for the network parameters (excluding biases).
* This applies weight decay with multiplying the learning rate - see {@link WeightDecay} for more details.
* * @param coefficient Weight decay regularization coefficient * @see #weightDecay(double, boolean) */ public Builder weightDecay(double coefficient) { return weightDecay(coefficient, true); } /** * Add weight decay regularization for the network parameters (excluding biases). See {@link WeightDecay} for more details.
* * @param coefficient Weight decay regularization coefficient * @param applyLR Whether the learning rate should be multiplied in when performing weight decay updates. See {@link WeightDecay} for more details. * @see #weightDecay(double, boolean) */ public Builder weightDecay(double coefficient, boolean applyLR) { //Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both NetworkUtils.removeInstances(this.regularization, WeightDecay.class); if(coefficient > 0.0) { NetworkUtils.removeInstancesWithWarning(this.regularization, L2Regularization.class, "L2 regularization removed: incompatible with added WeightDecay regularization"); this.regularization.add(new WeightDecay(coefficient, applyLR)); } else { removeWD = true; } return this; } /** * Weight decay for the biases only - see {@link #weightDecay(double)} for more details. * This applies weight decay with multiplying the learning rate.
* * @param coefficient Weight decay regularization coefficient * @see #weightDecayBias(double, boolean) */ public Builder weightDecayBias(double coefficient) { return weightDecayBias(coefficient, true); } /** * Weight decay for the biases only - see {@link #weightDecay(double)} for more details
* * @param coefficient Weight decay regularization coefficient */ public Builder weightDecayBias(double coefficient, boolean applyLR) { //Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both NetworkUtils.removeInstances(this.regularizationBias, WeightDecay.class); if(coefficient > 0) { NetworkUtils.removeInstancesWithWarning(this.regularizationBias, L2Regularization.class, "L2 bias regularization removed: incompatible with added WeightDecay regularization"); this.regularizationBias.add(new WeightDecay(coefficient, applyLR)); } else { removeWDBias = true; } return this; } /** * Set the dropout * * @param dropout Dropout, such as {@link Dropout}, {@link org.deeplearning4j.nn.conf.dropout.GaussianDropout}, * {@link org.deeplearning4j.nn.conf.dropout.GaussianNoise} etc */ public Builder dropout(IDropout dropout) { this.dropout = Optional.ofNullable(dropout); return this; } /** * Dropout probability. This is the probability of retaining each input activation value for a layer. * dropOut(x) will keep an input activation with probability x, and set to 0 with probability 1-x.
* dropOut(0.0) is a special value / special case - when set to 0.0., dropout is disabled (not applied). Note * that a dropout value of 1.0 is functionally equivalent to no dropout: i.e., 100% probability of retaining * each input activation.
*

* Note 1: Dropout is applied at training time only - and is automatically not applied at test time * (for evaluation, etc)
* Note 2: This sets the probability per-layer. Care should be taken when setting lower values for * complex networks (too much information may be lost with aggressive (very low) dropout values).
* Note 3: Frequently, dropout is not applied to (or, has higher retain probability for) input (first layer) * layers. Dropout is also often not applied to output layers. This needs to be handled MANUALLY by the user * - set .dropout(0) on those layers when using global dropout setting.
* Note 4: Implementation detail (most users can ignore): DL4J uses inverted dropout, as described here: * http://cs231n.github.io/neural-networks-2/ *

* * @param inputRetainProbability Dropout probability (probability of retaining each input activation value for a layer) * @see #dropout(IDropout) */ public Builder dropOut(double inputRetainProbability){ if(inputRetainProbability == 0.0){ return dropout(null); } return dropout(new Dropout(inputRetainProbability)); } /** * Set the weight noise (such as {@link org.deeplearning4j.nn.conf.weightnoise.DropConnect} and * {@link org.deeplearning4j.nn.conf.weightnoise.WeightNoise}) * * @param weightNoise Weight noise instance to use */ public Builder weightNoise(IWeightNoise weightNoise) { this.weightNoise = Optional.ofNullable(weightNoise); return this; } /** * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} * or {@link org.nd4j.linalg.learning.config.Nesterovs} * * @param updater Updater to use */ public Builder updater(IUpdater updater) { this.updater = updater; return this; } /** * @deprecated Use {@link #updater(IUpdater)} */ @Deprecated public Builder updater(Updater updater) { return updater(updater.getIUpdaterWithDefaultConfig()); } /** * Gradient updater configuration, for the biases only. If not set, biases will use the updater as * set by {@link #updater(IUpdater)} * * @param biasUpdater Updater to use for bias parameters */ public Builder biasUpdater(IUpdater biasUpdater) { this.biasUpdater = biasUpdater; return this; } /** * Whether scores and gradients should be divided by the minibatch size.
* Most users should leave this ast he default value of true. */ public Builder miniBatch(boolean miniBatch) { this.miniBatch = miniBatch; return this; } public Builder maxNumLineSearchIterations(int maxNumLineSearchIterations) { this.maxNumLineSearchIterations = maxNumLineSearchIterations; return this; } /** * RNG seed for reproducibility * @param seed RNG seed to use */ public Builder seed(long seed) { this.seed = seed; return this; } /** * RNG seed for reproducibility * @param seed RNG seed to use */ public Builder seed(int seed){ return seed((long)seed); } public Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgo) { this.optimizationAlgo = optimizationAlgo; return this; } public Builder stepFunction(StepFunction stepFunction) { this.stepFunction = stepFunction; return this; } public Builder minimize(boolean minimize) { this.minimize = minimize; return this; } /** * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc. * See {@link GradientNormalization} for details * * @param gradientNormalization Type of normalization to use. Defaults to None. * @see GradientNormalization */ public Builder gradientNormalization(GradientNormalization gradientNormalization) { this.gradientNormalization = Optional.ofNullable(gradientNormalization); return this; } /** * Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer, * GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue
* Not used otherwise.
* L2 threshold for first two types of clipping, or absolute value threshold for last type of clipping */ public Builder gradientNormalizationThreshold(double gradientNormalizationThreshold) { this.gradientNormalizationThreshold = gradientNormalizationThreshold; return this; } /** * Sets the convolution mode for convolutional layers, which impacts padding and output sizes. * See {@link ConvolutionMode} for details. Defaults to ConvolutionMode.TRUNCATE
* @param convolutionMode Convolution mode to use */ public Builder convolutionMode(ConvolutionMode convolutionMode) { this.convolutionMode = convolutionMode; return this; } /** * Sets the cuDNN algo mode for convolutional layers, which impacts performance and memory usage of cuDNN. * See {@link ConvolutionLayer.AlgoMode} for details. Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory. */ public Builder cudnnAlgoMode(ConvolutionLayer.AlgoMode cudnnAlgoMode) { this.cudnnAlgoMode = cudnnAlgoMode; return this; } /** * Set constraints to be applied to all layers. Default: no constraints.
* Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, * etc). These constraints are applied at each iteration, after the parameters have been updated. * * @param constraints Constraints to apply to all parameters of all layers */ public Builder constraints(List constraints) { this.constraints = Optional.ofNullable(constraints); return this; } public Builder pretrain(boolean pretrain) { this.pretrain = pretrain; return this; } public Builder backprop(boolean backprop) { this.backprop = backprop; return this; } /** * The type of backprop. Default setting is used for most networks (MLP, CNN etc), * but optionally truncated BPTT can be used for training recurrent neural networks. * If using TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength() * * @param backpropType Type of backprop. Default: BackpropType.Standard */ public Builder backpropType(BackpropType backpropType) { this.backpropType = backpropType; return this; } /** * When doing truncated BPTT: how many steps of forward pass should we do * before doing (truncated) backprop?
* Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
* Typically tBPTTForwardLength parameter is same as the tBPTTBackwardLength parameter, * but may be larger than it in some circumstances (but never smaller)
* Ideally your training data time series length should be divisible by this * This is the k1 parameter on pg23 of * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf * * @param tbpttFwdLength Forward length > 0, >= backwardLength */ public Builder tbpttFwdLength(int tbpttFwdLength) { this.tbpttFwdLength = tbpttFwdLength; return this; } /** * When doing truncated BPTT: how many steps of backward should we do?
* Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
* This is the k2 parameter on pg23 of * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf * * @param tbpttBackLength <= forwardLength */ public Builder tbpttBackLength(int tbpttBackLength) { this.tbpttBackLength = tbpttBackLength; return this; } /** * This method defines Workspace mode being used during training: * NONE: workspace won't be used * ENABLED: workspaces will be used for training (reduced memory and better performance) * * @param trainingWorkspaceMode Workspace mode for training * @return Builder */ public Builder trainingWorkspaceMode(WorkspaceMode trainingWorkspaceMode) { this.trainingWorkspaceMode = trainingWorkspaceMode; return this; } /** * This method defines Workspace mode being used during inference:
* NONE: workspace won't be used
* ENABLED: workspaces will be used for inference (reduced memory and better performance) * * @param inferenceWorkspaceMode Workspace mode for inference * @return Builder */ public Builder inferenceWorkspaceMode(WorkspaceMode inferenceWorkspaceMode) { this.inferenceWorkspaceMode = inferenceWorkspaceMode; return this; } public FineTuneConfiguration build() { return new FineTuneConfiguration(activation, weightInitFn, biasInit, regularization, regularizationBias, removeL2, removeL2Bias, removeL1, removeL1Bias, removeWD, removeWDBias, dropout, weightNoise, updater, biasUpdater, miniBatch, maxNumLineSearchIterations, seed, optimizationAlgo, stepFunction, minimize, gradientNormalization, gradientNormalizationThreshold, convolutionMode, cudnnAlgoMode, constraints, pretrain, backprop, backpropType, tbpttFwdLength, tbpttBackLength, trainingWorkspaceMode, inferenceWorkspaceMode); } } public NeuralNetConfiguration appliedNeuralNetConfiguration(NeuralNetConfiguration nnc) { applyToNeuralNetConfiguration(nnc); nnc = new NeuralNetConfiguration.Builder(nnc.clone()).build(); return nnc; } public void applyToNeuralNetConfiguration(NeuralNetConfiguration nnc) { Layer l = nnc.getLayer(); Updater originalUpdater = null; WeightInit origWeightInit = null; if (l != null) { //As per NeuralNetConfiguration.configureLayer and LayerValidation.configureBaseLayer: only copy dropout to base layers // this excludes things like subsampling and activation layers if (dropout != null && l instanceof BaseLayer) { IDropout d = dropout.orElse(null); if(d != null) d = d.clone(); //Clone to avoid shared state between layers l.setIDropout(d); } if(constraints != null) l.setConstraints(constraints.orElse(null)); } if (l != null && l instanceof BaseLayer) { BaseLayer bl = (BaseLayer) l; if (activationFn != null) bl.setActivationFn(activationFn); if (weightInitFn != null) bl.setWeightInitFn(weightInitFn); if (biasInit != null) bl.setBiasInit(biasInit); if (regularization != null && !regularization.isEmpty()) bl.setRegularization(regularization); if (regularizationBias != null && !regularizationBias.isEmpty()) bl.setRegularizationBias(regularizationBias); if (removeL2) NetworkUtils.removeInstances(bl.getRegularization(), L2Regularization.class); if (removeL2Bias) NetworkUtils.removeInstances(bl.getRegularizationBias(), L2Regularization.class); if (removeL1) NetworkUtils.removeInstances(bl.getRegularization(), L1Regularization.class); if (removeL1Bias) NetworkUtils.removeInstances(bl.getRegularizationBias(), L1Regularization.class); if (removeWD) NetworkUtils.removeInstances(bl.getRegularization(), WeightDecay.class); if (removeWDBias) NetworkUtils.removeInstances(bl.getRegularizationBias(), WeightDecay.class); if (gradientNormalization != null) bl.setGradientNormalization(gradientNormalization.orElse(null)); if (gradientNormalizationThreshold != null) bl.setGradientNormalizationThreshold(gradientNormalizationThreshold); if (updater != null){ bl.setIUpdater(updater); } if (biasUpdater != null){ bl.setBiasUpdater(biasUpdater); } if (weightNoise != null){ bl.setWeightNoise(weightNoise.orElse(null)); } } if (miniBatch != null) nnc.setMiniBatch(miniBatch); if (maxNumLineSearchIterations != null) nnc.setMaxNumLineSearchIterations(maxNumLineSearchIterations); if (seed != null) nnc.setSeed(seed); if (optimizationAlgo != null) nnc.setOptimizationAlgo(optimizationAlgo); if (stepFunction != null) nnc.setStepFunction(stepFunction); if (minimize != null) nnc.setMinimize(minimize); if (convolutionMode != null && l instanceof ConvolutionLayer) { ((ConvolutionLayer) l).setConvolutionMode(convolutionMode); } if (cudnnAlgoMode != null && l instanceof ConvolutionLayer) { ((ConvolutionLayer) l).setCudnnAlgoMode(cudnnAlgoMode); } if (convolutionMode != null && l instanceof SubsamplingLayer) { ((SubsamplingLayer) l).setConvolutionMode(convolutionMode); } //Perform validation if (l != null) { LayerValidation.generalValidation(l.getLayerName(), l, get(dropout), regularization, regularizationBias, get(constraints), null, null); } } private static T get(Optional optional){ if(optional == null){ return null; } return optional.orElse(null); } public void applyToMultiLayerConfiguration(MultiLayerConfiguration conf) { if (backpropType != null) conf.setBackpropType(backpropType); if (tbpttFwdLength != null) conf.setTbpttFwdLength(tbpttFwdLength); if (tbpttBackLength != null) conf.setTbpttBackLength(tbpttBackLength); } public void applyToComputationGraphConfiguration(ComputationGraphConfiguration conf) { if (backpropType != null) conf.setBackpropType(backpropType); if (tbpttFwdLength != null) conf.setTbpttFwdLength(tbpttFwdLength); if (tbpttBackLength != null) conf.setTbpttBackLength(tbpttBackLength); } public NeuralNetConfiguration.Builder appliedNeuralNetConfigurationBuilder() { NeuralNetConfiguration.Builder confBuilder = new NeuralNetConfiguration.Builder(); if (activationFn != null) confBuilder.setActivationFn(activationFn); if (weightInitFn != null) confBuilder.setWeightInitFn(weightInitFn); if (biasInit != null) confBuilder.setBiasInit(biasInit); if (regularization != null) confBuilder.setRegularization(regularization); if (regularizationBias != null) confBuilder.setRegularizationBias(regularizationBias); if (dropout != null) confBuilder.setIdropOut(dropout.orElse(null)); if (updater != null) confBuilder.updater(updater); if(biasUpdater != null) confBuilder.biasUpdater(biasUpdater); if (miniBatch != null) confBuilder.setMiniBatch(miniBatch); if (maxNumLineSearchIterations != null) confBuilder.setMaxNumLineSearchIterations(maxNumLineSearchIterations); if (seed != null) confBuilder.setSeed(seed); if (optimizationAlgo != null) confBuilder.setOptimizationAlgo(optimizationAlgo); if (stepFunction != null) confBuilder.setStepFunction(stepFunction); if (minimize != null) confBuilder.setMinimize(minimize); if (gradientNormalization != null) confBuilder.setGradientNormalization(gradientNormalization.orElse(null)); if (gradientNormalizationThreshold != null) confBuilder.setGradientNormalizationThreshold(gradientNormalizationThreshold); if (trainingWorkspaceMode != null) confBuilder.trainingWorkspaceMode(trainingWorkspaceMode); if (inferenceWorkspaceMode != null) confBuilder.inferenceWorkspaceMode(inferenceWorkspaceMode); return confBuilder; } public String toJson() { try { return NeuralNetConfiguration.mapper().writeValueAsString(this); } catch (JsonProcessingException e) { throw new RuntimeException(e); } } public String toYaml() { try { return NeuralNetConfiguration.mapperYaml().writeValueAsString(this); } catch (JsonProcessingException e) { throw new RuntimeException(e); } } public static FineTuneConfiguration fromJson(String json) { try { return NeuralNetConfiguration.mapper().readValue(json, FineTuneConfiguration.class); } catch (IOException e) { throw new RuntimeException(e); } } public static FineTuneConfiguration fromYaml(String yaml) { try { return NeuralNetConfiguration.mapperYaml().readValue(yaml, FineTuneConfiguration.class); } catch (IOException e) { throw new RuntimeException(e); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy