All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.arbiter.layers.LayerSpace Maven / Gradle / Ivy

There is a newer version: 1.0.0-beta7
Show newest version
/*
 *
 *  * Copyright 2016 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */
package org.deeplearning4j.arbiter.layers;

import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.util.CollectionUtils;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.WeightInit;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * LayerSpace contains the common Layre hyperparameters; should match {@link Layer} in terms of features
 *
 * @author Alex Black
 */
public abstract class LayerSpace implements ParameterSpace {

    protected ParameterSpace activationFunction;
    protected ParameterSpace weightInit;
    protected ParameterSpace biasInit;
    protected ParameterSpace dist;
    protected ParameterSpace learningRate;
    protected ParameterSpace biasLearningRate;
    protected ParameterSpace> learningRateAfter;
    protected ParameterSpace lrScoreBasedDecay;
    protected ParameterSpace l1;
    protected ParameterSpace l2;
    protected ParameterSpace dropOut;
    protected ParameterSpace momentum;
    protected ParameterSpace> momentumAfter;
    protected ParameterSpace updater;
    protected ParameterSpace epsilon;
    protected ParameterSpace rho;
    protected ParameterSpace rmsDecay;
    protected ParameterSpace adamMeanDecay;
    protected ParameterSpace adamVarDecay;
    protected ParameterSpace gradientNormalization;
    protected ParameterSpace gradientNormalizationThreshold;

    protected int numParameters;

    @SuppressWarnings("unchecked")
    protected LayerSpace(Builder builder) {
        this.activationFunction = builder.activationFunction;
        this.weightInit = builder.weightInit;
        this.biasInit = builder.biasInit;
        this.dist = builder.dist;
        this.learningRate = builder.learningRate;
        this.biasLearningRate = builder.biasLearningRate;
        this.learningRateAfter = builder.learningRateAfter;
        this.lrScoreBasedDecay = builder.lrScoreBasedDecay;
        this.l1 = builder.l1;
        this.l2 = builder.l2;
        this.dropOut = builder.dropOut;
        this.momentum = builder.momentum;
        this.momentumAfter = builder.momentumAfter;
        this.updater = builder.updater;
        this.epsilon = builder.epsilon;
        this.rho = builder.rho;
        this.rmsDecay = builder.rmsDecay;
        this.adamMeanDecay = builder.adamMeanDecay;
        this.adamVarDecay = builder.adamVarDecay;
        this.gradientNormalization = builder.gradientNormalization;
        this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold;
    }

//    public abstract L randomLayer();

    public List collectLeaves() {
        List list = new ArrayList<>();
        if (activationFunction != null) list.addAll(activationFunction.collectLeaves());
        if (weightInit != null) list.addAll(weightInit.collectLeaves());
        if (biasInit != null) list.addAll(biasInit.collectLeaves());
        if (dist != null) list.addAll(dist.collectLeaves());
        if (learningRate != null) list.addAll(learningRate.collectLeaves());
        if (biasLearningRate != null) list.addAll(biasLearningRate.collectLeaves());
        if (learningRateAfter != null) list.addAll(learningRateAfter.collectLeaves());
        if (lrScoreBasedDecay != null) list.addAll(lrScoreBasedDecay.collectLeaves());
        if (l1 != null) list.addAll(l1.collectLeaves());
        if (l2 != null) list.addAll(l2.collectLeaves());
        if (dropOut != null) list.addAll(dropOut.collectLeaves());
        if (momentum != null) list.addAll(momentum.collectLeaves());
        if (momentumAfter != null) list.addAll(momentumAfter.collectLeaves());
        if (updater != null) list.addAll(updater.collectLeaves());
        if (epsilon != null) list.addAll(epsilon.collectLeaves());
        if (rho != null) list.addAll(rho.collectLeaves());
        if (rmsDecay != null) list.addAll(rmsDecay.collectLeaves());
        if (adamMeanDecay != null) list.addAll(adamMeanDecay.collectLeaves());
        if (adamVarDecay != null) list.addAll(adamVarDecay.collectLeaves());
        if (gradientNormalization != null) list.addAll(gradientNormalization.collectLeaves());
        if (gradientNormalizationThreshold != null) list.addAll(gradientNormalizationThreshold.collectLeaves());
        return list;
    }

    @Override
    public int numParameters() {
        return numParameters;
    }

    @Override
    public boolean isLeaf() {
        return false;
    }

    @Override
    public void setIndices(int... indices) {
        throw new UnsupportedOperationException("Cannot set indices for non-leaf parameter space");
    }


    protected void setLayerOptionsBuilder(Layer.Builder builder, double[] values) {
        if (activationFunction != null) builder.activation(activationFunction.getValue(values));
        if (weightInit != null) builder.weightInit(weightInit.getValue(values));
        if (biasInit != null) builder.biasInit(biasInit.getValue(values));
        if (dist != null) builder.dist(dist.getValue(values));
        if (learningRate != null) builder.learningRate(learningRate.getValue(values));
        if (biasLearningRate != null) builder.biasLearningRate(biasLearningRate.getValue(values));
        if (learningRateAfter != null) builder.learningRateSchedule(learningRateAfter.getValue(values));
        if (lrScoreBasedDecay != null) builder.learningRate(lrScoreBasedDecay.getValue(values));
        if (l1 != null) builder.l1(l1.getValue(values));
        if (l2 != null) builder.l2(l2.getValue(values));
        if (dropOut != null) builder.dropOut(dropOut.getValue(values));
        if (momentum != null) builder.momentum(momentum.getValue(values));
        if (momentumAfter != null) builder.momentumAfter(momentumAfter.getValue(values));
        if (updater != null) builder.updater(updater.getValue(values));
        if (epsilon != null) builder.epsilon(epsilon.getValue(values));
        if (rho != null) builder.rho(rho.getValue(values));
        if (rmsDecay != null) builder.rmsDecay(rmsDecay.getValue(values));
        if (adamMeanDecay != null) builder.adamMeanDecay(adamMeanDecay.getValue(values));
        if (adamVarDecay != null) builder.adamVarDecay(adamVarDecay.getValue(values));
        if (gradientNormalization != null) builder.gradientNormalization(gradientNormalization.getValue(values));
        if (gradientNormalizationThreshold != null)
            builder.gradientNormalizationThreshold(gradientNormalizationThreshold.getValue(values));
    }


    @Override
    public String toString() {
        return toString(", ");
    }

    protected String toString(String delim) {
        StringBuilder sb = new StringBuilder();
        if (activationFunction != null) sb.append("activationFunction: ").append(activationFunction).append(delim);
        if (weightInit != null) sb.append("weightInit: ").append(weightInit).append(delim);
        if (biasInit != null) sb.append("biasInit: ").append(biasInit).append(delim);
        if (dist != null) sb.append("dist: ").append(dist).append(delim);
        if (learningRate != null) sb.append("learningRate: ").append(learningRate).append(delim);
        if (biasLearningRate != null) sb.append("biasLearningRate: ").append(biasLearningRate).append(delim);
        if (learningRateAfter != null) sb.append("learningRateAfter: ").append(learningRateAfter).append(delim);
        if (lrScoreBasedDecay != null) sb.append("lrScoreBasedDecay: ").append(lrScoreBasedDecay).append(delim);
        if (l1 != null) sb.append("l1: ").append(l1).append(delim);
        if (l2 != null) sb.append("l2: ").append(l2).append(delim);
        if (dropOut != null) sb.append("dropOut: ").append(dropOut).append(delim);
        if (momentum != null) sb.append("momentum: ").append(momentum).append(delim);
        if (momentumAfter != null) sb.append("momentumAfter: ").append(momentumAfter).append(delim);
        if (updater != null) sb.append("updater: ").append(updater).append(delim);
        if (epsilon != null) sb.append("epsilon: ").append(epsilon).append(delim);
        if (rho != null) sb.append("rho: ").append(rho).append(delim);
        if (rmsDecay != null) sb.append("rmsDecay: ").append(rmsDecay).append(delim);
        if (adamMeanDecay != null) sb.append("adamMeanDecay: ").append(adamMeanDecay).append(delim);
        if (adamVarDecay != null) sb.append("adamVarDecay: ").append(adamVarDecay).append(delim);
        if (gradientNormalization != null)
            sb.append("gradientNormalization: ").append(gradientNormalization).append(delim);
        if (gradientNormalizationThreshold != null)
            sb.append("gradientNormalizationThreshold").append(gradientNormalizationThreshold);
        String s = sb.toString();

        if (s.endsWith(delim)) {
            //Remove final delimiter
            int last = s.lastIndexOf(delim);
            return s.substring(0, last);
        } else return s;
    }

    //    public abstract static class Builder> {
    @SuppressWarnings("unchecked")
    public abstract static class Builder {
        protected ParameterSpace activationFunction;
        protected ParameterSpace weightInit;
        protected ParameterSpace biasInit;
        protected ParameterSpace dist;
        protected ParameterSpace learningRate;
        protected ParameterSpace biasLearningRate;
        protected ParameterSpace> learningRateAfter;
        protected ParameterSpace lrScoreBasedDecay;
        protected ParameterSpace l1;
        protected ParameterSpace l2;
        protected ParameterSpace dropOut;
        protected ParameterSpace updater;
        protected ParameterSpace momentum;
        protected ParameterSpace> momentumAfter;
        protected ParameterSpace epsilon;
        protected ParameterSpace rho;
        protected ParameterSpace rmsDecay;
        protected ParameterSpace adamMeanDecay;
        protected ParameterSpace adamVarDecay;
        protected ParameterSpace gradientNormalization;
        protected ParameterSpace gradientNormalizationThreshold;


        public T activation(String activationFunction) {
            return activation(new FixedValue(activationFunction));
        }

        public T activation(ParameterSpace activationFunction) {
            this.activationFunction = activationFunction;
            return (T) this;
        }

        public T weightInit(WeightInit weightInit) {
            return (T) weightInit(new FixedValue(weightInit));
        }

        public T weightInit(ParameterSpace weightInit) {
            this.weightInit = weightInit;
            return (T) this;
        }

        public T dist(Distribution dist) {
            return dist(new FixedValue(dist));
        }

        public T dist(ParameterSpace dist) {
            this.dist = dist;
            return (T) this;
        }

        public T learningRate(double learningRate) {
            return learningRate(new FixedValue(learningRate));
        }

        public T learningRate(ParameterSpace learningRate) {
            this.learningRate = learningRate;
            return (T) this;
        }

        public T biasLearningRate(double biasLearningRate) {
            return biasLearningRate(new FixedValue(biasLearningRate));
        }

        public T biasLearningRate(ParameterSpace biasLearningRate) {
            this.biasLearningRate = biasLearningRate;
            return (T) this;
        }


        public T learningRateAfter(Map learningRateAfter) {
            return learningRateAfter(new FixedValue>(learningRateAfter));
        }

        public T learningRateAfter(ParameterSpace> learningRateAfter) {
            this.learningRateAfter = learningRateAfter;
            return (T) this;
        }

        public T learningRateScoreBasedDecayRate(double lrScoreBasedDecay) {
            return learningRateScoreBasedDecayRate(new FixedValue(lrScoreBasedDecay));
        }

        public T learningRateScoreBasedDecayRate(ParameterSpace lrScoreBasedDecay) {
            this.lrScoreBasedDecay = lrScoreBasedDecay;
            return (T) this;
        }

        public T l1(double l1) {
            return l1(new FixedValue(l1));
        }

        public T l1(ParameterSpace l1) {
            this.l1 = l1;
            return (T) this;
        }

        public T l2(double l2) {
            return l2(new FixedValue(l2));
        }

        public T l2(ParameterSpace l2) {
            this.l2 = l2;
            return (T) this;
        }

        public T dropOut(double dropOut) {
            return dropOut(new FixedValue(dropOut));
        }

        public T dropOut(ParameterSpace dropOut) {
            this.dropOut = dropOut;
            return (T) this;
        }

        public T momentum(double momentum) {
            return momentum(new FixedValue(momentum));
        }

        public T momentum(ParameterSpace momentum) {
            this.momentum = momentum;
            return (T) this;
        }

        public T momentumAfter(Map momentumAfter) {
            return momentumAfter(new FixedValue>(momentumAfter));
        }

        public T momentumAfter(ParameterSpace> momentumAfter) {
            this.momentumAfter = momentumAfter;
            return (T) this;
        }

        public T updater(Updater updater) {
            return updater(new FixedValue(updater));
        }

        public T updater(ParameterSpace updater) {
            this.updater = updater;
            return (T) this;
        }

        public T epsilon(double epsilon) {
            return epsilon(new FixedValue(epsilon));
        }

        public T epsilon(ParameterSpace epsilon) {
            this.epsilon = epsilon;
            return (T) this;
        }

        public T rho(double rho) {
            return rho(new FixedValue(rho));
        }

        public T rho(ParameterSpace rho) {
            this.rho = rho;
            return (T) this;
        }

        public T rmsDecay(double rmsDecay) {
            return rmsDecay(new FixedValue(rmsDecay));
        }

        public T rmsDecay(ParameterSpace rmsDecay) {
            this.rmsDecay = rmsDecay;
            return (T) this;
        }

        public T adamMeanDecay(double adamMeanDecay) {
            return adamMeanDecay(new FixedValue(adamMeanDecay));
        }

        public T adamMeanDecay(ParameterSpace adamMeanDecay) {
            this.adamMeanDecay = adamMeanDecay;
            return (T) this;
        }

        public T adamVarDecay(double adamVarDecay) {
            return adamVarDecay(new FixedValue(adamVarDecay));
        }

        public T adamVarDecay(ParameterSpace adamVarDecay) {
            this.adamVarDecay = adamVarDecay;
            return (T) this;
        }

        public T gradientNormalization(GradientNormalization gradientNormalization) {
            return gradientNormalization(new FixedValue(gradientNormalization));
        }

        public T gradientNormalization(ParameterSpace gradientNormalization) {
            this.gradientNormalization = gradientNormalization;
            return (T) this;
        }

        public T gradientNormalizationThreshold(double threshold) {
            return gradientNormalizationThreshold(new FixedValue(threshold));
        }

        public T gradientNormalizationThreshold(ParameterSpace gradientNormalizationThreshold) {
            this.gradientNormalizationThreshold = gradientNormalizationThreshold;
            return (T) this;
        }

        public abstract  E build();
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy