// Source listing: com.intel.analytics.zoo.serving.http.Frontend2.scala (Maven / Gradle / Ivy)
/*
* Copyright 2018 Analytics Zoo Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.zoo.serving.http
import java.io.File
import java.security.{KeyStore, SecureRandom}
import java.util
import java.util.UUID
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import javax.net.ssl.{KeyManagerFactory, SSLContext, TrustManagerFactory}
import akka.actor.{ActorRef, ActorSystem, Props}
import akka.http.scaladsl.{ConnectionContext, Http}
import akka.http.scaladsl.server.Directives.{complete, path, _}
import akka.pattern.ask
import akka.stream.ActorMaterializer
import akka.util.Timeout
import com.codahale.metrics.MetricRegistry
import com.google.common.util.concurrent.RateLimiter
import com.intel.analytics.zoo.pipeline.inference.EncryptSupportive
import com.intel.analytics.zoo.serving.utils.Conventions
import com.fasterxml.jackson.databind.ObjectMapper
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.zoo.serving.ClusterServing
import org.apache.logging.log4j.LogManager
import org.slf4j.LoggerFactory
import redis.clients.jedis.JedisPool
import scala.concurrent.Await
import scala.concurrent.duration.DurationInt
object Frontend2 extends Supportive with EncryptSupportive {
// SLF4J logger for this object; overrides the `logger` member of a parent trait.
override val logger = LoggerFactory.getLogger(getClass)
// Human-readable service name; used in the startup timing label and the welcome response.
val name = "analytics zoo web serving frontend"
// Actor system hosting the HTTP server and the Redis I/O actor created in `main`.
implicit val system = ActorSystem("zoo-serving-frontend-system")
// Stream materializer; built after `system` because it resolves it implicitly.
implicit val materializer = ActorMaterializer()
// Execution context backed by the actor system's dispatcher.
implicit val executionContext = system.dispatcher
// 100-second timeout; `main` uses it both as the implicit ask timeout and as the
// maximum time a predict request waits for the actor's reply.
implicit val timeout: Timeout = Timeout(100, TimeUnit.SECONDS)
/**
 * Entry point of the serving frontend.
 *
 * Parses the command line, creates the Redis connection pool and the
 * `redis-getter` actor, builds the akka-http route (`path("")` welcome page,
 * `GET metrics`, `POST predict`), binds it on the plain port (and additionally
 * on the secure port when https is enabled), and then keeps sending
 * `DequeueMessage` to the actor from the calling thread. The polling loop
 * never terminates, so this method does not return normally.
 *
 * @param args command line arguments understood by `argumentsParser`
 * @throws IllegalArgumentException if the arguments cannot be parsed
 */
def main(args: Array[String]): Unit = {
  import scala.util.control.NonFatal

  timing(s"$name started successfully.")() {
    val arguments = timing("parse arguments")() {
      argumentsParser.parse(args, FrontEndAppArguments()) match {
        case Some(parsed) =>
          logger.info(s"starting with $parsed")
          parsed
        case None =>
          // Fail fast with a clear message instead of continuing with a null
          // configuration, which previously surfaced as a NullPointerException.
          throw new IllegalArgumentException("miss args, please see the usage info")
      }
    }

    val jedisPool = new JedisPool(
      ClusterServing.jedisPoolConfig, arguments.redisHost, arguments.redisPort)

    // Present only when the token bucket is enabled.
    val rateLimiter: Option[RateLimiter] =
      if (arguments.tokenBucketEnabled) {
        Some(RateLimiter.create(arguments.tokensPerSecond))
      } else {
        None
      }

    val actorName = s"redis-getter"
    val ioActor = timing(s"$actorName initialized.")() {
      val getterProps = Props(new RedisIOActor(jedisPool = jedisPool))
      system.actorOf(getterProps, name = actorName)
    }

    // Created once and shared by all requests instead of once per request.
    val objectMapper = new ObjectMapper()

    // Sends the raw request body to the I/O actor under a fresh id, blocks until
    // the reply arrives (or `timeout` elapses) and renders every entry of the
    // reply's value map as a JSON string.
    def processPredictionInput(inputs: String): Seq[PredictionOutput[String]] = {
      val result = timing("response waiting")() {
        val id = UUID.randomUUID().toString
        val results = timing(s"query message wait for key $id")() {
          Await.result(ioActor ? DataInputMessage(id, inputs), timeout.duration)
            .asInstanceOf[ModelOutputMessage].valueMap
        }
        results.map(r => {
          val resultStr = objectMapper.writeValueAsString(r._2)
          PredictionOutput(r._1, resultStr)
        })
      }
      result.toSeq
    }

    val route = timing("initialize http route")() {
      path("") {
        timing("welcome")() {
          complete("welcome to " + name)
        }
      } ~ (get & path("metrics")) {
        timing("metrics")() {
          val keys = metrics.getTimers().keySet()
          val servingMetrics = keys.toArray.map(key => {
            val timer = metrics.getTimers().get(key)
            ServingTimerMetrics(key.toString, timer)
          }).toList
          complete(jacksonJsonSerializer.serialize(servingMetrics))
        }
      } ~ (post & path("predict") &
        extract(_.request.entity.contentType) & entity(as[String])) {
        (contentType, content) => {
          // Rejected only when a limiter exists and no token could be acquired
          // within the configured timeout.
          val rejected = rateLimiter.exists(limiter =>
            !limiter.tryAcquire(arguments.tokenAcquireTimeout, TimeUnit.MILLISECONDS))
          if (rejected) {
            val error = ServingError("limited")
            complete(500, error.toString)
          } else {
            try {
              val result = timing("predict")() {
                val outputs = processPredictionInput(content)
                Predictions(outputs)
              }
              complete(200, result.toString)
            } catch {
              // NonFatal lets fatal JVM errors and interrupts propagate instead of
              // being reported to the client as a malformed request.
              case NonFatal(e) =>
                val message = e.getMessage
                val exampleJson =
                  """{
"instances" : [ {
"intScalar" : 12345,
"floatScalar" : 3.14159,
"stringScalar" : "hello, world. hello, zoo.",
"intTensor" : [ 7756, 9549, 1094, 9808, 4959, 3831, 3926, 6578, 1870, 1741 ],
"floatTensor" : [ 0.6804766, 0.30136853, 0.17394465, 0.44770062, 0.20275897 ],
"stringTensor" : [ "come", "on", "united" ],
"intTensor2" : [ [ 1, 2 ], [ 3, 4 ], [ 5, 6 ] ],
"floatTensor2" : [ [ [ 0.2, 0.3 ], [ 0.5, 0.6 ] ], [ [ 0.2, 0.3 ], [ 0.5, 0.6 ] ] ],
"stringTensor2" : [ [ [ [ "come", "on", "united" ], [ "come", "on", "united" ] ] ] ],
"sparseTensor" : {
"shape" : [ 100, 10000, 10 ],
"data" : [ 0.2, 0.5, 3.45, 6.78 ],
"indices" : [ [ 1, 1, 1 ], [ 2, 2, 2 ], [ 3, 3, 3 ], [ 4, 4, 4 ] ]
},
"image": "/9j/4AAQSkZJRgABAQEASABIAAD/7RcEUGhvdG9za..."
} ]
}"""
                val error = ServingError(s"Wrong content format.\n" +
                  s"Details: ${message}\n" +
                  s"Please refer to examples:\n" +
                  s"$exampleJson\n")
                complete(400, error.error)
            }
          }
        }
      }
    }

    if (arguments.httpsEnabled) {
      val serverContext = defineServerContext(arguments.httpsKeyStoreToken,
        arguments.httpsKeyStorePath)
      // Binding is asynchronous; log a failure instead of dropping it silently.
      Http().bindAndHandle(route, arguments.interface, port = arguments.securePort,
        connectionContext = serverContext)
        .failed.foreach(e => logger.error(
          s"failed to bind https on ${arguments.interface}:${arguments.securePort}", e)
        )(executionContext)
      logger.info(s"https started at https://${arguments.interface}:${arguments.securePort}")
    }

    Http().bindAndHandle(route, arguments.interface, arguments.port)
      .failed.foreach(e => logger.error(
        s"failed to bind http on ${arguments.interface}:${arguments.port}", e)
      )(executionContext)
    logger.info(s"http started at http://${arguments.interface}:${arguments.port}")

    // Drive the actor from this thread: one DequeueMessage roughly every millisecond.
    while (true) {
      ioActor ! DequeueMessage()
      Thread.sleep(1)
    }
  }
}
// Registry holding the timers below; the metrics route built in `main` lists
// every timer registered here.
val metrics = new MetricRegistry
// Named timers registered in `metrics`. They are not read anywhere in this
// file's visible code — presumably used by inherited or sibling code; verify.
val overallRequestTimer = metrics.timer("zoo.serving.request.overall")
val predictRequestTimer = metrics.timer("zoo.serving.request.predict")
val putRedisTimer = metrics.timer("zoo.serving.redis.put")
val getRedisTimer = metrics.timer("zoo.serving.redis.get")
val waitRedisTimer = metrics.timer("zoo.serving.redis.wait")
val metricsRequestTimer = metrics.timer("zoo.serving.request.metrics")
// Serializer used by the metrics route to render the timer list.
val jacksonJsonSerializer = new JacksonJsonSerializer()
/**
 * Command-line parser that fills a `FrontEndAppArguments` from the process
 * arguments. Every option copies its value into the field of the same name.
 *
 * Each single-character alias is registered for one option only. Aliases that
 * an earlier option had already claimed (`-i`, `-p`, `-s`, `-w`) were dropped
 * from the later options: the parser resolves a token to the first registered
 * option that matches it, so those aliases could never reach the later options
 * and only made the usage text misleading. All long names are unchanged, and
 * the second, identical registration of `httpsEnabled` was removed.
 */
val argumentsParser = new scopt.OptionParser[FrontEndAppArguments]("AZ Serving") {
  head("Analytics Zoo Serving Frontend")

  // --- HTTP endpoint ---
  opt[String]('i', "interface")
    .action((x, c) => c.copy(interface = x))
    .text("network interface of frontend")
  opt[Int]('p', "port")
    .action((x, c) => c.copy(port = x))
    .text("network port of frontend")
  opt[Int]('s', "securePort")
    .action((x, c) => c.copy(securePort = x))
    .text("https port of frontend")

  // --- Redis connection and queues ---
  opt[String]('h', "redisHost")
    .action((x, c) => c.copy(redisHost = x))
    .text("host of redis")
  opt[Int]('r', "redisPort")
    .action((x, c) => c.copy(redisPort = x))
    .text("port of redis")
  opt[String]("redisInputQueue")
    .action((x, c) => c.copy(redisInputQueue = x))
    .text("input queue of redis")
  opt[String]('o', "redisOutputQueue")
    .action((x, c) => c.copy(redisOutputQueue = x))
    .text("output queue of redis")

  // --- Frontend tuning ---
  opt[Int]('l', "parallelism")
    .action((x, c) => c.copy(parallelism = x))
    .text("parallelism of frontend")
  opt[Int]('t', "timeWindow")
    .action((x, c) => c.copy(timeWindow = x))
    .text("timeWindow of frontend")
  opt[Int]('c', "countWindow")
    .action((x, c) => c.copy(countWindow = x))
    .text("countWindow of frontend")

  // --- Token-bucket rate limiting ---
  opt[Boolean]('e', "tokenBucketEnabled")
    .action((x, c) => c.copy(tokenBucketEnabled = x))
    .text("Token Bucket Enabled or not")
  opt[Int]('k', "tokensPerSecond")
    .action((x, c) => c.copy(tokensPerSecond = x))
    .text("tokens per second")
  opt[Int]('a', "tokenAcquireTimeout")
    .action((x, c) => c.copy(tokenAcquireTimeout = x))
    .text("token acquire timeout")

  // --- HTTPS ---
  opt[Boolean]("httpsEnabled")
    .action((x, c) => c.copy(httpsEnabled = x))
    .text("https enabled or not")
  opt[String]("httpsKeyStorePath")
    .action((x, c) => c.copy(httpsKeyStorePath = x))
    .text("https keyStore path")
  opt[String]('w', "httpsKeyStoreToken")
    .action((x, c) => c.copy(httpsKeyStoreToken = x))
    .text("https keyStore token")

  // --- Secure Redis ---
  opt[Boolean]("redisSecureEnabled")
    .action((x, c) => c.copy(redisSecureEnabled = x))
    .text("redis secure enabled or not")
  opt[String]("redissTrustStorePath")
    .action((x, c) => c.copy(redissTrustStorePath = x))
    .text("rediss trustStore path")
  opt[String]("redissTrustStoreToken")
    .action((x, c) => c.copy(redissTrustStoreToken = x))
    .text("rediss trustStore password")
}
/**
 * Builds the HTTPS connection context for the secure listener.
 *
 * Loads a PKCS12 keystore from `httpsKeyStorePath`, initializes SunX509 key
 * and trust manager factories from that same keystore, and wraps the
 * resulting TLS `SSLContext` in an akka-http connection context.
 *
 * @param httpsKeyStoreToken password used to open the keystore and its keys
 * @param httpsKeyStorePath  file system path of the PKCS12 keystore
 * @return the connection context to pass when binding the https port
 * @throws IllegalArgumentException if `httpsKeyStorePath` is null
 */
def defineServerContext(httpsKeyStoreToken: String,
    httpsKeyStorePath: String): ConnectionContext = {
  // Check the path itself: opening the stream throws on failure rather than
  // returning null, so a null check on the stream could never fire.
  require(httpsKeyStorePath != null, "Keystore required!")
  val token = httpsKeyStoreToken.toCharArray
  val keyStore = KeyStore.getInstance("PKCS12")
  val keystoreInputStream = new File(httpsKeyStorePath).toURI().toURL().openStream()
  try {
    keyStore.load(keystoreInputStream, token)
  } finally {
    // The keystore is fully read by load(); release the file handle either way.
    keystoreInputStream.close()
  }

  val keyManagerFactory = KeyManagerFactory.getInstance("SunX509")
  keyManagerFactory.init(keyStore, token)

  val trustManagerFactory = TrustManagerFactory.getInstance("SunX509")
  trustManagerFactory.init(keyStore)

  val sslContext = SSLContext.getInstance("TLS")
  sslContext.init(keyManagerFactory.getKeyManagers,
    trustManagerFactory.getTrustManagers, new SecureRandom)
  ConnectionContext.https(sslContext)
}
}
/**
 * Frontend configuration holder with a default for every setting.
 *
 * The field names match the long option names registered by
 * `Frontend2.argumentsParser`. Nothing in this file's visible code
 * instantiates this class.
 *
 * NOTE(review): both path fields default to `null` and both token fields
 * default to a fixed literal — confirm callers always supply real values
 * before enabling the corresponding secure mode.
 */
case class FrontEndApp2Arguments(
  // HTTP endpoint
  interface: String = "0.0.0.0",
  port: Int = 10020,
  securePort: Int = 10023,

  // Redis connection and queue names
  redisHost: String = "localhost",
  redisPort: Int = 6379,
  redisInputQueue: String = Conventions.SERVING_STREAM_DEFAULT_NAME,
  redisOutputQueue: String =
    s"${Conventions.RESULT_PREFIX}${Conventions.SERVING_STREAM_DEFAULT_NAME}:",

  // Frontend tuning
  parallelism: Int = 1000,
  timeWindow: Int = 0,
  countWindow: Int = 56,

  // Token-bucket rate limiting
  tokenBucketEnabled: Boolean = false,
  tokensPerSecond: Int = 100,
  tokenAcquireTimeout: Int = 100,

  // HTTPS
  httpsEnabled: Boolean = false,
  httpsKeyStorePath: String = null,
  httpsKeyStoreToken: String = "1234qwer",

  // Secure Redis
  redisSecureEnabled: Boolean = false,
  redissTrustStorePath: String = null,
  redissTrustStoreToken: String = "1234qwer"
)
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy