
org.deeplearning4j.nn.updater.UpdaterUtils

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.updater;

import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.nd4j.linalg.learning.config.IUpdater;

public class UpdaterUtils {


    public static boolean updaterConfigurationsEquals(Trainable layer1, String param1, Trainable layer2, String param2) {
        TrainingConfig l1 = layer1.getConfig();
        TrainingConfig l2 = layer2.getConfig();
        IUpdater u1 = l1.getUpdaterByParam(param1);
        IUpdater u2 = l2.getUpdaterByParam(param2);

        //For updaters to be equal (and hence combinable), we require that:
        //(a) The updater-specific configurations are equal (including learning rate and any LR/momentum schedules, etc.)
        //(b) If one or both of the params are pretrain params, they belong to the same layer.
        //    This last point is necessary because we don't want to modify the pretrain gradient/updater state during
        //    backprop, or modify the pretrain gradient/updater state of one layer while training another
        if (!u1.equals(u2)) {
            //Different updaters or different config
            return false;
        }

        boolean isPretrainParam1 = l1.isPretrainParam(param1);
        boolean isPretrainParam2 = l2.isPretrainParam(param2);
        if (isPretrainParam1 || isPretrainParam2) {
            //One or both of the params are pretrain params.
            //Either the layers differ -> don't want to combine pretrain updaters across layers
            //Or one param is pretrain and the other isn't -> don't want to combine pretrain and non-pretrain updaters within a layer
            return layer1 == layer2 && isPretrainParam1 && isPretrainParam2;
        }

        return true;
    }
}
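
For context, condition (a) above relies on IUpdater.equals(...), which compares both the updater type and its hyperparameters. Below is a minimal usage sketch, not part of the original file: it assumes an already-initialized MultiLayerNetwork named net whose layers implement Trainable (as DL4J layers do), and the layer indices and the "W" weight parameter key are illustrative only.

import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;

// Condition (a): IUpdater.equals(...) compares updater type and all hyperparameters
IUpdater a = new Adam(1e-3);
IUpdater b = new Adam(1e-3);
boolean sameConfig = a.equals(b);   // true: same type, same learning rate, no schedules

// Hypothetical call: `net` is assumed to be an initialized MultiLayerNetwork.
// If both layers use the same updater config for "W" and neither param is a pretrain
// param, their updater state can be handled by a single combined updater block.
Trainable layer0 = net.getLayer(0);
Trainable layer1 = net.getLayer(1);
boolean combinable = UpdaterUtils.updaterConfigurationsEquals(layer0, "W", layer1, "W");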