/*
 *  * Copyright 2016 Skymind, Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 */

package org.deeplearning4j.arbiter;

import lombok.AllArgsConstructor;
import org.deeplearning4j.arbiter.layers.LayerSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
import org.deeplearning4j.nn.weights.WeightInit;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * This is an abstract ParameterSpace for both MultiLayerNetworks (MultiLayerSpace)
 * and ComputationGraphs (ComputationGraphSpace).
 * <p>
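 * A minimal usage sketch via a concrete subclass. This is illustrative only, not code from this
 * file: it assumes {@code MultiLayerSpace}, {@code DenseLayerSpace} and
 * {@code ContinuousParameterSpace} from the same Arbiter release, with arbitrary example values:
 * <pre>{@code
 * MultiLayerSpace space = new MultiLayerSpace.Builder()
 *         .learningRate(new ContinuousParameterSpace(0.0001, 0.1))   //search LR in [0.0001, 0.1]
 *         .regularization(true)
 *         .l2(new ContinuousParameterSpace(0.0001, 0.01))            //search L2 in [0.0001, 0.01]
 *         .addLayer(new DenseLayerSpace.Builder().nIn(10).nOut(10).build())
 *         .build();
 * }</pre>
 * <p>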
 * Functionality here should match {@link org.deeplearning4j.nn.conf.NeuralNetConfiguration.Builder}
 *
 * @param <T> Type of network (MultiLayerNetwork or ComputationGraph)
 * @author Alex Black
 */
public abstract class BaseNetworkSpace<T> implements ParameterSpace<T> {

    protected ParameterSpace<Boolean> useDropConnect;
    protected ParameterSpace<Integer> iterations;
    protected Long seed;
    protected ParameterSpace<OptimizationAlgorithm> optimizationAlgo;
    protected ParameterSpace<Boolean> regularization;
    protected ParameterSpace<Boolean> schedules;
    protected ParameterSpace<String> activationFunction;
    protected ParameterSpace<Double> biasInit;
    protected ParameterSpace<WeightInit> weightInit;
    protected ParameterSpace<Distribution> dist;
    protected ParameterSpace<Double> learningRate;
    protected ParameterSpace<Double> biasLearningRate;
    protected ParameterSpace<Map<Integer, Double>> learningRateAfter;
    protected ParameterSpace<Double> lrScoreBasedDecay;
    protected ParameterSpace<LearningRatePolicy> learningRateDecayPolicy;
    protected ParameterSpace<Map<Integer, Double>> learningRateSchedule;
    protected ParameterSpace<Double> lrPolicyDecayRate;
    protected ParameterSpace<Double> lrPolicyPower;
    protected ParameterSpace<Double> lrPolicySteps;
    protected ParameterSpace<Integer> maxNumLineSearchIterations;
    protected ParameterSpace<Boolean> miniBatch;
    protected ParameterSpace<Boolean> minimize;
    protected ParameterSpace<StepFunction> stepFunction;
    protected ParameterSpace<Double> l1;
    protected ParameterSpace<Double> l2;
    protected ParameterSpace<Double> dropOut;
    protected ParameterSpace<Double> momentum;
    protected ParameterSpace<Map<Integer, Double>> momentumAfter;
    protected ParameterSpace<Updater> updater;
    protected ParameterSpace<Double> epsilon;
    protected ParameterSpace<Double> rho;
    protected ParameterSpace<Double> rmsDecay;
    protected ParameterSpace<Double> adamMeanDecay;
    protected ParameterSpace<Double> adamVarDecay;
    protected ParameterSpace<GradientNormalization> gradientNormalization;
    protected ParameterSpace<Double> gradientNormalizationThreshold;
    protected ParameterSpace<int[]> cnnInputSize;
    protected List<LayerConf> layerSpaces = new ArrayList<>();

    //NeuralNetConfiguration.ListBuilder/MultiLayerConfiguration.Builder options:
    protected ParameterSpace<Boolean> backprop;
    protected ParameterSpace<Boolean> pretrain;
    protected ParameterSpace<BackpropType> backpropType;
    protected ParameterSpace<Integer> tbpttFwdLength;
    protected ParameterSpace<Integer> tbpttBwdLength;

    protected int numEpochs = 1;

    @SuppressWarnings("unchecked")
    protected BaseNetworkSpace(Builder builder) {
        this.useDropConnect = builder.useDropConnect;
        this.iterations = builder.iterations;
        this.seed = builder.seed;
        this.optimizationAlgo = builder.optimizationAlgo;
        this.regularization = builder.regularization;
        this.schedules = builder.schedules;
        this.activationFunction = builder.activationFunction;
        this.biasInit = builder.biasInit;
        this.weightInit = builder.weightInit;
        this.dist = builder.dist;
        this.learningRate = builder.learningRate;
        this.biasLearningRate = builder.biasLearningRate;
        this.learningRateAfter = builder.learningRateAfter;
        this.lrScoreBasedDecay = builder.lrScoreBasedDecay;
        this.learningRateDecayPolicy = builder.learningRateDecayPolicy;
        this.learningRateSchedule = builder.learningRateSchedule;
        this.lrPolicyDecayRate = builder.lrPolicyDecayRate;
        this.lrPolicyPower = builder.lrPolicyPower;
        this.lrPolicySteps = builder.lrPolicySteps;
        this.maxNumLineSearchIterations = builder.maxNumLineSearchIterations;
        this.miniBatch = builder.miniBatch;
        this.minimize = builder.minimize;
        this.stepFunction = builder.stepFunction;
        this.l1 = builder.l1;
        this.l2 = builder.l2;
        this.dropOut = builder.dropOut;
        this.momentum = builder.momentum;
        this.momentumAfter = builder.momentumAfter;
        this.updater = builder.updater;
        this.epsilon = builder.epsilon;
        this.rho = builder.rho;
        this.rmsDecay = builder.rmsDecay;
        this.gradientNormalization = builder.gradientNormalization;
        this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold;
        this.adamMeanDecay = builder.adamMeanDecay;
        this.adamVarDecay = builder.adamVarDecay;
        this.backprop = builder.backprop;
        this.pretrain = builder.pretrain;
        this.backpropType = builder.backpropType;
        this.tbpttFwdLength = builder.tbpttFwdLength;
        this.tbpttBwdLength = builder.tbpttBwdLength;
        this.cnnInputSize = builder.cnnInputSize;
        this.numEpochs = builder.numEpochs;
    }
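    /**
     * Create a {@link NeuralNetConfiguration.Builder} from this space: every global hyperparameter
     * that has been set (i.e., is non-null) is resolved to a concrete value via
     * {@code getValue(values)}.
     */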
    protected NeuralNetConfiguration.Builder randomGlobalConf(double[] values) {
        //Create MultiLayerConfiguration...
        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
        if (useDropConnect != null) builder.useDropConnect(useDropConnect.getValue(values));
        if (iterations != null) builder.iterations(iterations.getValue(values));
        if (seed != null) builder.seed(seed);
        if (optimizationAlgo != null) builder.optimizationAlgo(optimizationAlgo.getValue(values));
        if (regularization != null) builder.regularization(regularization.getValue(values));
        //        if (schedules != null) builder.learningRateSchedule(schedules.getValue(values));
        if (activationFunction != null) builder.activation(activationFunction.getValue(values));
        if (biasInit != null) builder.biasInit(biasInit.getValue(values));
        if (weightInit != null) builder.weightInit(weightInit.getValue(values));
        if (dist != null) builder.dist(dist.getValue(values));
        if (learningRate != null) builder.learningRate(learningRate.getValue(values));
        if (biasLearningRate != null) builder.biasLearningRate(biasLearningRate.getValue(values));
        if (learningRateAfter != null) builder.learningRateSchedule(learningRateAfter.getValue(values));
        if (lrScoreBasedDecay != null) builder.learningRateScoreBasedDecayRate(lrScoreBasedDecay.getValue(values));
        if (learningRateDecayPolicy != null) builder.learningRateDecayPolicy(learningRateDecayPolicy.getValue(values));
        if (learningRateSchedule != null) builder.learningRateSchedule(learningRateSchedule.getValue(values));
        if (lrPolicyDecayRate != null) builder.lrPolicyDecayRate(lrPolicyDecayRate.getValue(values));
        if (lrPolicyPower != null) builder.lrPolicyPower(lrPolicyPower.getValue(values));
        if (lrPolicySteps != null) builder.lrPolicySteps(lrPolicySteps.getValue(values));
        if (maxNumLineSearchIterations != null) builder.maxNumLineSearchIterations(maxNumLineSearchIterations.getValue(values));
        if (miniBatch != null) builder.miniBatch(miniBatch.getValue(values));
        if (minimize != null) builder.minimize(minimize.getValue(values));
        if (stepFunction != null) builder.stepFunction(stepFunction.getValue(values));
        if (l1 != null) builder.l1(l1.getValue(values));
        if (l2 != null) builder.l2(l2.getValue(values));
        if (dropOut != null) builder.dropOut(dropOut.getValue(values));
        if (momentum != null) builder.momentum(momentum.getValue(values));
        if (momentumAfter != null) builder.momentumAfter(momentumAfter.getValue(values));
        if (updater != null) builder.updater(updater.getValue(values));
        if (epsilon != null) builder.epsilon(epsilon.getValue(values));
        if (rho != null) builder.rho(rho.getValue(values));
        if (rmsDecay != null) builder.rmsDecay(rmsDecay.getValue(values));
        if (gradientNormalization != null) builder.gradientNormalization(gradientNormalization.getValue(values));
        if (gradientNormalizationThreshold != null) builder.gradientNormalizationThreshold(gradientNormalizationThreshold.getValue(values));
        if (adamMeanDecay != null) builder.adamMeanDecay(adamMeanDecay.getValue(values));
        if (adamVarDecay != null) builder.adamVarDecay(adamVarDecay.getValue(values));
        return builder;
    }
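    /**
     * Collect the leaf parameter spaces (the individually searchable hyperparameters) from all
     * non-null fields of this network space.
     */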
    @Override
    public List<ParameterSpace> collectLeaves() {
        List<ParameterSpace> list = new ArrayList<>();
        if (useDropConnect != null) list.addAll(useDropConnect.collectLeaves());
        if (iterations != null) list.addAll(iterations.collectLeaves());
        if (optimizationAlgo != null) list.addAll(optimizationAlgo.collectLeaves());
        if (regularization != null) list.addAll(regularization.collectLeaves());
        if (schedules != null) list.addAll(schedules.collectLeaves());
        if (activationFunction != null) list.addAll(activationFunction.collectLeaves());
        if (biasInit != null) list.addAll(biasInit.collectLeaves());
        if (weightInit != null) list.addAll(weightInit.collectLeaves());
        if (dist != null) list.addAll(dist.collectLeaves());
        if (learningRate != null) list.addAll(learningRate.collectLeaves());
        if (biasLearningRate != null) list.addAll(biasLearningRate.collectLeaves());
        if (learningRateAfter != null) list.addAll(learningRateAfter.collectLeaves());
        if (lrScoreBasedDecay != null) list.addAll(lrScoreBasedDecay.collectLeaves());
        if (learningRateDecayPolicy != null) list.addAll(learningRateDecayPolicy.collectLeaves());
        if (learningRateSchedule != null) list.addAll(learningRateSchedule.collectLeaves());
        if (lrPolicyDecayRate != null) list.addAll(lrPolicyDecayRate.collectLeaves());
        if (lrPolicyPower != null) list.addAll(lrPolicyPower.collectLeaves());
        if (lrPolicySteps != null) list.addAll(lrPolicySteps.collectLeaves());
        if (maxNumLineSearchIterations != null) list.addAll(maxNumLineSearchIterations.collectLeaves());
        if (miniBatch != null) list.addAll(miniBatch.collectLeaves());
        if (minimize != null) list.addAll(minimize.collectLeaves());
        if (stepFunction != null) list.addAll(stepFunction.collectLeaves());
        if (l1 != null) list.addAll(l1.collectLeaves());
        if (l2 != null) list.addAll(l2.collectLeaves());
        if (dropOut != null) list.addAll(dropOut.collectLeaves());
        if (momentum != null) list.addAll(momentum.collectLeaves());
        if (momentumAfter != null) list.addAll(momentumAfter.collectLeaves());
        if (updater != null) list.addAll(updater.collectLeaves());
        if (epsilon != null) list.addAll(epsilon.collectLeaves());
        if (rho != null) list.addAll(rho.collectLeaves());
        if (rmsDecay != null) list.addAll(rmsDecay.collectLeaves());
        if (gradientNormalization != null) list.addAll(gradientNormalization.collectLeaves());
        if (gradientNormalizationThreshold != null) list.addAll(gradientNormalizationThreshold.collectLeaves());
        if (cnnInputSize != null) list.addAll(cnnInputSize.collectLeaves());
        if (adamMeanDecay != null) list.addAll(adamMeanDecay.collectLeaves());
        if (adamVarDecay != null) list.addAll(adamVarDecay.collectLeaves());
        if (backprop != null) list.add(backprop);
        if (pretrain != null) list.add(pretrain);
        if (backpropType != null) list.add(backpropType);
        if (tbpttBwdLength != null) list.add(tbpttBwdLength);
        if (tbpttFwdLength != null) list.add(tbpttFwdLength);
        return list;
    }

    @Override
    public boolean isLeaf() {
        return false;
    }
    @Override
    public void setIndices(int... indices) {
        throw new UnsupportedOperationException("Cannot set indices for non leaf");
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        if (useDropConnect != null) sb.append("useDropConnect: ").append(useDropConnect).append("\n");
        if (iterations != null) sb.append("iterations: ").append(iterations).append("\n");
        if (seed != null) sb.append("seed: ").append(seed).append("\n");
        if (optimizationAlgo != null) sb.append("optimizationAlgo: ").append(optimizationAlgo).append("\n");
        if (regularization != null) sb.append("regularization: ").append(regularization).append("\n");
        if (schedules != null) sb.append("schedules: ").append(schedules).append("\n");
        if (activationFunction != null) sb.append("activationFunction: ").append(activationFunction).append("\n");
        if (weightInit != null) sb.append("weightInit: ").append(weightInit).append("\n");
        if (dist != null) sb.append("dist: ").append(dist).append("\n");
        if (learningRate != null) sb.append("learningRate: ").append(learningRate).append("\n");
        if (biasLearningRate != null) sb.append("biasLearningRate: ").append(biasLearningRate).append("\n");
        if (learningRateAfter != null) sb.append("learningRateAfter: ").append(learningRateAfter).append("\n");
        if (lrScoreBasedDecay != null) sb.append("lrScoreBasedDecay: ").append(lrScoreBasedDecay).append("\n");
        if (learningRateDecayPolicy != null) sb.append("learningRateDecayPolicy: ").append(learningRateDecayPolicy).append("\n");
        if (learningRateSchedule != null) sb.append("learningRateSchedule: ").append(learningRateSchedule).append("\n");
        if (lrPolicyDecayRate != null) sb.append("lrPolicyDecayRate: ").append(lrPolicyDecayRate).append("\n");
        if (lrPolicyPower != null) sb.append("lrPolicyPower: ").append(lrPolicyPower).append("\n");
        if (lrPolicySteps != null) sb.append("lrPolicySteps: ").append(lrPolicySteps).append("\n");
        if (maxNumLineSearchIterations != null) sb.append("maxNumLineSearchIterations: ").append(maxNumLineSearchIterations).append("\n");
        if (miniBatch != null) sb.append("miniBatch: ").append(miniBatch).append("\n");
        if (minimize != null) sb.append("minimize: ").append(minimize).append("\n");
        if (stepFunction != null) sb.append("stepFunction: ").append(stepFunction).append("\n");
        if (l1 != null) sb.append("l1: ").append(l1).append("\n");
        if (l2 != null) sb.append("l2: ").append(l2).append("\n");
        if (dropOut != null) sb.append("dropOut: ").append(dropOut).append("\n");
        if (momentum != null) sb.append("momentum: ").append(momentum).append("\n");
        if (momentumAfter != null) sb.append("momentumAfter: ").append(momentumAfter).append("\n");
        if (updater != null) sb.append("updater: ").append(updater).append("\n");
        if (epsilon != null) sb.append("epsilon: ").append(epsilon).append("\n");
        if (rho != null) sb.append("rho: ").append(rho).append("\n");
        if (rmsDecay != null) sb.append("rmsDecay: ").append(rmsDecay).append("\n");
        if (gradientNormalization != null) sb.append("gradientNormalization: ").append(gradientNormalization).append("\n");
        if (gradientNormalizationThreshold != null) sb.append("gradientNormalizationThreshold: ").append(gradientNormalizationThreshold).append("\n");
        if (backprop != null) sb.append("backprop: ").append(backprop).append("\n");
        if (pretrain != null) sb.append("pretrain: ").append(pretrain).append("\n");
        if (backpropType != null) sb.append("backpropType: ").append(backpropType).append("\n");
        if (tbpttFwdLength != null) sb.append("tbpttFwdLength: ").append(tbpttFwdLength).append("\n");
        if (tbpttBwdLength != null) sb.append("tbpttBwdLength: ").append(tbpttBwdLength).append("\n");
        if (cnnInputSize != null) sb.append("cnnInputSize: ").append(cnnInputSize).append("\n");
        if (adamMeanDecay != null) sb.append("adamMeanDecay: ").append(adamMeanDecay).append("\n");
        if (adamVarDecay != null) sb.append("adamVarDecay: ").append(adamVarDecay).append("\n");
        int i = 0;
        for (LayerConf conf : layerSpaces) {
            sb.append("Layer config ").append(i++).append(": (Number layers:").append(conf.numLayers)
                    .append(", duplicate: ").append(conf.duplicateConfig).append("), ")
                    .append(conf.layerSpace.toString()).append("\n");
        }
        return sb.toString();
    }

    //Pairing of a LayerSpace with the number of layers to create from it, and whether the same
    //configuration should be duplicated across those layers
    @AllArgsConstructor
    private static class LayerConf {
        private final LayerSpace<?> layerSpace;
        private final ParameterSpace<Integer> numLayers;
        private final boolean duplicateConfig;
    }

    @SuppressWarnings("unchecked")
    protected abstract static class Builder<T extends Builder<T>> {
        private ParameterSpace<String> activationFunction;
        private ParameterSpace<WeightInit> weightInit;
        private ParameterSpace<Double> biasInit;
        private ParameterSpace<Boolean> useDropConnect;
        private ParameterSpace<Integer> iterations;
        private Long seed;
        private ParameterSpace<OptimizationAlgorithm> optimizationAlgo;
        private ParameterSpace<Boolean> regularization;
        private ParameterSpace<Boolean> schedules;
        private ParameterSpace<Distribution> dist;
        private ParameterSpace<Double> learningRate;
        private ParameterSpace<Double> biasLearningRate;
        private ParameterSpace<Map<Integer, Double>> learningRateAfter;
        private ParameterSpace<Double> lrScoreBasedDecay;
        private ParameterSpace<LearningRatePolicy> learningRateDecayPolicy;
        private ParameterSpace<Map<Integer, Double>> learningRateSchedule;
        private ParameterSpace<Double> lrPolicyDecayRate;
        private ParameterSpace<Double> lrPolicyPower;
        private ParameterSpace<Double> lrPolicySteps;
        private ParameterSpace<Integer> maxNumLineSearchIterations;
        private ParameterSpace<Boolean> miniBatch;
        private ParameterSpace<Boolean> minimize;
        private ParameterSpace<StepFunction> stepFunction;
        private ParameterSpace<Double> l1;
        private ParameterSpace<Double> l2;
        private ParameterSpace<Double> dropOut;
        private ParameterSpace<Double> momentum;
        private ParameterSpace<Map<Integer, Double>> momentumAfter;
        private ParameterSpace<Updater> updater;
        private ParameterSpace<Double> epsilon;
        private ParameterSpace<Double> rho;
        private ParameterSpace<Double> rmsDecay;
        private ParameterSpace<GradientNormalization> gradientNormalization;
        private ParameterSpace<Double> gradientNormalizationThreshold;
        private ParameterSpace<int[]> cnnInputSize;
        private ParameterSpace<Double> adamMeanDecay;
        private ParameterSpace<Double> adamVarDecay;
        //NeuralNetConfiguration.ListBuilder/MultiLayerConfiguration.Builder options:
        private ParameterSpace<Boolean> backprop;
        private ParameterSpace<Boolean> pretrain;
        private ParameterSpace<BackpropType> backpropType;
        private ParameterSpace<Integer> tbpttFwdLength;
        private ParameterSpace<Integer> tbpttBwdLength;
        //Early stopping configuration / (fixed) number of epochs:
        private EarlyStoppingConfiguration earlyStoppingConfiguration;
        private int numEpochs = 1;

        //Each hyperparameter has two setters: a convenience overload that wraps a fixed value in a
        //FixedValue, and an overload taking a ParameterSpace over which that hyperparameter is searched

        public T useDropConnect(boolean useDropConnect) {
            return useDropConnect(new FixedValue<>(useDropConnect));
        }

        public T useDropConnect(ParameterSpace<Boolean> parameterSpace) {
            this.useDropConnect = parameterSpace;
            return (T) this;
        }

        public T iterations(int iterations) {
            return iterations(new FixedValue<>(iterations));
        }

        public T iterations(ParameterSpace<Integer> parameterSpace) {
            this.iterations = parameterSpace;
            return (T) this;
        }

        public T seed(long seed) {
            this.seed = seed;
            return (T) this;
        }

        public T optimizationAlgo(OptimizationAlgorithm optimizationAlgorithm) {
            return optimizationAlgo(new FixedValue<>(optimizationAlgorithm));
        }

        public T optimizationAlgo(ParameterSpace<OptimizationAlgorithm> parameterSpace) {
            this.optimizationAlgo = parameterSpace;
            return (T) this;
        }

        public T regularization(boolean useRegularization) {
            return regularization(new FixedValue<>(useRegularization));
        }

        public T regularization(ParameterSpace<Boolean> parameterSpace) {
            this.regularization = parameterSpace;
            return (T) this;
        }
        public T schedules(boolean schedules) {
            return schedules(new FixedValue<>(schedules));
        }

        public T schedules(ParameterSpace<Boolean> schedules) {
            this.schedules = schedules;
            return (T) this;
        }

        public T activation(String activationFunction) {
            return activation(new FixedValue<>(activationFunction));
        }

        public T activation(ParameterSpace<String> activationFunction) {
            this.activationFunction = activationFunction;
            return (T) this;
        }

        public T weightInit(WeightInit weightInit) {
            return weightInit(new FixedValue<>(weightInit));
        }

        public T weightInit(ParameterSpace<WeightInit> weightInit) {
            this.weightInit = weightInit;
            return (T) this;
        }

        public T dist(Distribution dist) {
            return dist(new FixedValue<>(dist));
        }

        public T dist(ParameterSpace<Distribution> dist) {
            this.dist = dist;
            return (T) this;
        }

        public T learningRate(double learningRate) {
            return learningRate(new FixedValue<>(learningRate));
        }

        public T learningRate(ParameterSpace<Double> learningRate) {
            this.learningRate = learningRate;
            return (T) this;
        }

        public T biasLearningRate(double learningRate) {
            return biasLearningRate(new FixedValue<>(learningRate));
        }

        public T biasLearningRate(ParameterSpace<Double> biasLearningRate) {
            this.biasLearningRate = biasLearningRate;
            return (T) this;
        }

        public T learningRateAfter(Map<Integer, Double> learningRateAfter) {
            return learningRateAfter(new FixedValue<>(learningRateAfter));
        }

        public T learningRateAfter(ParameterSpace<Map<Integer, Double>> learningRateAfter) {
            this.learningRateAfter = learningRateAfter;
            return (T) this;
        }

        public T learningRateScoreBasedDecayRate(double lrScoreBasedDecay) {
            return learningRateScoreBasedDecayRate(new FixedValue<>(lrScoreBasedDecay));
        }

        public T learningRateScoreBasedDecayRate(ParameterSpace<Double> lrScoreBasedDecay) {
            this.lrScoreBasedDecay = lrScoreBasedDecay;
            return (T) this;
        }

        public T learningRateDecayPolicy(LearningRatePolicy learningRatePolicy) {
            return learningRateDecayPolicy(new FixedValue<>(learningRatePolicy));
        }

        public T learningRateDecayPolicy(ParameterSpace<LearningRatePolicy> learningRateDecayPolicy) {
            this.learningRateDecayPolicy = learningRateDecayPolicy;
            return (T) this;
        }

        public T learningRateSchedule(Map<Integer, Double> learningRateSchedule) {
            return learningRateSchedule(new FixedValue<>(learningRateSchedule));
        }

        public T learningRateSchedule(ParameterSpace<Map<Integer, Double>> learningRateSchedule) {
            this.learningRateSchedule = learningRateSchedule;
            return (T) this;
        }

        public T lrPolicyDecayRate(double lrPolicyDecayRate) {
            return lrPolicyDecayRate(new FixedValue<>(lrPolicyDecayRate));
        }

        public T lrPolicyDecayRate(ParameterSpace<Double> lrPolicyDecayRate) {
            this.lrPolicyDecayRate = lrPolicyDecayRate;
            return (T) this;
        }

        public T lrPolicyPower(double lrPolicyPower) {
            return lrPolicyPower(new FixedValue<>(lrPolicyPower));
        }

        public T lrPolicyPower(ParameterSpace<Double> lrPolicyPower) {
            this.lrPolicyPower = lrPolicyPower;
            return (T) this;
        }

        public T lrPolicySteps(double lrPolicySteps) {
            return lrPolicySteps(new FixedValue<>(lrPolicySteps));
        }

        public T lrPolicySteps(ParameterSpace<Double> lrPolicySteps) {
            this.lrPolicySteps = lrPolicySteps;
            return (T) this;
        }

        public T maxNumLineSearchIterations(int maxNumLineSearchIterations) {
            return maxNumLineSearchIterations(new FixedValue<>(maxNumLineSearchIterations));
        }

        public T maxNumLineSearchIterations(ParameterSpace<Integer> maxNumLineSearchIterations) {
            this.maxNumLineSearchIterations = maxNumLineSearchIterations;
            return (T) this;
        }

        public T miniBatch(boolean minibatch) {
            return miniBatch(new FixedValue<>(minibatch));
        }

        public T miniBatch(ParameterSpace<Boolean> miniBatch) {
            this.miniBatch = miniBatch;
            return (T) this;
        }
        public T minimize(boolean minimize) {
            return minimize(new FixedValue<>(minimize));
        }

        public T minimize(ParameterSpace<Boolean> minimize) {
            this.minimize = minimize;
            return (T) this;
        }

        public T stepFunction(StepFunction stepFunction) {
            return stepFunction(new FixedValue<>(stepFunction));
        }

        public T stepFunction(ParameterSpace<StepFunction> stepFunction) {
            this.stepFunction = stepFunction;
            return (T) this;
        }

        public T l1(double l1) {
            return l1(new FixedValue<>(l1));
        }

        public T l1(ParameterSpace<Double> l1) {
            this.l1 = l1;
            return (T) this;
        }

        public T l2(double l2) {
            return l2(new FixedValue<>(l2));
        }

        public T l2(ParameterSpace<Double> l2) {
            this.l2 = l2;
            return (T) this;
        }

        public T dropOut(double dropOut) {
            return dropOut(new FixedValue<>(dropOut));
        }

        public T dropOut(ParameterSpace<Double> dropOut) {
            this.dropOut = dropOut;
            return (T) this;
        }

        public T momentum(double momentum) {
            return momentum(new FixedValue<>(momentum));
        }

        public T momentum(ParameterSpace<Double> momentum) {
            this.momentum = momentum;
            return (T) this;
        }

        public T momentumAfter(Map<Integer, Double> momentumAfter) {
            return momentumAfter(new FixedValue<>(momentumAfter));
        }

        public T momentumAfter(ParameterSpace<Map<Integer, Double>> momentumAfter) {
            this.momentumAfter = momentumAfter;
            return (T) this;
        }

        public T updater(Updater updater) {
            return updater(new FixedValue<>(updater));
        }

        public T updater(ParameterSpace<Updater> updater) {
            this.updater = updater;
            return (T) this;
        }

        public T epsilon(double epsilon) {
            return epsilon(new FixedValue<>(epsilon));
        }

        public T epsilon(ParameterSpace<Double> epsilon) {
            this.epsilon = epsilon;
            return (T) this;
        }

        public T rho(double rho) {
            return rho(new FixedValue<>(rho));
        }

        public T rho(ParameterSpace<Double> rho) {
            this.rho = rho;
            return (T) this;
        }

        public T rmsDecay(double rmsDecay) {
            return rmsDecay(new FixedValue<>(rmsDecay));
        }

        public T rmsDecay(ParameterSpace<Double> rmsDecay) {
            this.rmsDecay = rmsDecay;
            return (T) this;
        }

        public T gradientNormalization(GradientNormalization gradientNormalization) {
            return gradientNormalization(new FixedValue<>(gradientNormalization));
        }

        public T gradientNormalization(ParameterSpace<GradientNormalization> gradientNormalization) {
            this.gradientNormalization = gradientNormalization;
            return (T) this;
        }

        public T gradientNormalizationThreshold(double threshold) {
            return gradientNormalizationThreshold(new FixedValue<>(threshold));
        }

        public T gradientNormalizationThreshold(ParameterSpace<Double> gradientNormalizationThreshold) {
            this.gradientNormalizationThreshold = gradientNormalizationThreshold;
            return (T) this;
        }

        public T backprop(boolean backprop) {
            return backprop(new FixedValue<>(backprop));
        }

        public T backprop(ParameterSpace<Boolean> backprop) {
            this.backprop = backprop;
            return (T) this;
        }

        public T pretrain(boolean pretrain) {
            return pretrain(new FixedValue<>(pretrain));
        }

        public T pretrain(ParameterSpace<Boolean> pretrain) {
            this.pretrain = pretrain;
            return (T) this;
        }

        public T backpropType(BackpropType backpropType) {
            return backpropType(new FixedValue<>(backpropType));
        }

        public T backpropType(ParameterSpace<BackpropType> backpropType) {
            this.backpropType = backpropType;
            return (T) this;
        }

        public T tbpttFwdLength(int tbpttFwdLength) {
            return tbpttFwdLength(new FixedValue<>(tbpttFwdLength));
        }

        public T tbpttFwdLength(ParameterSpace<Integer> tbpttFwdLength) {
            this.tbpttFwdLength = tbpttFwdLength;
            return (T) this;
        }

        public T tbpttBwdLength(int tbpttBwdLength) {
            return tbpttBwdLength(new FixedValue<>(tbpttBwdLength));
        }

        public T tbpttBwdLength(ParameterSpace<Integer> tbpttBwdLength) {
            this.tbpttBwdLength = tbpttBwdLength;
            return (T) this;
        }

        /**
         * Fixed number of training epochs. Default: 1
         * Note: if both an EarlyStoppingConfiguration and a fixed number of epochs are present,
         * early stopping will be used in preference.
         */
        public T numEpochs(int numEpochs) {
            this.numEpochs = numEpochs;
            return (T) this;
        }

        public T biasInit(double biasInit) {
            return biasInit(new FixedValue<>(biasInit));
        }

        public T biasInit(ParameterSpace<Double> biasInit) {
            this.biasInit = biasInit;
            return (T) this;
        }

        public T adamMeanDecay(double adamMeanDecay) {
            return adamMeanDecay(new FixedValue<>(adamMeanDecay));
        }

        public T adamMeanDecay(ParameterSpace<Double> adamMeanDecay) {
            this.adamMeanDecay = adamMeanDecay;
            return (T) this;
        }

        public T adamVarDecay(double adamVarDecay) {
            return adamVarDecay(new FixedValue<>(adamVarDecay));
        }

        public T adamVarDecay(ParameterSpace<Double> adamVarDecay) {
            this.adamVarDecay = adamVarDecay;
            return (T) this;
        }

        public abstract <E extends BaseNetworkSpace> E build();
    }
}