streaming.rest.RestPredictController.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.rest

import net.csdn.annotation.rest.At
import net.csdn.modules.http.ApplicationController
import net.csdn.modules.http.RestRequest.Method._
import net.sf.json.{JSONArray, JSONObject}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.execution.datasources.json.WowJsonInferSchema
import streaming.core.strategy.platform.{PlatformManager, SparkRuntime}
import streaming.dsl.{MLSQLExecuteContext, ScriptSQLExec}

import scala.collection.JavaConversions._


/**
  * Created by allwefantasy on 20/4/2018.
  */
class RestPredictController extends ApplicationController {

  @At(path = Array("/model/predict"), types = Array(GET, POST))
  def modelPredict = {
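    // Example request (hypothetical host/port; the sql value is the UDF pipeline
    // shown in the comment inside vec2vecPredict below):
    //   POST http://127.0.0.1:9003/model/predict
    //     dataType=vector&vecType=dense&data=[[1.0,2.0,3.0]]
    //     &sql=select vec_argmax(tf_predict(feature,"feature","label",2)) as predict_label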
    intercept()
    createContext
    val res = param("dataType", "vector") match {
      case "vector" => vec2vecPredict
      case "string" => string2vecPredict
      case "row" => row2vecPredict
    }
    render(200, res)
  }

  def getSession = {
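    // One SparkSession per `owner` when sessionPerUser=true; otherwise the shared session.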
    if (paramAsBoolean("sessionPerUser", false)) {
      runtime.asInstanceOf[SparkRuntime].getSession(param("owner", "admin"))
    } else {
      runtime.asInstanceOf[SparkRuntime].sparkSession
    }
  }

  @At(path = Array("/compute"), types = Array(GET, POST))
  def compute = {
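    // Example (hypothetical host/port): GET http://127.0.0.1:9003/compute?sql=select 1 as a
    // returns the result rows rendered as JSON.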
    intercept()
    val sparkSession = getSession
    val res = WowJsonInferSchema.toJson(sparkSession.sql(param("sql"))).mkString(",")

    render(200, res)
  }

  def createContext = {
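    // Any request parameter prefixed with "context." is passed through (prefix stripped)
    // as a user-defined variable on the MLSQL execution context.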
    val userDefineParams = params.toMap
      .filter { case (k, _) => k.startsWith("context.") }
      .map { case (k, v) => (k.substring("context.".length), v) }
    ScriptSQLExec.setContext(new MLSQLExecuteContext(null, param("owner"), "", "", userDefineParams))
  }

  def getSQL = {
    if (hasParam("sql")) {
      param("sql", "").split("select").mkString("")
    } else if (hasParam("pipeline")) {
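      // e.g. (hypothetical UDF names) pipeline=vec_scale,tf_predict expands to
      // "tf_predict(vec_scale(feature)) as feature"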
      param("pipeline", "").split(",").reverse.foldRight("feature") { (acc, o) =>
        s"${acc}(${o})"
      } + " as feature"
    } else throw new IllegalArgumentException("parameter sql or pipeline is required")
  }

  def vec2vecPredict = {
    //dense or sparse
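    // Expected `data` payloads (hypothetical values):
    //   dense:  [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    //   sparse: [{"0": 1.0, "3": 2.0}] together with vecSize=<dimension>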
    val vectorType = param("vecType", "dense")
    val sparkSession = getSession
    val vectors = JSONArray.fromObject(param("data", "[]")).map { f =>
      val vec = vectorType match {
        case "dense" =>
          val v = f.asInstanceOf[JSONArray].map(f => f.asInstanceOf[Number].doubleValue()).toArray
          Vectors.dense(v)
        case "sparse" =>
          // JSON object keys are strings, so parse each index from the key text
          val v = f.asInstanceOf[JSONObject].map(f => (f._1.toString.toInt, f._2.asInstanceOf[Number].doubleValue())).toMap
          require(paramAsInt("vecSize", -1) != -1, "when vector type is sparse, vecSize is required")
          // the Seq variant of Vectors.sparse sorts the (index, value) pairs
          Vectors.sparse(paramAsInt("vecSize", -1), v.toSeq)
        case other => throw new IllegalArgumentException(s"unsupported vecType: ${other}")
      }
      Feature(feature = vec)
    }
    import sparkSession.implicits._
    //select vec_argmax(tf_predict(feature,"feature","label",2)) as predict_label
    val sql = getSQL
    val res = sparkSession.createDataset(sparkSession.sparkContext.parallelize(vectors))
      .selectExpr(sql).toJSON.collect().mkString(",")
    "[" + res + "]"
  }

  def row2vecPredict = {
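    // Each element of `data` is a JSON object treated as one row; the schema is inferred
    // from the strings. Example (hypothetical): data=[{"a":1.0,"b":"x"}]&sql=select a as f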
    val sparkSession = getSession
    val strList = JSONArray.fromObject(param("data", "[]")).map(f => f.toString)
    val sql = getSQL
    val res = WowJsonInferSchema.toJson(WowJsonInferSchema.createDataSet(strList, sparkSession).selectExpr(sql)).mkString(",")
    "[" + res + "]"
  }

  def string2vecPredict = {
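    // Each element of `data` is wrapped in StringFeature (a case class defined alongside
    // this controller), so the sql/pipeline operates on a single string column.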
    val sparkSession = getSession
    val strList = JSONArray.fromObject(param("data", "[]")).map(f => StringFeature(f.toString))
    val sql = getSQL
    import sparkSession.implicits._
    val res = WowJsonInferSchema.toJson(sparkSession.createDataset(strList).selectExpr(sql)).mkString(",")
    "[" + res + "]"
  }

  def runtime = PlatformManager.getRuntime

  def intercept() = {
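    // Optionally runs a user-supplied RestInterceptor (class name configured via
    // streaming.rest.intercept.clzz) before the request is handled.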
    val jparams = runtime.asInstanceOf[SparkRuntime].params
    if (jparams.containsKey("streaming.rest.intercept.clzz")) {
      val interceptor = Class.forName(jparams("streaming.rest.intercept.clzz").toString).getDeclaredConstructor().newInstance()
      interceptor.asInstanceOf[RestInterceptor].before(request = request.httpServletRequest(), response = restResponse.httpServletResponse())
    }
  }
}
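
A minimal client sketch (not part of the original file): it posts a dense vector to /model/predict using java.net.http. The host, port, and object name are assumptions; the sql value reuses the UDF pipeline from the comment in vec2vecPredict.

import java.net.{URI, URLEncoder}
import java.net.http.{HttpClient, HttpRequest, HttpResponse}
import java.nio.charset.StandardCharsets

object RestPredictClientSketch {
  def main(args: Array[String]): Unit = {
    def enc(s: String) = URLEncoder.encode(s, StandardCharsets.UTF_8)
    // Form-encode the parameters that modelPredict reads via param(...)
    val form = Seq(
      "dataType" -> "vector",
      "vecType" -> "dense",
      "data" -> "[[1.0, 2.0, 3.0]]",
      "sql" -> """select vec_argmax(tf_predict(feature,"feature","label",2)) as predict_label"""
    ).map { case (k, v) => s"${enc(k)}=${enc(v)}" }.mkString("&")

    val request = HttpRequest.newBuilder()
      .uri(URI.create("http://127.0.0.1:9003/model/predict")) // host/port are assumptions
      .header("Content-Type", "application/x-www-form-urlencoded")
      .POST(HttpRequest.BodyPublishers.ofString(form))
      .build()

    val response = HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString())
    println(response.body()) // expected: a JSON array of prediction rows
  }
}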




