All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.nn.params.BidirectionalParamInitializer Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.params;

import lombok.val;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;

/**
 * Parameter initializer for {@link Bidirectional} wrapper layers.
 * <p>
 * A bidirectional layer holds two copies of its underlying (wrapped) layer's parameters —
 * one for the forward pass over the sequence, one for the backward pass. This initializer
 * delegates to the underlying layer's own initializer for each half of the flattened
 * parameter view, and prefixes every parameter key with {@link #FORWARD_PREFIX} or
 * {@link #BACKWARD_PREFIX} to keep the two copies distinct.
 */
public class BidirectionalParamInitializer implements ParamInitializer {
    /** Prefix for parameter keys of the forward-direction copy of the underlying layer. */
    public static final String FORWARD_PREFIX = "f";
    /** Prefix for parameter keys of the backward-direction copy of the underlying layer. */
    public static final String BACKWARD_PREFIX = "b";

    private final Bidirectional layer;
    private final Layer underlying;

    // Lazily computed and cached: prefixed forward+backward copies of the underlying layer's keys
    private List<String> paramKeys;
    private List<String> weightKeys;
    private List<String> biasKeys;

    public BidirectionalParamInitializer(Bidirectional layer){
        this.layer = layer;
        this.underlying = underlying(layer);
    }

    @Override
    public long numParams(NeuralNetConfiguration conf) {
        return numParams(conf.getLayer());
    }

    @Override
    public long numParams(Layer layer) {
        Layer u = underlying(layer);
        // Two full copies of the underlying layer's parameters: forward and backward
        return 2 * u.initializer().numParams(u);
    }

    @Override
    public List<String> paramKeys(Layer layer) {
        if(paramKeys == null) {
            Layer u = underlying(layer);
            List<String> orig = u.initializer().paramKeys(u);
            paramKeys = withPrefixes(orig);
        }
        return paramKeys;
    }

    @Override
    public List<String> weightKeys(Layer layer) {
        if(weightKeys == null) {
            Layer u = underlying(layer);
            List<String> orig = u.initializer().weightKeys(u);
            weightKeys = withPrefixes(orig);
        }
        return weightKeys;
    }

    @Override
    public List<String> biasKeys(Layer layer) {
        if(biasKeys == null) {
            Layer u = underlying(layer);
            // Fixed: previously delegated to weightKeys(u), so bias keys (and hence
            // isBiasParam) incorrectly reflected the underlying layer's WEIGHT keys.
            List<String> orig = u.initializer().biasKeys(u);
            biasKeys = withPrefixes(orig);
        }
        return biasKeys;
    }

    @Override
    public boolean isWeightParam(Layer layer, String key) {
        // Note: intentionally keyed off this.layer (the wrapped Bidirectional), not the argument
        return weightKeys(this.layer).contains(key);
    }

    @Override
    public boolean isBiasParam(Layer layer, String key) {
        // Note: intentionally keyed off this.layer (the wrapped Bidirectional), not the argument
        return biasKeys(this.layer).contains(key);
    }

    /**
     * Initializes parameters by splitting the flattened view in half: the first half backs the
     * forward copy of the underlying layer, the second half the backward copy. The underlying
     * initializer is invoked once per half (on cloned configurations, so variable lists don't
     * collide), and the resulting maps are merged with direction prefixes on every key.
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        val n = paramsView.length() / 2;
        // Flatten to rank 1 so interval() indexing splits the view cleanly in half
        INDArray paramsReshape = paramsView.reshape(paramsView.length());
        INDArray forwardView = paramsReshape.get(interval(0, n));
        INDArray backwardView = paramsReshape.get(interval(n, 2 * n));

        conf.clearVariables();

        NeuralNetConfiguration c1 = conf.clone();
        NeuralNetConfiguration c2 = conf.clone();
        c1.setLayer(underlying);
        c2.setLayer(underlying);
        Map<String, INDArray> origFwd = underlying.initializer().init(c1, forwardView, initializeParams);
        Map<String, INDArray> origBwd = underlying.initializer().init(c2, backwardView, initializeParams);
        // Collect the variable names registered by each half-initialization, with prefixes
        List<String> variables = addPrefixes(c1.getVariables(), c2.getVariables());
        conf.setVariables(variables);

        return addPrefixes(origFwd, origBwd);
    }

    /** Merges two parameter maps into one, prefixing forward keys with "f" and backward keys with "b". */
    private Map<String, INDArray> addPrefixes(Map<String, INDArray> fwd, Map<String, INDArray> bwd){
        Map<String, INDArray> out = new LinkedHashMap<>();
        for(Map.Entry<String, INDArray> e : fwd.entrySet()){
            out.put(FORWARD_PREFIX + e.getKey(), e.getValue());
        }
        for(Map.Entry<String, INDArray> e : bwd.entrySet()){
            out.put(BACKWARD_PREFIX + e.getKey(), e.getValue());
        }
        return out;
    }

    /** Concatenates two key lists, prefixing forward keys with "f" and backward keys with "b". */
    private List<String> addPrefixes(List<String> fwd, List<String> bwd){
        List<String> out = new ArrayList<>(fwd.size() + bwd.size());
        for(String s : fwd){
            out.add(FORWARD_PREFIX + s);
        }
        for(String s : bwd){
            out.add(BACKWARD_PREFIX + s);
        }
        return out;
    }

    /**
     * Splits the flattened gradient view exactly as {@link #init} splits the parameter view
     * (forward half then backward half) and delegates each half to the underlying initializer,
     * prefixing the resulting keys by direction.
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        val n = gradientView.length() / 2;

        INDArray gradientsViewReshape = gradientView.reshape(gradientView.length());
        INDArray forwardView = gradientsViewReshape.get(interval(0, n));
        INDArray backwardView = gradientsViewReshape.get(interval(n, 2 * n));

        Map<String, INDArray> origFwd = underlying.initializer().getGradientsFromFlattened(conf, forwardView);
        Map<String, INDArray> origBwd = underlying.initializer().getGradientsFromFlattened(conf, backwardView);

        return addPrefixes(origFwd, origBwd);
    }

    /** Extracts the wrapped (single-direction) layer configuration from a Bidirectional layer. */
    private Layer underlying(Layer layer){
        Bidirectional b = (Bidirectional) layer;
        return b.getFwd();
    }

    /** Produces the prefixed key list for one underlying key set: "f"-prefixed copies, then "b"-prefixed. */
    private List<String> withPrefixes(List<String> orig){
        // Forward and backward copies share the same underlying key set
        return addPrefixes(orig, orig);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy