All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.spotify.scio.tensorflow.TensorFlowDoFns.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2019 Spotify AB.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.spotify.scio.tensorflow

import java.time.Duration
import java.util.concurrent.ConcurrentHashMap
import com.spotify.zoltar.tf.{TensorFlowLoader, TensorFlowModel}
import com.spotify.zoltar.Model
import org.apache.beam.sdk.transforms.DoFn.{
  Element,
  OutputReceiver,
  ProcessElement,
  Setup,
  Teardown
}
import org.slf4j.LoggerFactory
import org.tensorflow._
import org.tensorflow.types.TString
import org.tensorflow.ndarray.NdArrays
import org.tensorflow.proto.example.Example

import scala.jdk.CollectionConverters._
import com.spotify.scio.transforms.DoFnWithResource
import com.spotify.scio.transforms.DoFnWithResource.ResourceType
import com.spotify.zoltar.Model.Id

/**
 * Base `DoFn` that feeds every element through a TensorFlow session runner and emits the value
 * extracted from the fetched tensors.
 *
 * Models are kept in a [[PredictDoFn.Resource]] map (declared as `ResourceType.PER_CLASS`), keyed
 * by [[modelId]] and stored together with a reference count: [[setup]] loads the model or bumps
 * the count, [[teardown]] lowers the count and closes the model when the count was 1.
 *
 * @tparam T
 *   input element type
 * @tparam V
 *   output element type
 * @tparam M
 *   model type
 */
sealed trait PredictDoFn[T, V, M <: Model[_]]
    extends DoFnWithResource[T, V, PredictDoFn.Resource[M]] {
  import PredictDoFn._

  /** Key under which the model is registered in the resource map. */
  def modelId: String

  /** Loads the model; called from [[setup]] when nothing is registered for [[modelId]] yet. */
  def loadModel(): M

  /** The model currently registered under [[modelId]]. */
  def model: M = {
    val (_, registered) = getResource.get(modelId)
    registered
  }

  /** Hands a session runner to `f` and returns whatever `f` produces. */
  def withRunner(f: Session#Runner => V): V

  /** Builds the tensors to feed for `input`, keyed by the name passed to `runner.feed`. */
  def extractInput(input: T): Map[String, Tensor]

  /**
   * Turns the fetched tensors into the output value. `out` pairs each entry of
   * [[outputTensorNames]] with the tensor fetched for it.
   */
  def extractOutput(input: T, out: Map[String, Tensor]): V

  /** Names passed to `runner.fetch`, in fetch order. */
  def outputTensorNames: Seq[String]

  override def getResourceType: DoFnWithResource.ResourceType = ResourceType.PER_CLASS

  override def createResource(): Resource[M] = new ConcurrentHashMap[String, (Int, M)]()

  /** Registers this instance as a user of the model, loading the model on first use. */
  @Setup
  override def setup(): Unit = {
    super.setup()
    getResource.compute(
      modelId,
      {
        case (_, null)            => (1, loadModel())
        case (_, (users, loaded)) => (users + 1, loaded)
      }
    )
  }

  /**
   * Runs one prediction for `input` and emits the extracted result.
   *
   * Fetched tensors are closed once [[extractOutput]] has returned, and the input tensors are
   * closed afterwards, whether or not the run succeeded.
   */
  @ProcessElement
  def processElement(@Element input: T, out: OutputReceiver[V]): Unit = {
    val prediction = withRunner { runner =>
      val feeds = extractInput(input)

      try {
        feeds.foreach { case (name, tensor) => runner.feed(name, tensor) }
        for (name <- outputTensorNames) runner.fetch(name)
        val fetched = runner.run()
        try {
          val fetchedByName = outputTensorNames.iterator.zip(fetched.iterator().asScala).toMap
          extractOutput(input, fetchedByName)
        } finally {
          Log.debug("Closing down output tensors")
          fetched.asScala.foreach(_.close())
        }
      } finally {
        Log.debug("Closing down input tensors")
        feeds.foreach { case (_, tensor) => tensor.close() }
      }
    }

    out.output(prediction)
  }

  /** Releases this instance's hold on the model; the holder of the last reference closes it. */
  @Teardown
  override def teardown(): Unit = {
    Log.info("Tearing down predict DoFn {}", this)
    getResource.compute(
      modelId,
      {
        case (_, null) =>
          Log.warn("No model to close while tearing down predict DoFn")
          null
        case (_, (users, loaded)) if users == 1 =>
          // Returning null from compute removes the entry.
          loaded.close()
          null
        case (_, (users, loaded)) =>
          (users - 1, loaded)
      }
    )
  }
}

object PredictDoFn {

  /** Shared model registry: model id -> (reference count, model). */
  type Resource[M <: Model[_]] = ConcurrentHashMap[String, Tuple2[Int, M]]

  // Logger used by the PredictDoFn trait, which imports this companion's members.
  private val Log = LoggerFactory.getLogger(getClass)
}

/**
 * [[PredictDoFn]] backed by a [[TensorFlowModel]] obtained through `TensorFlowLoader`.
 *
 * @param uri
 *   location handed to `TensorFlowLoader.create`
 * @param signatureName
 *   signature name handed to `TensorFlowLoader.create`
 * @param options
 *   model options; their tags become part of [[modelId]]
 */
abstract private[tensorflow] class SavedBundlePredictDoFn[T, V](
  uri: String,
  signatureName: String,
  options: TensorFlowModel.Options
) extends PredictDoFn[T, V, TensorFlowModel] {

  /** `tf:<uri>:<signatureName>:<tags joined by ":">`. */
  override def modelId: String = {
    val tags = options.tags.asScala.mkString(":")
    s"tf:$uri:$signatureName:$tags"
  }

  /** Creates the loader for [[modelId]] and waits for the model without a practical time limit. */
  override def loadModel(): TensorFlowModel = {
    val loader = TensorFlowLoader.create(Id.create(modelId), uri, options, signatureName)
    loader.get(Duration.ofDays(Integer.MAX_VALUE))
  }

  /** Applies `f` to a new runner obtained from the model's session. */
  override def withRunner(f: Session#Runner => V): V = {
    val session = model.instance().session()
    f(session.runner())
  }
}

/** Factories for the concrete [[SavedBundlePredictDoFn]] variants. */
object SavedBundlePredictDoFn {

  /**
   * Model id suffix for optional fetch ops: the op names joined with ":", or an empty string when
   * no fetch ops were given. Mirrors how `forRaw` joins its fetch ops.
   */
  private def fetchOpsSuffix(fetchOps: Option[Seq[String]]): String =
    fetchOps.toList.flatten.mkString(":")

  /** Error raised when input names cannot be resolved through the model's input name map. */
  private def unknownInputs(
    uri: String,
    signatureName: String,
    tensorIds: Iterable[String]
  ): IllegalArgumentException =
    new IllegalArgumentException(
      s"Input(s) [${tensorIds.mkString(", ")}] not found in signature '$signatureName' " +
        s"of model $uri"
    )

  /**
   * Creates a DoFn that uses names as given, without translating them through the model's
   * input/output name maps.
   *
   * @param uri
   *   model location
   * @param fetchOps
   *   names to fetch; also the keys of the map handed to `outFn`
   * @param options
   *   model options
   * @param signatureName
   *   signature name used when loading the model
   * @param inFn
   *   builds the tensors to feed, keyed by the name to feed them under
   * @param outFn
   *   builds the output from the input element and the fetched tensors
   */
  def forRaw[T, V](
    uri: String,
    fetchOps: Seq[String],
    options: TensorFlowModel.Options,
    signatureName: String,
    inFn: T => Map[String, Tensor],
    outFn: (T, Map[String, Tensor]) => V
  ): SavedBundlePredictDoFn[T, V] = new SavedBundlePredictDoFn[T, V](uri, signatureName, options) {
    override def extractInput(input: T): Map[String, Tensor] = inFn(input)

    override def extractOutput(input: T, out: Map[String, Tensor]): V = outFn(input, out)

    override def outputTensorNames: Seq[String] = fetchOps

    override def modelId: String = s"${super.modelId}:${fetchOps.mkString(":")}"
  }

  /**
   * Creates a DoFn whose input and output names are translated through the model's
   * `inputsNameMap` / `outputsNameMap` before being fed or fetched.
   *
   * If `fetchOps` isn't provided, all outputs listed in the model's `outputsNameMap` are fetched.
   *
   * @param uri
   *   model location
   * @param fetchOps
   *   output names to fetch (keys of `outputsNameMap`); also the keys of the map handed to `outFn`
   * @param options
   *   model options
   * @param signatureName
   *   signature name used when loading the model
   * @param inFn
   *   builds the tensors to feed, keyed by input name (keys of `inputsNameMap`)
   * @param outFn
   *   builds the output from the input element and the fetched tensors
   * @throws IllegalArgumentException
   *   from `extractInput` when `inFn` returns a name missing from `inputsNameMap`
   */
  def forInput[T, V](
    uri: String,
    fetchOps: Option[Seq[String]],
    options: TensorFlowModel.Options,
    signatureName: String,
    inFn: T => Map[String, Tensor],
    outFn: (T, Map[String, Tensor]) => V
  ): SavedBundlePredictDoFn[T, V] = new SavedBundlePredictDoFn[T, V](uri, signatureName, options) {
    private lazy val exportedFetchOps =
      model.outputsNameMap().asScala.toMap
    private lazy val requestedFetchOps: Map[String, String] = fetchOps
      .map { tensorIds =>
        tensorIds.iterator.map(tensorId => tensorId -> exportedFetchOps(tensorId)).toMap
      }
      .getOrElse(exportedFetchOps)

    override def extractInput(input: T): Map[String, Tensor] = {
      val extractedInput = inFn(input)
      val inputOps = model.inputsNameMap()
      val unknownIds = extractedInput.keysIterator.filter(id => inputOps.get(id) == null).toList
      if (unknownIds.nonEmpty) {
        // processElement only closes the tensors this method returns, so release them here
        // before failing instead of passing a null name on to the runner.
        extractedInput.valuesIterator.foreach(_.close())
        throw unknownInputs(uri, signatureName, unknownIds)
      }
      extractedInput.iterator.map { case (tensorId, tensor) =>
        inputOps.get(tensorId) -> tensor
      }.toMap
    }

    override def extractOutput(input: T, out: Map[String, Tensor]): V =
      outFn(
        input,
        requestedFetchOps.iterator.map { case (tensorId, opName) =>
          tensorId -> out(opName)
        }.toMap
      )

    override def outputTensorNames: Seq[String] = requestedFetchOps.values.toSeq

    override def modelId: String = s"${super.modelId}:${fetchOpsSuffix(fetchOps)}"
  }

  /**
   * Creates a DoFn that serializes each element to bytes and feeds it as a single string tensor
   * under the input named `exampleTensorName`.
   *
   * Note: if fetchOps isn't provided, then all outputs defined in the signature of the model are
   * retrieved. This can be expensive and unwanted, depending on the model.
   *
   * @param uri
   *   model location
   * @param exampleTensorName
   *   input name (key of the model's `inputsNameMap`) that receives the serialized element
   * @param fetchOps
   *   output names to fetch (keys of `outputsNameMap`); also the keys of the map handed to `outFn`
   * @param options
   *   model options
   * @param signatureName
   *   signature name used when loading the model
   * @param outFn
   *   builds the output from the input element and the fetched tensors
   * @throws IllegalArgumentException
   *   from `extractInput` when `exampleTensorName` is missing from `inputsNameMap`
   */
  def forTensorFlowExample[T, V](
    uri: String,
    exampleTensorName: String,
    fetchOps: Option[Seq[String]],
    options: TensorFlowModel.Options,
    signatureName: String,
    outFn: (T, Map[String, Tensor]) => V
  )(implicit ev: T <:< Example): SavedBundlePredictDoFn[T, V] =
    new SavedBundlePredictDoFn[T, V](uri, signatureName, options) {
      private lazy val exportedFetchOps =
        model.outputsNameMap().asScala.toMap
      private lazy val requestedFetchOps: Map[String, String] = fetchOps
        .map { tensorIds =>
          tensorIds.iterator.map(tensorId => tensorId -> exportedFetchOps(tensorId)).toMap
        }
        .getOrElse(exportedFetchOps)

      override def outputTensorNames: Seq[String] = requestedFetchOps.values.toSeq

      override def extractInput(input: T): Map[String, Tensor] = {
        // Resolve the name before allocating the tensor so nothing needs releasing on failure.
        val opName = Option(model.inputsNameMap().get(exampleTensorName))
          .getOrElse(throw unknownInputs(uri, signatureName, Seq(exampleTensorName)))
        val bytes = NdArrays.vectorOfObjects(input.toByteArray())
        val tensor = TString.tensorOfBytes(bytes)
        Map(opName -> tensor)
      }

      override def extractOutput(input: T, out: Map[String, Tensor]): V =
        outFn(
          input,
          requestedFetchOps.iterator.map { case (tensorId, opName) =>
            tensorId -> out(opName)
          }.toMap
        )

      override def modelId: String = s"${super.modelId}:${fetchOpsSuffix(fetchOps)}"
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy