All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.conf.layers.samediff;

import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.*;


public abstract class SameDiffLambdaVertex extends SameDiffVertex {

    protected transient VertexInputs inputs;

    /**
     * The defineVertex method is used to define the foward pass for the vertex
     *
     * @param sameDiff SameDiff instance to use to define the vertex
     * @param inputs   Layer input variable
     * @return The output variable (orresponding to the output activations for the vertex)
     */
    public abstract SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs);

    @Override
    public SDVariable defineVertex(SameDiff sameDiff, Map layerInput,
                                   Map paramTable, Map maskVars) {
        VertexInputs vi = getInputs(sameDiff);
        int i = 0;
        if (vi.map.size() == 0 && layerInput.size() > 0) {
            for (SDVariable v : layerInput.values()) {
                vi.map.put(i++, v);
            }
        }
        return defineVertex(sameDiff, getInputs(sameDiff));
    }

    @Override
    public void defineParametersAndInputs(SDVertexParams params) {
        //Parameters are no op, for lamda vertex - but inputs are NOT
        SameDiff temp = SameDiff.create();
        VertexInputs tempInputs = new VertexInputs(temp);
        defineVertex(temp, tempInputs);
        List list = new ArrayList<>();
        for (Integer i : tempInputs.map.keySet()) {
            list.add(tempInputs.map.get(i).name());
        }
        params.defineInputs(list.toArray(new String[list.size()]));
    }

    @Override
    public void initializeParameters(Map params) {
        //No op, for lambda vertex
    }

    @Override
    public GraphVertex clone() {
        try {
            return getClass().getConstructor().newInstance();
        } catch (Exception e){
            throw new RuntimeException("Unable to create new instance of class " + getClass().getName() + " from no-arg constructor");
        }
    }

    protected VertexInputs getInputs(SameDiff sd) {
        if (inputs == null) {
            inputs = new VertexInputs(sd);
        }
        return inputs;
    }

    public class VertexInputs {

        private SameDiff sameDiff;
        private Map map = new LinkedHashMap<>();

        protected VertexInputs(SameDiff sd) {
            this.sameDiff = sd;
        }

        public SDVariable getInput(int inputNum) {
            Preconditions.checkArgument(inputNum >= 0, "Input number must be >= 0." + "Got: %s", inputNum);

            if (!map.containsKey(inputNum)) {
                //Lazily define extra input variable as required
                SDVariable var = sameDiff.var("var_" + inputNum, dataType, -1); //TODO is this shape safe?
                map.put(inputNum, var);
            }

            return map.get(inputNum);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy