All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.samediff.SDVariable Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.autodiff.samediff;

import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.shape.CreateView;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.weightinit.WeightInitScheme;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;

@Data
@NoArgsConstructor
@Slf4j
public class SDVariable implements Serializable {

    protected SameDiff sameDiff;

    @Getter
    protected String varName;
    @Getter
    @Setter
    protected VariableType variableType;

    @Setter(AccessLevel.NONE)
    protected long[] shape;

    @Getter (AccessLevel.NONE)
    @Setter
    protected DataType dataType;

    private DifferentialFunction creator;

    // autogen_tag::sdvars::start


    public SDVariable(@NonNull String varName, @NonNull VariableType varType, @NonNull SameDiff sameDiff, long[] shape, DataType dataType){
        if(varType != VariableType.PLACEHOLDER)
            Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName);
        if(varName == null)
            varName = sameDiff.generateNewVarName(varName, 0, true);

        this.sameDiff = sameDiff;
        this.varName = varName;
        this.variableType = varType;
        this.dataType = dataType;
        this.shape = shape;
    }

    /**
     * Get the name of the SDVariable
     * @return Name of the variable
     */
    public String name(){
        return varName;
    }

    public void setVarName(String varName) {
        this.varName = varName;
    }

    /**
     * @deprecated Use {@link #name()}
     */
    @Deprecated
    public String getVarName(){
        return name();
    }

    /**
     * Returns true if this variable is a placeholder
     * @return
     */
    public boolean isPlaceHolder() {
        return variableType == VariableType.PLACEHOLDER;
    }

    public boolean isConstant(){
        return variableType == VariableType.CONSTANT;
    }

    /**
     * A getter for the allocated ndarray with this {@link SDVariable}.
     *
     * This getter will lazy initialize an array if one is not found based on the associated shape and
     * {@link WeightInitScheme} - if this is possible. If this is not possible (due to shapes being unknown, etc)
     * null is returned
     *
     * @return the {@link INDArray} associated with this variable.
     */
    public INDArray getArr() {
        return getArr(false);
    }


    // autogen_tag::sdvars::end
    /**
     * A getter for the allocated ndarray with this {@link SDVariable}.
     *
     * This getter will lazy initialize an array if one is not found based on the associated shape and
     * {@link WeightInitScheme} - if this is possible.
* If this is not possible (due to shapes being unknown, etc) either:
* (a) null is returned - if enforceExistence == false, or
* (b) an IllegalStateException is thrown, if enforceExistence == true * * @return the {@link INDArray} associated with this variable. */ public INDArray getArr(boolean enforceExistence) { if(sameDiff.arrayAlreadyExistsForVarName(getVarName())) return sameDiff.getArrForVarName(getVarName()); if(variableType == VariableType.ARRAY && enforceExistence) { throw new UnsupportedOperationException("Cannot get array for ARRAY type SDVariable - use SDVariable.exec or SameDiff.output instead"); } else if(variableType == VariableType.ARRAY) { if(sameDiff.isEagerMode()) { return sameDiff.getEagerArrForVarName(name()); } return null; } INDArray ret = sameDiff.getArrForVarName(getVarName()); if(enforceExistence && ret == null) { throw new IllegalStateException("No array exists for variable \"" + name() + "\""); } return ret; } /** * Alias for the gradient variable - same as {@link #getGradient()}. * The gradient variable is the variable that represents the derivative of the loss function with respect * to the output of this variable. I.e., if this variable is X and loss function is L, then gradient() returns the * variable representing dL/dX.
* Note that only floating point variables can have gradients. */ public SDVariable gradient() { return getGradient(); } /** * The gradient variable is the variable that represents the derivative of the loss function with respect * to the output of this variable. I.e., if this variable is X and loss function is L, then gradient() returns the * variable representing dL/dX
* Note that only floating point variables can have gradients.
* Note also that a gradient may not yet be defined, and/or if no loss function variables have been set.
* You can set the loss function variables using {@link SameDiff#setLossVariables(String...)} and then create the * gradient functions using {@link SameDiff#createGradFunction()}. Alternatively, the gradient function will be * created automatically when training is performed. */ public SDVariable getGradient() { Preconditions.checkState(dataType().isFPType(), "Cannot get gradient of %s datatype variable \"%s\": only floating" + " point variables have gradients", dataType(), getVarName()); return sameDiff.getGradForVariable(getVarName()); } /** * Returns the shape of this variable * @return Shape of the variable */ public long[] getShape() { if (variableType == VariableType.PLACEHOLDER || sameDiff.isEagerMode() && shape != null) { return shape; } else if(variableType == VariableType.VARIABLE || variableType == VariableType.CONSTANT) { if(getArr() != null) return getArr().shape(); } return null; } public void setShape(long... shape) { this.shape = shape; } public long[] placeholderShape(){ if(variableType != VariableType.PLACEHOLDER){ throw new IllegalStateException("placeholderShape() can only be used for placeholder variables: variable \"" + getVarName() + " is a variable of type " + variableType); } return shape; } public DataType dataType() { if(this.dataType == null) { //Try to infer datatype instead of returning null if(variableType != VariableType.ARRAY && getArr() != null) { this.dataType = getArr().dataType(); } else { this.dataType = DataType.UNKNOWN; } } return this.dataType; } public LongShapeDescriptor getShapeDescriptor() { return LongShapeDescriptor.fromShape(getShape(), this.dataType()); } public SDVariable castTo(@NonNull DataType dataType){ return castTo(null, dataType); } public SDVariable castTo(String name, @NonNull DataType dataType){ return sameDiff.castTo(name, this, dataType); } /** * Create a new SDVariable, the contents of which is copied from this current variable * @return The new variable */ public SDVariable dup() { return sameDiff.var(this); } /** * Return a variable with equal shape to the input, but all elements set to the specified value * * @param value Value for returned variable * @return new variable */ public SDVariable assign(Number value) { return sameDiff.scalarSet(this, value.doubleValue()); } /** * Negate op - returns a new variable with the values of the current variable negated * @return Negated variable */ public SDVariable neg(){ return sameDiff.math.neg(this); } /** * Negate op - returns a new variable with the values of the current variable negated * @param name Name of the new variable * @return Negated variable */ public SDVariable neg(String name){ return sameDiff.math().neg(name, this); } /** * See {@link #lt(String, double)} */ public SDVariable lt(double value){ return lt(null, value); } /** * Less than operation: elementwise {@code this < value}
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Name of the output variable * @param value value argument to use in operation * @return Output SDVariable with values 0 (not satisfied) and 1 (where the condition is satisfied) */ public SDVariable lt(String name, double value){ return sameDiff.lt(name, this, value); } /** * See {@link #lte(String, double)} */ public SDVariable lte(double value){ return lte(null, value); } /** * Less than or equals operation: elementwise {@code this <= value}
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Name of the output variable * @param value value argument to use in operation * @return Output SDVariable with values 0 (not satisfied) and 1 (where the condition is satisfied) */ public SDVariable lte(String name, double value){ return sameDiff.lte(name, this, value); } /** * See {@link #gt(String, double)} */ public SDVariable gt(double value){ return gt(null, value); } /** * Greater than operation: elementwise {@code this > value}
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Name of the output variable * @param value value argument to use in operation * @return Output SDVariable with values 0 (not satisfied) and 1 (where the condition is satisfied) */ public SDVariable gt(String name, double value){ return sameDiff.gt(name, this, value); } /** * See {@link #gte(String, double)} */ public SDVariable gte(double value){ return gte(null, value); } /** * Greater than or equals operation: elementwise {@code this >= value}
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Name of the output variable * @param value value argument to use in operation * @return Output SDVariable with values 0 (not satisfied) and 1 (where the condition is satisfied) */ public SDVariable gte(String name, double value){ return sameDiff.gte(name, this, value); } /** * See {@link #eq(String, double)} */ public SDVariable eq(double value){ return eq(null, value); } /** * Equals operation: elementwise {@code this == value}
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Name of the output variable * @param value value argument to use in operation * @return Output SDVariable with values 0 (not satisfied) and 1 (where the condition is satisfied) */ public SDVariable eq(String name, double value){ return sameDiff.eq(name, this, value); } /** * See {@link #neq(SDVariable)} */ public SDVariable neq(double value){ return neq(null, value); } /** * Not equals operation: elementwise {@code this != value}
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Name of the output variable * @param value value argument to use in operation * @return Output SDVariable with values 0 (not satisfied) and 1 (where the condition is satisfied) */ public SDVariable neq(String name, double value){ return sameDiff.neq(name, this, value); } /** * See {@link #lt(String, SDVariable)} */ public SDVariable lt(SDVariable other){ return lt(null, other); } /** * Less than operation: elementwise {@code this < y}
* If x and y arrays have equal shape, the output shape is the same as the inputs.
* Supports broadcasting: if x and y have different shapes and are broadcastable, the output shape is broadcast.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param other Variable to compare values against * @return Output SDVariable with values 0 (not satisfied) and 1 (where the condition is satisfied) */ public SDVariable lt(String name, SDVariable other){ return sameDiff.lt(name, this, other); } /** * See {@link #lte(String, SDVariable)} */ public SDVariable lte(SDVariable other){ return lte(null, other); } /** * Less than or equal to operation: elementwise {@code this <= y}
* If x and y arrays have equal shape, the output shape is the same as the inputs.
* Supports broadcasting: if x and y have different shapes and are broadcastable, the output shape is broadcast.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param other Variable to compare values against * @return Output SDVariable with values 0 (not satisfied) and 1 (where the condition is satisfied) */ public SDVariable lte(String name, SDVariable other){ return sameDiff.lte(name, this, other); } /** * See {@link #gt(String, SDVariable)} */ public SDVariable gt(SDVariable other){ return gt(null, other); } /** * Greater than operation: elementwise {@code this > y}
* If x and y arrays have equal shape, the output shape is the same as the inputs.
* Supports broadcasting: if x and y have different shapes and are broadcastable, the output shape is broadcast.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param other Variable to compare values against * @return Output SDVariable with values 0 (not satisfied) and 1 (where the condition is satisfied) */ public SDVariable gt(String name, SDVariable other){ return sameDiff.gt(name, this, other); } /** * See {@link #gte(String, SDVariable)} */ public SDVariable gte(SDVariable other){ return gte(null, other); } /** * Greater than or equal to operation: elementwise {@code this >= y}
* If x and y arrays have equal shape, the output shape is the same as the inputs.
* Supports broadcasting: if x and y have different shapes and are broadcastable, the output shape is broadcast.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param other Variable to compare values against * @return Output SDVariable with values 0 (not satisfied) and 1 (where the condition is satisfied) */ public SDVariable gte(String name, SDVariable other){ return sameDiff.gte(name, this, other); } /** * See {@link #eq(String, SDVariable)} */ public SDVariable eq(SDVariable other){ return eq(null, other); } /** * Equal to operation: elementwise {@code this == y}
* If x and y arrays have equal shape, the output shape is the same as the inputs.
* Supports broadcasting: if x and y have different shapes and are broadcastable, the output shape is broadcast.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param other Variable to compare values against * @return Output SDVariable with values 0 (not satisfied) and 1 (where the condition is satisfied) */ public SDVariable eq(String name, SDVariable other){ return sameDiff.eq(name, this, other); } /** * See {@link #neq(String, SDVariable)} */ public SDVariable neq(SDVariable other){ return neq(null, other); } /** * Not equal to operation: elementwise {@code this != y}
* If x and y arrays have equal shape, the output shape is the same as the inputs.
* Supports broadcasting: if x and y have different shapes and are broadcastable, the output shape is broadcast.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param other Variable to compare values against * @return Output SDVariable with values 0 (not satisfied) and 1 (where the condition is satisfied) */ public SDVariable neq(String name, SDVariable other){ return sameDiff.neq(name, this, other); } /** * See {@link #mmul(String, SDVariable)} */ public SDVariable mmul(SDVariable other){ return mmul(null, other); } /** * Matrix multiplication: out = mmul(this,other) * * @param name Name of the output variable * @param other Other variable to perform matrix multiplication with * @return Output variable (result of mmul) */ public SDVariable mmul(String name, SDVariable other){ return sameDiff.mmul(name, this, other); } /** * Matrix multiplication: out = mmul(this,other) * * @param name Name of the output variable * @param other Other variable to perform matrix multiplication with * @param mMulTranspose Matrix transpose configuration * @return Output variable (result of mmul) */ public SDVariable mmul(String name, SDVariable other, @NonNull MMulTranspose mMulTranspose) { return sameDiff.mmul(name, this, other, mMulTranspose.isTransposeA(), mMulTranspose.isTransposeB(), mMulTranspose.isTransposeResult()); } /** * See {@link #dot(String, SDVariable, int...)} */ public SDVariable dot(SDVariable other, int... dimensions){ return dot(null, other, dimensions); } /** * Matrix dot product: out = dot(this,other, dimensions) * * @param name Name of the output variable * @param other Other variable to perform matrix multiplication with * @return Output variable (result of mmul) */ public SDVariable dot(String name, SDVariable other, int... dimensions){ return sameDiff.dot(name, this, other, dimensions); } /** * See {@link #add(String, double)} */ public SDVariable add(double scalar) { return add(null,scalar); } /** * Scalar addition: {@code out = this + scalar}
* Output variable has the same shape as the input variable * * @param varName Output variable name * @param scalar Scalar for operation * @return Output variable */ public SDVariable add(String varName, double scalar) { val function = sameDiff.math.add(this,scalar); return sameDiff.updateVariableNameAndReference(function,varName); } /** * See {@link #add(String, SDVariable)} */ public SDVariable add(SDVariable other) { return add(null,other); } /** * Addition operation: elementwise {@code this + x}
* If this and x variables have equal shape, the output shape is the same as the inputs.
* Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast. * * @param name Name of the output variable * @param x Variable to perform operation with * @return Output (result) SDVariable */ public SDVariable add(String name, SDVariable x) { val result = sameDiff.math.add(this, x); return sameDiff.updateVariableNameAndReference(result, name); } /** * For Kotlin operator interop * @see #add(String, SDVariable) */ public SDVariable plus(SDVariable other){ return add(other); } /** * For Kotlin operator interop * @see #add(String, double) */ public SDVariable plus(double other){ return add(other); } /** * See {@link #sub(String, double)} */ public SDVariable sub(double scalar) { return sub(null,scalar); } /** * Scalar subtraction: {@code out = this - scalar}
* Output variable has the same shape as the input variable * * @param varName Output variable name * @param scalar Scalar for operation * @return Output variable */ public SDVariable sub(String varName, double scalar) { val result = sameDiff.math.sub(this, scalar); return sameDiff.updateVariableNameAndReference(result, varName); } /** * See {@link #sub(String, SDVariable)} */ public SDVariable sub(SDVariable x) { return sub(null,x); } /** * Subtraction operation: elementwise {@code this - x}
* If this and x variables have equal shape, the output shape is the same as the inputs.
* Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast. * * @param name Name of the output variable * @param x Variable to perform operation with * @return Output (result) SDVariable */ public SDVariable sub(String name, SDVariable x) { val result = sameDiff.math.sub(this,x); return sameDiff.updateVariableNameAndReference(result,name); } /** * For Kotlin operator interop * @see #sub(String, SDVariable) */ public SDVariable minus(SDVariable other){ return sub(other); } /** * For Kotlin operator interop * @see #sub(String, double) */ public SDVariable minus(double other){ return sub(other); } /** * See {@link #div(String,double)} */ public SDVariable div(double scalar) { return div(null,scalar); } /** * Scalar division: {@code out = this / scalar}
* Output variable has the same shape as the input variable * * @param varName Output variable name * @param scalar Scalar for operation * @return Output variable */ public SDVariable div(String varName, double scalar) { val function = sameDiff.math.div(this,scalar); return sameDiff.updateVariableNameAndReference(function,varName); } /** * See {@link #div(String, SDVariable)} */ public SDVariable div(SDVariable x) { return div(null,x); } /** * Division operation: elementwise {@code this / x}
* If this and x variables have equal shape, the output shape is the same as the inputs.
* Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast. * * @param name Name of the output variable * @param x Variable to perform operation with * @return Output (result) SDVariable */ public SDVariable div(String name, SDVariable x) { val result = sameDiff.math.div(this, x); return sameDiff.updateVariableNameAndReference(result, name); } /** * Floor division operation: elementwise {@code this // x}
* If this and x variables have equal shape, the output shape is the same as the inputs.
* Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast. * * @param name Name of the output variable * @param x Variable to perform operation with * @return Output (result) SDVariable */ public SDVariable fdiv(String name, SDVariable x) { val result = sameDiff.math.floorDiv(this, x); return sameDiff.updateVariableNameAndReference(result, name); } /** * Modulo operation: elementwise {@code this / x}
* If this and x variables have equal shape, the output shape is the same as the inputs.
* Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast. * * @param name Name of the output variable * @param x Variable to perform operation with * @return Output (result) SDVariable */ public SDVariable mod(String name, SDVariable x) { val result = sameDiff.math.mod(this, x); return sameDiff.updateVariableNameAndReference(result, name); } /** * See {@link #mul(String, double)} */ public SDVariable mul(double scalar) { return mul(null,scalar); } /** * Scalar multiplication: {@code out = this * scalar}
* Output variable has the same shape as the input variable * * @param varName Output variable name * @param scalar Scalar for operation * @return Output variable */ public SDVariable mul(String varName, double scalar) { val function = sameDiff.math.mul(this, scalar); return sameDiff.updateVariableNameAndReference(function,varName); } /** * See {@link #mul(String, SDVariable)} */ public SDVariable mul(SDVariable x) { return mul(null,x); } /** * Multiplication operation: elementwise {@code this * x}
* If this and x variables have equal shape, the output shape is the same as the inputs.
* Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast. * * @param name Name of the output variable * @param x Variable to perform operation with * @return Output (result) SDVariable */ public SDVariable mul(String name, SDVariable x) { val result = sameDiff.math.mul(this, x); return sameDiff.updateVariableNameAndReference(result,name); } /** * For Kotlin operator interop * @see #mul(String, SDVariable) */ public SDVariable times(SDVariable other){ return mul(other); } /** * For Kotlin operator interop * @see #mul(String, double) */ public SDVariable times(double other){ return mul(other); } /** * See {@link #pow(String, double)} */ public SDVariable pow(double scalar) { return pow(null, scalar); } /** * Scalar power operation: {@code out = this ^ scalar}
* Output variable has the same shape as the input variable * * @param varName Output variable name * @param scalar Scalar for operation * @return Output variable */ public SDVariable pow(String varName, double scalar) { SDVariable ret = sameDiff.math.pow(this, scalar); return sameDiff.updateVariableNameAndReference(ret, varName); } /** * See {@link #rsub(String, double)} */ public SDVariable rsub(double scalar) { return rsub(null,scalar); } /** * Scalar reverse subtraction: {@code out = scalar - this}
* Output variable has the same shape as the input variable * * @param varName Output variable name * @param scalar Scalar for operation * @return Output variable */ public SDVariable rsub(String varName, double scalar) { val function = sameDiff.math.rsub(this,scalar); return sameDiff.updateVariableNameAndReference(function,varName); } /** * See {@link #rsub(String, SDVariable)} */ public SDVariable rsub(SDVariable x) { return rsub(null,x); } /** * Reverse subtraction operation: elementwise {@code x - this}
* If this and x variables have equal shape, the output shape is the same as the inputs.
* Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast. * * @param name Name of the output variable * @param x Variable to perform operation with * @return Output (result) SDVariable */ public SDVariable rsub(String name, SDVariable x) { val result = sameDiff.math.rsub(this,x); return sameDiff.updateVariableNameAndReference(result,name); } /** * See {@link #rdiv(String, double)} */ public SDVariable rdiv(double scalar) { return rdiv(null,scalar); } /** * Scalar reverse division: {@code out = scalar / this}
* Output variable has the same shape as the input variable * * @param varName Output variable name * @param scalar Scalar for operation * @return Output variable */ public SDVariable rdiv(String varName, double scalar) { val function = sameDiff.math.rdiv(this, scalar); return sameDiff.updateVariableNameAndReference(function, varName); } /** * See {@link #rdiv(String, SDVariable)} */ public SDVariable rdiv(SDVariable sameDiffVariable) { return rdiv(null,sameDiffVariable); } /** * Reverse division operation: elementwise {@code x / this}
* If this and x variables have equal shape, the output shape is the same as the inputs.
* Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast. * * @param name Name of the output variable * @param x Variable to perform operation with * @return Output (result) SDVariable */ public SDVariable rdiv(String name, SDVariable x) { val result = sameDiff.math.rdiv(this,x); return sameDiff.updateVariableNameAndReference(result,name); } /** * See {@link #squaredDifference(String, SDVariable)} */ public SDVariable squaredDifference(SDVariable x) { return squaredDifference(null,x); } /** * Squared difference operation: {@code (this - x)^2} * @param x Other input variable * @return squared difference between variables */ public SDVariable squaredDifference(String name, SDVariable x) { val result = sameDiff.math().squaredDifference(this, x); return sameDiff.updateVariableNameAndReference(result, name); } /** * See {@link #sum(String, boolean, int...)} */ public SDVariable sum(int... dimensions){ return sum(null, dimensions); } /** * See {@link #sum(String, boolean, int...)} */ public SDVariable sum(boolean keepDims, int... dimensions){ return sum(null, keepDims, dimensions); } /** * See {@link #sum(String, boolean, int...)} */ public SDVariable sum(String name, int... dimensions){ return sum(name, false, dimensions); } /** * Sum array reduction operation, optionally along specified dimensions.
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Output variable name * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false, or * of rank (input rank) if keepdims = true */ public SDVariable sum(String name, boolean keepDims, int... dimensions){ return sameDiff.sum(name, this, keepDims, dimensions); } /** * See {@link #mean(String, boolean, int...)} */ public SDVariable mean(boolean keepDims, int... dimensions){ return mean(null, keepDims, dimensions); } /** * See {@link #mean(String, boolean, int...)} */ public SDVariable mean(String name, int... dimensions){ return mean(name, false, dimensions); } /** * See {@link #mean(String, boolean, int...)} */ public SDVariable mean(int... dimensions){ return mean(null, false, dimensions); } /** * Mean (average) array reduction operation, optionally along specified dimensions
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Output variable name * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable mean(String name, boolean keepDims, int... dimensions){ return sameDiff.mean(name, this, keepDims, dimensions); } /** * See {@link #std(String, boolean, boolean, int...)} */ public SDVariable std(boolean biasCorrected, int... dimensions){ return std(null, biasCorrected, dimensions); } /** * See {@link #std(String, boolean, boolean, int...)} */ public SDVariable std(String name, boolean biasCorrected, int... dimensions){ return sameDiff.standardDeviation(name, this, biasCorrected, dimensions); } /** * Stardard deviation array reduction operation, optionally along specified dimensions
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable std(String name, boolean biasCorrected, boolean keepDims, int... dimensions){ return sameDiff.standardDeviation(name, this, biasCorrected, keepDims, dimensions); } /** * See {@link #prod(String, boolean, int...)} */ public SDVariable prod(int... dimensions){ return prod(null, dimensions); } /** * See {@link #prod(String, boolean, int...)} */ public SDVariable prod(boolean keepDims, int... dimensions){ return prod(null, keepDims, dimensions); } /** * See {@link #prod(String, boolean, int...)} */ public SDVariable prod(String name, int... dimensions){ return sameDiff.prod(name, this, dimensions); } /** * Product array reduction operation, optionally along specified dimensions
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Output variable name * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable prod(String name, boolean keepDims, int... dimensions){ return sameDiff.prod(name, this, keepDims, dimensions); } /** * See {@link #min(String, boolean, int...)} */ public SDVariable min(int... dimensions){ return min(null, dimensions); } /** * See {@link #min(String, boolean, int...)} */ public SDVariable min(boolean keepDims, int... dimensions){ return min(null, keepDims, dimensions); } /** * See {@link #min(String, boolean, int...)} */ public SDVariable min(String name, int... dimensions){ return min(name, false, dimensions); } /** * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Output variable name * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable min(String name, boolean keepDims, int... dimensions){ return sameDiff.min(name, this, keepDims, dimensions); } /** * See {@link #max(String, boolean, int...)} */ public SDVariable max(int... dimensions) { return max(null, dimensions); } /** * See {@link #max(String, boolean, int...)} */ public SDVariable max(String name, int... dimensions) { return max(name, false, dimensions); } /** * See {@link #max(String, boolean, int...)} */ public SDVariable max(boolean keepDims, int... dimensions) { return max(null, keepDims, dimensions); } /** * Maximum array reduction operation, optionally along specified dimensions
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Output variable name * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable max(String name, boolean keepDims, int... dimensions) { return sameDiff.max(name, this, keepDims, dimensions); } /** * See {@link #norm1(String, boolean, int...)} */ public SDVariable norm1(int... dimensions){ return norm1(null, dimensions); } /** * See {@link #norm1(String, boolean, int...)} */ public SDVariable norm1(boolean keepDims, int... dimensions){ return norm1(null, keepDims, dimensions); } /** * See {@link #norm1(String, boolean, int...)} */ public SDVariable norm1(String name, int... dimensions){ return norm1(name, false, dimensions); } /** * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
* {@code out = sum_i abs(x[i])}
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Output variable name * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions * @param dimensions dimensions to reduce over * @return Output variable */ public SDVariable norm1(String name, boolean keepDims, int... dimensions) { return sameDiff.norm1(name, this, keepDims, dimensions); } /** * See {@link #norm2(String, boolean, int...)} */ public SDVariable norm2(int... dimensions){ return norm2(null, dimensions); } /** * See {@link #norm2(String, boolean, int...)} */ public SDVariable norm2(boolean keepDims, int... dimensions){ return norm2(null, keepDims, dimensions); } /** * See {@link #norm2(String, boolean, int...)} */ public SDVariable norm2(String name, int... dimensions){ return norm2(name, false, dimensions); } /** * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
* {@code out = sqrt(sum_i x[i]^2)}
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Output variable name * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions * @param dimensions dimensions to reduce over * @return Output variable */ public SDVariable norm2(String name, boolean keepDims, int... dimensions){ return sameDiff.norm2(name, this, keepDims, dimensions); } /** * See {@link #normmax(String, boolean, int...)} */ public SDVariable normmax(int... dimensions){ return normmax(null, dimensions); } /** * See {@link #normmax(String, boolean, int...)} */ public SDVariable normmax(boolean keepDims, int... dimensions){ return normmax(null, keepDims, dimensions); } /** * See {@link #normmax(String, boolean, int...)} */ public SDVariable normmax(String name, int... dimensions){ return normmax(name, false, dimensions); } /** * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the * specified dimensions:
* {@code out = max(abs(x[i]))}
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Output variable name * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions * @param dimensions dimensions to reduce over * @return Output variable */ public SDVariable normmax(String name, boolean keepDims, int... dimensions){ return sameDiff.normmax(name, this, keepDims, dimensions); } /** * See {@link #argmax(String, boolean, int...)} */ public SDVariable argmax(int... dimensions){ return argmax(null, dimensions); } /** * See {@link #argmax(String, boolean, int...)} */ public SDVariable argmax(String name, int... dimensions){ return sameDiff.argmax(name, this, dimensions); } /** * Argmax array reduction operation, optionally along specified dimensions.
* Output values are the index of the maximum value of each slice along the specified dimension.
*
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Name of the output variable * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false, or * of rank (input rank) if keepdims = true */ public SDVariable argmax(String name, boolean keepDims, int... dimensions) { return sameDiff.argmax(name, this, keepDims, dimensions); } /** * See {@link #argmin(String, boolean, int...)} */ public SDVariable argmin(int... dimensions){ return argmin(null, dimensions); } /** * See {@link #argmin(String, boolean, int...)} */ public SDVariable argmin(String name, int... dimensions){ return sameDiff.argmin(name, this, dimensions); } /** * Argmin array reduction operation, optionally along specified dimensions.
* Output values are the index of the minimum value of each slice along the specified dimension.
*
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Name of the output variable * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false, or * of rank (input rank) if keepdims = true */ public SDVariable argmin(String name, boolean keepDims, int... dimensions) { return sameDiff.argmax(name, this, keepDims, dimensions); } /** * Return the total number of elements in this array * @return */ public SDVariable length() { return sameDiff.prod(shape()); } /** * Get the shape of the array as a dynamic SDVariable * @return Shape SDVariable */ public SDVariable shape(){ return sameDiff.shape(this); } /** * Get the rank of this variable as a dynamic SDVariable * @return Rank SDVariable */ public SDVariable rank(){ return sameDiff.rank(this); } /** * Reshape the current variable to the specified (dynamic) shape. The output variable will have the same values as the * input, but with the specified shape.
* Note that prod(shape) must match length(input) == prod(input.shape) * * @param newShape New shape for variable * @return Output variable */ public SDVariable reshape(SDVariable newShape) { return sameDiff.reshape(this, newShape); } /** * Reshape the current variable to the specified (dynamic) shape. The output variable will have the same values as the * input, but with the specified shape.
* Note that prod(shape) must match length(input) == prod(input.shape) * * @param newShape New shape for variable * @return Output variable */ public SDVariable reshape(String name,SDVariable newShape) { return sameDiff.reshape(name,this, newShape); } /** * Reshape the current variable to the specified shape. The output variable will have the same values as the * input, but with the specified shape.
* Note that prod(shape) must match length(input) == prod(input.shape) * * @param newShape New shape for variable * @return Output variable */ public SDVariable reshape(int... newShape){ return sameDiff.reshape(this, ArrayUtil.toLongArray(newShape)); } /** * Reshape the current variable to the specified shape. The output variable will have the same values as the * input, but with the specified shape.
* Note that prod(shape) must match length(input) == prod(input.shape) * * @param newShape New shape for variable * @return Output variable */ public SDVariable reshape(long... newShape){ return sameDiff.reshape(this, newShape); } /** * Permute the dimensions of the current variable according to the specified permutation indices.
* Example: if the current variable has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b] * * @param dimensions The new dimension order * @return Output variable (permuted input) */ public SDVariable permute(int... dimensions){ return sameDiff.permute(this, dimensions); } public SDVariable permute(SDVariable dimensions){ return sameDiff.permute( this, dimensions); } /** * Associate the specified array with this variable * @param array Array to associate with this variable * @return This variable */ public SDVariable setArray(INDArray array){ sameDiff.associateArrayWithVariable(array, this); return this; } /** * Evaluate the result of this variable * @return */ public INDArray eval() { Map m = sameDiff.output((Map)null, name()); return m.get(name()); } /** * Evaluate the result of this variable * @return */ public INDArray eval(Map placeholders) { Map m = sameDiff.output(placeholders, name()); return m.get(name()); } @Override public String toString() { return "SDVariable(name=\"" + varName + "\",variableType=" + variableType + ",dtype=" + dataType + (variableType == VariableType.PLACEHOLDER && shape != null ? ",shape=" + Arrays.toString(shape): "") + ")"; } /** * Add a control dependency for this variable on the specified variable.
* Control dependencies can be used to enforce the execution order. * For example, if a control dependency X->Y exists, then Y will only be executed after X is executed - even * if Y wouldn't normally depend on the result/values of X. * * @param controlDependency Control dependency to add for this variable */ public void addControlDependency(SDVariable controlDependency){ Variable vThis = sameDiff.getVariables().get(getVarName()); Variable vCD = sameDiff.getVariables().get(controlDependency.name()); //If possible: add control dependency on ops if(vThis.getOutputOfOp() != null && vCD.getOutputOfOp() != null ){ //Op -> Op case SameDiffOp oThis = sameDiff.getOps().get(vThis.getOutputOfOp()); SameDiffOp oCD = sameDiff.getOps().get(vCD.getOutputOfOp()); if(oThis.getControlDeps() == null) oThis.setControlDeps(new ArrayList<>()); if(!oThis.getControlDeps().contains(oCD.getName())) oThis.getControlDeps().add(oCD.getName()); if(oCD.getControlDepFor() == null) oCD.setControlDepFor(new ArrayList<>()); if(!oCD.getControlDepFor().contains(oThis.getName())) oCD.getControlDepFor().add(oThis.getName()); } else { if(vThis.getOutputOfOp() != null){ //const/ph -> op case SameDiffOp oThis = sameDiff.getOps().get(vThis.getOutputOfOp()); if(oThis.getVarControlDeps() == null) oThis.setVarControlDeps(new ArrayList<>()); if(!oThis.getVarControlDeps().contains(vCD.getName())) oThis.getVarControlDeps().add(vCD.getName()); if(vCD.getControlDepsForOp() == null) vCD.setControlDepsForOp(new ArrayList<>()); if(!vCD.getControlDepsForOp().contains(oThis.getName())) vCD.getControlDepsForOp().add(oThis.getName()); } else { //const/ph -> const/ph case if(vThis.getControlDeps() == null) vThis.setControlDeps(new ArrayList<>()); if(!vThis.getControlDeps().contains(vCD.getName())) vThis.getControlDeps().add(vCD.getName()); if(vCD.getControlDepsForVar() == null) vCD.setControlDepsForVar(new ArrayList<>()); if(!vCD.getControlDepsForVar().contains(vThis.getName())) vCD.getControlDepsForVar().add(vThis.getName()); } } } /** * Get a variable with content equal to a specified sub-array of this variable.
* Can be used (for example) to get rows, columns, sub-matrices, etc. * @param indices Indices to get * @return Sub-array variable */ public SDVariable getView(SDIndex... indices) { //copy because we can mutate this internally SDVariable[] indicesVars = new SDVariable[indices.length]; for(int i = 0; i < indices.length; i++) { //convert indices to SDVariable based indices switch(indices[i].getIndexType()) { case INTERVAL: indicesVars[i] = CreateView.createInterval(sameDiff,indices[i].getIntervalBegin(),indices[i].getIntervalEnd(),indices[i].getIntervalStrides(),indices[i].isInclusive() ? 1 : 0); break; case POINT: indicesVars[i] = CreateView.createPoint(sameDiff,indices[i].getPointIndex()); break; case POINT_INPUT: indicesVars[i] = CreateView.createPoint(sameDiff,indices[i].getPointVar()); break; case INTERVAL_INPUT: indicesVars[i] = CreateView.createInterval(sameDiff,indices[i].getIntervalInputBegin(),indices[i].getIntervalInputEnd(),indices[i].getIntervalStrideInput(), indices[i].getInclusiveInput()); break; case ALL: indicesVars[i] = CreateView.createAll(sameDiff); break; default: throw new IllegalArgumentException("Illegal type " + indices[i].getIndexType()); } } return sameDiff.createView(this,indicesVars); } /** * Get a variable with content equal to a specified sub-array of this variable.
* Can be used (for example) to get rows, columns, sub-matrices, etc. * @param indices Indices to get * @return Sub-array variable */ public SDVariable get(SDIndex... indices) { int ndims = indices.length; boolean variableIndices = false; //copy because we can mutate this internally SDIndex[] inputIndices = Arrays.copyOf(indices,indices.length); indices = inputIndices; for(int i = 0; i < indices.length; i++) { if(indices[i].getIndexType() == SDIndex.IndexType.POINT_INPUT || indices[i].getIndexType() == SDIndex.IndexType.INTERVAL_INPUT) { variableIndices = true; } //convert indices to SDVariable based indices if(variableIndices && (indices[i].getIndexType() == SDIndex.IndexType.INTERVAL || indices[i].getIndexType() == SDIndex.IndexType.POINT)) { switch(indices[i].getIndexType()) { case INTERVAL: indices[i] = SDIndex.interval(sameDiff.constant(indices[i].getIntervalBegin()),sameDiff.constant(indices[i].getIntervalEnd()),sameDiff.constant(indices[i].getIntervalEnd())); break; case POINT: indices[i] = SDIndex.point(sameDiff.constant(indices[i].getPointIndex()),indices[i].isPointKeepDim()); break; } } } long[] begin = new long[ndims]; long[] end = new long[ndims]; long[] strides = new long[ndims]; int[] begin_mask_arr = new int[ndims]; int[] end_mask_arr = new int[ndims]; int[] shrink_axis_mask_arr = new int[ndims]; SDVariable beginVar = null; SDVariable endVar = null; SDVariable stridesVar = null; for (int i = 0; i < ndims; i++) { strides[i] = 1; SDIndex index = indices[i]; SDIndex.IndexType indexType = index.getIndexType(); if (indexType == SDIndex.IndexType.ALL) { begin_mask_arr[i] = 1; end_mask_arr[i] = 1; } else if (indexType == SDIndex.IndexType.POINT || indexType == SDIndex.IndexType.POINT_INPUT) { if(indexType == SDIndex.IndexType.POINT) { long pointIndex = index.getPointIndex(); begin[i] = pointIndex; end[i] = pointIndex + 1; } else if(indexType == SDIndex.IndexType.POINT_INPUT) { if(beginVar == null && endVar == null) { beginVar = index.getPointVar(); endVar = index.getPointVar().add(1.0); } else { beginVar = sameDiff.concat(0,beginVar,index.getPointVar()); endVar = sameDiff.concat(0,endVar,index.getPointVar().add(1.0)); } } if(!index.isPointKeepDim()) { shrink_axis_mask_arr[i] = 1; } } else if (indexType == SDIndex.IndexType.INTERVAL || indexType == SDIndex.IndexType.INTERVAL_INPUT) { if (index.getIntervalBegin() == null && indexType != SDIndex.IndexType.INTERVAL_INPUT) { begin_mask_arr[i] = 1; } else if(indexType == SDIndex.IndexType.INTERVAL_INPUT) { if(beginVar == null) { beginVar = index.getIntervalInputBegin(); } else { beginVar = sameDiff.concat(0,beginVar,index.getIntervalInputBegin()); } } else { begin[i] = index.getIntervalBegin(); } if (index.getIntervalEnd() == null && indexType != SDIndex.IndexType.INTERVAL_INPUT) { end_mask_arr[i] = 1; } else if(indexType == SDIndex.IndexType.INTERVAL_INPUT) { if(endVar == null) { endVar = index.getIntervalInputEnd(); } else { endVar = sameDiff.concat(0,endVar,index.getIntervalInputEnd()); } } else { end[i] = index.getIntervalEnd(); } if (index.getIntervalStrides() == null) { strides[i] = 1; if(stridesVar != null) { stridesVar = sameDiff.concat(0,stridesVar,sameDiff.constant(1).reshape(1)); } else { stridesVar = sameDiff.constant(1).reshape(1); } } else { strides[i] = index.getIntervalStrides(); if(stridesVar != null) { stridesVar = sameDiff.concat(0,stridesVar,index.getIntervalStrideInput()); } else { stridesVar = index.getIntervalStrideInput(); } } } } // convert binary int[] to int int begin_mask = binArrToInt(begin_mask_arr); int end_mask = binArrToInt(end_mask_arr); int shrink_axis = binArrToInt(shrink_axis_mask_arr); if(variableIndices) { if(stridesVar == null) { stridesVar = sameDiff.onesLike(beginVar); } return this.sameDiff.stridedSlice(this, beginVar, endVar, stridesVar, begin_mask, end_mask, 0, 0, shrink_axis); } else { return this.sameDiff.stridedSlice(this, begin, end, strides, begin_mask, end_mask, 0, 0, shrink_axis); } } public static SDVariable sliceEnd(SDVariable input,SDVariable sliceIndexInput) { SameDiff sameDiff = input.getSameDiff(); SDVariable range = sameDiff.range(sameDiff.constant(0), input.rank(), sameDiff.constant(1), DataType.INT64); //0 1 1 SDVariable mask = range.gt(0.0).castTo(DataType.INT64); SDVariable sliceMask = range.eq(0).castTo(DataType.INT64); SDVariable sliceIndex = sliceMask.mul(sliceIndexInput); SDVariable outputShape = input.shape().mul(mask).add(sliceIndex); return outputShape; } /** * Get a variable with content equal to a specified sub-array of this variable.
* Can be used (for example) to get rows, columns, sub-matrices, etc. * * This will loop over the indices (think of it as a list) and concatenate * each slice of the input array to the final result. * * Expected input for indices would be a vector with indices such as 0,1,2,3,4. * For each element in the index we then concatenate the result to the previous iteration. * * Note that this is slow and should only be used in very specific circumstances. * Otherwise {@link org.nd4j.linalg.api.ops.impl.shape.StridedSlice} will be more performant * for creating views. Many times {@link org.nd4j.linalg.api.ops.impl.shape.StridedSlice} avoids * this slower approach by directly calculating the strides of a view. * * @param indices Indices to get * @return Sub-array variable */ public SDVariable get(SDVariable indices) { SDVariable initialSize = sameDiff.zerosLike(shape()).castTo(DataType.INT64); //pull from the first slice as the starting point and concatenate each result together SDVariable startResult = sameDiff.slice(this, initialSize.castTo(DataType.INT64), sliceEnd(this, sameDiff.onesLike(shape()).castTo(DataType.INT64))); //start at 1 because we start with the initial output (basically the item at the first element in the indices) SDVariable currIteration = sameDiff.var(Nd4j.ones(1).castTo(DataType.INT32)); //this condition is normally used when you want to toss in an extra condition to terminate early SDVariable cond = sameDiff.constant("curr_cond",true); //the total length of the indices to loop till SDVariable indicesLength = indices.length(); //sub graph that uses invoke SameDiff loop = createLoopConcat(this,indices); //collect slices along the first dimension concatenating the result along the way return this.sameDiff.loopWithConditions(ControlFlow.LoopParams.builder() .functionBody(loop) .loopVars(new SDVariable[] { currIteration, indicesLength, cond, startResult, this, indices }).functionBodyInputs(new String[] { //note here all inputs are the same as the outputs, and we return the original //input concatenated with the starting input (the first slice at index 0) //and then loop over each index in the list till we get the specific result "index", "max", "cond", "input", "pullFrom", "indices" }) .functionBodyOutputs(new String[]{ "index", "max", "cond", "output", "pullFrom", "indices"}) .functionName("slices") .loopName("outputs") //note the ordering here is important. Output is the accumulated output of each iteration appending //a result to the previous iteration. We start with the initial input and add more overtime. .build())[3]; } /** * Get a variable with content equal to a specified sub-array of this variable.
* Can be used (for example) to get rows, columns, sub-matrices, etc. * * This will loop over the indices (think of it as a list) and add each slice * specified by the indices from the source to the new array. * * The end result will be this variable but with the new updated results. * * Note that this is slow and should only be used in very specific circumstances. * Otherwise {@link org.nd4j.linalg.api.ops.impl.shape.StridedSlice} will be more performant * for creating views. Many times {@link org.nd4j.linalg.api.ops.impl.shape.StridedSlice} avoids * this slower approach by directly calculating the strides of a view. * * @param indices Indices to get * @param toPut the source array to pull results from to put in to this array * @param putIndices the equivalent indices for the other array * @return the updated array with the elements from the toPut array put in to this new array */ public SDVariable put(SDVariable indices,SDVariable toPut,SDVariable putIndices) { //start at 1 because we start with the initial output (basically the item at the first element in the indices) SDVariable currIteration = sameDiff.var(Nd4j.zeros(1).castTo(DataType.INT32)); //this condition is normally used when you want to toss in an extra condition to terminate early SDVariable cond = sameDiff.constant(true); //the total length of the indices to loop till SDVariable indicesLength = indices.length(); //sub graph that uses invoke SameDiff loop = createLoopPut(this,indices); loop.setEnableCache(false); //collect slices along the first dimension concatenating the result along the way return this.sameDiff.loopWithConditions(ControlFlow.LoopParams.builder() .functionBody(loop) .loopVars(new SDVariable[] { currIteration, indicesLength, cond, this, toPut, indices, putIndices }).functionBodyInputs(new String[] { //note here all inputs are the same as the outputs, and we return the original //the default 3 values (current iteration, max index to loop to and optional condition) //index,max,cond,assignTo,putIn,indices,indicesPut "index", "max", "cond", "assignTo", "toPut", "indices", "indicesPut" }) .functionBodyOutputs(new String[]{ "index", "max", "cond", "assignOutput", "toPut", "indices", "indicesPut"}) .functionName("sliceputs") .loopName("outputs") //note the ordering here is important. Output is the original array where we assigned values. .build())[3]; } /** * Create a graph that takes in the indices as a placeholder, loops over each element in the index vector * and appends the slice to the end result. This graph is equivalent to something like: * INDArray input = ....; * INDArray indices = ...; * INDArray result = input.get(NDArrayIndex.point(indices.getInt(0)); * for(int i = i; i < maxIndex && customInputResult; i++) { * result = Nd4j.concat(0,input.get(NDArrayIndex.point(i))); * } * return result *

* Note this is similar to {@link INDArray#get(INDArray)} * * @param relative the expected target input variable. We use this to pull expected * return data type for the result * @param indices the indices to get * @return the graph for dynamically creating a result graph */ public static SameDiff createLoopPut(SDVariable relative,SDVariable indices) { //standard loop body for loopWithConditions SameDiff loop = SameDiff.create(); //curr index SDVariable index = loop.placeHolder("index",DataType.INT32); //loop until SDVariable maxIndex = loop.placeHolder("max",DataType.INT32); //constant condition of true for custom, just loop till max iterations hit SDVariable currCondition = loop.placeHolder("cond",DataType.BOOL); //the actual variable to pull from SDVariable assignTo = loop.placeHolder("assignTo",relative.dataType()); SDVariable toPut = loop.placeHolder("toPut",relative.dataType()); //the indices to loop over (the input variable SDVariable indicesLoop = loop.placeHolder("indices",indices.dataType()); //standardize indices to length 1 indicesLoop = indicesLoop.reshape("indicesReshape",indicesLoop.length()); SDVariable indicesPut = loop.placeHolder("indicesPut",indices.dataType()); indicesPut = indicesPut.reshape("indicesPutReshape",indicesPut.length()); //the current index to retrieve SDVariable indexToRetrieve = indicesLoop.getView(SDIndex.point(index)).reshape(1).castTo("indexToReceive",DataType.INT64); SDVariable indexToPut = indicesPut.getView(SDIndex.point(index)).reshape(1).castTo("indexToPut",DataType.INT64); SDVariable toAssign = toPut.getView(SDIndex.point(indexToPut)); SDVariable sliceOutput = assignTo.getView(SDIndex.point(indexToRetrieve)); SDVariable assignOutput = loop.assign(sliceOutput,toAssign); SDVariable outputIdentity = loop.identity("assignOutput",assignTo); //ensure the output depends on the final assign so it gets executed, return the final output as a view outputIdentity.addControlDependency(assignOutput); return loop; } /** * Create a graph that takes in the indices as a placeholder, loops over each element in the index vector * and appends the slice to the end result. This graph is equivalent to something like: * INDArray input = ....; * INDArray indices = ...; * INDArray result = input.get(NDArrayIndex.point(indices.getInt(0)); * for(int i = i; i < maxIndex && customInputResult; i++) { * result = Nd4j.concat(0,input.get(NDArrayIndex.point(i))); * } * return result * * Note this is similar to {@link INDArray#get(INDArray)} * @param relative the expected target input variable. We use this to pull expected * return data type for the result * @param indices the indices to get * @return the graph for dynamically creating a result graph */ public static SameDiff createLoopConcat(SDVariable relative,SDVariable indices) { //standard loop body for loopWithConditions SameDiff loop = SameDiff.create(); //curr index SDVariable index = loop.placeHolder("index",DataType.INT32); //loop until SDVariable maxIndex = loop.placeHolder("max",DataType.INT32); //constant condition of true for custom, just loop till max iterations hit SDVariable currCondition = loop.placeHolder("cond",DataType.BOOL); //the input to pull from (in this case this) SDVariable input = loop.placeHolder("input", relative.dataType()); //the actual variable to pull from SDVariable pullFrom = loop.placeHolder("pullFrom",relative.dataType()); //the indices to loop over (the input variable SDVariable indicesLoop = loop.placeHolder("indices",indices.dataType()); //standardize indices to length 1 indicesLoop = indicesLoop.reshape("indicesReshape",indicesLoop.length()); //the current index to retrieve SDVariable indexToRetrieve = indicesLoop.get(SDIndex.point(index)).reshape(1).castTo("indexToReceive",DataType.INT64); //the final concatenated output SDVariable sliceOutput = loop.expandDims("outputSlice",pullFrom.get(SDIndex.point(indexToRetrieve)),0); SDVariable output = loop.concat("output",0,input,sliceOutput); return loop; } /** * Convert this variable to a constant. This is equivalent to "freezing" a variable so that it's value * won't be changed by further training.
* This can only be done for variables and placeholders, not ARRAY type variables (which are usually network activations). * As a constant, this variable will no longer be modified by any subsequent training. * * @return This variable (now a constant) */ public SDVariable convertToConstant(){ return sameDiff.convertToConstant(this); } /** * Convert this variable to a VARIABLE type SDVariable.
* This can only be done for constants and placeholders, not ARRAY type variables (which are usually network activations). * As a variable, this variable will modified during any subsequent training. * * @return This variable (now a variable type SDVariable) */ public SDVariable convertToVariable(){ return sameDiff.convertToVariable(this); } /** * Rename this variable to a new name. Equivalent to {@link SameDiff#renameVariable(String, String)} * * @param newName The new name for the variable - no variable with this name must already exist * @return The current variable (same object) */ public SDVariable rename(String newName) { sameDiff.renameVariable(getVarName(), newName); return this; } /** * Mark this variable as a loss function variable. This means that this variable will be minimized via backprop during training.
* This will add the variable as a loss to any others - i.e., if multiple variables are marked as losses, their values will be summed * to give the total network loss.
* Note that only floating point (Float16/32/64) variables may be marked as a loss.
* Note also that only ARRAY type SDVariables can be marked as losses to be minimized. That is, we cannot mark the value * of a constant, variable or placeholder to be minimized as doing so would not make sense.
* This is equivalent to {@link SameDiff#addLossVariable(String)} */ public void markAsLoss(){ sameDiff.addLossVariable(getVarName()); } /** * Determine if this variable has a gradient with respect to the current loss. Note that: * (a) Non-floating-point variables (integer, string, etc) will never have gradients
* (b) This method will return false if no gradient function has been created yet. See {@link SameDiff#createGradFunction()} * and {@link SameDiff#setLossVariables(String...)}
* (c) Floating point variables may not have any gradient if the current loss does not depend on the variable at all
* @return True if a gradient variable exists for the specified variable, for the current loss */ public boolean hasGradient(){ return sameDiff.variableHasGradient(getVarName()); } private static int binArrToInt(int[] arr) { int x = 0; int m = 1; for (int i = 0; i < arr.length; i++) { if (arr[i] == 1) { x += m; } m *= 2; } return x; } @Override public int hashCode() { int result = super.hashCode(); result = 31 * result + (varName != null ? varName.hashCode() : 0); result = 31 * result + (variableType != null ? variableType.hashCode() : 0); result = 31 * result + (dataType != null ? dataType.hashCode() : 0); return result; } public SDVariable clone(String name,SameDiff sd) { SDVariable v = new SDVariable(); v.varName = name; v.variableType = variableType; v.shape = shape == null ? null : shape.clone(); v.dataType = dataType; v.sameDiff = sd; return v; } public SDVariable clone(SameDiff sd) { SDVariable v = new SDVariable(); v.varName = varName; v.variableType = variableType; v.shape = shape == null ? null : shape.clone(); v.dataType = dataType; v.sameDiff = sd; return v; } @Override public boolean equals(Object o){ if(o == this) return true; if(!(o instanceof SDVariable)) return false; SDVariable s = (SDVariable)o; if(!varName.equals(s.varName)) return false; if(variableType != s.variableType) return false; if(dataType != s.dataType) return false; if(variableType == VariableType.VARIABLE || variableType == VariableType.CONSTANT){ INDArray a1 = getArr(); INDArray a2 = s.getArr(); return a1.equals(a2); } return true; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy