All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.zoo.tfpark.TFDataFeatureSet.scala Maven / Gradle / Ivy

/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.zoo.tfpark

import com.intel.analytics.bigdl.dataset.{DistributedDataSet, MiniBatch}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.zoo.feature.{DistributedDataSetWrapper, DistributedFeatureSet}
import com.intel.analytics.zoo.pipeline.api.keras.layers.utils.EngineRef
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.tensorflow.DataType

import com.intel.analytics.zoo.tfpark.TFTensorNumeric.NumericByteArray


/**
 * A [[DistributedFeatureSet]] whose mini-batches are produced by running a TensorFlow graph
 * inside each Spark partition.
 *
 * One [[GraphRunner]] is built per partition of `graphRDD` from that partition's first element
 * (any further elements are ignored); the resulting RDD is cached and materialized when the
 * feature set is constructed. `data` then fetches `outputNames` from those runners.
 *
 * @param graphRDD       RDD of serialized TensorFlow graphs; every partition must be non-empty
 * @param initIteratorOp name of the op run to initialize the dataset iterator
 * @param initTableOp    name of the op run alongside `initIteratorOp` to initialize tables
 * @param outputNames    names of the tensors fetched for each mini-batch
 * @param outputTypes    TensorFlow types of `outputNames`, in the same order
 * @param shardIndex     name of the placeholder fed with the partition index, or null to
 *                       feed nothing
 * @param sessionConfig  session options, serialized and handed to every [[GraphRunner]]
 */
class TFDataFeatureSet(private val graphRDD: RDD[Array[Byte]],
                       private val initIteratorOp: String,
                       private val initTableOp: String,
                       private val outputNames: Array[String],
                       private val outputTypes: Array[DataType],
                       private val shardIndex: String,
                       sessionConfig: SessionConfig)
  extends DistributedFeatureSet[MiniBatch[Float]] {

  // Built eagerly: constructing the feature set launches a Spark job (see `count()` below).
  private val graphRunnerRDD = getGraphRunnerRDD(graphRDD)

  /**
   * Creates one [[GraphRunner]] per partition of `rdd`, caches the result and forces it with
   * `count()`.
   *
   * @throws IllegalArgumentException (on the executors) if any partition of `rdd` is empty
   */
  private def getGraphRunnerRDD(rdd: RDD[Array[Byte]]): RDD[GraphRunner] = {
    // Held in a local so the Spark closure captures only these bytes, not `this`.
    val configBytes = sessionConfig.toByteArray()

    val runners = rdd.mapPartitions { partition =>
      if (!partition.hasNext) {
        throw new IllegalArgumentException("the input dataset rdd has an empty partition")
      }
      val graphDef = partition.next()
      Iterator.single(GraphRunner(graphDef, null, null, null, null, configBytes))
    }
    runners.setName("GraphRunnerRDD").cache()
    runners.count()
    runners
  }

  /** The cached RDD holding one [[GraphRunner]] per partition. */
  override def originRDD(): RDD[_] = graphRunnerRDD

  /**
   * Mini-batches produced by each partition's [[GraphRunner]].
   *
   * @param train passed through to `TFDataFeatureSet.makeIterators`; true yields an endless
   *              iterator, false a single pass
   */
  override def data(train: Boolean): RDD[MiniBatch[Float]] = {
    // Copy every field the closure needs into a local so the closure does not capture `this`.
    val iteratorInitOp = initIteratorOp
    val tableInitOp = initTableOp
    val shardPlaceholder = shardIndex
    val tensorNames = outputNames.toVector
    val tensorTypes = outputTypes.toVector

    graphRunnerRDD.mapPartitionsWithIndex { (partitionIndex, runners) =>
      TFDataFeatureSet.makeIterators(
        graphRunner = runners.next(),
        train = train,
        initOp = iteratorInitOp,
        initTableOp = tableInitOp,
        idx = partitionIndex,
        shardIdx = shardPlaceholder,
        types = tensorTypes,
        names = tensorNames)
    }
  }

  /** Deliberately a no-op. NOTE(review): shuffling presumably happens inside the TF graph. */
  override def shuffle(): Unit = ()

  /** Always -1. NOTE(review): presumably signals "size unknown" to callers — confirm. */
  override def size(): Long = -1L

  /** Wraps this feature set as a BigDL [[DistributedDataSet]]. */
  override def toDistributed(): DistributedDataSet[MiniBatch[Float]] =
    new DistributedDataSetWrapper[MiniBatch[Float]](this)
}

object TFDataFeatureSet {

  /**
   * Creates a [[TFDataFeatureSet]] from a single serialized TensorFlow graph. The graph is
   * broadcast and replicated into one partition per node by `createGraphRDD`.
   *
   * @param graph                     serialized TensorFlow graph
   * @param initIteratorOp            name of the op run to initialize the dataset iterator
   * @param initTableOp               name of the op run alongside it to initialize tables
   * @param outputNames               names of the tensors fetched for each mini-batch
   * @param outputTypes               integer type codes of `outputNames`, converted with
   *                                  `TFUtils.tfenum2datatype`
   * @param shardIndex                name of the placeholder fed with the partition index,
   *                                  or null to feed nothing
   * @param interOpParallelismThreads inter-op thread count for the TensorFlow session
   * @param intraOpParallelismThreads intra-op thread count for the TensorFlow session
   */
  def apply(graph: Array[Byte],
            initIteratorOp: String,
            initTableOp: String,
            outputNames: Array[String],
            outputTypes: Array[Int],
            shardIndex: String,
            interOpParallelismThreads: Int,
            intraOpParallelismThreads: Int
           ): TFDataFeatureSet = {
    // Convert types first so invalid type codes fail before any Spark work is started.
    val types = outputTypes.map(TFUtils.tfenum2datatype)
    new TFDataFeatureSet(createGraphRDD(graph),
      initIteratorOp, initTableOp, outputNames, types, shardIndex,
      sessionConfig = makeSessionConfig(interOpParallelismThreads, intraOpParallelismThreads))
  }

  /**
   * Creates a [[TFDataFeatureSet]] from an RDD whose partitions each hold a serialized graph
   * (only the first element of each partition is used).
   *
   * Parameters other than `graphRDD` are as in the single-graph overload.
   *
   * @throws IllegalArgumentException if `graphRDD` does not have exactly one partition per node
   */
  def apply(graphRDD: RDD[Array[Byte]],
            initIteratorOp: String,
            initTableOp: String,
            outputNames: Array[String],
            outputTypes: Array[Int],
            shardIndex: String,
            interOpParallelismThreads: Int,
            intraOpParallelismThreads: Int
           ): TFDataFeatureSet = {
    val types = outputTypes.map(TFUtils.tfenum2datatype)
    val nodeNumber = EngineRef.getNodeNumber()
    val numPartitions = graphRDD.getNumPartitions
    require(nodeNumber == numPartitions,
      s"number partitions should be the same as node number, " +
      s"got number partitions ${numPartitions}, node number ${nodeNumber}")
    new TFDataFeatureSet(graphRDD, initIteratorOp, initTableOp, outputNames, types, shardIndex,
      sessionConfig = makeSessionConfig(interOpParallelismThreads, intraOpParallelismThreads))
  }

  /** Session configuration shared by both `apply` overloads. */
  private def makeSessionConfig(interOpParallelismThreads: Int,
                                intraOpParallelismThreads: Int): SessionConfig =
    SessionConfig(intraOpParallelismThreads = intraOpParallelismThreads,
      interOpParallelismThreads = interOpParallelismThreads)

  /**
   * Replicates `graph` into an RDD with `EngineRef.getNodeNumber()` partitions, each holding
   * exactly one element: the broadcast graph bytes.
   *
   * The method first spreads a placeholder RDD over `nodeNumber * 10` partitions. It then
   * coalesces it to `nodeNumber` partitions, persists it to disk and materializes it with
   * `count()`. The graph RDD is derived from that placeholder.
   * NOTE(review): this presumably pins each partition to an executor so later jobs see a
   * stable placement — confirm. The placeholder RDD is never unpersisted.
   */
  private[zoo] def createGraphRDD(graph: Array[Byte]): RDD[Array[Byte]] = {
    val sc = SparkContext.getOrCreate()
    val nodeNumber = EngineRef.getNodeNumber()

    val broadcastedGraph = sc.broadcast(graph)
    val originRdd = sc.parallelize(
      Array.tabulate(nodeNumber * 20)(_ => 0), nodeNumber * 10)
      .mapPartitions(_ => (0 until 20).iterator)
      .coalesce(nodeNumber)
      .setName("PartitionRDD")
      .persist(StorageLevel.DISK_ONLY)
    originRdd.count()
    originRdd.mapPartitions { _ =>
      Iterator.single(broadcastedGraph.value)
    }.setName("GraphRDD")
  }

  /**
   * Allocates one empty output tensor per entry of `types`: an `Array[Byte]` tensor for
   * `DataType.STRING`, a `Float` tensor for every other type.
   */
  private[zoo] def generateOutputTensors(types: Vector[DataType]): Array[Tensor[_]] = {
    Array.tabulate[Tensor[_]](types.length) { i =>
      if (types(i) == DataType.STRING) Tensor[Array[Byte]]() else Tensor[Float]()
    }
  }

  /**
   * Turns a partition's [[GraphRunner]] into an iterator of mini-batches. Each batch is one
   * fetch of `names`, typed by `types`.
   *
   *  - `train = true`: an endless iterator (`hasNext` is always true). The init ops are not
   *    run up front. When a fetch fails with `IndexOutOfBoundsException` or
   *    `IllegalStateException`, the init ops are run and the fetch is retried once.
   *  - `train = false`: the init ops are run once, eagerly. The iterator ends at the first
   *    fetch that fails with `IndexOutOfBoundsException`, and no further fetches are made
   *    after that.
   *
   * NOTE(review): the two exception types are presumably how TensorFlow reports end-of-data
   * (OUT_OF_RANGE) and an uninitialized iterator (FAILED_PRECONDITION) through `GraphRunner`.
   * Confirm this.
   *
   * @param idx      partition index; fed into the `shardIdx` placeholder when that is non-null
   * @param shardIdx name of the shard-index placeholder, or null to feed nothing
   */
  private[zoo] def makeIterators(graphRunner: GraphRunner,
                                 train: Boolean,
                                 initOp: String,
                                 initTableOp: String,
                                 idx: Int,
                                 shardIdx: String,
                                 types: Vector[DataType],
                                 names: Vector[String]): Iterator[TFMiniBatch] = {

    // Runs both init ops. When `shardIdx` is set, the partition index is fed into it.
    // NOTE(review): the index is sent as a Float scalar while the declared input type is
    // INT64; presumably GraphRunner converts it — confirm.
    def initIterator(): Unit = {
      if (shardIdx != null) {
        graphRunner.runTargets(Vector(initOp, initTableOp),
          inputs = Vector(Tensor.scalar[Float](idx.toFloat)),
          inputTypes = Vector(DataType.INT64),
          inputNames = Vector(shardIdx))
      } else {
        graphRunner.runTargets(Vector(initOp, initTableOp))
      }
    }

    if (train) {
      new Iterator[TFMiniBatch] {
        override def hasNext: Boolean = true

        override def next(): TFMiniBatch = {
          val outputs = generateOutputTensors(types)
          val outputVec = outputs.toVector
          try {
            graphRunner.runOutputs(outputVec, names, types)
          } catch {
            // Presumably "not initialized yet" or "out of data": (re)initialize, retry once.
            case _: java.lang.IndexOutOfBoundsException | _: java.lang.IllegalStateException =>
              initIterator()
              graphRunner.runOutputs(outputVec, names, types)
          }
          TFMiniBatch(outputs)
        }
      }
    } else {
      initIterator()
      new Iterator[TFMiniBatch] {
        // Batch fetched by `hasNext` but not yet handed out by `next()`.
        private var pending: Option[Array[Tensor[_]]] = None
        // Set once a fetch reports end of data. Nothing re-initializes the iterator in this
        // mode, so the graph is not run again after that.
        private var exhausted = false

        override def hasNext: Boolean = {
          if (pending.isEmpty && !exhausted) {
            val outputs = generateOutputTensors(types)
            try {
              graphRunner.runOutputs(outputs.toVector, names, types)
              pending = Some(outputs)
            } catch {
              case _: java.lang.IndexOutOfBoundsException => exhausted = true
            }
          }
          pending.isDefined
        }

        override def next(): TFMiniBatch = {
          if (!hasNext) {
            throw new NoSuchElementException("Next on an empty iterator")
          }
          val batch = TFMiniBatch(pending.get) // defined: hasNext just returned true
          pending = None
          batch
        }
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy