All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.optim.PredictionService.scala Maven / Gradle / Ivy

/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.optim

import java.util.concurrent.LinkedBlockingQueue

import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.ArrayValue
import com.intel.analytics.bigdl.serialization.Bigdl.{AttrValue, BigDLTensor, DataType, TensorStorage}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric.{NumericBoolean, NumericChar, NumericDouble, NumericFloat, NumericInt, NumericLong, NumericString}
import com.intel.analytics.bigdl.utils._
import com.intel.analytics.bigdl.utils.serializer.converters.DataConverter
import com.intel.analytics.bigdl.utils.serializer.{DeserializeContext, ModuleSerializer, ProtoStorageType, SerializeContext}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.Type
import scala.util.{Failure, Success, Try}

/**
 * Thread-safe Prediction Service for Concurrent Calls.
 *
 * In this service, concurrency is kept not greater than `numThreads` by a `BlockingQueue`,
 * which contains available model instances.
 *
 * `numThreads` model instances sharing weights/bias
 * will be put into the `BlockingQueue` during initialization.
 *
 * When the predict method is called, the service will try to take an instance from the
 * `BlockingQueue`, which means if all instances are serving, the predicting request
 * will be blocked until some instance is released.
 *
 * If a (non-fatal) exception is caught during predict,
 * a scalar Tensor[String] will be returned with the thrown message.
 *
 * @param model BigDL model used to do predictions
 * @param numThreads max concurrency; the queue is filled with `(1 to numThreads)` copies,
 *                   so a non-positive value leaves it empty and every `predict` call
 *                   would block on `take()`
 */
class PredictionService[T: ClassTag] private[optim](
  model: Module[T],
  numThreads: Int
)(implicit ev: TensorNumeric[T]) {

  // Local import keeps this class self-contained; only non-fatal errors are
  // converted into error tensors, fatal ones must propagate to the caller.
  import scala.util.control.NonFatal

  protected val instQueue: LinkedBlockingQueue[Module[T]] = {
    val shallowCopies = (1 to numThreads)
      .map(_ => model.clone(false).evaluate()).asJava

    new LinkedBlockingQueue[Module[T]](shallowCopies)
  }

  /**
   * Thread-safe single sample prediction.
   *
   * Runs model prediction with the input Activity as soon as
   * there is a vacant instance (the size of the pool is `numThreads`).
   * Otherwise, it will hold on till some instance is released.
   *
   * Outputs are cloned after model prediction, so later predictions made with the
   * same model instance do not overwrite them.
   *
   * The borrowed instance is always handed back to the queue, even when a fatal
   * error escapes from the forward pass.
   *
   * @param request input Activity, could be Tensor or Table(key, Tensor)
   * @return output Activity, could be Tensor or Table(key, Tensor);
   *         a scalar Tensor[String] describing the failure if a non-fatal error occurred
   */
  def predict(request: Activity): Activity = {
    // Take an instance from blocking queue,
    // it will cause a thread blocking when no instance is available.
    val module = instQueue.take()

    try {
      // do predictions; Try only captures non-fatal failures
      val forwardResult = Try(module.forward(request)) match {
        case Success(activity) => activity
        case Failure(e) => errorTensor("running forward", e)
      }

      // cloned values after prediction finished
      try {
        forwardResult match {
          case tensor: Tensor[_] =>
            tensor.clone()
          case table: Table =>
            val clonedMap = mutable.HashMap[Any, Any]()
            table.getState().foreach { x => (x: @unchecked) match {
                case (k: Tensor[_], v: Tensor[_]) =>
                  clonedMap += k.clone() -> v.clone()
                case (k, v: Tensor[_]) =>
                  clonedMap += k -> v.clone()
              }
            }
            new Table(clonedMap)
        }
      } catch {
        // A MatchError from an unsupported result/value type also ends up here.
        case NonFatal(e) => errorTensor("Clone Result", e)
      }
    } finally {
      // Release module instance back to blocking queue, whatever happened above.
      instQueue.offer(module)
    }
  }

  /**
   * Thread-safe single sample prediction on serialized input.
   *
   * Firstly, the input bytes are deserialized into an Activity.
   *
   * Then, model prediction is run with the deserialized input
   * as soon as there is a vacant instance (total number is `numThreads`).
   * Otherwise, it will hold on till some instance is released.
   *
   * Finally, the prediction result is serialized to Array[Byte] according to `Bigdl.proto`.
   * Non-fatal failures in any stage are reported as a serialized scalar Tensor[String].
   *
   * @param request input bytes, which will be deserialized according to `Bigdl.proto`
   * @return output bytes, serialized according to `Bigdl.proto`
   */
  def predict(request: Array[Byte]): Array[Byte] = {
    val output = Try(
      PredictionService.deSerializeActivity(request)
    ) match {
      case Success(activity) => predict(activity)
      case Failure(e) => errorTensor("DeSerialize Input", e)
    }

    try {
      PredictionService.serializeActivity(output)
    } catch {
      case NonFatal(e) =>
        val act = errorTensor("Serialize Output", e)
        PredictionService.serializeActivity(act)
    }
  }

  /**
   * Builds the scalar string tensor that is returned to callers instead of a result
   * when `stage` failed with `e`.
   */
  private def errorTensor(stage: String, e: Throwable): Tensor[String] = {
    val msg = s"Exception caught during [$stage]! \n" +
      s"The message is ${e.getMessage} \n" +
      s"The cause is ${e.getCause}"
    Tensor.scalar(msg)
  }

}


object PredictionService {

  /**
   * Thread-safe Prediction Service for Concurrent Calls.
   *
   * In this service, concurrency is kept not greater than `numThreads` by a `BlockingQueue`,
   * which contains available model instances.
   *
   * If a (non-fatal) exception is caught during predict,
   * a scalar Tensor[String] will be returned with the thrown message.
   *
   * @param model BigDL model used to do predictions
   * @param numThreads max concurrency
   * @return a PredictionService instance
   */
  def apply[T: ClassTag](
    model: Module[T],
    numThreads: Int
  )(implicit ev: TensorNumeric[T]): PredictionService[T] = {
    new PredictionService[T](model, numThreads)
  }

  /**
   * Serialize activities to Array[Byte] according to `Bigdl.proto`.
   *
   * For now, `Tensor` and `Table[primitive|Tensor, Tensor]` are supported.
   * A `Table` is written as an array of tensors laid out as
   * `[isKeyPrimitive, keys..., values...]`; the kind of the first key decides how
   * all keys are interpreted. An empty `Table` is written as the flag tensor only.
   *
   * @param activity activity to be serialized
   * @return the serialized `AttrValue` bytes
   * @throws UnsupportedOperationException if the activity or the table key type
   *                                       is not supported
   */
  def serializeActivity(activity: Activity): Array[Byte] = {
    val attrBuilder = AttrValue.newBuilder()
    activity match {
      case table: Table =>
        val state = table.getState()
        // None for an empty table; `head` would throw NoSuchElementException there.
        val firstKey = state.headOption.map(_._1)
        val keyIsPrimitive = firstKey match {
          case Some(_: Tensor[_]) => false
          case _ => true
        }

        val tensorState: Array[(Tensor[_], Tensor[_])] = firstKey match {
          case None =>
            Array.empty[(Tensor[_], Tensor[_])]
          case Some(_: Tensor[_]) =>
            state.map { x => (x: @unchecked) match {
              case (k: Tensor[_], v: Tensor[_]) => k -> v }}.toArray
          case Some(_: Int) =>
            state.map { x => (x: @unchecked) match {
              case (k: Int, v: Tensor[_]) => Tensor.scalar(k) -> v }}.toArray
          case Some(_: Long) =>
            state.map { x => (x: @unchecked) match {
              case (k: Long, v: Tensor[_]) => Tensor.scalar(k) -> v }}.toArray
          case Some(_: Char) =>
            state.map { x => (x: @unchecked) match {
              case (k: Char, v: Tensor[_]) => Tensor.scalar(k) -> v }}.toArray
          case Some(_: Short) =>
            state.map { x => (x: @unchecked) match {
              case (k: Short, v: Tensor[_]) => Tensor.scalar(k) -> v }}.toArray
          case Some(_: Float) =>
            state.map { x => (x: @unchecked) match {
              case (k: Float, v: Tensor[_]) => Tensor.scalar(k) -> v }}.toArray
          case Some(_: Double) =>
            state.map { x => (x: @unchecked) match {
              case (k: Double, v: Tensor[_]) => Tensor.scalar(k) -> v }}.toArray
          case Some(_: Boolean) =>
            state.map { x => (x: @unchecked) match {
              case (k: Boolean, v: Tensor[_]) => Tensor.scalar(k) -> v }}.toArray
          case Some(_: String) =>
            state.map { x => (x: @unchecked) match {
              case (k: String, v: Tensor[_]) => Tensor.scalar(k) -> v }}.toArray
          case Some(key) =>
            throw new UnsupportedOperationException(s"Unsupported Table key: $key!")
        }
        val (keys, values) = tensorState.unzip
        // tensors structure: [isKeyPrimitive, keys, values]
        val tensors = Array(Tensor.scalar(keyIsPrimitive)) ++ keys ++ values

        val arrayValue = ArrayValue.newBuilder
        arrayValue.setDatatype(DataType.TENSOR)
        arrayValue.setSize(tensors.length)
        tensors.foreach { tensor =>
          arrayValue.addTensor(buildBigDLTensor(tensor, attrBuilder))
          // the builder is only scratch space for buildBigDLTensor; reset it per tensor
          attrBuilder.clear()
        }
        attrBuilder.setDataType(DataType.ARRAY_VALUE)
        attrBuilder.setArrayValue(arrayValue)
      case tensor: Tensor[_] =>
        attrBuilder.setTensorValue(buildBigDLTensor(tensor, attrBuilder))
      case _ =>
        throw new UnsupportedOperationException("Unsupported Activity Type!")
    }
    val attr = attrBuilder.build()
    attr.toByteArray
  }

  /**
   * Deserialize Array[Byte] to activities according to `Bigdl.proto`.
   *
   * For now, `Tensor` and `Table[primitive|Tensor, Tensor]` are supported.
   * It will convert `AttrValue(Array(BigdlTensor))` to a `Table`.
   * It will convert `AttrValue(BigdlTensor)` to a `Tensor`.
   *
   * @param bytes bytes data for Activity to be deserialized
   * @return the deserialized `Tensor` or `Table`
   * @throws UnsupportedOperationException if the top-level data type is neither
   *                                       `ARRAY_VALUE` nor `TENSOR`
   */
  def deSerializeActivity(bytes: Array[Byte]): Activity = {
    val attr = AttrValue.parseFrom(bytes)
    attr.getDataType match {
      case DataType.ARRAY_VALUE =>
        val dataType = attr.getArrayValue.getTensor(0).getDatatype
        // tensors structure: [isKeyPrimitive, keys, values]
        val tensors = getAttr(dataType, attr).asInstanceOf[Array[Tensor[_]]]

        // one flag tensor followed by as many keys as values
        val nElement = (tensors.length - 1) / 2
        val keyIsPrimitive = tensors.head.asInstanceOf[Tensor[Boolean]].value()
        val _keys = tensors.slice(1, nElement + 1)
        // primitive keys were wrapped into scalar tensors when serialized; unwrap them
        val keys = if (keyIsPrimitive) _keys.map(_.value()) else _keys
        val values = tensors.slice(nElement + 1, tensors.length)
        val table = T()
        keys.zip(values).foreach { case (k, v) => table.update(k, v) }
        table
      case DataType.TENSOR =>
        val tValue = attr.getTensorValue
        val tensor = getAttr(tValue.getDatatype, attr)
        tensor.asInstanceOf[Tensor[_]]
      case tpe =>
        throw new UnsupportedOperationException(s"Unsupported DataType($tpe)!")
    }
  }

  /**
   * Converts `tensor` into its protobuf form with the storage attached.
   *
   * The converter records the produced tensor and storage messages in `status`,
   * keyed by identity hash code; they are looked up there and merged.
   * `attrBuilder` is used as scratch space and is modified by this call.
   */
  private def buildBigDLTensor(tensor: Tensor[_], attrBuilder: AttrValue.Builder): BigDLTensor = {
    val status = mutable.HashMap[Int, Any]()

    val partial = partialSetAttr(tensor.getTensorNumeric(), status)
    partial(attrBuilder, tensor, ModuleSerializer.tensorType)

    val tensorId = System.identityHashCode(tensor)
    val _tensor = status(tensorId).asInstanceOf[BigDLTensor]
    val tensorBuilder = BigDLTensor.newBuilder(_tensor)

    val storageId = System.identityHashCode(tensor.storage().array())
    val _storage = status(storageId).asInstanceOf[TensorStorage]
    tensorBuilder.setStorage(_storage)

    tensorBuilder.build()
  }

  /**
   * Returns a function that writes a value into an `AttrValue.Builder` using a
   * `SerializeContext` typed after `numeric` and backed by `status`.
   *
   * Throws UnsupportedOperationException for numeric types that are not listed,
   * instead of failing with a bare MatchError.
   */
  private def partialSetAttr(numeric: TensorNumeric[_], status: mutable.HashMap[Int, Any]) = {
    numeric match {
      case NumericFloat =>
        val sc = SerializeContext[Float](null, status, ProtoStorageType)
        (attrBuilder: AttrValue.Builder, value: Any, tpe: Type) =>
          DataConverter.setAttributeValue[Float](sc, attrBuilder, value, tpe)
      case NumericDouble =>
        val sc = SerializeContext[Double](null, status, ProtoStorageType)
        (attrBuilder: AttrValue.Builder, value: Any, tpe: Type) =>
          DataConverter.setAttributeValue[Double](sc, attrBuilder, value, tpe)
      case NumericChar =>
        val sc = SerializeContext[Char](null, status, ProtoStorageType)
        (attrBuilder: AttrValue.Builder, value: Any, tpe: Type) =>
          DataConverter.setAttributeValue[Char](sc, attrBuilder, value, tpe)
      case NumericBoolean =>
        val sc = SerializeContext[Boolean](null, status, ProtoStorageType)
        (attrBuilder: AttrValue.Builder, value: Any, tpe: Type) =>
          DataConverter.setAttributeValue[Boolean](sc, attrBuilder, value, tpe)
      case NumericString =>
        val sc = SerializeContext[String](null, status, ProtoStorageType)
        (attrBuilder: AttrValue.Builder, value: Any, tpe: Type) =>
          DataConverter.setAttributeValue[String](sc, attrBuilder, value, tpe)
      case NumericInt =>
        val sc = SerializeContext[Int](null, status, ProtoStorageType)
        (attrBuilder: AttrValue.Builder, value: Any, tpe: Type) =>
          DataConverter.setAttributeValue[Int](sc, attrBuilder, value, tpe)
      case NumericLong =>
        val sc = SerializeContext[Long](null, status, ProtoStorageType)
        (attrBuilder: AttrValue.Builder, value: Any, tpe: Type) =>
          DataConverter.setAttributeValue[Long](sc, attrBuilder, value, tpe)
      case other =>
        throw new UnsupportedOperationException(s"Unsupported TensorNumeric($other)!")
    }
  }

  /**
   * Reads the value stored in `attr`, using the converter instantiated for the
   * element type named by `dataType`.
   *
   * Throws UnsupportedOperationException for data types that are not listed.
   */
  private def getAttr(dataType: DataType, attr: AttrValue) = {
    val status = mutable.HashMap[Int, Any]()
    val dsc = DeserializeContext(null, status, ProtoStorageType)
    dataType match {
      case DataType.INT32 =>
        DataConverter.getAttributeValue[Int](dsc, attr)
      case DataType.INT64 =>
        DataConverter.getAttributeValue[Long](dsc, attr)
      case DataType.FLOAT =>
        DataConverter.getAttributeValue[Float](dsc, attr)
      case DataType.DOUBLE =>
        DataConverter.getAttributeValue[Double](dsc, attr)
      case DataType.STRING =>
        DataConverter.getAttributeValue[String](dsc, attr)
      case DataType.BOOL =>
        DataConverter.getAttributeValue[Boolean](dsc, attr)
      case DataType.CHAR =>
        DataConverter.getAttributeValue[Char](dsc, attr)
      case _ =>
        throw new UnsupportedOperationException(s"Unsupported DataType($dataType)!")
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy