All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.arbiter.layers.BaseLayerSpace Maven / Gradle / Ivy

There is a newer version: 1.0.0-beta7
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.arbiter.layers;

import org.nd4j.shade.guava.base.Preconditions;
import lombok.AccessLevel;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.deeplearning4j.arbiter.adapter.ActivationParameterSpaceAdapter;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.shade.jackson.annotation.JsonInclude;

import java.util.Map;

/**
 * BaseLayerSpace contains the common Layer hyperparameters; should match {@link BaseLayer} in terms of features.
 * <p>
 * Every hyperparameter is held as a {@link ParameterSpace}. A {@code null} space means "not configured": in that
 * case {@link #setLayerOptionsBuilder(BaseLayer.Builder, double[])} does not touch the corresponding layer option.
 * <p>
 * NOTE(review): the generic type arguments in this class were reconstructed from the imports and from how each
 * value is passed to {@link BaseLayer.Builder}; confirm them against the declarations of {@code LayerSpace},
 * {@code LayerSpace.Builder} and {@code ParameterSpace}.
 *
 * @param <L> layer configuration type, forwarded to {@code LayerSpace} (presumably the layer type this space
 *            produces - TODO confirm)
 * @author Alex Black
 */
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@EqualsAndHashCode(callSuper = true)
@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization
public abstract class BaseLayerSpace<L extends BaseLayer> extends LayerSpace<L> {
    protected ParameterSpace<IActivation> activationFunction;
    protected ParameterSpace<WeightInit> weightInit;
    protected ParameterSpace<Double> biasInit;
    protected ParameterSpace<Distribution> dist;
    protected ParameterSpace<Double> l1;
    protected ParameterSpace<Double> l2;
    protected ParameterSpace<Double> l1Bias;
    protected ParameterSpace<Double> l2Bias;
    protected ParameterSpace<IUpdater> updater;
    protected ParameterSpace<IUpdater> biasUpdater;
    protected ParameterSpace<IWeightNoise> weightNoise;
    protected ParameterSpace<GradientNormalization> gradientNormalization;
    protected ParameterSpace<Double> gradientNormalizationThreshold;
    /** Value reported by {@link #numParameters()}; this class never assigns it itself. */
    protected int numParameters;

    /**
     * Copies every hyperparameter space from the given builder into this layer space.
     *
     * @param builder builder holding the configured (possibly {@code null}) parameter spaces
     */
    protected BaseLayerSpace(Builder<?> builder) {
        super(builder);
        this.activationFunction = builder.activationFunction;
        this.weightInit = builder.weightInit;
        this.biasInit = builder.biasInit;
        this.dist = builder.dist;
        this.l1 = builder.l1;
        this.l2 = builder.l2;
        this.l1Bias = builder.l1Bias;
        this.l2Bias = builder.l2Bias;
        this.updater = builder.updater;
        this.biasUpdater = builder.biasUpdater;
        this.weightNoise = builder.weightNoise;
        this.gradientNormalization = builder.gradientNormalization;
        this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold;
    }

    /**
     * @return the current value of the {@code numParameters} field
     */
    @Override
    public int numParameters() {
        return numParameters;
    }

    /**
     * @return always {@code false}: a layer space is treated as a non-leaf parameter space
     */
    @Override
    public boolean isLeaf() {
        return false;
    }

    /**
     * Not supported for this (non-leaf) parameter space.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setIndices(int... indices) {
        throw new UnsupportedOperationException("Cannot set indices for non-leaf parameter space");
    }


    /**
     * Applies every configured hyperparameter to the given layer builder, after letting the superclass apply
     * its own options. Spaces that are {@code null} are skipped, so the builder keeps whatever it already had
     * for those options.
     *
     * @param builder layer builder to configure
     * @param values  candidate values, passed unchanged to {@code getValue} of each parameter space
     */
    protected void setLayerOptionsBuilder(BaseLayer.Builder<?> builder, double[] values) {
        super.setLayerOptionsBuilder(builder, values);
        if (activationFunction != null) {
            builder.activation(activationFunction.getValue(values));
        }
        if (biasInit != null) {
            builder.biasInit(biasInit.getValue(values));
        }
        if (weightInit != null) {
            builder.weightInit(weightInit.getValue(values));
        }
        if (dist != null) {
            builder.dist(dist.getValue(values));
        }
        if (l1 != null) {
            builder.l1(l1.getValue(values));
        }
        if (l2 != null) {
            builder.l2(l2.getValue(values));
        }
        if (l1Bias != null) {
            builder.l1Bias(l1Bias.getValue(values));
        }
        if (l2Bias != null) {
            builder.l2Bias(l2Bias.getValue(values));
        }
        if (updater != null) {
            builder.updater(updater.getValue(values));
        }
        if (biasUpdater != null) {
            builder.biasUpdater(biasUpdater.getValue(values));
        }
        if (weightNoise != null) {
            builder.weightNoise(weightNoise.getValue(values));
        }
        if (gradientNormalization != null) {
            builder.gradientNormalization(gradientNormalization.getValue(values));
        }
        if (gradientNormalizationThreshold != null) {
            builder.gradientNormalizationThreshold(gradientNormalizationThreshold.getValue(values));
        }
    }


    /**
     * @return the nested parameter spaces rendered with {@code ", "} as the separator
     */
    @Override
    public String toString() {
        return toString(", ");
    }

    /**
     * Renders each entry of {@code getNestedSpaces()} as {@code "key: value"}.
     *
     * @param delim separator placed between consecutive entries (no trailing separator is emitted)
     * @return the joined description; empty if there are no nested spaces
     */
    protected String toString(String delim) {
        StringBuilder sb = new StringBuilder();

        for (Map.Entry<?, ?> e : getNestedSpaces().entrySet()) {
            // Every appended entry contains ": ", so a non-empty buffer means at least one entry precedes this one.
            if (sb.length() > 0) {
                sb.append(delim);
            }
            sb.append(e.getKey()).append(": ").append(e.getValue());
        }
        return sb.toString();
    }

    /**
     * Fluent builder for the hyperparameters shared by all {@link BaseLayerSpace} implementations.
     * Most options can be given either as a single fixed value (wrapped in a {@link FixedValue}) or as a
     * {@link ParameterSpace} to search over.
     *
     * @param <T> concrete builder type returned by every setter so that calls can be chained
     */
    @SuppressWarnings("unchecked") // every setter returns (T) this; T is expected to be the concrete builder type
    public abstract static class Builder<T> extends LayerSpace.Builder<T> {
        protected ParameterSpace<IActivation> activationFunction;
        protected ParameterSpace<WeightInit> weightInit;
        protected ParameterSpace<Double> biasInit;
        protected ParameterSpace<Distribution> dist;
        protected ParameterSpace<Double> l1;
        protected ParameterSpace<Double> l2;
        protected ParameterSpace<Double> l1Bias;
        protected ParameterSpace<Double> l2Bias;
        protected ParameterSpace<IUpdater> updater;
        protected ParameterSpace<IUpdater> biasUpdater;
        protected ParameterSpace<IWeightNoise> weightNoise;
        protected ParameterSpace<GradientNormalization> gradientNormalization;
        protected ParameterSpace<Double> gradientNormalizationThreshold;

        /**
         * Sets the activation function from one or more candidates: a single candidate is used as a fixed
         * value, several candidates become a {@link DiscreteParameterSpace}.
         *
         * @param activations candidates; must contain at least one element (checked via
         *                    {@code Preconditions.checkArgument})
         */
        public T activation(Activation... activations){
            Preconditions.checkArgument(activations.length > 0, "Activations length must be 1 or more");
            if(activations.length == 1){
                return activation(activations[0]);
            }
            return activation(new DiscreteParameterSpace<>(activations));
        }

        /** Uses the given activation as a fixed value. */
        public T activation(Activation activation) {
            return activation(new FixedValue<>(activation));
        }

        /** Uses the given activation function instance as a fixed value. */
        public T activation(IActivation iActivation) {
            return activationFn(new FixedValue<>(iActivation));
        }

        /**
         * Sets a space of {@link Activation} values; it is wrapped in an
         * {@link ActivationParameterSpaceAdapter} before being stored.
         */
        public T activation(ParameterSpace<Activation> activationFunction) {
            return activationFn(new ActivationParameterSpaceAdapter(activationFunction));
        }

        /** Stores the given activation function space as-is. */
        public T activationFn(ParameterSpace<IActivation> activationFunction) {
            this.activationFunction = activationFunction;
            return (T) this;
        }

        /** Uses the given weight initialization scheme as a fixed value. */
        public T weightInit(WeightInit weightInit) {
            return weightInit(new FixedValue<>(weightInit));
        }

        /** Stores the given weight initialization space. */
        public T weightInit(ParameterSpace<WeightInit> weightInit) {
            this.weightInit = weightInit;
            return (T) this;
        }

        /**
         * Sets the weight initialization to {@link WeightInit#DISTRIBUTION} and uses the given distribution
         * as a fixed value for {@link #dist(Distribution)}.
         */
        public T weightInit(Distribution distribution){
            weightInit(WeightInit.DISTRIBUTION);
            return dist(distribution);
        }

        /** Uses the given bias initialization value as a fixed value. */
        public T biasInit(double biasInit){
            return biasInit(new FixedValue<>(biasInit));
        }

        /** Stores the given bias initialization space. */
        public T biasInit(ParameterSpace<Double> biasInit){
            this.biasInit = biasInit;
            return (T) this;
        }

        /** Uses the given distribution as a fixed value. */
        public T dist(Distribution dist) {
            return dist(new FixedValue<>(dist));
        }

        /** Stores the given distribution space. */
        public T dist(ParameterSpace<Distribution> dist) {
            this.dist = dist;
            return (T) this;
        }

        /** Uses the given L1 coefficient as a fixed value. */
        public T l1(double l1) {
            return l1(new FixedValue<>(l1));
        }

        /** Stores the given L1 coefficient space. */
        public T l1(ParameterSpace<Double> l1) {
            this.l1 = l1;
            return (T) this;
        }

        /** Uses the given L2 coefficient as a fixed value. */
        public T l2(double l2) {
            return l2(new FixedValue<>(l2));
        }

        /** Stores the given L2 coefficient space. */
        public T l2(ParameterSpace<Double> l2) {
            this.l2 = l2;
            return (T) this;
        }

        /** Uses the given bias L1 coefficient as a fixed value. */
        public T l1Bias(double l1Bias) {
            return l1Bias(new FixedValue<>(l1Bias));
        }

        /** Stores the given bias L1 coefficient space. */
        public T l1Bias(ParameterSpace<Double> l1Bias) {
            this.l1Bias = l1Bias;
            return (T) this;
        }

        /** Uses the given bias L2 coefficient as a fixed value. */
        public T l2Bias(double l2Bias) {
            return l2Bias(new FixedValue<>(l2Bias));
        }

        /** Stores the given bias L2 coefficient space. */
        public T l2Bias(ParameterSpace<Double> l2Bias) {
            this.l2Bias = l2Bias;
            return (T) this;
        }

        /** Uses the given updater as a fixed value. */
        public T updater(IUpdater updater) {
            return updater(new FixedValue<>(updater));
        }

        /** Stores the given updater space. */
        public T updater(ParameterSpace<IUpdater> updater) {
            this.updater = updater;
            return (T) this;
        }

        /** Uses the given bias updater as a fixed value. */
        public T biasUpdater(IUpdater biasUpdater) {
            return biasUpdater(new FixedValue<>(biasUpdater));
        }

        /** Stores the given bias updater space. */
        public T biasUpdater(ParameterSpace<IUpdater> biasUpdater) {
            this.biasUpdater = biasUpdater;
            return (T) this;
        }

        /** Uses the given weight noise as a fixed value. */
        public T weightNoise(IWeightNoise weightNoise) {
            return weightNoise(new FixedValue<>(weightNoise));
        }

        /** Stores the given weight noise space. */
        public T weightNoise(ParameterSpace<IWeightNoise> weightNoise) {
            this.weightNoise = weightNoise;
            return (T) this;
        }

        /** Uses the given gradient normalization strategy as a fixed value. */
        public T gradientNormalization(GradientNormalization gradientNormalization) {
            return gradientNormalization(new FixedValue<>(gradientNormalization));
        }

        /** Stores the given gradient normalization space. */
        public T gradientNormalization(ParameterSpace<GradientNormalization> gradientNormalization) {
            this.gradientNormalization = gradientNormalization;
            return (T) this;
        }

        /** Uses the given gradient normalization threshold as a fixed value. */
        public T gradientNormalizationThreshold(double threshold) {
            return gradientNormalizationThreshold(new FixedValue<>(threshold));
        }

        /** Stores the given gradient normalization threshold space. */
        public T gradientNormalizationThreshold(ParameterSpace<Double> gradientNormalizationThreshold) {
            this.gradientNormalizationThreshold = gradientNormalizationThreshold;
            return (T) this;
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy