All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.nn.layers.FrozenLayer Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.layers;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.misc.DummyConfig;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.OneTimeLogger;

/**
 * Wrapper that freezes an underlying {@link Layer}.
 *
 * <p>Forward passes are delegated to the wrapped layer with the training flag forced to
 * {@code false}. Fitting, parameter updates, assignment of the gradient view array, cache-mode
 * changes, constraint application and {@link #init()} are all no-ops, and the regularization
 * score is reported as {@code 0}.
 *
 * <p>The first time an ignored operation is requested on an instance, a message is sent through
 * {@link OneTimeLogger}; a per-instance flag (one per category: fit, update, gradient, test mode)
 * suppresses any further message of that category from the same instance.
 */
@Slf4j
public class FrozenLayer extends BaseWrapperLayer {

    // Per-instance "already reported" flags, one per category of ignored operation.
    private boolean logUpdate = false;
    private boolean logFit = false;
    private boolean logTestMode = false;
    private boolean logGradient = false;

    /** Gradient handed back by every gradient-returning method; built once in the constructor. */
    private final Gradient zeroGradient;

    /** Lazily created by {@link #getConfig()}. */
    private transient DummyConfig config;

    /**
     * Wraps {@code insideLayer} so that it is treated as frozen.
     *
     * @param insideLayer the layer to freeze
     * @throws IllegalArgumentException if {@code insideLayer} is an {@link OutputLayer}
     */
    public FrozenLayer(Layer insideLayer) {
        super(insideLayer);
        if (insideLayer instanceof OutputLayer) {
            throw new IllegalArgumentException("Output Layers are not allowed to be frozen " + layerId());
        }
        this.zeroGradient = new DefaultGradient(insideLayer.params());
        if (insideLayer.paramTable() != null) {
            for (String paramType : insideLayer.paramTable().keySet()) {
                // Register every parameter key with a null gradient array
                // (note carried over from the original source: "save memory??").
                zeroGradient.setGradientFor(paramType, null);
            }
        }
    }

    /** No-op: the requested cache mode is ignored. */
    @Override
    public void setCacheMode(CacheMode mode) {
        // no-op
    }

    /**
     * Builds a short description of the wrapped layer for use in messages.
     *
     * @return a string containing the wrapped layer's name (or {@code ""} in quotes when the name
     *         is {@code null}) and its index
     */
    protected String layerId() {
        String name = underlying.conf().getLayer().getLayerName();
        return "(layer name: " + (name == null ? "\"\"" : name) + ", layer index: " + underlying.getIndex() + ")";
    }

    /**
     * @return always {@code 0}
     */
    @Override
    public double calcRegularizationScore(boolean backpropParamsOnly) {
        return 0;
    }

    /**
     * Performs no backpropagation through the wrapped layer.
     *
     * @return a pair holding this layer's fixed gradient object and a {@code null} epsilon
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        return new Pair<>(zeroGradient, null);
    }

    /**
     * Activates the wrapped layer, always passing {@code false} as its training flag.
     *
     * @param training requested mode; only used to decide whether the test-mode notice is logged
     */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        logTestMode(training);
        return underlying.activate(false, workspaceMgr);
    }

    /**
     * Activates the wrapped layer on {@code input}, always passing {@code false} as its training flag.
     *
     * @param training requested mode; only used to decide whether the test-mode notice is logged
     */
    @Override
    public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) {
        logTestMode(training);
        return underlying.activate(input, false, workspaceMgr);
    }

    /** No-op apart from the one-time notice. */
    @Override
    public void fit() {
        logFitOnce("Frozen layers cannot be fit. Warning will be issued only once per instance");
    }

    /** No-op apart from the one-time notice; {@code gradient} is ignored. */
    @Override
    public void update(Gradient gradient) {
        logUpdateOnce();
    }

    /** No-op apart from the one-time notice; both arguments are ignored. */
    @Override
    public void update(INDArray gradient, String paramType) {
        logUpdateOnce();
    }

    /**
     * Logs the one-time gradient notice and invokes {@code score()} on the wrapped layer; the
     * returned value is not used here.
     */
    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        logGradientOnce();
        underlying.score();
    }

    /** No-op apart from the one-time notice; {@code gradients} is ignored. */
    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        logGradientOnce();
    }

    /** No-op apart from the one-time notice; both arguments are ignored. */
    @Override
    public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr) {
        // Message text intentionally differs from fit() (no space before "Warning"); kept as-is.
        logFitOnce("Frozen layers cannot be fit.Warning will be issued only once per instance");
    }

    /**
     * @return this layer's fixed gradient object
     */
    @Override
    public Gradient gradient() {
        return zeroGradient;
    }

    /**
     * @return a pair holding this layer's fixed gradient object and the wrapped layer's score
     */
    //FIXME (marker carried over from the original source; the intended change is not described there)
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        logGradientOnce();
        return new Pair<>(zeroGradient, underlying.score());
    }

    /** No-op. */
    @Override
    public void applyConstraints(int iteration, int epoch) {
        //No-op
    }

    /**
     * Init the model. This implementation does nothing.
     */
    @Override
    public void init() {

    }

    /**
     * Logs, at most once per instance, that the layer is always run in test mode.
     *
     * @param training the mode the caller asked for; nothing is logged when {@code false}
     */
    public void logTestMode(boolean training) {
        if (!training || logTestMode) {
            return;
        }
        OneTimeLogger.info(log,
                        "Frozen layer instance found! Frozen layers are treated as always in test mode. Warning will only be issued once per instance");
        logTestMode = true;
    }

    /**
     * Enum-based variant of {@link #logTestMode(boolean)}.
     *
     * @param training the mode the caller asked for; nothing is logged for {@code TEST}. Any other
     *                 value, including {@code null}, is treated as a training request.
     */
    public void logTestMode(TrainingMode training) {
        // Identity comparison is the idiomatic enum check and does not throw on null.
        logTestMode(training != TrainingMode.TEST);
    }

    /**
     * @return the wrapped layer
     */
    public Layer getInsideLayer() {
        return underlying;
    }

    /**
     * Returns a {@link DummyConfig} named after the wrapped layer's configuration, creating and
     * caching it on first use.
     */
    @Override
    public TrainingConfig getConfig() {
        if (config == null) {
            config = new DummyConfig(getUnderlying().getConfig().getLayerName());
        }
        return config;
    }

    /** Emits {@code message} unless a fit notice was already issued by this instance. */
    private void logFitOnce(String message) {
        if (!logFit) {
            OneTimeLogger.info(log, message);
            logFit = true;
        }
    }

    /** Emits the "will not be updated" notice unless this instance already did so. */
    private void logUpdateOnce() {
        if (!logUpdate) {
            OneTimeLogger.info(log, "Frozen layers will not be updated. Warning will be issued only once per instance");
            logUpdate = true;
        }
    }

    /** Emits the "gradients are not set" notice unless this instance already did so. */
    private void logGradientOnce() {
        if (!logGradient) {
            OneTimeLogger.info(log,
                            "Gradients for the frozen layer are not set and will therefore will not be updated.Warning will be issued only once per instance");
            logGradient = true;
        }
    }
}






© 2015 - 2024 Weber Informatics LLC | Privacy Policy