Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance. Project price only 1 $
You can buy this project and download/modify it how often you want.
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.list.compat.TensorList;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Map;
public class TensorArrayV3 extends BaseTensorOp {
TensorList list;
@Override
public String tensorflowName() {
return "TensorArrayV3";
}
public TensorArrayV3(String name, SameDiff sameDiff){
super(name, sameDiff, new SDVariable[]{});
this.list = new TensorList(this.getOwnName());
}
public TensorArrayV3(SameDiff sameDiff){
super(sameDiff, new SDVariable[]{});
this.list = new TensorList(this.getOwnName());
}
public TensorArrayV3(TensorArrayV3 ta){
super(ta.sameDiff, new SDVariable[]{});
this.list = ta.list;
}
public TensorArrayV3(TensorArrayV3 ta, SDVariable[] inputs){
super(ta.sameDiff, inputs);
this.list = ta.list;
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) {
val idd = nodeDef.getInput(nodeDef.getInputCount() - 1);
NodeDef iddNode = null;
for(int i = 0; i < graph.getNodeCount(); i++) {
if(graph.getNode(i).getName().equals(idd)) {
iddNode = graph.getNode(i);
}
}
val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value",iddNode,graph);
if (arr != null) {
int idx = arr.getInt(0);
addIArgument(idx);
}
this.list = new TensorList(this.getOwnName());
}
public TensorArrayV3(){
this.list = new TensorList(this.getOwnName());
}
@Override
public TensorList execute(SameDiff sameDiff) {
/*val list = new TensorList(this.getOwnName());
// we might have size here
return list;*/
return this.list;
}
@Override
public String toString() {
return opName();
}
@Override
public String opName() {
return "tensorarrayv3";
}
@Override
public Op.Type opType() {
return Op.Type.CUSTOM;
}
private SDVariable getVar(){
// get var associated with the tensor list
getSameDiff().putListByName(list.getName(), list);
val name = list.getName();
if (getSameDiff().variableMap().containsKey(name)){
return getSameDiff().variableMap().get(name);
}
return getSameDiff().var(list.getName(), new long[]{1});
}
@Override
public SameDiff getSameDiff(){
val sd = this.sameDiff;
if (sd.getChild() != null){
return sd.getChild();
}
return sd;
}
private SDVariable intToVar(int... index){
return this.sameDiff.var(Nd4j.create(ArrayUtil.toDouble(index)));
}
//----------- read ops-----------------\\
public SDVariable read(int index){
return new TensorArrayReadV3(getSameDiff(), new SDVariable[]{getVar(), intToVar(index)}).outputVariable();
}
public SDVariable read(SDVariable index){
return new TensorArrayReadV3(getSameDiff(), new SDVariable[]{getVar(), index}).outputVariable();
}
public SDVariable gather(int... indices){
return new TensorArrayGatherV3(getSameDiff(), new SDVariable[]{getVar(), intToVar(indices)}).outputVariable();
}
public SDVariable gather(SDVariable indices){
return new TensorArrayGatherV3(getSameDiff(), new SDVariable[]{getVar(), indices}).outputVariable();
}
public SDVariable stack(){
return new TensorArrayGatherV3(getSameDiff(), new SDVariable[]{getVar(), intToVar(-1)}).outputVariable();
}
public SDVariable concat(){
return new TensorArrayConcatV3(getSameDiff(), new SDVariable[]{getVar()}).outputVariable();
}
//----------- write ops-----------------\\
public void write(int index, SDVariable value){
//return new TensorArrayV3(this,
new TensorArrayWriteV3(getSameDiff(),
new SDVariable[]{getVar(),
intToVar(index), value}).outputVariables();//);
}
public void write(SDVariable index, SDVariable value){
System.out.println("TA write - " + this.sameDiff);
//return new TensorArrayV3(this,
new TensorArrayWriteV3(getSameDiff(),
new SDVariable[]{getVar(),
index, value}).outputVariables();//);
}
public void scatter(SDVariable value, int... indices){
//return new TensorArrayV3(this,
new TensorArrayScatterV3(getSameDiff(),
new SDVariable[]{getVar(),
intToVar(indices),
value}).outputVariables();//);
}
public void scatter(SDVariable value, SDVariable indices){
//return new TensorArrayV3(this,
new TensorArrayScatterV3(getSameDiff(),
new SDVariable[]{getVar(),
indices,
value}).outputVariables();//);
}
public void unstack(SDVariable value){
//return new TensorArrayV3(this,
new TensorArrayScatterV3(getSameDiff(),
new SDVariable[]{getVar(),
intToVar(-1),
value}).outputVariables();//);
}
}