org.deeplearning4j.nn.graph.vertex.impl.StackVertex Maven / Gradle / Ivy
/*
*
* * Copyright 2016 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.graph.vertex.impl;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
/**
* StackVertex allows for stacking of inputs so that they may be forwarded through
* a network. This is useful for cases such as Triplet Embedding, where shared parameters
* are not supported by the network.
*
* This vertex will automatically stack all available inputs.
*
* @author Justin Long (crockpotveggies)
*/
public class StackVertex extends BaseGraphVertex {
public StackVertex(ComputationGraph graph, String name, int vertexIndex){
this(graph, name, vertexIndex, null, null);
}
public StackVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices,
VertexIndices[] outputVertices) {
super(graph, name, vertexIndex, inputVertices, outputVertices);
}
@Override
public boolean hasLayer() {
return false;
}
@Override
public boolean isOutputVertex() {
return false;
}
@Override
public Layer getLayer() {
return null;
}
@Override
public INDArray doForward(boolean training) {
// stacking along dimension 0
// inputs[] is an array of INDArray (e.g.: shape of 3 x [nExamples, nSize])
// what we want to do is make a stacked output (e.g.: [3 x nExamples, nSize])
int nStack = inputs.length;
int[] inShape = inputs[0].shape();
int[] outShape = new int[inShape.length];
// create the new shape
for ( int i=0; i doBackward(boolean tbptt) {
// this is basically doForward on UnstackVertex
if(!canDoForward()) throw new IllegalStateException("Cannot do forward pass: input not set");
int nStack = inputs.length;
INDArray[] out = new INDArray[nStack];
int step = epsilon.size(0)/nStack;
for(int i=0; i(null,out);
}
@Override
public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
if(backpropGradientsViewArray != null) throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here");
}
@Override
public String toString(){
return "StackVertex(id=" + this.getVertexIndex() + ",name=\"" + this.getVertexName() + ")";
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy