All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.nn.conf.graph.ReshapeVertex Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.conf.graph;

import lombok.Data;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.Arrays;

/**
 * Graph vertex configuration that reshapes its single input to a user-specified shape.
 * <p>
 * Holds the reshape order ('c' or 'f'), the new shape for the activations and an optional
 * mask shape, and passes them to the runtime
 * {@link org.deeplearning4j.nn.graph.vertex.impl.ReshapeVertex} in
 * {@link #instantiate(ComputationGraph, String, int, INDArray, boolean, DataType)}.
 */
@Data
public class ReshapeVertex extends GraphVertex {
    public static final char DEFAULT_RESHAPE_ORDER = 'c';

    protected char reshapeOrder = DEFAULT_RESHAPE_ORDER;
    protected int[] newShape;
    protected int[] maskShape;

    /**
     * Reshape with the default reshape order of 'c' and no mask shape.
     *
     * @param newShape New shape for activations
     */
    public ReshapeVertex(int... newShape) {
        this(DEFAULT_RESHAPE_ORDER, newShape, null);
    }

    /**
     * @param reshapeOrder Order (must be 'c' or 'f') for the activations
     * @param newShape     New shape
     * @param maskShape    Mask shape (may be null)
     * @throws IllegalStateException if {@code reshapeOrder} is neither 'c' nor 'f'
     *                               (NOTE(review): exception type assumed from the checkState name - confirm)
     */
    public ReshapeVertex(@JsonProperty("reshapeOrder") char reshapeOrder, @JsonProperty("newShape") int[] newShape,
                         @JsonProperty("maskShape") int[] maskShape) {
        Preconditions.checkState(reshapeOrder == 'c' || reshapeOrder == 'f', "Reshape order must be 'c' or 'f'. Got: '%s'", String.valueOf(reshapeOrder));
        this.reshapeOrder = reshapeOrder;
        this.newShape = newShape;
        this.maskShape = maskShape;
    }

    /**
     * Creates a copy of this vertex configuration that carries over the reshape order, the new shape
     * and the mask shape. The shape arrays are copied so that later changes to one vertex's arrays
     * are not visible through the other.
     *
     * @return an independent copy of this configuration
     */
    @Override
    public ReshapeVertex clone() {
        int[] shapeCopy = newShape == null ? null : newShape.clone();
        int[] maskCopy = maskShape == null ? null : maskShape.clone();
        return new ReshapeVertex(reshapeOrder, shapeCopy, maskCopy);
    }

    /**
     * Two reshape vertices are equal when reshape order, new shape and mask shape all match.
     * The same three fields are used by {@link #hashCode()}, keeping the two methods consistent.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (!(o instanceof ReshapeVertex))
            return false;
        ReshapeVertex other = (ReshapeVertex) o;
        return reshapeOrder == other.reshapeOrder
                        && Arrays.equals(newShape, other.newShape)
                        && Arrays.equals(maskShape, other.maskShape);
    }

    /**
     * Hash over the fields compared in {@link #equals(Object)}: reshape order, new shape, mask shape.
     */
    @Override
    public int hashCode() {
        int result = reshapeOrder ^ Arrays.hashCode(newShape);
        return 31 * result + Arrays.hashCode(maskShape);
    }

    /**
     * @return always 0: this vertex has no parameters
     */
    @Override
    public long numParams(boolean backprop) {
        return 0;
    }

    /**
     * @return 1: exactly one input is expected
     */
    @Override
    public int minVertexInputs() {
        return 1;
    }

    /**
     * @return 1: exactly one input is expected
     */
    @Override
    public int maxVertexInputs() {
        return 1;
    }

    /**
     * Creates the runtime vertex for this configuration, passing through the reshape order,
     * new shape and mask shape. {@code paramsView} and {@code initializeParams} are not used here.
     */
    @Override
    public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx,
                                                                      INDArray paramsView, boolean initializeParams, DataType networkDatatype) {
        return new org.deeplearning4j.nn.graph.vertex.impl.ReshapeVertex(graph, name, idx, reshapeOrder, newShape, maskShape, networkDatatype);
    }

    /**
     * Infers the output type purely from the length of {@code newShape}; the input types are not inspected.
     * <ul>
     *     <li>length 2: feed-forward, size {@code newShape[1]}</li>
     *     <li>length 3: recurrent, size {@code newShape[1]}</li>
     *     <li>length 4: convolutional, built from {@code newShape[2]}, {@code newShape[3]}, {@code newShape[1]}</li>
     * </ul>
     *
     * @throws UnsupportedOperationException if {@code newShape} has any other length
     * @throws NullPointerException          if {@code newShape} is null
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
        //Infer output shape from specified shape:
        switch (newShape.length) {
            case 2:
                return InputType.feedForward(newShape[1]);
            case 3:
                return InputType.recurrent(newShape[1]);
            case 4:
                return InputType.convolutional(newShape[2], newShape[3], newShape[1]); //[mb,d,h,w] for activations
            default:
                throw new UnsupportedOperationException(
                                "Cannot infer input type for reshape array " + Arrays.toString(newShape));
        }
    }

    /**
     * Builds a memory report with zero standard, working and cache memory for this vertex.
     *
     * @param inputTypes input types to the vertex; the first one is recorded as the report's input type
     */
    @Override
    public MemoryReport getMemoryReport(InputType... inputTypes) {
        //Assume it's a reshape-with-copy op. In this case: memory use is accounted for in activations
        InputType outputType = getOutputType(-1, inputTypes);
        return new LayerMemoryReport.Builder(null, ReshapeVertex.class, inputTypes[0], outputType).standardMemory(0, 0) //No params
                        .workingMemory(0, 0, 0, 0).cacheMemory(0, 0) //No caching
                        .build();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy