org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray Maven / Gradle / Ivy
The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import lombok.Getter;
import lombok.Setter;
import lombok.val;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
public class TensorArray extends BaseTensorOp {
@Getter
@Setter
protected DataType tensorArrayDataType;
@Getter
@Setter
protected SDVariable flow;
@Getter
@Setter
protected boolean clearOnRead = true;
@Override
public String tensorflowName() {
return "TensorArrayV3";
}
public TensorArray(String name, SameDiff sameDiff, DataType dataType){
super(name, sameDiff, new SDVariable[]{});
this.tensorArrayDataType = dataType;
}
public TensorArray(SameDiff sameDiff, DataType dataType){
super(sameDiff, new SDVariable[]{});
this.tensorArrayDataType = dataType;
}
public TensorArray(TensorArray ta) {
super(ta.sameDiff, new SDVariable[]{});
this.tensorArrayDataType = ta.tensorArrayDataType;
}
public TensorArray(TensorArray ta, SDVariable[] inputs){
super(ta.sameDiff, inputs);
this.tensorArrayDataType = ta.tensorArrayDataType;
}
@Override
public void configureFromArguments() {
super.configureFromArguments();
if(!bArguments.isEmpty()) {
this.clearOnRead = bArguments.get(0);
}
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) {
val idd = nodeDef.getInput(nodeDef.getInputCount() - 1);
NodeDef iddNode = null;
for(int i = 0; i < graph.getNodeCount(); i++) {
if(graph.getNode(i).getName().equals(idd)) {
iddNode = graph.getNode(i);
}
}
val arr = TFGraphMapper.getNDArrayFromTensor(iddNode);
if (arr != null) {
int idx = arr.getInt(0);
addIArgument(idx);
}
this.tensorArrayDataType = TFGraphMapper.convertType(attributesForNode.get("dtype").getType());
}
public TensorArray(){
this(DataType.FLOAT);
}
public TensorArray(DataType dataType){
this.tensorArrayDataType = dataType;
}
@Override
public String toString() {
return opName();
}
@Override
public String opName() {
return "create_list";
}
@Override
public Op.Type opType() {
return Op.Type.CUSTOM;
}
public SDVariable getVar() {
if(flow != null)
return flow;
return outputVariables()[0];
}
@Override
public SameDiff getSameDiff() {
val sd = this.sameDiff;
if (sd.getChild() != null) {
return sd.getChild();
}
return sd;
}
private SDVariable intToVar(int... index){
return this.sameDiff.constant(Nd4j.createFromArray(index));
}
//----------- read ops-----------------\\
public SDVariable read(int index) {
return new TensorArrayRead(getSameDiff(), new SDVariable[]{getVar(), intToVar(index)}).outputVariable();
}
public SDVariable read(SDVariable from,SDVariable index) {
return new TensorArrayRead(getSameDiff(), new SDVariable[]{from, index}).outputVariable();
}
public SDVariable read(SDVariable index) {
return new TensorArrayRead(getSameDiff(), new SDVariable[]{getVar(), index}).outputVariable();
}
public SDVariable gather(SDVariable flow, int... indices){
return new TensorArrayGather(getSameDiff(), new SDVariable[]{getVar(), sameDiff.constant(Nd4j.createFromArray(indices)), flow}).outputVariable();
}
public SDVariable gather(SDVariable flow, SDVariable indices){
return new TensorArrayGather(getSameDiff(), new SDVariable[]{getVar(), indices, flow}).outputVariable();
}
public SDVariable stack(SDVariable flow){
return new TensorArrayGather(getSameDiff(), new SDVariable[]{getVar(), intToVar(-1), flow}).outputVariable();
}
public SDVariable concat(SDVariable flow) {
return new TensorArrayConcat(getSameDiff(), new SDVariable[]{getVar()}).outputVariable();
}
//----------- write ops-----------------\\
public SDVariable write(SDVariable flow, int index, SDVariable value){
return write(flow, intToVar(index), value);
}
public SDVariable write(SDVariable flow, SDVariable index, SDVariable value){
return new TensorArrayWrite(getSameDiff(),
new SDVariable[]{getVar(),
index, value, flow}).outputVariable();
}
public SDVariable scatter(SDVariable flow, SDVariable value, int... indices){
return new TensorArrayScatter(getSameDiff(),
new SDVariable[]{getVar(),
intToVar(indices),
value, flow}).outputVariable();
}
public SDVariable scatter(SDVariable flow, SDVariable value, SDVariable indices){
return new TensorArrayScatter(getSameDiff(),
new SDVariable[]{getVar(),
indices,
value, flow}).outputVariable();
}
public SDVariable unstack(SDVariable flow, SDVariable value) {
return new TensorArrayScatter(getSameDiff(),
new SDVariable[]{getVar(),
intToVar(-1),
value, flow}).outputVariable();
}
public SDVariable size( SDVariable value) {
return new TensorArraySize(getSameDiff(),value).outputVariable();
}
public SDVariable remove( SDVariable value,SDVariable idx) {
return new TensorArrayRemove(getSameDiff(),value,idx).outputVariable();
}
public SDVariable remove( SDVariable value,int idx) {
return new TensorArrayRemove(getSameDiff(),value,idx).outputVariable();
}
public SDVariable remove( SDVariable value) {
return remove(value,-1);
}
@Override
public List calculateOutputDataTypes(List inputDataType) {
//The SDVariable that is the output of this "function" is just a dummy variable anyway...
//Usually 2 outputs... seems like first one is dummy, second one is a float??
//TODO work out exactly what this second output is for (it's used in TensorArrayWrite for example...
return Arrays.asList(DataType.BOOL, DataType.FLOAT);
}
@Override
public int getNumOutputs(){
return 2;
}
/**
* Returns the item at the specified index
* in the specified list.
* @param sd the same diff instance to use
* @param inputs the inputs including the relevant tensor array variable and position
* @return
*/
public static SDVariable itemAtIndex(SameDiff sd,SDVariable[] inputs) {
return itemAtIndex(sd,inputs,null);
}
/**
* Returns the item at the specified index
* in the specified list. The output variable
* name to specify for the final output.
* @param sd the same diff instance to use
* @param inputs the inputs including the relevant tensor array variable and position
* @param outputVarName the name of the output variable for the read
* @return
*/
public static SDVariable itemAtIndex(SameDiff sd,SDVariable[] inputs,String outputVarName) {
SDVariable sequenceVar = inputs[0];
SDVariable position = inputs.length < 2 ? sd.constant(-1) : inputs[1];
TensorArray ta = getTensorArray(sd, sequenceVar);
SDVariable read = ta.read(sequenceVar,position);
for(int i = 0; i < inputs.length; i++)
read.addControlDependency(inputs[i]);
if(outputVarName != null) {
read = read.rename(outputVarName);
}
for(int i = 0; i < inputs.length; i++)
read.addControlDependency(inputs[i]);
return read;
}
/**
* Returns the required shape for elements in this tensor array.
* If a second input is not present an {@link IllegalArgumentException} is thrown.
* @return
*/
public long[] requiredShape() {
Preconditions.checkState(args().length > 1,"Missing input shape.");
INDArray inputShape = arg(1).getArr();
long[] inputShapeArr = inputShape.toLongVector();
return inputShapeArr;
}
/**
* Get the associated {@link TensorArray} instance
* related to this op.
* Sometimes when a TensorArray op is returned
* it can be renamed or may not directly be the associated
* {@link TensorArray} instance. This helps discover the underlying
* {@link TensorArray} op for use to declare other operations to manipulate
* that instance such as {@link TensorArray#read(int)}
* @param sd the input instance
* @param sequenceVar the relevant variable to discover the {@link TensorArray}
* for
* @return
*/
public static TensorArray getTensorArray(SameDiff sd, SDVariable sequenceVar) {
DifferentialFunction baseTensorOp = sd.getVariableOutputOp(sequenceVar.name());
TensorArray ta = null;
if(baseTensorOp instanceof TensorArray) {
ta = (TensorArray) baseTensorOp;
} else {
while(!(baseTensorOp instanceof TensorArray)) {
for(SDVariable input : baseTensorOp.args()) {
if(sd.getVariableOutputOp(input.name()) instanceof TensorArray) {
baseTensorOp = sd.getVariableOutputOp(input.name());
ta = (TensorArray) baseTensorOp;
return ta;
} else {
return getTensorArray(sd,input);
}
}
}
}
return ta;
}
/**
* Remove the last element from the relevant
* {@link TensorArray}
* @param sameDiff the samediff instance to use
* @param inputSequence the relevant variable for the associated
* {@link TensorArray}
* @return
*/
public static SDVariable removeFromTensorArray(SameDiff sameDiff,SDVariable inputSequence) {
return removeFromTensorArray(sameDiff,inputSequence, sameDiff.constant(-1),null);
}
/**
* Remove an element from the relevant
* {@link TensorArray}
* @param sameDiff the samediff sinstance to use
* @param inputSequence the relevant variable for the associated
* {@link TensorArray}
* @param position the position to remove
* @return
*/
public static SDVariable removeFromTensorArray(SameDiff sameDiff,SDVariable inputSequence,SDVariable position) {
return removeFromTensorArray(sameDiff,inputSequence,position,null);
}
/**
* Remove an element from the relevant
* {@link TensorArray}
* @param sameDiff the samediff instance to use
* @param inputSequence the relevant variable for the associated
* {@link TensorArray}
* @param position the position to remove
* @param outputVarName the name of the output variable
* @return
*/
public static SDVariable removeFromTensorArray(SameDiff sameDiff,SDVariable inputSequence,SDVariable position,String outputVarName) {
TensorArray ta = TensorArray.getTensorArray(sameDiff,inputSequence);
SDVariable outputVar = ta.remove(inputSequence,position);
outputVar.addControlDependency(inputSequence);
outputVar.addControlDependency(position);
if(outputVarName != null)
return outputVar.rename(outputVarName);
return outputVar;
}
/**
* Create an empty sequence with the specified data type.
* An output variable name will be generated.
* @param sd the samediff instance to use
* @param sequence the output variable of the sequence to get the size of
* @return the output variable of the created sequence
*/
public static SDVariable sizeOfTensorArray(SameDiff sd,SDVariable sequence) {
return sizeOfTensorArray(sd,sequence,null);
}
/**
* Create an empty sequence with the specified data type.
* An output variable name will be generated.
* @param sd the samediff instance to use
* @param sequence the output variable of the sequence to get the size of
* @param outputVarName the output name of the size variable
* @return the output variable of the created sequence
*/
public static SDVariable sizeOfTensorArray(SameDiff sd,SDVariable sequence,String outputVarName) {
TensorArray tensorArray = TensorArray.getTensorArray(sd,sequence);
SDVariable outputVar = tensorArray.size(sequence);
outputVar.addControlDependency(sequence);
if(outputVarName != null)
outputVar = outputVar.rename(outputVarName);
return outputVar;
}
/**
* Create an empty sequence with the specified data type.
* An output variable name will be generated.
* @param sd the samediff instance to use
* @param dataType the data type of the sequence
* @return the output variable of the created sequence
*/
public static SDVariable createEmpty(SameDiff sd,DataType dataType) {
return createEmpty(sd,dataType,null);
}
/**
* Create an empty sequence with the specified data type.
* @param sd the samediff instance to use
* @param dataType the data type of the sequence
* @param outputVarName the output variable name of the sequence
* @return the output variable of the created sequence
*/
public static SDVariable createEmpty(SameDiff sd,DataType dataType,String outputVarName) {
TensorArray ta = sd.tensorArray(dataType);
SDVariable outputVar = ta.outputVariable();
if(outputVar.name() != null)
return outputVar.rename(outputVarName);
return outputVar;
}
/**
* Create an {@link TensorArray} op from the given inputs,
* note this is the same as calling {@link #createTensorArrayFrom(SameDiff, SDVariable[],String)}
* with null. The null value will avoid renaming the output
* @param sd the {@link SameDiff} instance to use
* @param inputs the input variables to create a {@link TensorArray} for
* @return the output variable for the tensor array
*/
public static SDVariable createTensorArrayFrom(SameDiff sd,SDVariable[] inputs) {
return createTensorArrayFrom(sd,inputs,null);
}
/**
* Create an {@link TensorArray} op from the given inputs
* @param sd the {@link SameDiff} instance to use
* @param inputs the input variables to create a {@link TensorArray} for
* @param outputVarName the name of the output variable to use for the final output of the loop
* @return the output variable for the tensor array
*/
public static SDVariable createTensorArrayFrom(SameDiff sd,SDVariable[] inputs,String outputVarName) {
TensorArray outputVar = sd.tensorArray(inputs[0].dataType());
SDVariable outTmp = outputVar.getVar();
for(int i = 0; i < inputs.length; i++) {
val write = outputVar.write(outTmp,i,inputs[i]);
if(outTmp != null) {
write.addControlDependency(outTmp);
}
outTmp = write;
}
if(outputVarName != null) {
outTmp = outTmp.rename(outputVarName);
}
return outTmp;
}
}