org.deeplearning4j.nn.conf.graph.SubsetVertex Maven / Gradle / Ivy
/*-
*
* * Copyright 2016 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.conf.graph;
import lombok.EqualsAndHashCode;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import lombok.Data;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
/**
 * SubsetVertex is used to select a subset of the activations out of another GraphVertex.
 * For example, a subset of the activations out of a layer.
 * Note that this subset is specified by means of an interval of the original activations,
 * where both ends of the interval are inclusive.
 * For example, to get the first 10 activations of a layer (or, first 10 features out of a CNN layer) use
 * new SubsetVertex(0,9).
 * In the case of convolutional (4d) activations, this is done along depth; height and width are unchanged.
 *
 * @author Alex Black
 */
@Data
public class SubsetVertex extends GraphVertex {

    /** First index of the selected interval (inclusive). */
    private int from;
    /** Last index of the selected interval (inclusive). */
    private int to;

    /**
     * @param from The first column index, inclusive
     * @param to   The last column index, inclusive
     */
    public SubsetVertex(@JsonProperty("from") int from, @JsonProperty("to") int to) {
        this.from = from;
        this.to = to;
    }

    @Override
    public SubsetVertex clone() {
        return new SubsetVertex(from, to);
    }

    /**
     * Two SubsetVertex configurations are equal when they select the same interval.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof SubsetVertex)) {
            return false;
        }
        SubsetVertex s = (SubsetVertex) o;
        return s.from == from && s.to == to;
    }

    /**
     * Hash consistent with {@link #equals(Object)}: depends only on {@code from} and {@code to}.
     */
    @Override
    public int hashCode() {
        // Avoids the deprecated Integer(int) constructor, and unlike from ^ to it does not
        // collide for swapped bounds or hash every (x, x) interval to 0.
        return 31 * from + to;
    }

    @Override
    public int numParams(boolean backprop) {
        return 0;
    }

    @Override
    public int minVertexInputs() {
        return 1;
    }

    @Override
    public int maxVertexInputs() {
        return 1;
    }

    @Override
    public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx,
                    INDArray paramsView, boolean initializeParams) {
        return new org.deeplearning4j.nn.graph.vertex.impl.SubsetVertex(graph, name, idx, from, to);
    }

    /**
     * Computes the output type of this vertex: the input type with its feature (FF/RNN) or depth (CNN)
     * dimension reduced to the size of the inclusive interval [from, to].
     *
     * @param layerIndex   index of this vertex (unused here)
     * @param vertexInputs exactly one input type
     * @return the output type after subsetting
     * @throws InvalidInputTypeException if there is not exactly one input, if the interval is invalid
     *                                   ({@code from < 0} or {@code to < from}), or if {@code to} exceeds
     *                                   the CNN input depth
     * @throws UnsupportedOperationException for flattened CNN (CNNFlat) inputs
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
        if (vertexInputs.length != 1) {
            throw new InvalidInputTypeException(
                            "SubsetVertex expects single input type. Received: " + Arrays.toString(vertexInputs));
        }
        if (from < 0 || to < from) {
            // Without this check, the size computed below would be zero or negative.
            throw new InvalidInputTypeException("Invalid range: SubsetVertex requires 0 <= from <= to, got ["
                            + from + "," + to + "]");
        }
        switch (vertexInputs[0].getType()) {
            case FF:
                return InputType.feedForward(subsetSize());
            case RNN:
                return InputType.recurrent(subsetSize());
            case CNN:
                InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) vertexInputs[0];
                int depth = conv.getDepth();
                if (to >= depth) {
                    throw new InvalidInputTypeException("Invalid range: Cannot select depth subset [" + from + "," + to
                                    + "] inclusive from CNN activations with " + " [depth,width,height] = [" + depth
                                    + "," + conv.getWidth() + "," + conv.getHeight() + "]");
                }
                // Subsetting is along depth only; the spatial dimensions pass through unchanged.
                return InputType.convolutional(conv.getHeight(), conv.getWidth(), subsetSize());
            case CNNFlat:
                //TODO work out how to do this - could be difficult...
                throw new UnsupportedOperationException(
                                "Subsetting data in flattened convolutional format not yet supported");
            default:
                throw new RuntimeException("Unknown input type: " + vertexInputs[0]);
        }
    }

    /** Number of elements in the inclusive interval [from, to]. */
    private int subsetSize() {
        return to - from + 1;
    }

    @Override
    public MemoryReport getMemoryReport(InputType... inputTypes) {
        // Presumably the subset is taken without a dup (a view of the input), so this
        // reports no parameter, working, or cache memory.
        InputType outputType = getOutputType(-1, inputTypes);
        return new LayerMemoryReport.Builder(null, SubsetVertex.class, inputTypes[0], outputType).standardMemory(0, 0) //No params
                        .workingMemory(0, 0, 0, 0).cacheMemory(0, 0) //No caching
                        .build();
    }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy