// org.deeplearning4j.nn.updater.UpdaterUtils — Maven / Gradle / Ivy
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.updater;
import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.nd4j.linalg.learning.config.IUpdater;
/**
 * Static helpers for comparing updater configurations across layers/parameters,
 * used to decide whether the updater state for two parameters can be combined
 * into a single updater block.
 */
public class UpdaterUtils {

    // Utility class: static methods only, not instantiable (Effective Java Item 4).
    private UpdaterUtils() {
    }

    /**
     * Determine whether the updater configuration for one parameter equals (and hence can be
     * combined with) the updater configuration for another parameter, possibly in another layer.
     * <p>
     * For updaters to be considered combinable, we require that:
     * <ul>
     *   <li>(a) the updater-specific configurations are equal (including learning rate,
     *       LR/momentum schedules, etc.), and</li>
     *   <li>(b) if one or more of the params are pretrainable params, they are in the same
     *       layer and both are pretrain params.</li>
     * </ul>
     * Point (b) is necessary as we don't want to modify the pretrain gradient/updater state
     * during backprop, or modify the pretrain gradient/updater state of one layer while
     * training another.
     *
     * @param layer1 layer for the first parameter
     * @param param1 name of the first parameter
     * @param layer2 layer for the second parameter
     * @param param2 name of the second parameter
     * @return true if the two parameters' updater configurations are equal and combinable
     */
    public static boolean updaterConfigurationsEquals(Trainable layer1, String param1, Trainable layer2, String param2) {
        TrainingConfig l1 = layer1.getConfig();
        TrainingConfig l2 = layer2.getConfig();
        IUpdater u1 = l1.getUpdaterByParam(param1);
        IUpdater u2 = l2.getUpdaterByParam(param2);

        // NOTE(review): assumes getUpdaterByParam never returns null for a valid param name;
        // if it can, this throws NPE — confirm against TrainingConfig implementations.
        if (!u1.equals(u2)) {
            // Different updaters, or same updater type with different configuration
            return false;
        }

        boolean isPretrainParam1 = l1.isPretrainParam(param1);
        boolean isPretrainParam2 = l2.isPretrainParam(param2);
        if (isPretrainParam1 || isPretrainParam2) {
            // One or both of the params are pretrainable. Combinable only when both are
            // pretrain params in the *same* layer:
            // - layers differ -> don't want to combine pretrain updaters across layers
            // - one is pretrain, the other isn't -> don't combine pretrain updaters within a layer
            return layer1 == layer2 && isPretrainParam1 && isPretrainParam2;
        }

        return true;
    }
}