All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.tensorflow.TensorflowUtil Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.tensorflow;

import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.simiacryptus.ref.lang.RefAware;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.util.Util;
import org.tensorflow.*;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.*;
import org.tensorflow.op.Ops;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class TensorflowUtil {
  private static final SumFn doubleSum = new SumFn(Double.class);
  private static final SumFn floatSum = new SumFn(Float.class);

  @Nullable
  public static Operation find(@Nonnull Graph graph, String name) {
    Iterator operations = graph.operations();
    while (operations.hasNext()) {
      Operation operation = operations.next();
      if (operation.name().equals(name)) {
        return operation;
      }
    }
    return null;
  }

  public static byte[] makeGraph(@Nonnull Consumer builder) {
    try (Graph graph = new Graph()) {
      builder.accept(Ops.create(graph));
      byte[] bytes = graph.toGraphDef();
      try {
        validate(GraphDef.parseFrom(bytes));
      } catch (InvalidProtocolBufferException e) {
        throw Util.throwException(e);
      }
      return bytes;
    }
  }

  public static void validate(@Nonnull GraphDef graphDef) {
    List names = graphDef.getNodeList().stream().map(x -> x.getName()).collect(Collectors.toList());
    graphDef.getNodeList().stream().map(x -> x.getName()).distinct().forEach(o -> names.remove(o));
    if (!names.isEmpty()) {
      throw new IllegalStateException("Duplicate names: " + RefUtil.get(names.stream().reduce((a, b) -> a + ", " + b)));
    }
  }

  @Nonnull
  public static String addConst(GraphDef.Builder graphBuilder, String name, TensorShapeProto shape, String... label) {
    graphBuilder.addNode(newConst(name, shape, label));
    return name;
  }

  @Nonnull
  public static NodeDef newConst(@Nonnull String name, TensorShapeProto shape, String[] label) {
    return newConst(name, getTensorAttr(shape, label), DataType.DT_STRING);
  }

  @Nonnull
  public static AttrValue getTensorAttr(TensorShapeProto shape, String... label) {
    return AttrValue.newBuilder().setTensor(buildTensor(shape, label)).build();
  }

  @Nonnull
  public static NodeDef newConst(@Nonnull String name, @Nonnull AttrValue attrValue, @Nonnull DataType dtString) {
    return NodeDef.newBuilder()
        .setName(name)
        .setOp("Const")
        .putAttr("dtype", AttrValue.newBuilder().setType(dtString).build())
        .putAttr("value", attrValue)
        .build();
  }

  @Nonnull
  public static GraphDef editGraph(@Nonnull GraphDef graph, @Nonnull @RefAware Function edit) {
    GraphDef build = edit.apply(graph.toBuilder()).build();
    RefUtil.freeRef(build);
    return build;
  }

  public static void editNode(GraphDef.Builder graphBuilder, String name, @Nonnull @RefAware Function edit) {
    List nodeList = graphBuilder.getNodeList();
    NodeDef nodeDef = nodeList.stream().filter(x -> x.getName().equals(name)).findAny()
        .orElseGet(() -> {
          throw new NoSuchElementException(String.format(
              "%s not found in %s",
              name,
              RefUtil.get(nodeList.stream().map(nodeDef1 -> nodeDef1.getName()).reduce((a, b) -> a + "," + b))
          ));
        });
    int index = nodeList.indexOf(nodeDef);
    graphBuilder.removeNode(index);
    graphBuilder.addNode(edit.apply(nodeDef.toBuilder()).build());
    RefUtil.freeRef(edit);
  }

  public static void editNodes(GraphDef.Builder graphBuilder, Function edit) {
    new ArrayList<>(graphBuilder.getNodeList()).stream().forEach(previousValue -> {
      NodeDef newValue = edit.apply(previousValue);
      if (newValue != previousValue) {
        graphBuilder.removeNode(graphBuilder.getNodeList().indexOf(previousValue));
        graphBuilder.addNode(newValue);
      }
    });
  }

  @Nonnull
  public static  Tensor add(@Nonnull Tensor... tensors) {
    return add(Arrays.stream(tensors));
  }

  @Nonnull
  public static TensorProto buildTensor(TensorShapeProto tensorShapeProto, @Nonnull int... vs) {
    TensorProto.Builder tensor = TensorProto.newBuilder()
        .setDtype(DataType.DT_INT32)
        .setTensorShape(tensorShapeProto);
    for (int l : vs) {
      tensor.addIntVal(l);
    }
    return tensor.build();
  }

  @Nonnull
  public static TensorProto buildTensor(TensorShapeProto tensorShapeProto, @Nonnull String... vs) {
    TensorProto.Builder tensor = TensorProto.newBuilder()
        .setDtype(DataType.DT_STRING)
        .setTensorShape(tensorShapeProto);
    for (String l : vs) {
      tensor.addStringVal(ByteString.copyFromUtf8(l));
    }
    return tensor.build();
  }

  @Nonnull
  public static TensorShapeProto buildTensorShape(@Nonnull long... dims) {
    TensorShapeProto.Builder builder = TensorShapeProto.newBuilder();
    for (long l : dims) {
      builder.addDim(TensorShapeProto.Dim.newBuilder().setSize(l).build());
    }
    return builder.build();
  }

  @Nonnull
  public static  Tensor add(@Nonnull Stream> stream) {
    return RefUtil.get(stream.reduce((a, b) -> {
      if (a.dataType() == org.tensorflow.DataType.DOUBLE) {
        Tensor tensor = doubleSum.add(a.expect(Double.class), b.expect(Double.class));
        a.close();
        b.close();
        return tensor;
      } else {
        Tensor tensor = floatSum.add(a.expect(Float.class), b.expect(Float.class));
        a.close();
        b.close();
        return tensor;
      }
    }));
  }

  @Nonnull
  public static List rankNode(@Nonnull NodeDef node, @Nonnull DataType type, @Nonnull String rankNode) {
    String endNode = rankNode + "/end";
    String startNode = rankNode + "/start";
    String stepNode = rankNode + "/step";
    return Arrays.asList(
        newConst(startNode, AttrValue.newBuilder().setTensor(
            buildTensor(TensorflowUtil.buildTensorShape(), 0)
        ).build(), DataType.DT_INT32),
        newConst(stepNode, AttrValue.newBuilder().setTensor(
            buildTensor(TensorflowUtil.buildTensorShape(), 1)
        ).build(), DataType.DT_INT32),
        NodeDef.newBuilder()
            .setName(endNode)
            .addInput(node.getName())
            .setOp("Rank")
            .putAttr("T", AttrValue.newBuilder().setType(type).build())
            .build(),
        NodeDef.newBuilder()
            .addInput(startNode)
            .addInput(endNode)
            .addInput(stepNode)
            .setName(rankNode)
            .setOp("Range")
            .build()
    );
  }

  public static class SumFn {
    @Nonnull
    private final Graph sumGraph;
    @Nonnull
    private final Session sumSession;
    private final Output in1;
    private final Output in2;
    private final Output out;

    public SumFn(Class dtype) {
      sumGraph = new Graph();
      Ops ops = Ops.create(sumGraph);
      in1 = ops.placeholder(dtype).asOutput();
      in2 = ops.placeholder(dtype).asOutput();
      out = ops.math.add(
          in1,
          in2
      ).asOutput();
      sumSession = new Session(sumGraph);
    }

    @Nonnull
    public Tensor add(Tensor a, Tensor b) {
      return sumSession.runner().feed(in1, a).feed(in2, b).fetch(out).run().get(0).expect(Double.class);
    }

    public void close() {
      sumSession.close();
      sumGraph.close();
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy