// com.intel.analytics.zoo.serving.http.FrontEndApp.scala Maven / Gradle / Ivy
/*
* Copyright 2018 Analytics Zoo Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.zoo.serving.http
import java.io.File
import java.security.{KeyStore, SecureRandom}
import java.util.concurrent.TimeUnit
import javax.net.ssl.{KeyManagerFactory, SSLContext, TrustManagerFactory}
import akka.actor.{ActorRef, ActorSystem, Props}
import akka.http.scaladsl.{ConnectionContext, Http}
import akka.http.scaladsl.server.Directives.{complete, path, _}
import akka.pattern.ask
import akka.stream.ActorMaterializer
import akka.util.Timeout
import com.codahale.metrics.{MetricRegistry, Timer}
import com.intel.analytics.zoo.pipeline.inference.EncryptSupportive
import com.intel.analytics.zoo.serving.utils.Conventions
import org.apache.logging.log4j.LogManager
import org.slf4j.LoggerFactory
import scala.collection.mutable
import scala.concurrent.Await
/**
 * HTTP(S) frontend of the serving stack.
 *
 * The object owns the actor system, the stream materializer and the
 * execution context shared by every route, plus the metric timers and the
 * command line parser declared further down.
 */
object FrontEndApp extends Supportive with EncryptSupportive {
// slf4j logger used by this object.
override val logger = LoggerFactory.getLogger(getClass)
// Human readable name, used in the startup timing message and the welcome route.
val name = "analytics zoo web serving frontend"
// Actor system hosting the HTTP server and the redis putter actor created in main.
implicit val system = ActorSystem("zoo-serving-frontend-system")
implicit val materializer = ActorMaterializer()
// The actor system dispatcher doubles as the execution context.
implicit val executionContext = system.dispatcher
// Timeout applied to actor asks and to the blocking wait on their replies (100 seconds).
implicit val timeout: Timeout = Timeout(100, TimeUnit.SECONDS)
/**
 * Starts the serving frontend: parses the command line, loads the servable
 * manager configuration, builds the HTTP route and binds it on the plain
 * HTTP port, and additionally on the HTTPS port when `httpsEnabled` is set.
 *
 * Routes served:
 *  - root path: welcome message
 *  - `GET metrics`: JSON dump of every registered timer
 *  - `POST model-secure`: body `secret=xxx&salt=xxxx`, forwarded to the redis putter actor
 *  - `GET models`, `GET models/{name}`, `GET models/{name}/versions/{version}`:
 *    servable metadata as JSON
 *  - `POST models/{name}/versions/{version}/predict`: model inference
 *
 * The metadata and predict routes answer 404 for `ModelNotFoundException`,
 * 405 for `ServingRuntimeException` and 500 for anything else.
 *
 * @param args command line arguments, see `argumentsParser`
 * @throws IllegalArgumentException when the command line cannot be parsed
 */
def main(args: Array[String]): Unit = {
  timing(s"$name started successfully.")() {
    val arguments = timing("parse arguments")() {
      argumentsParser.parse(args, FrontEndAppArguments()) match {
        case Some(parsed) =>
          logger.info(s"starting with $parsed")
          parsed
        case None =>
          // Fail fast with a clear message instead of carrying a null
          // configuration forward into the servable loading step.
          throw new IllegalArgumentException("miss args, please see the usage info")
      }
    }
    val servableManager = new ServableManager
    logger.info("Multi Serving Mode")
    timing("load servable manager")() {
      try servableManager.load(arguments.servableManagerPath)
      catch {
        case e: ServableLoadException =>
          throw e
        case e: Throwable =>
          // Every other failure is logged together with an example of the
          // expected configuration layout before being propagated unchanged.
          val exampleYaml =
            """
              |---
              |modelMetaDataList:
              |- !
              |modelName: "1"
              |modelVersion:"1.0"
              |redisHost: "localhost"
              |redisPort: "6381"
              |redisInputQueue: "serving_stream2"
              |redisOutputQueue: "cluster-serving_serving_stream2:"
              |- !
              |modelName: "1"
              |modelVersion:"1.0"
              |modelPath:"/"
              |modelType:"OpenVINO"
              |features:
              |- "a"
              |- "b"
              |""".stripMargin
          logger.info("Example Format of Input:" + exampleYaml)
          throw e
      }
    }
    logger.info("Servable Manager Load Success!")
    // Created on first use by the model-secure route. A lazy val (instead of a
    // null-checked var) keeps concurrent requests from creating the actor twice
    // under the same name.
    lazy val redisPutter: ActorRef = {
      val redisPutterName = "redis-putter"
      timing(s"$redisPutterName initialized.")() {
        val redisPutterProps = Props(new RedisPutActor(
          arguments.redisHost,
          arguments.redisPort,
          arguments.redisInputQueue,
          arguments.redisOutputQueue,
          arguments.timeWindow,
          arguments.countWindow,
          arguments.redisSecureEnabled,
          arguments.redissTrustStorePath,
          arguments.redissTrustStoreToken))
        system.actorOf(redisPutterProps, name = redisPutterName)
      }
    }
    val route = timing("initialize http route")() {
      path("") {
        timing("welcome")(overallRequestTimer) {
          complete("welcome to " + name)
        }
      } ~ (get & path("metrics")) {
        timing("metrics")(overallRequestTimer, metricsRequestTimer) {
          // Read the registry once so keys and timers come from the same snapshot.
          val timers = metrics.getTimers()
          val servingMetrics = timers.keySet().toArray.map(key => {
            ServingTimerMetrics(key.toString, timers.get(key))
          }).toList
          complete(jacksonJsonSerializer.serialize(servingMetrics))
        }
      } ~ (post & path("model-secure") &
        extract(_.request.entity.contentType) & entity(as[String])) {
        (contentType, content) => {
          try {
            // Force the lazily created actor before the body is parsed, so a
            // creation failure is reported through the same error path.
            val putter = redisPutter
            // Expected body: secret=xxx&salt=xxxx
            val secrets = content.split("&")
            val secret = secrets(0).split("=")(1)
            val salt = secrets(1).split("=")(1)
            val message = SecuredModelSecretSaltMessage(secret, salt)
            val succeeded = Await.result(putter ? message, timeout.duration)
              .asInstanceOf[Boolean]
            if (succeeded) {
              complete("model secured secrect and salt succeed to put in redis")
            } else {
              complete("model secured secrect and salt failed to put in redis")
            }
          } catch {
            case e: Exception =>
              logger.error("model-secure request failed", e)
              val error = ServingError(e.getMessage + "\n please post a content like " +
                "secret=xxx&salt=xxxx")
              complete(500, error.toString)
          }
        }
      } ~ (get & path("models")) {
        timing("get all model infos")(overallRequestTimer, servablesRetriveTimer) {
          try {
            val servables = servableManager.retriveAllServables
            val metaData = servables.map(e => e.getMetaData)
            val json = JsonUtil.toJson(metaData)
            complete(200, json)
          }
          catch {
            case e: ModelNotFoundException =>
              complete(404, "Model Not Found")
            case e: ServingRuntimeException =>
              complete(405, "Serving Runtime Error Err: " + e)
            // Deliberately broad: every failure is turned into a 500 response.
            case e: Throwable =>
              complete(500, "Internal Error: " + e)
          }
        }
      } ~ pathPrefix("models") {
        concat(
          (get & path(Segment)) {
            (modelName) => {
              timing("get model infos with model name")(overallRequestTimer,
                servablesRetriveTimer) {
                try {
                  val servables = servableManager.retriveServables(modelName)
                  val metaData = servables.map(e => e.getMetaData)
                  val json = JsonUtil.toJson(metaData)
                  complete(200, json)
                }
                catch {
                  case e: ModelNotFoundException =>
                    complete(404, "Model Not Found")
                  case e: ServingRuntimeException =>
                    complete(405, "Serving Runtime Error Err: " + e)
                  // Deliberately broad: every failure is turned into a 500 response.
                  case e: Throwable =>
                    complete(500, "Internal Error: " + e)
                }
              }
            }
          } ~ (get & path(Segment / "versions" / Segment)) {
            (modelName, modelVersion) => {
              timing("get model info with model name and model version")(overallRequestTimer,
                servableRetriveTimer) {
                try {
                  val servables = servableManager.retriveServable(modelName, modelVersion)
                  val metaData = servables.getMetaData
                  val json = JsonUtil.toJson(metaData)
                  complete(200, json)
                }
                catch {
                  case e: ModelNotFoundException =>
                    complete(404, "Model Not Found")
                  case e: ServingRuntimeException =>
                    complete(405, "Serving Runtime Error Err: " + e)
                  // Deliberately broad: every failure is turned into a 500 response.
                  case e: Throwable =>
                    complete(500, "Internal Error: " + e)
                }
              }
            }
          } ~ (post & path(Segment / "versions" / Segment / "predict")
            & extract(_.request.entity.contentType) & entity(as[String])) {
            (modelName, modelVersion, contentType, content) => {
              timing("backend inference")(overallRequestTimer, backendInferenceTimer) {
                try {
                  logger.info("model name: " + modelName + ", model version: " + modelVersion)
                  val servable = timing("servable retrive")(servableRetriveTimer) {
                    servableManager.retriveServable(modelName, modelVersion)
                  }
                  val modelInferenceTimer = modelInferenceTimersMap(modelName)(modelVersion)
                  servable match {
                    case _: ClusterServingServable =>
                      val result = timing("cluster serving inference")(predictRequestTimer) {
                        val instances = timing("json deserialization")() {
                          JsonUtil.fromJson(classOf[Instances], content)
                        }
                        val outputs = timing("model inference")(modelInferenceTimer) {
                          servable.predict(instances)
                        }
                        Predictions(outputs)
                      }
                      timing("cluster serving response complete")() {
                        complete(200, result.toString)
                      }
                    case _: InferenceModelServable =>
                      val result = timing("inference model inference")(predictRequestTimer) {
                        // "direct" hands the raw request body to the model,
                        // "instance" deserializes it into Instances first.
                        val outputs = servable.getMetaData.
                          asInstanceOf[InferenceModelMetaData].inputCompileType match {
                          case "direct" => timing("model inference")(modelInferenceTimer) {
                            servable.predict(content)
                          }
                          case "instance" =>
                            val instances = timing("json deserialization")() {
                              JsonUtil.fromJson(classOf[Instances], content)
                            }
                            timing("model inference")(modelInferenceTimer) {
                              servable.predict(instances)
                            }
                        }
                        JsonUtil.toJson(outputs.map(_.result))
                      }
                      timing("inference model response complete")() {
                        complete(200, result)
                      }
                  }
                }
                catch {
                  case e: ModelNotFoundException =>
                    complete(404, "Model Not Found. Err: " + e.message)
                  case e: ServingRuntimeException =>
                    complete(405, "Serving Runtime Error Err: " + e.message)
                  // Deliberately broad: every failure is turned into a 500 response.
                  case e: Throwable =>
                    logger.error("predict request failed", e)
                    complete(500, "Internal Error: " + e)
                }
              }
            }
          }
        )
      }
    }
    if (arguments.httpsEnabled) {
      val serverContext = defineServerContext(arguments.httpsKeyStoreToken,
        arguments.httpsKeyStorePath)
      val httpsBinding = Http().bindAndHandle(route, arguments.interface,
        port = arguments.securePort, connectionContext = serverContext)
      // Binding is asynchronous; report a failure instead of dropping it.
      httpsBinding.failed.foreach { e =>
        logger.error(
          s"failed to bind https on ${arguments.interface}:${arguments.securePort}", e)
      }
      logger.info(s"https started at https://${arguments.interface}:${arguments.securePort}")
    }
    val httpBinding = Http().bindAndHandle(route, arguments.interface, arguments.port)
    // Binding is asynchronous; report a failure instead of dropping it.
    httpBinding.failed.foreach { e =>
      logger.error(s"failed to bind http on ${arguments.interface}:${arguments.port}", e)
    }
    logger.info(s"http started at http://${arguments.interface}:${arguments.port}")
  }
}
// Registry holding every timer exposed through the metrics route.
val metrics: MetricRegistry = new MetricRegistry

// Request level timers.
val overallRequestTimer: Timer = metrics.timer("zoo.serving.request.overall")
val predictRequestTimer: Timer = metrics.timer("zoo.serving.request.predict")

// Servable lookup timers (single servable / list of servables).
val servableRetriveTimer: Timer = metrics.timer("zoo.serving.retrive.servable")
val servablesRetriveTimer: Timer = metrics.timer("zoo.serving.retrive.servables")

val backendInferenceTimer: Timer = metrics.timer("zoo.serving.backend.inference")

// Redis interaction timers.
val putRedisTimer: Timer = metrics.timer("zoo.serving.redis.put")
val getRedisTimer: Timer = metrics.timer("zoo.serving.redis.get")
val waitRedisTimer: Timer = metrics.timer("zoo.serving.redis.wait")

val metricsRequestTimer: Timer = metrics.timer("zoo.serving.request.metrics")

// Timers indexed first by model name, then by model version.
val modelInferenceTimersMap: mutable.HashMap[String, mutable.HashMap[String, Timer]] =
  mutable.HashMap.empty[String, mutable.HashMap[String, Timer]]
val purePredictTimersMap: mutable.HashMap[String, mutable.HashMap[String, Timer]] =
  mutable.HashMap.empty[String, mutable.HashMap[String, Timer]]

val makeActivityTimer: Timer = metrics.timer("zoo.serving.activity.make")
val handleResponseTimer: Timer = metrics.timer("zoo.serving.response.handling")

// Serializer used to render the metrics route response.
val jacksonJsonSerializer: JacksonJsonSerializer = new JacksonJsonSerializer()
/**
 * Command line definition of the frontend; every option overrides the
 * matching default of `FrontEndAppArguments`.
 *
 * Each short name is attached to exactly one option. The short names `i`,
 * `p`, `s` and `w` stay with the first option that declared them; options
 * that used to repeat one of those letters are reachable through their long
 * name only. NOTE(review): this assumes the parser resolves a repeated short
 * name to the first declaration, so the later aliases were never usable;
 * confirm against the scopt version in use.
 */
val argumentsParser = new scopt.OptionParser[FrontEndAppArguments]("AZ Serving") {
  head("Analytics Zoo Serving Frontend")
  opt[String]('i', "interface")
    .action((x, c) => c.copy(interface = x))
    .text("network interface of frontend")
  opt[Int]('p', "port")
    .action((x, c) => c.copy(port = x))
    .text("network port of frontend")
  opt[Int]('s', "securePort")
    .action((x, c) => c.copy(securePort = x))
    .text("https port of frontend")
  opt[String]('h', "redisHost")
    .action((x, c) => c.copy(redisHost = x))
    .text("host of redis")
  opt[Int]('r', "redisPort")
    .action((x, c) => c.copy(redisPort = x))
    .text("port of redis")
  // 'i' already belongs to --interface.
  opt[String]("redisInputQueue")
    .action((x, c) => c.copy(redisInputQueue = x))
    .text("input queue of redis")
  opt[String]('o', "redisOutputQueue")
    .action((x, c) => c.copy(redisOutputQueue = x))
    .text("output queue of redis")
  opt[Int]('l', "parallelism")
    .action((x, c) => c.copy(parallelism = x))
    .text("parallelism of frontend")
  opt[Int]('t', "timeWindow")
    .action((x, c) => c.copy(timeWindow = x))
    .text("timeWindow of frontend")
  opt[Int]('c', "countWindow")
    .action((x, c) => c.copy(countWindow = x))
    .text("countWindow of frontend")
  opt[Boolean]('e', "tokenBucketEnabled")
    .action((x, c) => c.copy(tokenBucketEnabled = x))
    .text("Token Bucket Enabled or not")
  opt[Int]('k', "tokensPerSecond")
    .action((x, c) => c.copy(tokensPerSecond = x))
    .text("tokens per second")
  opt[Int]('a', "tokenAcquireTimeout")
    .action((x, c) => c.copy(tokenAcquireTimeout = x))
    .text("token acquire timeout")
  // 's' already belongs to --securePort; declared once only.
  opt[Boolean]("httpsEnabled")
    .action((x, c) => c.copy(httpsEnabled = x))
    .text("https enabled or not")
  // 'p' already belongs to --port.
  opt[String]("httpsKeyStorePath")
    .action((x, c) => c.copy(httpsKeyStorePath = x))
    .text("https keyStore path")
  opt[String]('w', "httpsKeyStoreToken")
    .action((x, c) => c.copy(httpsKeyStoreToken = x))
    .text("https keyStore token")
  // 's' already belongs to --securePort.
  opt[Boolean]("redisSecureEnabled")
    .action((x, c) => c.copy(redisSecureEnabled = x))
    .text("redis secure enabled or not")
  // 'p' already belongs to --port.
  opt[String]("redissTrustStorePath")
    .action((x, c) => c.copy(redissTrustStorePath = x))
    .text("rediss trustStore path")
  // 'w' already belongs to --httpsKeyStoreToken.
  opt[String]("redissTrustStoreToken")
    .action((x, c) => c.copy(redissTrustStoreToken = x))
    .text("rediss trustStore password")
  opt[String]('z', "servableManagerConfPath")
    .action((x, c) => c.copy(servableManagerPath = x))
    .text("servableManagerConfPath")
}
/**
 * Builds the HTTPS connection context from a PKCS12 keystore file.
 *
 * The same keystore initialises both the key managers and the trust managers.
 *
 * @param httpsKeyStoreToken password of the keystore and of its keys
 * @param httpsKeyStorePath  file system path of the PKCS12 keystore
 * @return the connection context to pass when binding the HTTPS port
 * @throws IllegalArgumentException when no keystore path is given
 */
def defineServerContext(httpsKeyStoreToken: String,
                        httpsKeyStorePath: String): ConnectionContext = {
  // The path defaults to null in FrontEndAppArguments, so validate it up front;
  // openStream() itself never returns null, it throws when the file is missing.
  require(httpsKeyStorePath != null, "Keystore required!")
  val token = httpsKeyStoreToken.toCharArray
  val keyStore = KeyStore.getInstance("PKCS12")
  val keystoreInputStream = new File(httpsKeyStorePath).toURI().toURL().openStream()
  // Release the file handle whether or not loading succeeds.
  try keyStore.load(keystoreInputStream, token)
  finally keystoreInputStream.close()
  val keyManagerFactory = KeyManagerFactory.getInstance("SunX509")
  keyManagerFactory.init(keyStore, token)
  val trustManagerFactory = TrustManagerFactory.getInstance("SunX509")
  trustManagerFactory.init(keyStore)
  val sslContext = SSLContext.getInstance("TLS")
  sslContext.init(keyManagerFactory.getKeyManagers,
    trustManagerFactory.getTrustManagers, new SecureRandom)
  ConnectionContext.https(sslContext)
}
}
/**
 * Command line configuration of the serving frontend, with its defaults.
 *
 * @param interface             network interface the frontend binds to
 * @param port                  plain HTTP port
 * @param securePort            HTTPS port, bound only when `httpsEnabled` is true
 * @param redisHost             host of redis
 * @param redisPort             port of redis
 * @param redisInputQueue       input queue of redis
 * @param redisOutputQueue      output queue of redis
 * @param parallelism           parallelism of frontend
 * @param timeWindow            timeWindow of frontend, handed to the redis putter actor
 * @param countWindow           countWindow of frontend, handed to the redis putter actor
 * @param tokenBucketEnabled    token bucket enabled or not
 * @param tokensPerSecond       tokens per second
 * @param tokenAcquireTimeout   token acquire timeout
 * @param httpsEnabled          https enabled or not
 * @param httpsKeyStorePath     https keyStore path; no default (null)
 * @param httpsKeyStoreToken    https keyStore token
 * @param redisSecureEnabled    redis secure enabled or not
 * @param redissTrustStorePath  rediss trustStore path; no default (null)
 * @param redissTrustStoreToken rediss trustStore password
 * @param servableManagerPath   path of the servable manager configuration file
 */
// NOTE(review): both store tokens default to a hard-coded literal; confirm that
// deployments are expected to override them.
case class FrontEndAppArguments(
interface: String = "0.0.0.0",
port: Int = 10020,
securePort: Int = 10023,
redisHost: String = "localhost",
redisPort: Int = 6379,
redisInputQueue: String = Conventions.SERVING_STREAM_DEFAULT_NAME,
redisOutputQueue: String =
Conventions.RESULT_PREFIX + Conventions.SERVING_STREAM_DEFAULT_NAME
+ ":",
parallelism: Int = 1000,
timeWindow: Int = 0,
countWindow: Int = 0,
tokenBucketEnabled: Boolean = false,
tokensPerSecond: Int = 100,
tokenAcquireTimeout: Int = 100,
httpsEnabled: Boolean = false,
httpsKeyStorePath: String = null,
httpsKeyStoreToken: String = "1234qwer",
redisSecureEnabled: Boolean = false,
redissTrustStorePath: String = null,
redissTrustStoreToken: String = "1234qwer",
servableManagerPath: String = "./servables-conf.yaml"
)
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy