All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.nn.weights.WeightInit Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.weights;

import org.deeplearning4j.nn.conf.distribution.Distribution;

/**
 * Built-in weight initialization schemes.
 *
 * <p>Each constant is turned into a concrete {@link IWeightInit} by
 * {@link #getWeightInitFunction(Distribution)}. Some constants share one implementation:
 * {@link #NORMAL}, {@link #LECUN_NORMAL} and {@link #XAVIER_FAN_IN} all produce a
 * {@link WeightInitNormal}.
 *
 * <p>NOTE(review): the declaration order below is kept exactly as before. If any caller relies
 * on {@code ordinal()} (e.g. for serialization), reordering would break it. Confirm before
 * changing.
 */
public enum WeightInit {
    /** Delegates to {@link WeightInitDistribution} with the caller-supplied {@link Distribution}. */
    DISTRIBUTION,
    /** Delegates to {@link WeightInitConstant} with value {@code 0.0}. */
    ZERO,
    /** Delegates to {@link WeightInitConstant} with value {@code 1.0}. */
    ONES,
    /** Delegates to {@link WeightInitSigmoidUniform}. */
    SIGMOID_UNIFORM,
    /** Delegates to {@link WeightInitNormal}. */
    NORMAL,
    /** Alias of {@link #NORMAL}: also delegates to {@link WeightInitNormal}. */
    LECUN_NORMAL,
    /** Delegates to {@link WeightInitUniform}. */
    UNIFORM,
    /** Delegates to {@link WeightInitXavier}. */
    XAVIER,
    /** Delegates to {@link WeightInitXavierUniform}. */
    XAVIER_UNIFORM,
    /** Alias of {@link #NORMAL}: also delegates to {@link WeightInitNormal}. */
    XAVIER_FAN_IN,
    /** Delegates to {@link WeightInitXavierLegacy}. */
    XAVIER_LEGACY,
    /** Delegates to {@link WeightInitRelu}. */
    RELU,
    /** Delegates to {@link WeightInitReluUniform}. */
    RELU_UNIFORM,
    /** Delegates to {@link WeightInitIdentity}. */
    IDENTITY,
    /** Delegates to {@link WeightInitLecunUniform}. */
    LECUN_UNIFORM,
    /** Delegates to {@link WeightInitVarScalingNormalFanIn}. */
    VAR_SCALING_NORMAL_FAN_IN,
    /** Delegates to {@link WeightInitVarScalingNormalFanOut}. */
    VAR_SCALING_NORMAL_FAN_OUT,
    /** Delegates to {@link WeightInitVarScalingNormalFanAvg}. */
    VAR_SCALING_NORMAL_FAN_AVG,
    /** Delegates to {@link WeightInitVarScalingUniformFanIn}. */
    VAR_SCALING_UNIFORM_FAN_IN,
    /** Delegates to {@link WeightInitVarScalingUniformFanOut}. */
    VAR_SCALING_UNIFORM_FAN_OUT,
    /** Delegates to {@link WeightInitVarScalingUniformFanAvg}. */
    VAR_SCALING_UNIFORM_FAN_AVG;

    /**
     * Builds the weight initialization function for this scheme without supplying a distribution.
     *
     * <p>Equivalent to calling {@link #getWeightInitFunction(Distribution)} with {@code null}. For
     * {@link #DISTRIBUTION} this hands a {@code null} distribution to
     * {@link WeightInitDistribution}. Whether that is rejected is decided by that class, so prefer
     * the one-argument overload for that constant.
     *
     * @return a newly created {@link IWeightInit}
     */
    public IWeightInit getWeightInitFunction() {
        final Distribution noDistribution = null;
        return getWeightInitFunction(noDistribution);
    }

    /**
     * Builds the weight initialization function for this scheme.
     *
     * @param distribution the distribution to use. It is read only for {@link #DISTRIBUTION} and
     *                     ignored by every other constant.
     * @return a newly created {@link IWeightInit}
     * @throws UnsupportedOperationException if this constant has no mapping. This is not
     *                                       reachable for the constants currently declared.
     */
    public IWeightInit getWeightInitFunction(Distribution distribution) {
        // Cases are listed in the same order as the constants are declared above.
        switch (this) {
            case DISTRIBUTION:
                return new WeightInitDistribution(distribution);
            case ZERO:
                return new WeightInitConstant(0.0);
            case ONES:
                return new WeightInitConstant(1.0);
            case SIGMOID_UNIFORM:
                return new WeightInitSigmoidUniform();
            case NORMAL:
            case LECUN_NORMAL:
            case XAVIER_FAN_IN:
                // These three names are aliases that share a single implementation.
                return new WeightInitNormal();
            case UNIFORM:
                return new WeightInitUniform();
            case XAVIER:
                return new WeightInitXavier();
            case XAVIER_UNIFORM:
                return new WeightInitXavierUniform();
            case XAVIER_LEGACY:
                return new WeightInitXavierLegacy();
            case RELU:
                return new WeightInitRelu();
            case RELU_UNIFORM:
                return new WeightInitReluUniform();
            case IDENTITY:
                return new WeightInitIdentity();
            case LECUN_UNIFORM:
                return new WeightInitLecunUniform();
            case VAR_SCALING_NORMAL_FAN_IN:
                return new WeightInitVarScalingNormalFanIn();
            case VAR_SCALING_NORMAL_FAN_OUT:
                return new WeightInitVarScalingNormalFanOut();
            case VAR_SCALING_NORMAL_FAN_AVG:
                return new WeightInitVarScalingNormalFanAvg();
            case VAR_SCALING_UNIFORM_FAN_IN:
                return new WeightInitVarScalingUniformFanIn();
            case VAR_SCALING_UNIFORM_FAN_OUT:
                return new WeightInitVarScalingUniformFanOut();
            case VAR_SCALING_UNIFORM_FAN_AVG:
                return new WeightInitVarScalingUniformFanAvg();
            default:
                // Safety net for a constant added to the enum without a matching case above.
                throw new UnsupportedOperationException(
                        "Unknown or not supported weight initialization function: " + this);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy