
org.deeplearning4j.nn.conf.graph.SubsetVertex Maven / Gradle / Ivy

/*-
 *
 *  * Copyright 2016 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.nn.conf.graph;

import lombok.Data;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.Arrays;

/**
 * SubsetVertex is used to select a subset of the activations out of another GraphVertex.
* For example, a subset of the activations out of a layer.
* Note that this subset is specifying by means of an interval of the original activations. * For example, to get the first 10 activations of a layer (or, first 10 features out of a CNN layer) use * new SubsetVertex(0,9).
* In the case of convolutional (4d) activations, this is done along depth. * * @author Alex Black */ @Data public class SubsetVertex extends GraphVertex { private int from; private int to; /** * @param from The first column index, inclusive * @param to The last column index, inclusive */ public SubsetVertex(@JsonProperty("from") int from, @JsonProperty("to") int to) { this.from = from; this.to = to; } @Override public SubsetVertex clone() { return new SubsetVertex(from, to); } @Override public boolean equals(Object o) { if (!(o instanceof SubsetVertex)) return false; SubsetVertex s = (SubsetVertex) o; return s.from == from && s.to == to; } @Override public int hashCode() { return new Integer(from).hashCode() ^ new Integer(to).hashCode(); } @Override public int numParams(boolean backprop) { return 0; } @Override public int minVertexInputs() { return 1; } @Override public int maxVertexInputs() { return 1; } @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams) { return new org.deeplearning4j.nn.graph.vertex.impl.SubsetVertex(graph, name, idx, from, to); } @Override public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException { if (vertexInputs.length != 1) { throw new InvalidInputTypeException( "SubsetVertex expects single input type. Received: " + Arrays.toString(vertexInputs)); } switch (vertexInputs[0].getType()) { case FF: return InputType.feedForward(to - from + 1); case RNN: return InputType.recurrent(to - from + 1); case CNN: InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) vertexInputs[0]; int depth = conv.getDepth(); if (to >= depth) { throw new InvalidInputTypeException("Invalid range: Cannot select depth subset [" + from + "," + to + "] inclusive from CNN activations with " + " [depth,width,height] = [" + depth + "," + conv.getWidth() + "," + conv.getHeight() + "]"); } return InputType.convolutional(conv.getHeight(), conv.getWidth(), from - to + 1); case CNNFlat: //TODO work out how to do this - could be difficult... throw new UnsupportedOperationException( "Subsetting data in flattened convolutional format not yet supported"); default: throw new RuntimeException("Unknown input type: " + vertexInputs[0]); } } @Override public MemoryReport getMemoryReport(InputType... inputTypes) { //Get op without dup - no additional memory use InputType outputType = getOutputType(-1, inputTypes); return new LayerMemoryReport.Builder(null, SubsetVertex.class, inputTypes[0], outputType).standardMemory(0, 0) //No params .workingMemory(0, 0, 0, 0).cacheMemory(0, 0) //No caching .build(); } }
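
Usage note: the sketch below shows one way a SubsetVertex could be wired into a ComputationGraphConfiguration through the standard graph builder. The vertex names, layer sizes, loss function and activation chosen here are illustrative assumptions for the example, not something prescribed by the class above.

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.SubsetVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class SubsetVertexExample {
    public static void main(String[] args) {
        //Small example graph: input -> dense layer (20 activations) -> subset of the first 10 -> output layer.
        //Vertex names ("dense", "subset", "out") and sizes are arbitrary choices for this sketch.
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                        .graphBuilder()
                        .addInputs("in")
                        .addLayer("dense", new DenseLayer.Builder().nIn(5).nOut(20).build(), "in")
                        //Select activations 0..9 inclusive from the 20 dense-layer outputs
                        .addVertex("subset", new SubsetVertex(0, 9), "dense")
                        .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                        .activation(Activation.SOFTMAX).nIn(10).nOut(3).build(), "subset")
                        .setOutputs("out")
                        .build();

        ComputationGraph net = new ComputationGraph(conf);
        net.init();
    }
}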



