All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.nn.graph.vertex.GraphVertex Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.graph.vertex;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;

import java.io.Serializable;
import java.util.Map;

/**
 * A vertex in a graph of vertices. A GraphVertex receives one or more input arrays (see
 * {@link #getNumInputArrays()}), produces a single output that may be consumed by an arbitrary
 * number of other vertices (see {@link #getNumOutputConnections()}), and supports forward and
 * backward passes. A vertex may optionally contain a {@link Layer} (see {@link #hasLayer()}).
 */
public interface GraphVertex extends Trainable, Serializable {

    /** Get the name/label of the GraphVertex. */
    String getVertexName();

    /** Get the index of the GraphVertex. */
    int getVertexIndex();

    /**
     * Get the number of input arrays. For example, a Layer may have only one input array, but in
     * general a GraphVertex may have an arbitrary ({@code >= 1}) number of input arrays (for
     * example, from multiple other layers).
     */
    int getNumInputArrays();

    /**
     * Get the number of outgoing connections from this GraphVertex. A GraphVertex may only have a
     * single output (for example, the activations out of a layer), but this output may be used as
     * the input to an arbitrary number of other GraphVertex instances. This method returns the
     * number of GraphVertex instances the output of this GraphVertex is input for.
     */
    int getNumOutputConnections();

    /**
     * A representation of the vertices that are inputs to this vertex (inputs during forward pass).
     * Specifically, if {@code inputVertices[X].getVertexIndex() = Y} and
     * {@code inputVertices[X].getVertexEdgeNumber() = Z}, then the Zth output connection (see
     * {@link #getNumOutputConnections()}) of vertex Y is the Xth input to this vertex.
     */
    VertexIndices[] getInputVertices();

    /**
     * Sets the input vertices.
     *
     * @see #getInputVertices()
     */
    void setInputVertices(VertexIndices[] inputVertices);

    /**
     * A representation of the vertices that this vertex is connected to (outputs during forward
     * pass). Specifically, if {@code outputVertices[X].getVertexIndex() = Y} and
     * {@code outputVertices[X].getVertexEdgeNumber() = Z}, then the Xth output of this vertex is
     * connected to the Zth input of vertex Y.
     */
    VertexIndices[] getOutputVertices();

    /**
     * Sets the output vertices.
     *
     * @see #getOutputVertices()
     */
    void setOutputVertices(VertexIndices[] outputVertices);

    /** Whether the GraphVertex contains a {@link Layer} object or not. */
    boolean hasLayer();

    /** Whether the GraphVertex is an input vertex. */
    boolean isInputVertex();

    /** Whether the GraphVertex is an output vertex. */
    boolean isOutputVertex();

    /** Set whether the GraphVertex is an output vertex. */
    void setOutputVertex(boolean outputVertex);

    /** Get the Layer (if any). Returns null if {@link #hasLayer()} == false. */
    Layer getLayer();

    /**
     * Set one of the input activations.
     *
     * @param inputNumber  Must be in range 0 to {@link #getNumInputArrays()}-1
     * @param input        The input array
     * @param workspaceMgr The {@link LayerWorkspaceMgr} to use when setting the input
     */
    void setInput(int inputNumber, INDArray input, LayerWorkspaceMgr workspaceMgr);

    /** Set the errors (epsilon - aka dL/dActivation) for this GraphVertex. */
    void setEpsilon(INDArray epsilon);

    /** Clear the internal state (if any) of the GraphVertex. For example, any stored inputs/errors. */
    void clear();

    /** Whether the GraphVertex can do forward pass. Typically, this is just whether all inputs are set. */
    boolean canDoForward();

    /** Whether the GraphVertex can do backward pass. Typically, this is just whether all errors/epsilons are set. */
    boolean canDoBackward();

    /**
     * Do forward pass using the stored inputs.
     *
     * @param training     if true: forward pass at training time. If false: forward pass at test time
     * @param workspaceMgr The {@link LayerWorkspaceMgr} to use for the forward pass
     * @return The output (for example, activations) of the GraphVertex
     */
    INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr);

    /**
     * Do backward pass.
     *
     * @param tbptt        If true: do backprop using truncated BPTT
     * @param workspaceMgr The {@link LayerWorkspaceMgr} to use for the backward pass
     * @return The gradients (may be null), and the errors/epsilons for all inputs to this GraphVertex
     */
    Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr);

    /** Get the array of inputs previously set for this GraphVertex. */
    INDArray[] getInputs();

    /** Get the epsilon/error (i.e., dL/dOutput) array previously set for this GraphVertex. */
    INDArray getEpsilon();

    /**
     * Set all inputs for this GraphVertex.
     *
     * @see #setInput(int, INDArray, LayerWorkspaceMgr)
     */
    void setInputs(INDArray... inputs);

    /**
     * See {@link Layer#setBackpropGradientsViewArray(INDArray)}.
     *
     * @param backpropGradientsViewArray the gradients view array to set
     */
    void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray);

    /**
     * Feed mask arrays forward through this vertex.
     *
     * @param maskArrays       mask arrays for this vertex's inputs (presumably one per input; may contain
     *                         nulls — TODO confirm against implementations)
     * @param currentMaskState the current {@link MaskState}
     * @param minibatchSize    the minibatch size
     * @return a pair of (output mask array, its {@link MaskState}) for this vertex's output
     *         — NOTE(review): confirm whether the mask element may be null when no mask applies
     */
    Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState,
                                                    int minibatchSize);

    /**
     * Only applies to layer vertices. Will throw exceptions on others.
     * If applied to a layer vertex it will treat the parameters of the layer within it as constant.
     * Activations through these will be calculated as they would at test time regardless of training mode.
     */
    void setLayerAsFrozen();

    /** Clears the input for this vertex. */
    void clearVertex();

    /**
     * Get the parameter table for the vertex.
     *
     * @param backpropOnly If true: exclude unsupervised training parameters
     * @return Parameter table, keyed by parameter name
     */
    Map<String, INDArray> paramTable(boolean backpropOnly);
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy