org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.conf.layers.samediff;

import lombok.Data;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.samediff.SameDiffGraphVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;

import java.util.List;
import java.util.Map;

@Data
public abstract class SameDiffVertex extends GraphVertex implements TrainingConfig {

    private SDVertexParams vertexParams;
    private String name;

    protected List<Regularization> regularization;
    protected List<Regularization> regularizationBias;
    protected IUpdater updater;
    protected IUpdater biasUpdater;
    protected GradientNormalization gradientNormalization;
    protected double gradientNormalizationThreshold = Double.NaN;
    protected DataType dataType;

    /**
     * Define the vertex
     * @param sameDiff   SameDiff instance
     * @param layerInput Input to the layer - keys as defined by {@link #defineParametersAndInputs(SDVertexParams)}
     * @param paramTable Parameter table - keys as defined by {@link #defineParametersAndInputs(SDVertexParams)}
     * @param maskVars  Masks of input, if available - keys as defined by {@link #defineParametersAndInputs(SDVertexParams)}
     * @return The final layer variable corresponding to the activations/output from the forward pass
     */
    public abstract SDVariable defineVertex(SameDiff sameDiff, Map<String, SDVariable> layerInput,
                                            Map<String, SDVariable> paramTable, Map<String, SDVariable> maskVars);
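
    /*
     * A minimal sketch of a possible implementation, assuming a vertex with a
     * single input named "input" and parameters "W" and "b" (the names are
     * illustrative, matching whatever was registered in
     * defineParametersAndInputs):
     *
     *   SDVariable in = layerInput.get("input");
     *   SDVariable w = paramTable.get("W");
     *   SDVariable b = paramTable.get("b");
     *   return sameDiff.math().tanh(in.mmul(w).add(b));
     */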

    /**
     * Define the parameters - and inputs - for the network.
     * Use {@link SDVertexParams#addWeightParam(String, long...)} and
     * {@link SDVertexParams#addBiasParam(String, long...)}.
     * Note also that you must define (and optionally name) the inputs to the vertex. This is required so that
     * DL4J knows how many inputs exist for the vertex.
     * @param params Object used to set parameters for this layer
     */
    public abstract void defineParametersAndInputs(SDVertexParams params);
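
    /*
     * For example, a vertex applying a single dense transform to one input
     * might register its inputs and parameters like this (defineInputs is
     * assumed here; nIn and nOut are illustrative placeholders):
     *
     *   params.defineInputs("input");
     *   params.addWeightParam("W", nIn, nOut);
     *   params.addBiasParam("b", 1, nOut);
     */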

    /**
     * Set the initial parameter values for this layer, if required
     * @param params Parameter arrays that may be initialized
     */
    public abstract void initializeParameters(Map<String, INDArray> params);
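
    /*
     * A sketch of one possible initialization, assuming parameters "W" and "b"
     * as in the example above; INDArray.assign fills the views into the
     * network's parameter array in place:
     *
     *   params.get("W").assign(Nd4j.rand(params.get("W").shape()));
     *   params.get("b").assign(0.0);
     */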

    public SDVertexParams getVertexParams() {
        if (vertexParams == null) {
            vertexParams = new SDVertexParams();
            defineParametersAndInputs(vertexParams);
        }
        return vertexParams;
    }

    @Override
    public long numParams(boolean backprop) {
        SDLayerParams params = getVertexParams();
        long count = 0;
        for (long[] l : params.getParamShapes().values()) {
            count += ArrayUtil.prodLong(l);
        }
        return count;
    }

    @Override
    public int minVertexInputs() {
        return 1;
    }

    @Override
    public int maxVertexInputs() {
        return -1;
    }

    @Override
    public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx,
                                                                      INDArray paramsView, boolean initializeParams, DataType networkDatatype) {
        this.name = name;
        return new SameDiffGraphVertex(this, graph, name, idx, paramsView, initializeParams, networkDatatype);
    }

    @Override
    public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    public Pair<INDArray[], MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) {
        throw new UnsupportedOperationException("Not yet supported");
    }

    /**
     * Validate input arrays to confirm that they fulfill the assumptions of the layer. If they don't, throw an exception.
     * @param input inputs to the layer
     */
    public void validateInput(INDArray[] input){/* no-op */}

    @Override
    public MemoryReport getMemoryReport(InputType... inputTypes) {
        return null;
    }


    public char paramReshapeOrder(String paramName) {
        return 'c';
    }


    public void applyGlobalConfig(NeuralNetConfiguration.Builder b) {
        if(regularization == null || regularization.isEmpty()){
            regularization = b.getRegularization();
        }
        if(regularizationBias == null || regularizationBias.isEmpty()){
            regularizationBias = b.getRegularizationBias();
        }
        if (updater == null) {
            updater = b.getIUpdater();
        }
        if (biasUpdater == null) {
            biasUpdater = b.getBiasUpdater();
        }
        if (gradientNormalization == null) {
            gradientNormalization = b.getGradientNormalization();
        }
        if (Double.isNaN(gradientNormalizationThreshold)) {
            gradientNormalizationThreshold = b.getGradientNormalizationThreshold();
        }

        applyGlobalConfigToLayer(b);
    }

    public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) {
        //Default implementation: no op
    }

    @Override
    public String getLayerName() {
        return name;
    }

    @Override
    public List<Regularization> getRegularizationByParam(String paramName){
        if((regularization == null || regularization.isEmpty()) && (regularizationBias == null || regularizationBias.isEmpty())){
            return null;
        }
        if (getVertexParams().isWeightParam(paramName)) {
            return regularization;
        }
        if (getVertexParams().isBiasParam(paramName)) {
            return regularizationBias;
        }
        throw new IllegalStateException("Unknown parameter name: " + paramName + " - not in weights ("
                + getVertexParams().getWeightParameterKeys() + ") or biases ("
                + getVertexParams().getBiasParameterKeys() + ")");
    }

    @Override
    public boolean isPretrainParam(String paramName) {
        return false;
    }

    @Override
    public IUpdater getUpdaterByParam(String paramName) {
        if (getVertexParams().isWeightParam(paramName)) {
            return updater;
        }
        if (getVertexParams().isBiasParam(paramName)) {
            if (biasUpdater == null) {
                return updater;
            }
            return biasUpdater;
        }
        throw new IllegalStateException("Unknown parameter name: " + paramName + " - not in weights ("
                        + getVertexParams().getWeightParameterKeys() + ") or biases ("
                        + getVertexParams().getBiasParameterKeys() + ")");
    }

    @Override
    public GradientNormalization getGradientNormalization() {
        return gradientNormalization;
    }

    @Override
    public double getGradientNormalizationThreshold() {
        return gradientNormalizationThreshold;
    }

    @Override
    public void setDataType(DataType dataType) {
        this.dataType = dataType;
    }
}
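
/*
 * Usage sketch: a concrete subclass of SameDiffVertex is added to a
 * ComputationGraph like any other vertex. MySameDiffVertex stands in for a
 * hypothetical implementation of the three abstract methods above:
 *
 *   ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
 *           .graphBuilder()
 *           .addInputs("input")
 *           .addVertex("sdVertex", new MySameDiffVertex(), "input")
 *           .addLayer("out", new OutputLayer.Builder().nIn(10).nOut(2).build(), "sdVertex")
 *           .setOutputs("out")
 *           .build();
 */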