All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.utils.tf.Session.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.utils.tf

import java.nio.{ByteOrder, DoubleBuffer, FloatBuffer}

import com.intel.analytics.bigdl.Criterion
import com.intel.analytics.bigdl.dataset.{DataSet, DistributedDataSet, MiniBatch, Sample}
import com.intel.analytics.bigdl.nn.{ClassNLLCriterion, Graph, Linear}
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractCriterion, AbstractModule, Activity}
import com.intel.analytics.bigdl.optim._
import com.intel.analytics.bigdl.tensor.TensorNumericMath.{NumericWildcard, TensorNumeric}
import com.intel.analytics.bigdl.tensor.{Storage, Tensor}
import com.intel.analytics.bigdl.utils._
import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.tensorflow.framework.{GraphDef, NodeDef}
import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.models.utils.ModelBroadcast
import TFTensorNumeric.NumericByteString
import com.intel.analytics.bigdl.utils.tf.BigDLSessionImpl.FakeCriterion
import org.apache.hadoop.io.{BytesWritable, NullWritable}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import scala.util.Random

/**
 * Abstraction over a session that trains, runs and persists a TensorFlow graph.
 *
 * @tparam T numeric type of the tensors handled by the session
 */
abstract class Session[T: ClassTag] {

  /**
   * Train the tensorflow graph with a data set that is supplied by the caller.
   * @param outputs names of the graph endpoints that form the model outputs
   * @param dataSet distributed mini-batches used as training data
   * @param optMethod optimization method used to update the model
   * @param criterion criterion applied to the model output
   * @param endWhen trigger that decides when training stops
   * @return the trained model
   */
  def train(outputs: Seq[String],
            dataSet: DistributedDataSet[MiniBatch[T]],
            optMethod: OptimMethod[T],
            criterion: Criterion[T],
            endWhen: Trigger): Graph[T]


  /**
   * Train the tensorflow graph. The model must be fed data with a queue
   * @param endPoints names of the graph endpoints to train towards
   * @param optMethod optimization method used to update the model
   * @param endWhen trigger that decides when training stops
   * @param isDataBatch if the model input is the batch
   * @param batchSize batch size, which should be original batch size * total core number
   * @param sc spark context used to build the input data
   * @param loss optional name of the loss endpoint
   * @return this session
   */
  def train(
    endPoints: Seq[String],
    optMethod: OptimMethod[T],
    endWhen: Trigger,
    isDataBatch: Boolean,
    batchSize: Int,
    sc: SparkContext,
    loss: Option[String]
  ): this.type

  /**
   * Predict data with tensorflow graph. The data must be held in a queue
   * @param endPoints names of the graph endpoints whose values are predicted
   * @param isDataBatch if the model input is the batch
   * @param batchSize batch size, which should be original batch size * total core number
   * @param sc spark context used to build the input data
   * @return the prediction results, one Activity per input record
   */
  def predict(
    endPoints: Seq[String],
    isDataBatch: Boolean,
    batchSize: Int,
    sc: SparkContext
  ): RDD[Activity]

  /**
   * Dump variable contents to a file
   * @param binFile path of the file to write
   * @return this session
   */
  def saveParameters(binFile: String): this.type
}

class BigDLSessionImpl[T: ClassTag](graph: Seq[NodeDef], context: Context[T],
  byteOrder: ByteOrder = ByteOrder.LITTLE_ENDIAN)
  (implicit ev: TensorNumeric[T]) extends Session[T] {

  import scala.collection.JavaConverters._

  /**
   * Train the graph on a caller-provided data set.
   *
   * The model is built for the given outputs with backward generation enabled and
   * without a context; its single input must be a `Placeholder` node.
   *
   * @param outputs names of the endpoints used as model outputs
   * @param dataSet distributed mini-batches used as training data
   * @param optMethod optimization method
   * @param criterion criterion applied to the model output
   * @param endWhen trigger that stops the optimization
   * @return the model after optimization
   */
  override def train(outputs: Seq[String],
    dataSet: DistributedDataSet[MiniBatch[T]],
    optMethod: OptimMethod[T],
    criterion: Criterion[T],
    endWhen: Trigger): Graph[T] = {

    val (bigdlModel, inputNode) = constructModel(outputs, byteOrder, true, None)

    val inputOpName = inputNode.element.getOp
    require(inputOpName == "Placeholder",
      "only support Placeholder as input when in-memory input data is provided")

    val optimizer = new DistriOptimizer(
      bigdlModel,
      dataSet,
      criterion
    )
    val configured = optimizer.setOptimMethod(optMethod).setEndWhen(endWhen)
    configured.optimize()
    bigdlModel
  }

  /**
   * Train the graph when its input data is produced by the graph itself.
   *
   * Updater nodes reachable from the endpoints are collected as
   * (weight, gradient) name pairs and registered on the context. The model
   * outputs are the loss (when given) followed by the gradient names.
   *
   * @param endPoints names of the endpoints from which updater nodes are searched
   * @param optMethod optimization method
   * @param endWhen trigger that stops the optimization
   * @param isDataBatch forwarded to [[getRDD]] as its `hasToBatch` flag
   * @param batchSize batch size handed to the optimizer
   * @param sc spark context used to build the input RDD
   * @param loss optional name of the loss endpoint
   * @return this session
   */
  override def train(
    endPoints: Seq[String],
    optMethod: OptimMethod[T],
    endWhen: Trigger,
    isDataBatch: Boolean,
    batchSize: Int,
    sc: SparkContext,
    loss: Option[String]
  ): this.type = {
    val endNodes = endPoints.map(name2Node)
    val reachableNodes = endNodes.flatMap(_.graph(true).DFS)
    val weightsAndGrads = reachableNodes.flatMap(n => TFUpdater(n.element)).toSet

    require(weightsAndGrads.nonEmpty, "Cannot find updater nodes")
    context.setAssignGrads(weightsAndGrads)

    // Mapping over the Set keeps gradient names unique before they become outputs.
    val gradNames = weightsAndGrads.map(_._2).toSeq
    val modelOutputs = loss match {
      case Some(lossName) => Seq(lossName) ++ gradNames
      case None => gradNames
    }

    val (bigdlModel, inputNode) = constructModel(modelOutputs, byteOrder, false, Some(context))
    val inputTables = getRDD(Seq(inputNode.element.getName), sc, isDataBatch)
    val samples = BigDLSessionImpl.toSample[T](inputTables)

    val optimizer = Optimizer[T](
      bigdlModel,
      samples,
      new FakeCriterion[T](),
      batchSize
    )
    val configured = optimizer.setOptimMethod(optMethod).setEndWhen(endWhen)
    configured.optimize()
    this
  }

  /**
   * Run the graph up to the given endpoints on data produced by the graph's own
   * input pipeline.
   *
   * @param endPoints names of the endpoints to evaluate
   * @param isDataBatch forwarded to [[getRDD]] as its `hasToBatch` flag
   * @param batchSize not referenced by this implementation
   * @param sc spark context used to build the input RDD
   * @return the model output for every input sample
   */
  override def predict(
    endPoints: Seq[String],
    isDataBatch: Boolean,
    batchSize: Int,
    sc: SparkContext
  ): RDD[Activity] = {
    val (bigdlModel, inputNode) = constructModel(endPoints, byteOrder, true, Some(context))
    val inputName = inputNode.element.getName
    val inputTables = getRDD(Seq(inputName), sc, isDataBatch)
    val samples = BigDLSessionImpl.toSample[T](inputTables)
    bigdlModel.predict(samples)
  }

  /**
   * Write the session context to a binary file via `TensorflowLoader.saveBinFile`.
   *
   * @param binFile path of the file to write
   * @return this session, to allow call chaining
   */
  override def saveParameters(binFile: String): this.type = {
    TensorflowLoader.saveBinFile(binFile, context)
    this
  }

  // Op types at which a sub-graph is cut: nodes with these ops are treated as inputs
  // when a graph is built for a set of endpoints.
  private val inputOp = Set("ReaderReadV2", "QueueDequeueV2", "QueueDequeueManyV2", "Placeholder")

  // Op types that read from a queue node; used to find the consumers attached to a queue.
  private val dequeueOp = Set("QueueDequeueV2", "QueueDequeueManyV2", "ReaderReadV2")

  // Op types that write into a queue node.
  private val enqueueOp = Set("QueueEnqueueV2", "QueueEnqueueManyV2")

  // Op types of queue nodes.
  // NOTE(review): not referenced anywhere in the visible part of this class.
  private val queueOp = Set("RandomShuffleQueueV2", "FIFOQueueV2")

  // Graph built from all node definitions (null is passed as the second argument);
  // only the graph itself is kept from the returned triple.
  private val (wholeTFGraph, _, _) = TensorflowLoader.buildTFGraph(graph.asJava, null)

  // Lookup from node name to graph node; nodes without an element are skipped.
  private val name2Node = wholeTFGraph.
    DFS.filter(_.element != null).map(node => (node.element.getName, node)).toMap

  /**
   * Build the record RDD for a `ReaderReadV2` node.
   *
   * The node's two inputs are the reader and the queue holding the file names.
   * The file names produced by the queue's enqueue nodes are split round-robin
   * between all consumers of that queue; the per-consumer slices are stored in
   * `cache` under the queue name, and a slice is nulled out in the cache once it
   * has been handed to its consumer.
   *
   * @param node the `ReaderReadV2` node
   * @param cache per-queue slices computed so far
   * @param sc spark context used to create the RDD
   * @return one table per record read from this reader's share of the files
   */
  private def handleReaderNode(node: Node[NodeDef], cache: DataCache,
    sc: SparkContext): RDD[Table] = {
    require(node.prevNodes.length == 2, "require ReaderReadV2 only has two inputs")
    val readerNode = node.prevNodes.head
    val fileQueue = node.prevNodes(1)

    // Distinct consumer names of the queue, in first-seen order.
    val consumerNames = fileQueue.nextNodes.collect {
      case n if n.element != null && dequeueOp(n.element.getOp) =>
        n.element.getName.split(":")(0)
    }.distinct
    val slot = consumerNames.zipWithIndex.toMap.apply(node.element.getName)
    val slotCount = consumerNames.size

    val producers = findEnqueueNodes(fileQueue)
    val queueName = fileQueue.element.getName
    val filesSeq = cache.get(queueName) match {
      case Some(slices) =>
        val mine = slices(slot)
        slices(slot) = null
        mine
      case None =>
        val perProducer = producers.map { producer =>
          val produced = constructLocalData(Seq(producer.element.getName), new DataCache())
          if (producer.element.getOp == "QueueEnqueueManyV2") {
            produced.flatMap(BigDLSessionImpl.splitTensorByFirstDim)
          } else {
            produced
          }
        }
        // Merge the producers element-wise into one sequence of tables.
        val merged = perProducer.reduce { (left, right) =>
          left.zip(right).map { case (l, r) =>
            l.add(r)
          }
        }
        val slices = split(merged, slotCount)
        cache.put(queueName, slices)
        slices(slot)
    }

    readerNode.element.getOp match {
      case "TFRecordReaderV2" => readTFRecord(filesSeq, sc)
      case "FixedLengthRecordReaderV2" => readFixedLengthRecord(filesSeq, readerNode.element, sc)
    }
  }

  /**
   * Distribute the elements of `xs` round-robin over `n` buckets.
   *
   * Element `j` (0-based) goes to bucket `j % n`, and every bucket keeps the
   * relative order of its elements. The input is traversed once with an iterator
   * so that linear sequences such as `List` are not indexed repeatedly.
   *
   * @param xs elements to distribute
   * @param n number of buckets
   * @return an array of `n` buckets; buckets that received nothing are empty
   */
  private def split[A](xs: Seq[A], n: Int): Array[Seq[A]] = {
    val buckets = Array.fill[Seq[A]](n)(Vector.empty[A])
    xs.iterator.zipWithIndex.foreach { case (x, j) =>
      buckets(j % n) = buckets(j % n) :+ x
    }
    buckets
  }

  /**
   * Decide whether a list of file names refers to HDFS or to non-HDFS files.
   *
   * A name counts as an HDFS path when it starts with `hdfs:`.
   *
   * @param fileNames file names to inspect, must not be empty
   * @return true if every name is an HDFS path, false if none is
   * @throws IllegalArgumentException if the list is empty or mixes both kinds
   */
  private def allHDFSFiles(fileNames: Seq[String]): Boolean = {
    // Fail with a clear message instead of an opaque "empty.reduceLeft".
    require(fileNames.nonEmpty, "filenames must not be empty")
    val hdfsCount = fileNames.count(_.startsWith("hdfs:"))
    if (hdfsCount == fileNames.length) {
      true
    } else if (hdfsCount == 0) {
      false
    } else {
      throw new IllegalArgumentException("filenames contain both local and hdfs path")
    }
  }

  /**
   * Read TFRecord files into an RDD of (key, value) tables.
   *
   * Every input table must hold exactly one scalar tensor containing a file name.
   * The files must be either all HDFS paths or all local paths. Each record
   * becomes a table whose first entry is the constant key "fake_key" and whose
   * second entry is the record bytes. An empty input yields an empty RDD.
   *
   * @param filesTable one table per file to read
   * @param sc spark context used to create the RDD
   * @return one table per record
   */
  private def readTFRecord(filesTable: Seq[Table], sc: SparkContext): RDD[Table] = {
    if (filesTable.isEmpty) {
      sc.parallelize(Seq.empty[Table], numSlices = Engine.coreNumber() * Engine.nodeNumber())
    } else {
      val fileNames = filesTable.map { fileTable =>
        require(fileTable.length() == 1, "Reader can only read one file at a time")
        val nameTensor = fileTable[Tensor[ByteString]](1)
        require(nameTensor.isScalar, s"require fileTensor to be a scalar," +
          s" but got size: ${nameTensor.size()}")
        nameTensor.value().toStringUtf8
      }

      if (allHDFSFiles(fileNames)) {
        // all files are hdfs files
        val perFileRdds = fileNames.map { fileName =>
          val raw = sc.newAPIHadoopFile[BytesWritable, NullWritable, TFRecordInputFormat](fileName)
          raw.map { case (recordBytes, _) =>
            val entry = T()
            val key = Tensor.scalar[ByteString](ByteString.copyFromUtf8("fake_key"))
            val value = Tensor.scalar[ByteString](ByteString.copyFrom(recordBytes.copyBytes()))
            entry.insert(key)
            entry.insert(value)
            entry
          }
        }
        perFileRdds.reduceLeft((left, right) => left.union(right))
      } else {
        // all files are local files
        val rawRecords = fileNames.flatMap { fileName =>
          TFRecordIterator(new java.io.File(fileName))
        }
        val entries = rawRecords.map { record =>
          val entry = T()
          val key = Tensor.scalar[ByteString](ByteString.copyFromUtf8("fake_key"))
          val value = Tensor.scalar[ByteString](ByteString.copyFrom(record))
          entry.insert(key)
          entry.insert(value)
          entry
        }
        sc.parallelize(entries, numSlices = Engine.coreNumber() * Engine.nodeNumber())
      }
    }
  }

  /**
   * Read fixed-length-record files into an RDD of (key, value) tables.
   *
   * The record layout is taken from the reader node attributes `footer_bytes`,
   * `header_bytes`, `hop_bytes` and `record_bytes`. Every input table must hold
   * exactly one scalar tensor containing a file name, and the files must be
   * either all HDFS paths or all local paths. When reading from HDFS only
   * `record_bytes` is supported; the other three attributes must be 0.
   * Each record becomes a table whose first entry is the constant key "fake_key"
   * and whose second entry is the record bytes. An empty input yields an empty
   * RDD, matching `readTFRecord`.
   *
   * @param filesTable one table per file to read
   * @param readerNode definition of the reader node carrying the layout attributes
   * @param sc spark context used to create the RDD
   * @return one table per record
   */
  private def readFixedLengthRecord(filesTable: Seq[Table], readerNode: NodeDef, sc: SparkContext)
    : RDD[Table] = {
    if (filesTable.isEmpty) {
      sc.parallelize(Seq.empty[Table], numSlices = Engine.coreNumber() * Engine.nodeNumber())
    } else {
      val footerBytes = readerNode.getAttrMap.get("footer_bytes").getI.toInt
      val headerBytes = readerNode.getAttrMap.get("header_bytes").getI.toInt
      val hopBytes = readerNode.getAttrMap.get("hop_bytes").getI.toInt
      val recordBytes = readerNode.getAttrMap.get("record_bytes").getI.toInt

      val fileNames = filesTable.map { t =>
        // The element type is erased at runtime, so only the Tensor class is checked.
        require(t.length() == 1 && t(1).isInstanceOf[Tensor[_]],
          "Reader can only read one file at a time")
        val fileTensor = t[Tensor[ByteString]](1)
        require(fileTensor.isScalar, s"require fileTensor to be a scalar," +
          s" but got size: ${fileTensor.size()}")
        val file = fileTensor.value()
        file.toStringUtf8
      }

      val isHdfs = allHDFSFiles(fileNames)

      if (isHdfs) {
        require(footerBytes == 0,
          s"Reading from HDFS does not support footer_bytes, but get footer_bytes $footerBytes")
        require(headerBytes == 0,
          s"Reading from HDFS does not support header_bytes, but get header_bytes $headerBytes")
        require(hopBytes == 0,
          s"Reading from HDFS does not support hop_bytes, but get hop_bytes $hopBytes")
        fileNames.map { file =>
          val rdd = sc.binaryRecords(file, recordBytes)
          rdd.map { bytes =>
            val table = T()
            val key = Tensor.scalar[ByteString](ByteString.copyFromUtf8("fake_key"))
            val value = Tensor.scalar[ByteString](ByteString.copyFrom(bytes))
            table.insert(key)
            table.insert(value)
            table
          }
        }.reduceLeft(_.union(_))
      } else {
        val result = fileNames.flatMap { file =>
          val iter = new FixedLengthRecordReader(
            new java.io.File(file),
            footerBytes,
            headerBytes,
            hopBytes,
            recordBytes)
          iter
        }.map { record =>
          val table = T()
          val key = Tensor[ByteString](Array(ByteString.copyFromUtf8("fake_key")), Array[Int]())
          val value = Tensor[ByteString](Array(ByteString.copyFrom(record)), Array[Int]())
          table.insert(key)
          table.insert(value)
          table
        }
        sc.parallelize(result, numSlices = Engine.coreNumber() * Engine.nodeNumber())
      }
    }
  }

  // Pass-through op types that may sit between a queue node and its enqueue nodes.
  private val identityOp = Set("Switch", "Identity", "Merge")

  /**
   * Find the enqueue nodes that feed a queue.
   *
   * Starting from the successors of `queueNode`, the graph is walked breadth-first.
   * Nodes whose op is an enqueue op are collected; nodes whose op is one of
   * `identityOp` are looked through; every other node ends its branch.
   * Each node is processed at most once, so a node reachable through several
   * paths is reported once and cycles cannot cause an endless loop.
   *
   * @param queueNode the queue node whose producers are searched
   * @return the enqueue nodes in the order they were discovered
   * @throws IllegalArgumentException if no enqueue node is found
   */
  private def findEnqueueNodes(queueNode: Node[NodeDef]): Seq[Node[NodeDef]] = {
    val queue = mutable.Queue[Node[NodeDef]]()
    val enqueNodes = mutable.ArrayBuffer[Node[NodeDef]]()
    queue.enqueue(queueNode.nextNodes: _*)
    val visited = mutable.HashSet[Node[NodeDef]]()
    while(queue.nonEmpty) {
      val node = queue.dequeue()
      // add returns false when the node was already recorded, so it is skipped.
      if (visited.add(node)) {
        if (node.element != null && enqueueOp(node.element.getOp)) {
          enqueNodes += node
        } else if (node.element != null && identityOp(node.element.getOp)) {
          queue.enqueue(node.nextNodes: _*)
        }
      }
    }
    if (enqueNodes.isEmpty) {
      throw new IllegalArgumentException(
        s"Cannot find enqueue node for queue: ${queueNode.element}")
    } else {
      enqueNodes
    }
  }

  /**
   * Compute, on the driver, the data a `QueueDequeueV2` node reads from its queue.
   *
   * The data produced by the queue's enqueue nodes is split round-robin between
   * all consumers of that queue; the per-consumer slices are stored in `cache`
   * under the queue name, and a slice is nulled out in the cache once it has been
   * handed to its consumer.
   *
   * @param node the dequeue node
   * @param cache per-queue slices computed so far
   * @return this node's share of the queue content
   */
  private def handleLocalDequeue(node: Node[NodeDef], cache: DataCache): Seq[Table] = {
    require(node.prevNodes.length == 1, "require QueueDequeueV2 only has one input")
    val sourceQueue = node.prevNodes.head
    val producers = findEnqueueNodes(sourceQueue)

    // Distinct consumer names of the queue, in first-seen order.
    val consumerNames = sourceQueue.nextNodes.collect {
      case n if n.element != null && dequeueOp(n.element.getOp) =>
        n.element.getName.split(":")(0)
    }.distinct
    val slot = consumerNames.zipWithIndex.toMap.apply(node.element.getName)
    val slotCount = consumerNames.size

    val queueName = sourceQueue.element.getName
    cache.get(queueName) match {
      case Some(slices) =>
        val mine = slices(slot)
        slices(slot) = null
        mine
      case None =>
        val perProducer = producers.map { producer =>
          val produced = constructLocalData(Seq(producer.element.getName), new DataCache())
          if (producer.element.getOp == "QueueEnqueueManyV2") {
            produced.flatMap(BigDLSessionImpl.splitTensorByFirstDim)
          } else {
            produced
          }
        }
        // Merge the producers element-wise into one sequence of tables.
        val merged = perProducer.reduce { (left, right) =>
          left.zip(right).map { case (l, r) =>
            l.add(r)
          }
        }
        val slices = split(merged, slotCount)
        cache.put(queueName, slices)
        slices(slot)
    }
  }

  /**
   * Group the tables of every partition into batches of up to `batchSize` tables.
   *
   * Within a batch, the i-th tensors of all tables are stacked along a new first
   * dimension. The result tensors are sized from the tensors of the first table
   * of the batch, with the number of tables in the batch prepended. The last
   * batch of a partition may hold fewer than `batchSize` tables.
   *
   * @param rdd tables to batch
   * @param batchSize maximum number of tables per batch, must be positive
   * @return one table of stacked tensors per batch
   */
  private def batchRdd(rdd: RDD[Table], batchSize: Int): RDD[Table] = {
    // A non-positive size would produce empty batches without consuming any input.
    require(batchSize > 0, s"batchSize must be positive, but got $batchSize")
    rdd.mapPartitions { iter =>

      new Iterator[Table] {
        override def hasNext: Boolean = iter.hasNext

        override def next(): Table = {
          require(iter.hasNext, "Call next() on a empty iterator")

          val tables =
            for (_ <- 0 until batchSize if iter.hasNext) yield {
              iter.next()
            }

          val batch = tables.map(_.toSeq[Tensor[T]])
          val firstSeq = batch.head

          // One empty tensor per position, shaped (batch length, original dims...).
          val results = firstSeq.map { tensor =>
            val size = Array.tabulate(tensor.nDimension() + 1) { d =>
              if (d == 0) batch.length else tensor.size(d)
            }
            tensor.emptyInstance().resize(size)
          }

          for ((seq, index) <- batch.zipWithIndex) {
            results.zip(seq).foreach { case (result, tensor) =>
              result.asInstanceOf[Tensor[NumericWildcard]]
                .narrow(1, index + 1, 1)
                .copy(tensor.asInstanceOf[Tensor[NumericWildcard]])
            }
          }
          T.seq(results)
        }
      }

    }
  }

  /**
   * Build the RDD of data that a dequeue node reads from its queue.
   *
   * The queue must have exactly one consumer. The RDDs produced by all enqueue
   * nodes of the queue are unioned; for a `QueueDequeueManyV2` node the result
   * is additionally batched with the constant batch size found on the node's
   * second input.
   *
   * @param node the dequeue node
   * @param cache per-queue slices passed on to the upstream computation
   * @param sc spark context used to create the RDDs
   * @return the data seen by the dequeue node
   */
  private def handleDistriDequeue(node: Node[NodeDef], cache: DataCache,
                                 sc: SparkContext): RDD[Table] = {
    val sourceQueue = node.prevNodes.head
    val consumerNames = sourceQueue.nextNodes
      .filter(n => n.element != null && dequeueOp(n.element.getOp))
      .map(n => n.element.getName.split(":")(0)).toSet
    require(consumerNames.size == 1, "only support one dequeue node after reader")

    // get previous rdd
    val producers = findEnqueueNodes(sourceQueue)
    val perProducer = producers.map { producer =>
      val produced = constructDistributeData(Seq(producer.element.getName), cache, sc)
      if (producer.element.getOp == "QueueEnqueueManyV2") {
        produced.flatMap(BigDLSessionImpl.splitTensorByFirstDim)
      } else {
        produced
      }
    }
    val upstream = perProducer.reduce { (left, right) =>
      left.union(right)
    }

    if (node.element.getOp == "QueueDequeueManyV2") {
      // get batch size
      val batchSizeNode = node.prevNodes(1)
      require(batchSizeNode.element.getOp == "Const", "batchsize must be a const")

      val batchSize = batchSizeNode.element.getAttrMap.get("value").getTensor.getIntVal(0)
      batchRdd(upstream, batchSize)
    } else {
      upstream
    }
  }

  // Cache keyed by queue node name; each value holds one slice of tables per
  // consumer of that queue.
  type DataCache = mutable.HashMap[String, Array[Seq[Table]]]

  /**
   * Reduce input names to distinct node names.
   *
   * Everything from the first ':' on is dropped from each name, then duplicates
   * are removed while keeping the order of first appearance.
   *
   * @param inputs input names, optionally carrying a ":suffix"
   * @return the distinct node names
   */
  private def adjustInputNames(inputs: Seq[String]): Seq[String] = {
    val nodeNames = inputs.map(name => name.split(":")(0))
    nodeNames.distinct
  }

  /**
   * Detach every enqueue node of the graph from its first predecessor.
   *
   * @param tfGraph graph to modify in place
   */
  private def checkAndRemoveQueueNode(tfGraph: DirectedGraph[NodeDef]): Unit = {
    for (enqueueNode <- tfGraph.DFS
         if enqueueNode.element != null && enqueueOp(enqueueNode.element.getOp)) {
      val firstInput = enqueueNode.prevNodes.head
      firstInput.delete(enqueueNode)
    }
  }

  /**
   * Evaluate the graph up to the given endpoints on the driver.
   *
   * A sub-graph is cut at nodes whose op is in `inputOp`, enqueue nodes are
   * detached, and a BigDL model without backward pass is built from it.
   * If the sub-graph has inputs, their data is fetched (only `QueueDequeueV2`
   * inputs are handled), merged element-wise into one table per record, and the
   * model is run on each record. Otherwise the model is run once on an empty table.
   *
   * @param endPoints names of the endpoints to evaluate
   * @param cache per-queue slices shared between the inputs of this evaluation
   * @return one output table per input record, or a single table when there are no inputs
   */
  private def constructLocalData(endPoints: Seq[String], cache: DataCache): Seq[Table] = {
    val isInputOp = (n: NodeDef) => inputOp(n.getOp)
    val (tfGraph, inputs, originInputs) = TensorflowLoader.
      buildTFGraph(graph.asJava, endPoints, isInputOp)

    checkAndRemoveQueueNode(tfGraph)

    val adjustedInputs = adjustInputNames(originInputs)
    val transformer = TensorflowLoader.buildBigDLModel(
      tfGraph,
      inputs.toSeq.map(_._2).flatten,
      endPoints,
      ByteOrder.LITTLE_ENDIAN,
      "",
      None,
      generatedBackward = false
    ).asInstanceOf[Graph[T]]


    if (adjustedInputs.nonEmpty) {
      // NOTE(review): nodes are looked up from originInputs here, while
      // constructDistributeData and constructModel look them up from the
      // adjusted names — confirm this difference is intentional.
      val inputNodes = originInputs.map(name2Node)
      val inputDataSeq = inputNodes.map { node => // this is the input op
        // Any other op type fails with a MatchError.
        node.element.getOp match {
          // only support Dequeue before reader
          case "QueueDequeueV2" => handleLocalDequeue(node, cache)
        }
      }

      // Combine the inputs record by record: the i-th tables of all inputs are
      // added into one table.
      val reducedInputSeq = inputDataSeq.reduce { (outerSeq1, outerSeq2) =>
        outerSeq1.zip(outerSeq2).map { case (seq1, seq2) =>
          seq1.add(seq2)
        }
      }

      reducedInputSeq.map { tensors =>
        val output = transformer.forward(tensors.flatten())
        toTable(output)
      }
    } else {
      // No inputs: run the model once on an empty table.
      Seq(toTable(transformer.forward(T())))
    }
  }

  /**
   * Present an activity as a table.
   *
   * @param activity a tensor or a table
   * @return a new table holding the tensor, or the table itself
   */
  private def toTable(activity: Activity): Table = activity match {
    case tensor: Tensor[_] => T(tensor)
    case table: Table => table
  }

  /**
   * Evaluate the graph up to the given endpoints as a distributed computation.
   *
   * A sub-graph is cut at nodes whose op is in `inputOp`, enqueue nodes are
   * detached, and a BigDL model without backward pass is built from it using the
   * session context. One RDD is built per input node, the RDDs are zipped into a
   * single RDD of merged tables, and the broadcast model is run on every record.
   *
   * @param endPoints names of the endpoints to evaluate
   * @param cache per-queue slices shared between the inputs of this evaluation
   * @param sc spark context used to create the RDDs and the broadcast
   * @return one output table per input record
   */
  private def constructDistributeData(endPoints: Seq[String], cache: DataCache,
      sc: SparkContext): RDD[Table] = {
    val isInputOp = (n: NodeDef) => inputOp(n.getOp)
    val (tfGraph, inputs, originInputs) =
      TensorflowLoader.buildTFGraph(graph.asJava, endPoints, isInputOp)

    checkAndRemoveQueueNode(tfGraph)

    val adjustedInputs = adjustInputNames(originInputs)

    val inputNodes = adjustedInputs.map(name2Node)

    val transformer = TensorflowLoader.buildBigDLModel(
      tfGraph,
      inputs.toSeq.map(_._2).flatten,
      endPoints,
      ByteOrder.LITTLE_ENDIAN,
      "",
      Some(context),
      generatedBackward = false
    ).asInstanceOf[Graph[T]]

    // Any input op other than the three below fails with a MatchError.
    val inputRdds = inputNodes.map { node => // this is the input op
      node.element.getOp match {
        case "ReaderReadV2" => handleReaderNode(node, cache, sc)
        case "QueueDequeueV2" => handleDistriDequeue(node, cache, sc)
        case "QueueDequeueManyV2" => handleDistriDequeue(node, cache, sc)
      }
    }
    // Combine the inputs record by record: the i-th tables of all inputs are
    // added into one table.
    val inputRdd = inputRdds.reduce { (rdd1, rdd2) =>
      rdd1.zip(rdd2).map { case (seq1, seq2) =>
        seq1.add(seq2)
      }
    }

    val modelBroadCast = ModelBroadcast[T]().broadcast(sc, transformer)
    inputRdd.map { tensors =>
      val trans = modelBroadCast.value()
      val output = trans.forward(tensors.flatten())
      // NOTE(review): same conversion as toTable; presumably inlined so that the
      // RDD closure does not reference the enclosing session — confirm before
      // replacing it with a call to toTable.
      output match {
        case t: Tensor[_] => T(t)
        case t: Table => t
      }
    }
  }


  /**
   * Build a BigDL model that computes the given endpoints.
   *
   * A sub-graph is cut at nodes whose op is in `inputOp` and enqueue nodes are
   * detached before the model is built. The sub-graph must have exactly one
   * distinct input node.
   *
   * @param endPoints names of the endpoints the model outputs
   * @param byteOrder byte order passed to the model builder
   * @param generateBackward whether the model builder should generate the backward pass
   * @param context optional context passed to the model builder
   * @return the model together with its single input node
   */
  private def constructModel(endPoints: Seq[String], byteOrder: ByteOrder,
    generateBackward: Boolean, context: Option[Context[T]])
      : (Graph[T], Node[NodeDef]) = {
    val cutAtInputOps = (n: NodeDef) => inputOp(n.getOp)
    val (subGraph, inputs, originInputs) =
      TensorflowLoader.buildTFGraph(graph.asJava, endPoints, cutAtInputOps)

    checkAndRemoveQueueNode(subGraph)

    val inputNodes = adjustInputNames(originInputs).map(name2Node)
    require(inputNodes.length == 1, "Only support one model input")

    val built = TensorflowLoader.buildBigDLModel(
      subGraph,
      inputs.toSeq.map(_._2).flatten,
      endPoints,
      byteOrder,
      "",
      context,
      generateBackward
    )
    (built.asInstanceOf[Graph[T]], inputNodes.head)
  }

  /**
   * Compute the data up to the specified endpoints and return it as a RDD[Table].
   * @param endPoints output endpoints
   * @param sc spark context used to create the RDD
   * @param hasToBatch indicate whether the subgraph to be executed already contains
   *        a batch operation. If yes, the batching is undone at the end of the
   *        execution by splitting every tensor along its first dimension.
   * @return the computed tables
   */
  def getRDD(endPoints: Seq[String], sc: SparkContext,
                            hasToBatch: Boolean = true): RDD[Table] = {
    val cache = new mutable.HashMap[String, Array[Seq[Table]]]()
    val computed = constructDistributeData(endPoints, cache, sc)
    if (hasToBatch) {
      computed.flatMap(BigDLSessionImpl.splitTensorByFirstDim)
    } else {
      computed
    }
  }
}

/**
 * Extractor of the inputs an updater node works on.
 */
object TFUpdater {
  /**
   * Return inputs 0 and 7 of an `ApplyRMSProp` node as a pair.
   *
   * @param node node definition to inspect
   * @return the pair of input names, or None for any other op
   */
  def apply(node: NodeDef): Option[(String, String)] = {
    if (node.getOp == "ApplyRMSProp") {
      Some((node.getInput(0), node.getInput(7)))
    } else {
      None
    }
  }
}

/**
 * Helpers shared by [[BigDLSessionImpl]] instances.
 */
object BigDLSessionImpl {

  /**
   * Criterion that performs no loss computation of its own.
   *
   * With `enable` set it reports 0; otherwise it reports the value of the first
   * tensor of the input table. Its gradient input is always null.
   *
   * @param enable whether to report the constant 0 instead of the input value
   */
  class FakeCriterion[T: ClassTag](enable: Boolean = false)(implicit ev: TensorNumeric[T])
    extends AbstractCriterion[Activity, Activity, T] {

    override def updateOutput(input: Activity, target: Activity): T = {
      if (!enable) {
        val firstTensor = input.toTable.apply[Tensor[T]](1)
        firstTensor.value()
      } else {
        ev.fromType(0.0)
      }
    }

    override def updateGradInput(input: Activity, target: Activity): Activity = null
  }

  /**
   * Split every tensor of a table along its first dimension.
   *
   * The number of resulting tables is the size of the first dimension of the
   * first tensor; the k-th result holds the k-th slice of every tensor.
   *
   * @param table non-empty table of tensors
   * @return one table per slice
   */
  private def splitTensorByFirstDim(table: Table): Array[Table] = {
    val tensorCount = table.length()
    require(tensorCount >= 1, "EnqueueManyV2 encounter a empty table")
    val head = table[Tensor[_]](1)
    require(head.nDimension() >= 1)
    val sliceCount = head.size(1)
    Array.tabulate(sliceCount) { slice =>
      val sliceTable = new Table()
      (1 to tensorCount).foreach { position =>
        val tensor = table[Tensor[ByteString]](position)
        sliceTable.insert(tensor(slice + 1))
      }
      sliceTable
    }
  }

  /**
   * Turn every table of tensors into a sample built from those tensors.
   *
   * @param rdd tables of tensors
   * @return one sample per table
   */
  private def toSample[T: ClassTag](rdd: RDD[Table])
                           (implicit ev: TensorNumeric[T]): RDD[Sample[T]] = {
    rdd.map { table =>
      val tensors = table.toSeq[Tensor[T]].toArray
      Sample[T](tensors)
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy