All downloads are FREE. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.util.TFConverter Maven / Gradle / Ivy

/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.util;

import com.google.common.collect.Streams;
import com.simiacryptus.mindseye.lang.Coordinate;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.layers.cudnn.*;
import com.simiacryptus.mindseye.layers.cudnn.conv.SimpleConvolutionLayer;
import com.simiacryptus.mindseye.layers.java.FullyConnectedLayer;
import com.simiacryptus.mindseye.layers.tensorflow.MatMulLayer;
import com.simiacryptus.mindseye.layers.tensorflow.TFLayer;
import com.simiacryptus.mindseye.layers.tensorflow.TFLayerBase;
import com.simiacryptus.mindseye.network.DAGNode;
import com.simiacryptus.mindseye.network.InnerNode;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.*;
import com.simiacryptus.tensorflow.GraphModel;
import com.simiacryptus.tensorflow.ImageNetworkPipeline;
import org.jetbrains.annotations.NotNull;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

/**
 * Converts TensorFlow graphs into MindsEye layers and networks.
 *
 * <p>{@link #getLayers}/{@link #getLayer} wrap the graphs of an {@link ImageNetworkPipeline} as
 * {@link TFLayer}s, {@link #getFCLayer} turns a {@link MatMulLayer} into a {@link FullyConnectedLayer},
 * and {@link #convert} rebuilds the constant graph of a {@link TFLayerBase} as a native
 * {@link PipelineNetwork}. {@link #convert} supports only the ops Conv2D, BiasAdd, Relu, LRN, MaxPool,
 * Concat and Placeholder.
 *
 * <p>Reference counting: every method that receives a reference-counted argument documented as
 * "consumed" releases that reference itself; callers must not free it again.
 */
public class TFConverter {

  /**
   * Wraps every graph of the pipeline as a {@link TFLayer}. See {@link #getLayer}.
   *
   * @param pipeline the pipeline whose {@code graphDefs} are converted
   * @return one layer per graph, in pipeline order
   */
  public static RefList<TFLayer> getLayers(@Nonnull ImageNetworkPipeline pipeline) {
    return RefIntStream.range(0, pipeline.graphDefs.size())
        .mapToObj(i -> getLayer(pipeline, i))
        .collect(RefCollectors.toList());
  }

  /**
   * Wraps graph {@code i} of the pipeline as a float-precision {@link TFLayer}.
   *
   * <p>The layer's output is {@code nodeIds().get(i)}. Its input is {@code "input"} for the first stage
   * and the previous stage's node id otherwise, which chains the stages together.
   *
   * @param pipeline the source pipeline
   * @param i        the stage index
   * @return a new layer (caller owns the reference)
   */
  @Nonnull
  public static TFLayer getLayer(@Nonnull ImageNetworkPipeline pipeline, int i) {
    GraphDef graphDef = pipeline.graphDefs.get(i);
    String output = pipeline.nodeIds().get(i);
    String input = i == 0 ? "input" : pipeline.nodeIds().get(i - 1);
    TFLayer layer = new TFLayer(graphDef.toByteArray(), new RefHashMap<>(), output, input);
    layer.setFloat(true);
    return layer;
  }

  /**
   * Converts a TensorFlow {@link MatMulLayer} into an equivalent {@link FullyConnectedLayer}.
   *
   * <p>The TF {@code "weights"} tensor is reshaped to {@code outputDims} followed by the input
   * dimensions in reverse order. It is then permuted with
   * {@code [rank-1, rank-2, ..., outputDims.length, 0, 1, ..., outputDims.length-1]} before being
   * copied into the new layer.
   *
   * @param matMulLayer the layer to convert; this reference is consumed
   * @return a new fully-connected layer holding the rearranged weights
   */
  @Nonnull
  public FullyConnectedLayer getFCLayer(@Nonnull MatMulLayer matMulLayer) {
    RefMap<String, Tensor> weightsMap = matMulLayer.getWeights();
    assert weightsMap != null;
    Tensor weights = weightsMap.get("weights");
    weightsMap.freeRef();
    int[] inputDims = matMulLayer.getIntputDims();
    int[] outputDims = matMulLayer.getOutputDims();
    matMulLayer.freeRef();
    assert weights != null;

    // Target view: output dims, then the input dims in reverse order.
    int[] tfView = Streams
        .concat(RefArrays.stream(outputDims),
            RefIntStream.range(0, inputDims.length).map(i -> inputDims[inputDims.length - 1 - i]))
        .toArray();
    // Permutation: the trailing (input) axes from last to first, then the output axes in order.
    int[] tfPermute = Streams
        .concat(RefIntStream.range(0, inputDims.length).map(i -> outputDims.length + inputDims.length - 1 - i),
            RefIntStream.range(0, outputDims.length))
        .toArray();

    Tensor reshaped = weights.reshapeCast(tfView);
    Tensor rearranged = reshaped.permuteDimensions(tfPermute);
    reshaped.freeRef();
    weights.freeRef();

    FullyConnectedLayer fullyConnectedLayer = new FullyConnectedLayer(inputDims, outputDims);
    Tensor layerWeights = fullyConnectedLayer.getWeights();
    assert layerWeights != null;
    // Ownership of 'rearranged' is handed to set().
    layerWeights.set(rearranged);
    layerWeights.freeRef();
    return fullyConnectedLayer;
  }

  /**
   * Rebuilds the constant graph of a TensorFlow layer as a native {@link PipelineNetwork}, starting from
   * the layer's output node and recursing through its inputs.
   *
   * @param tfLayer the layer to convert; this reference is consumed
   * @return the converted single-input network (caller owns the reference)
   * @throws RuntimeException if any node cannot be converted (for example an unsupported op)
   */
  @Nonnull
  public PipelineNetwork convert(@Nonnull TFLayerBase tfLayer) {
    final PipelineNetwork converted = new PipelineNetwork(1);
    // Cache of TF node id -> converted network node, so shared inputs are converted only once.
    RefConcurrentHashMap<String, DAGNode> nodes = new RefConcurrentHashMap<>();
    try {
      RefUtil.freeRef(getNode(tfLayer.getOutputNode(), converted.addRef(),
          new GraphModel(tfLayer.constGraph().toByteArray()), RefUtil.addRef(nodes)));
    } catch (RuntimeException e) {
      // The partially-built network is never handed to the caller, so release it here.
      converted.freeRef();
      throw e;
    } finally {
      tfLayer.freeRef();
      nodes.freeRef();
    }
    return converted;
  }

  /**
   * Returns the network node for TF node {@code id}. On first use, the node and (recursively) its inputs
   * are converted and cached in {@code map}.
   *
   * @param id      the TensorFlow node name
   * @param network the network being built; this reference is consumed
   * @param tfModel the parsed TensorFlow graph
   * @param map     the conversion cache; this reference is consumed
   * @return the cached node for {@code id}, as returned by {@code map.get(id)}
   * @throws RuntimeException wrapping any failure, with the offending node id in the message
   */
  @Nullable
  protected DAGNode getNode(@Nonnull String id, @Nonnull PipelineNetwork network, @Nonnull GraphModel tfModel,
                            @Nonnull RefMap<String, DAGNode> map) {
    try {
      if (map.containsKey(id)) {
        // Already converted: the network reference is not needed.
        network.freeRef();
      } else {
        // getDagNode takes ownership of the network reference and of the extra map reference.
        DAGNode result = getDagNode(id, network, tfModel, map.addRef());
        if (map.containsKey(id)) {
          // The entry was registered while converting the inputs; keep the existing entry.
          result.freeRef();
        } else {
          RefUtil.freeRef(map.put(id, result));
        }
      }
      return map.get(id);
    } catch (Throwable e) {
      throw new RuntimeException("Error converting " + id, e);
    } finally {
      // Release this call's map reference on both the normal and the exceptional path.
      map.freeRef();
    }
  }

  /**
   * Builds a max-pooling layer from a TF {@code MaxPool} node.
   *
   * <p>Only {@code SAME} padding is expected; this is checked only when assertions are enabled.
   * Window and stride sizes are read from elements 1 and 2 of the {@code ksize}/{@code strides} lists,
   * presumably the spatial entries of an NHWC layout (TODO confirm data_format). If an attribute is
   * absent, the layer's defaults are kept.
   *
   * @param graphNode the MaxPool node
   * @return a new pooling layer (caller owns the reference)
   */
  @Nonnull
  protected PoolingLayer getPoolingLayer(@Nonnull GraphModel.GraphNode graphNode) {
    PoolingLayer poolingLayer = new PoolingLayer();
    poolingLayer.setMode(PoolingLayer.PoolingMode.Max);
    Map<String, AttrValue> attrMap = graphNode.getNodeDef().getAttrMap();
    AttrValue padding = attrMap.get("padding");
    assert padding != null && "SAME".equals(padding.getS().toStringUtf8());
    AttrValue ksizeAttr = attrMap.get("ksize");
    if (null != ksizeAttr) {
      List<Long> ksize = ksizeAttr.getList().getIList();
      poolingLayer.setWindowX(Math.toIntExact(ksize.get(1)));
      poolingLayer.setWindowY(Math.toIntExact(ksize.get(2)));
    }
    AttrValue stridesAttr = attrMap.get("strides");
    if (null != stridesAttr) {
      List<Long> strides = stridesAttr.getList().getIList();
      poolingLayer.setStrideX(Math.toIntExact(strides.get(1)));
      poolingLayer.setStrideY(Math.toIntExact(strides.get(2)));
    }
    return poolingLayer;
  }

  /**
   * Builds a per-band bias layer from a TF {@code BiasAdd} node. The bias values come from the node's
   * second input, which is expected to be a {@code Const}.
   *
   * @param graphNode the BiasAdd node
   * @return a new bias layer with one bias per value (caller owns the reference)
   */
  @Nonnull
  protected ImgBandBiasLayer getBiasAdd(@Nonnull GraphModel.GraphNode graphNode) {
    GraphModel.GraphNode dataNode = graphNode.getInputs().get(1);
    assert dataNode.getOp().equals("Const");
    double[] data = dataNode.getData();
    assert data != null;
    Tensor bias = new Tensor(data, data.length);
    ImgBandBiasLayer biasLayer = new ImgBandBiasLayer(data.length);
    // Ownership of 'bias' is handed to set().
    biasLayer.set(bias);
    return biasLayer;
  }

  /**
   * Builds a convolution layer from a TF {@code Conv2D} node, whose second input must be a {@code Const}
   * rank-4 kernel.
   *
   * <p>The kernel is loaded with its dimensions reversed, then inverted back. It is copied with its
   * first two dimensions mirrored (presumably the spatial axes) and the last two copied unchanged.
   *
   * @param graphNode the Conv2D node
   * @return a new convolution layer (caller owns the reference)
   * @throws IllegalArgumentException if the kernel is not rank 4
   */
  @Nonnull
  protected Layer getConv2D(@Nonnull GraphModel.GraphNode graphNode) {
    GraphModel.GraphNode dataNode = graphNode.getInputs().get(1);
    assert dataNode.getOp().equals("Const");
    int[] kernelDims = RefArrays.stream(dataNode.getShape()).mapToInt(x -> (int) x).toArray();
    double[] data = dataNode.getData();
    // The previous "rank 0 -> 1-D" fallback still indexed kernelDims[3] and always failed with an
    // ArrayIndexOutOfBoundsException; report a malformed kernel explicitly instead.
    if (kernelDims.length != 4) {
      throw new IllegalArgumentException("Conv2D kernel must have rank 4, but has rank " + kernelDims.length);
    }
    Tensor tfKernel = new Tensor(data, kernelDims[3], kernelDims[2], kernelDims[1], kernelDims[0]);
    Tensor sourceKernel = tfKernel.invertDimensions();
    tfKernel.freeRef();
    int[] sourceKernelDimensions = sourceKernel.getDimensions();
    SimpleConvolutionLayer convolutionLayer = new SimpleConvolutionLayer(sourceKernelDimensions[0],
        sourceKernelDimensions[1], sourceKernelDimensions[2] * sourceKernelDimensions[3]);
    Tensor targetKernel = new Tensor(sourceKernelDimensions[0], sourceKernelDimensions[1], sourceKernelDimensions[2],
        sourceKernelDimensions[3]);
    // The lambda's parameter type is inferred from the stream's forEach target type.
    sourceKernel.coordStream(false).forEach(RefUtil.wrapInterface(c -> {
      int[] coord = c.getCoords();
      targetKernel.set(sourceKernelDimensions[0] - 1 - coord[0], sourceKernelDimensions[1] - 1 - coord[1], coord[2],
          coord[3], sourceKernel.get(c));
    }, targetKernel.addRef(), sourceKernel.addRef()));
    sourceKernel.freeRef();
    // Ownership of 'targetKernel' is handed to set().
    convolutionLayer.set(targetKernel);

    AttrValue stridesAttr = graphNode.getNodeDef().getAttrMap().get("strides");
    if (null != stridesAttr) {
      int[] strides = stridesAttr.getList().getIList().stream().mapToInt(Math::toIntExact).toArray();
      int strideX = strides[1];
      int strideY = strides[2];
      // Only non-unit strides override the layer's configuration.
      if (strideX > 1 || strideY > 1) {
        convolutionLayer.setStrideX(strideX);
        convolutionLayer.setStrideY(strideY);
      }
    }
    return convolutionLayer;
  }

  /**
   * Converts a single TF node into a network node, recursively converting its inputs via
   * {@link #getNode}.
   *
   * @param id      the TensorFlow node name
   * @param network the network being built; this reference is consumed
   * @param tfModel the parsed TensorFlow graph
   * @param map     the conversion cache; this reference is consumed
   * @return the new network node
   * @throws IllegalArgumentException if the node does not exist or its op is unsupported
   */
  @Nonnull
  private DAGNode getDagNode(@Nonnull String id, @Nonnull PipelineNetwork network, @Nonnull GraphModel tfModel,
                             @Nonnull RefMap<String, DAGNode> map) {
    try {
      GraphModel.GraphNode graphNode = tfModel.getChild(id);
      if (null == graphNode) {
        throw new IllegalArgumentException("No such node: " + id);
      }
      switch (graphNode.getOp()) {
        case "Placeholder":
          return network.getInput(0);
        case "Concat": {
          ImgConcatLayer concatLayer = new ImgConcatLayer();
          // The first input is skipped; the remaining inputs are converted and concatenated.
          DAGNode[] inputs = graphNode.getInputKeys().stream().skip(1)
              .map(inputKey -> getNode(inputKey, network.addRef(), tfModel, RefUtil.addRef(map)))
              .toArray(DAGNode[]::new);
          return network.add(concatLayer, inputs);
        }
        default: {
          // Build the layer first so unsupported ops fail before their inputs are converted.
          Layer layer = getSingleInputLayer(graphNode);
          return network.add(layer, getNode(graphNode.getInputKeys().get(0), network.addRef(), tfModel,
              RefUtil.addRef(map)));
        }
      }
    } finally {
      // Both references are released here on every path, including failures.
      map.freeRef();
      network.freeRef();
    }
  }

  /**
   * Creates the layer for an op that consumes only its first input.
   *
   * @param graphNode the TF node
   * @return a new layer (caller owns the reference)
   * @throws IllegalArgumentException with the op name if the op is not supported
   */
  @Nonnull
  private Layer getSingleInputLayer(@Nonnull GraphModel.GraphNode graphNode) {
    String op = graphNode.getOp();
    switch (op) {
      case "Conv2D":
        return getConv2D(graphNode);
      case "BiasAdd":
        return getBiasAdd(graphNode);
      case "Relu":
        return new ActivationLayer(ActivationLayer.Mode.RELU);
      case "LRN":
        return getLRNLayer(graphNode);
      case "MaxPool":
        return getPoolingLayer(graphNode);
      default:
        throw new IllegalArgumentException(op);
    }
  }

  /**
   * Builds a local-response-normalization layer from a TF {@code LRN} node.
   *
   * <p>The window width is {@code 2 * depth_radius + 1}. {@code alpha} is multiplied by that width
   * before being set, presumably because the target layer divides alpha by its window size
   * (TODO confirm).
   *
   * @param graphNode the LRN node
   * @return a new LRN layer (caller owns the reference)
   * @throws IllegalArgumentException if a required attribute is missing
   */
  @Nonnull
  private LRNLayer getLRNLayer(@Nonnull GraphModel.GraphNode graphNode) {
    Map<String, AttrValue> attrMap = graphNode.getNodeDef().getAttrMap();
    long depthRadius = requireAttr(attrMap, "depth_radius").getI();
    float alpha = requireAttr(attrMap, "alpha").getF();
    float bias = requireAttr(attrMap, "bias").getF();
    float beta = requireAttr(attrMap, "beta").getF();
    int width = Math.toIntExact(depthRadius * 2 + 1);
    LRNLayer lrnLayer = new LRNLayer(width);
    lrnLayer.setAlpha(alpha * width);
    lrnLayer.setBeta(beta);
    lrnLayer.setK(bias);
    return lrnLayer;
  }

  /**
   * Looks up a node attribute that must be present.
   *
   * @param attrMap the node's attribute map
   * @param name    the attribute name
   * @return the attribute value, never {@code null}
   * @throws IllegalArgumentException if the attribute is absent
   */
  @Nonnull
  private static AttrValue requireAttr(@Nonnull Map<String, AttrValue> attrMap, @Nonnull String name) {
    AttrValue value = attrMap.get(name);
    if (null == value) {
      throw new IllegalArgumentException("Missing required attribute: " + name);
    }
    return value;
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy