All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.arbiter.layers.BaseLayerSpace Maven / Gradle / Ivy

There is a newer version: 1.0.0-beta7
Show newest version
/*-
 *
 *  * Copyright 2016 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */
package org.deeplearning4j.arbiter.layers;

import lombok.AccessLevel;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.deeplearning4j.arbiter.adapter.ActivationParameterSpaceAdapter;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.shade.jackson.annotation.JsonInclude;

import java.util.Map;

/**
 * BaseLayerSpace contains the common Layer hyperparameters; should match {@link BaseLayer} in terms of features
 *
 * @author Alex Black
 */
@JsonInclude(JsonInclude.Include.NON_NULL)

@Data
@EqualsAndHashCode(callSuper = true)
@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization
public abstract class BaseLayerSpace extends LayerSpace {
    protected ParameterSpace activationFunction;
    protected ParameterSpace weightInit;
    protected ParameterSpace biasInit;
    protected ParameterSpace dist;
    protected ParameterSpace learningRate;
    protected ParameterSpace biasLearningRate;
    protected ParameterSpace> learningRateAfter;
    protected ParameterSpace lrScoreBasedDecay;
    protected ParameterSpace l1;
    protected ParameterSpace l2;
    protected ParameterSpace momentum;
    protected ParameterSpace> momentumAfter;
    protected ParameterSpace updater;
    protected ParameterSpace epsilon;
    protected ParameterSpace rho;
    protected ParameterSpace rmsDecay;
    protected ParameterSpace adamMeanDecay;
    protected ParameterSpace adamVarDecay;
    protected ParameterSpace gradientNormalization;
    protected ParameterSpace gradientNormalizationThreshold;
    protected int numParameters;

    @SuppressWarnings("unchecked")
    protected BaseLayerSpace(Builder builder) {
        this.activationFunction = builder.activationFunction;
        this.weightInit = builder.weightInit;
        this.biasInit = builder.biasInit;
        this.dist = builder.dist;
        this.learningRate = builder.learningRate;
        this.biasLearningRate = builder.biasLearningRate;
        this.learningRateAfter = builder.learningRateAfter;
        this.lrScoreBasedDecay = builder.lrScoreBasedDecay;
        this.l1 = builder.l1;
        this.l2 = builder.l2;
        this.momentum = builder.momentum;
        this.momentumAfter = builder.momentumAfter;
        this.updater = builder.updater;
        this.epsilon = builder.epsilon;
        this.rho = builder.rho;
        this.rmsDecay = builder.rmsDecay;
        this.adamMeanDecay = builder.adamMeanDecay;
        this.adamVarDecay = builder.adamVarDecay;
        this.gradientNormalization = builder.gradientNormalization;
        this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold;
    }

    @Override
    public int numParameters() {
        return numParameters;
    }

    @Override
    public boolean isLeaf() {
        return false;
    }

    @Override
    public void setIndices(int... indices) {
        throw new UnsupportedOperationException("Cannot set indices for non-leaf parameter space");
    }


    protected void setLayerOptionsBuilder(BaseLayer.Builder builder, double[] values) {
        if (activationFunction != null)
            builder.activation(activationFunction.getValue(values));
        if (weightInit != null)
            builder.weightInit(weightInit.getValue(values));
        if (biasInit != null)
            builder.biasInit(biasInit.getValue(values));
        if (dist != null)
            builder.dist(dist.getValue(values));
        if (learningRate != null)
            builder.learningRate(learningRate.getValue(values));
        if (biasLearningRate != null)
            builder.biasLearningRate(biasLearningRate.getValue(values));
        if (learningRateAfter != null)
            builder.learningRateSchedule(learningRateAfter.getValue(values));
        if (lrScoreBasedDecay != null)
            builder.learningRate(lrScoreBasedDecay.getValue(values));
        if (l1 != null)
            builder.l1(l1.getValue(values));
        if (l2 != null)
            builder.l2(l2.getValue(values));
        if (dropOut != null)
            builder.dropOut(dropOut.getValue(values));
        if (momentum != null)
            builder.momentum(momentum.getValue(values));
        if (momentumAfter != null)
            builder.momentumAfter(momentumAfter.getValue(values));
        if (updater != null)
            builder.updater(updater.getValue(values));
        if (epsilon != null)
            builder.epsilon(epsilon.getValue(values));
        if (rho != null)
            builder.rho(rho.getValue(values));
        if (rmsDecay != null)
            builder.rmsDecay(rmsDecay.getValue(values));
        if (adamMeanDecay != null)
            builder.adamMeanDecay(adamMeanDecay.getValue(values));
        if (adamVarDecay != null)
            builder.adamVarDecay(adamVarDecay.getValue(values));
        if (gradientNormalization != null)
            builder.gradientNormalization(gradientNormalization.getValue(values));
        if (gradientNormalizationThreshold != null)
            builder.gradientNormalizationThreshold(gradientNormalizationThreshold.getValue(values));
    }


    @Override
    public String toString() {
        return toString(", ");
    }

    protected String toString(String delim) {
        StringBuilder sb = new StringBuilder();
        if (activationFunction != null)
            sb.append("activationFunction: ").append(activationFunction).append(delim);
        if (weightInit != null)
            sb.append("weightInit: ").append(weightInit).append(delim);
        if (biasInit != null)
            sb.append("biasInit: ").append(biasInit).append(delim);
        if (dist != null)
            sb.append("dist: ").append(dist).append(delim);
        if (learningRate != null)
            sb.append("learningRate: ").append(learningRate).append(delim);
        if (biasLearningRate != null)
            sb.append("biasLearningRate: ").append(biasLearningRate).append(delim);
        if (learningRateAfter != null)
            sb.append("learningRateAfter: ").append(learningRateAfter).append(delim);
        if (lrScoreBasedDecay != null)
            sb.append("lrScoreBasedDecay: ").append(lrScoreBasedDecay).append(delim);
        if (l1 != null)
            sb.append("l1: ").append(l1).append(delim);
        if (l2 != null)
            sb.append("l2: ").append(l2).append(delim);
        if (momentum != null)
            sb.append("momentum: ").append(momentum).append(delim);
        if (momentumAfter != null)
            sb.append("momentumAfter: ").append(momentumAfter).append(delim);
        if (updater != null)
            sb.append("updater: ").append(updater).append(delim);
        if (epsilon != null)
            sb.append("epsilon: ").append(epsilon).append(delim);
        if (rho != null)
            sb.append("rho: ").append(rho).append(delim);
        if (rmsDecay != null)
            sb.append("rmsDecay: ").append(rmsDecay).append(delim);
        if (adamMeanDecay != null)
            sb.append("adamMeanDecay: ").append(adamMeanDecay).append(delim);
        if (adamVarDecay != null)
            sb.append("adamVarDecay: ").append(adamVarDecay).append(delim);
        if (gradientNormalization != null)
            sb.append("gradientNormalization: ").append(gradientNormalization).append(delim);
        if (gradientNormalizationThreshold != null)
            sb.append("gradientNormalizationThreshold").append(gradientNormalizationThreshold);
        String s = sb.toString();

        if (s.endsWith(delim)) {
            //Remove final delimiter
            int last = s.lastIndexOf(delim);
            return s.substring(0, last);
        } else
            return s;
    }

    @SuppressWarnings("unchecked")
    public abstract static class Builder extends LayerSpace.Builder {
        protected ParameterSpace activationFunction;
        protected ParameterSpace weightInit;
        protected ParameterSpace biasInit;
        protected ParameterSpace dist;
        protected ParameterSpace learningRate;
        protected ParameterSpace biasLearningRate;
        protected ParameterSpace> learningRateAfter;
        protected ParameterSpace lrScoreBasedDecay;
        protected ParameterSpace l1;
        protected ParameterSpace l2;
        protected ParameterSpace dropOut;
        protected ParameterSpace updater;
        protected ParameterSpace momentum;
        protected ParameterSpace> momentumAfter;
        protected ParameterSpace epsilon;
        protected ParameterSpace rho;
        protected ParameterSpace rmsDecay;
        protected ParameterSpace adamMeanDecay;
        protected ParameterSpace adamVarDecay;
        protected ParameterSpace gradientNormalization;
        protected ParameterSpace gradientNormalizationThreshold;


        @Deprecated
        public T activation(String activationFunction) {
            return activation(Activation.fromString(activationFunction));
        }

        public T activation(Activation activation) {
            return activation(new FixedValue<>(activation));
        }

        public T activation(IActivation iActivation) {
            return activationFn(new FixedValue<>(iActivation));
        }

        public T activation(ParameterSpace activationFunction) {
            return activationFn(new ActivationParameterSpaceAdapter(activationFunction));
        }

        public T activationFn(ParameterSpace activationFunction) {
            this.activationFunction = activationFunction;
            return (T) this;
        }

        public T weightInit(WeightInit weightInit) {
            return (T) weightInit(new FixedValue(weightInit));
        }

        public T weightInit(ParameterSpace weightInit) {
            this.weightInit = weightInit;
            return (T) this;
        }

        public T dist(Distribution dist) {
            return dist(new FixedValue<>(dist));
        }

        public T dist(ParameterSpace dist) {
            this.dist = dist;
            return (T) this;
        }

        public T learningRate(double learningRate) {
            return learningRate(new FixedValue(learningRate));
        }

        public T learningRate(ParameterSpace learningRate) {
            this.learningRate = learningRate;
            return (T) this;
        }

        public T biasLearningRate(double biasLearningRate) {
            return biasLearningRate(new FixedValue(biasLearningRate));
        }

        public T biasLearningRate(ParameterSpace biasLearningRate) {
            this.biasLearningRate = biasLearningRate;
            return (T) this;
        }


        public T learningRateAfter(Map learningRateAfter) {
            return learningRateAfter(new FixedValue>(learningRateAfter));
        }

        public T learningRateAfter(ParameterSpace> learningRateAfter) {
            this.learningRateAfter = learningRateAfter;
            return (T) this;
        }

        public T learningRateScoreBasedDecayRate(double lrScoreBasedDecay) {
            return learningRateScoreBasedDecayRate(new FixedValue(lrScoreBasedDecay));
        }

        public T learningRateScoreBasedDecayRate(ParameterSpace lrScoreBasedDecay) {
            this.lrScoreBasedDecay = lrScoreBasedDecay;
            return (T) this;
        }

        public T l1(double l1) {
            return l1(new FixedValue(l1));
        }

        public T l1(ParameterSpace l1) {
            this.l1 = l1;
            return (T) this;
        }

        public T l2(double l2) {
            return l2(new FixedValue(l2));
        }

        public T l2(ParameterSpace l2) {
            this.l2 = l2;
            return (T) this;
        }

        public T momentum(double momentum) {
            return momentum(new FixedValue(momentum));
        }

        public T momentum(ParameterSpace momentum) {
            this.momentum = momentum;
            return (T) this;
        }

        public T momentumAfter(Map momentumAfter) {
            return momentumAfter(new FixedValue>(momentumAfter));
        }

        public T momentumAfter(ParameterSpace> momentumAfter) {
            this.momentumAfter = momentumAfter;
            return (T) this;
        }

        public T updater(Updater updater) {
            return updater(new FixedValue(updater));
        }

        public T updater(ParameterSpace updater) {
            this.updater = updater;
            return (T) this;
        }

        public T epsilon(double epsilon) {
            return epsilon(new FixedValue(epsilon));
        }

        public T epsilon(ParameterSpace epsilon) {
            this.epsilon = epsilon;
            return (T) this;
        }

        public T rho(double rho) {
            return rho(new FixedValue(rho));
        }

        public T rho(ParameterSpace rho) {
            this.rho = rho;
            return (T) this;
        }

        public T rmsDecay(double rmsDecay) {
            return rmsDecay(new FixedValue(rmsDecay));
        }

        public T rmsDecay(ParameterSpace rmsDecay) {
            this.rmsDecay = rmsDecay;
            return (T) this;
        }

        public T adamMeanDecay(double adamMeanDecay) {
            return adamMeanDecay(new FixedValue(adamMeanDecay));
        }

        public T adamMeanDecay(ParameterSpace adamMeanDecay) {
            this.adamMeanDecay = adamMeanDecay;
            return (T) this;
        }

        public T adamVarDecay(double adamVarDecay) {
            return adamVarDecay(new FixedValue(adamVarDecay));
        }

        public T adamVarDecay(ParameterSpace adamVarDecay) {
            this.adamVarDecay = adamVarDecay;
            return (T) this;
        }

        public T gradientNormalization(GradientNormalization gradientNormalization) {
            return gradientNormalization(new FixedValue(gradientNormalization));
        }

        public T gradientNormalization(ParameterSpace gradientNormalization) {
            this.gradientNormalization = gradientNormalization;
            return (T) this;
        }

        public T gradientNormalizationThreshold(double threshold) {
            return gradientNormalizationThreshold(new FixedValue<>(threshold));
        }

        public T gradientNormalizationThreshold(ParameterSpace gradientNormalizationThreshold) {
            this.gradientNormalizationThreshold = gradientNormalizationThreshold;
            return (T) this;
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy