streaming.dsl.mmlib.algs.python.APIBatchPredict.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.dsl.mmlib.algs.python

import java.nio.file.{Files, Paths}
import java.util.UUID

import org.apache.spark.{APIDeployPythonRunnerEnv, TaskContext}
import org.apache.spark.api.python.WowPythonRunner
import org.apache.spark.ml.feature.PythonBatchPredictDataSchema
import org.apache.spark.ml.linalg.Matrix
import org.apache.spark.ml.linalg.SQLDataTypes.{MatrixType, VectorType}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}
import org.apache.spark.util.{ExternalCommandRunner, MatrixSerDer, ObjPickle}
import streaming.core.strategy.platform.PlatformManager
import streaming.dsl.mmlib.algs.{Functions, SQLPythonAlg, SQLPythonFunc}
import scala.collection.JavaConverters._
import scala.collection.mutable

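/**
 * Applies a previously trained Python model to `df` in batches: rows are
 * grouped into fixed-size batches, each batch's `inputCol` values are packed
 * into a feature matrix, shipped to a Python worker for prediction, and the
 * results are joined back onto the original rows and exposed as
 * `predictTable`. (The implementation is currently commented out.)
 */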
class APIBatchPredict(df: DataFrame, _path: String, params: Map[String, String]) extends Functions {
//  def predict: DataFrame = {
//    val sparkSession = df.sparkSession
//
//    val modelMetaManager = new ModelMetaManager(sparkSession, _path, params)
//    val modelMeta = modelMetaManager.loadMetaAndModel
//
//    val batchSize = params.getOrElse("batchSize", "10").toInt
//    val inputCol = params.getOrElse("inputCol", "")
//
//    require(inputCol != null && inputCol != "", s"inputCol in ${getClass} module should be configured!")
//
//    val batchPredictFun = params.getOrElse("predictFun", UUID.randomUUID().toString.replaceAll("-", ""))
//    val predictLabelColumnName = params.getOrElse("predictCol", "predict_label")
//    val predictTableName = params.getOrElse("predictTable", "")
//
//    require(
//      predictTableName != null && predictTableName != "",
//      s"predictTable in ${getClass} module should be configured!"
//    )
//
//    val schema = PythonBatchPredictDataSchema.newSchema(df)
//
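//    // group input rows into batches of `batchSize`; each emitted row pairs
//    // the batch's original rows with a feature matrix built from `inputCol`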
//    val rdd = df.rdd.mapPartitions(it => {
//      var list = List.empty[List[Row]]
//      var tmpList = List.empty[Row]
//      var batchCount = 0
//      while (it.hasNext) {
//        val e = it.next()
//        if (batchCount == batchSize) {
//          // flush the full batch and start a fresh one; the current element
//          // still has to be recorded, otherwise one row per batch is lost
//          list +:= tmpList
//          batchCount = 0
//          tmpList = List.empty[Row]
//        }
//        tmpList +:= e
//        batchCount += 1
//      }
//      if (tmpList.nonEmpty) {
//        list +:= tmpList
//      }
//      list.map(x => {
//        Row.fromSeq(Seq(x, SQLPythonAlg.createNewFeatures(x, inputCol)))
//      }).iterator
//    })
//
//    val systemParam = mapParams("systemParam", modelMeta.trainParams)
//    val pythonPath = systemParam.getOrElse("pythonPath", "python")
//    val pythonVer = systemParam.getOrElse("pythonVer", "2.7")
//    val kafkaParam = mapParams("kafkaParam", modelMeta.trainParams)
//
//    // load python script
//    val userPythonScript = SQLPythonFunc.findPythonPredictScript(sparkSession, params, "")
//
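//    // run the user's predict script once on the driver; it is expected to
//    // write a pickled predict function to `funcPath`, which is read back
//    // below and shipped to the Python workers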
//    val maps = new java.util.HashMap[String, java.util.Map[String, _]]()
//    val item = new java.util.HashMap[String, String]()
//    item.put("funcPath", "/tmp/" + System.currentTimeMillis())
//    maps.put("systemParam", item)
//    maps.put("internalSystemParam", modelMeta.resources.asJava)
//
//    val taskDirectory = SQLPythonFunc.getLocalRunPath(UUID.randomUUID().toString)
//
//    val res = ExternalCommandRunner.run(taskDirectory, Seq(pythonPath, userPythonScript.fileName),
//      maps,
//      MapType(StringType, MapType(StringType, StringType)),
//      userPythonScript.fileContent,
//      userPythonScript.fileName, modelPath = null, recordLog = SQLPythonFunc.recordAnyLog(kafkaParam)
//    )
//    // drain the output iterator so the external command runs to completion
//    res.foreach(_ => ())
//    val command = Files.readAllBytes(Paths.get(item.get("funcPath")))
//    val runtimeParams = PlatformManager.getRuntime.params.asScala.toMap
//
//    // register the batch predict python function as a Spark UDF
//
//    val recordLog = SQLPythonFunc.recordAnyLog(Map[String, String]())
//    val models = sparkSession.sparkContext.broadcast(modelMeta.modelEntityPaths)
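//    // per-batch predict: pickle the feature matrix together with the model
//    // path, hand both to a Python worker via WowPythonRunner, then unpickle
//    // the returned prediction matrix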
//    val f = (m: Matrix, modelPath: String) => {
//      val modelRow = InternalRow.fromSeq(Seq(SQLPythonFunc.getLocalTempModelPath(modelPath)))
//      val trainParamsRow = InternalRow.fromSeq(Seq(ArrayBasedMapData(params)))
//      val v_ser = ObjPickle.pickleInternalRow(Seq(MatrixSerDer.serialize(m)).toIterator, MatrixSerDer.matrixSchema())
//      val v_ser2 = ObjPickle.pickleInternalRow(Seq(modelRow).toIterator, StructType(Seq(StructField("modelPath", StringType))))
//      val v_ser3 = v_ser ++ v_ser2
//
//      if (TaskContext.get() == null) {
//        APIDeployPythonRunnerEnv.setTaskContext(APIDeployPythonRunnerEnv.createTaskContext())
//      }
//      val iter = WowPythonRunner.run(
//        pythonPath, pythonVer, command, v_ser3, TaskContext.get().partitionId(), Array(), runtimeParams, recordLog
//      )
//      val a = iter.next()
//      val predictValue = MatrixSerDer.deserialize(ObjPickle.unpickle(a).asInstanceOf[java.util.ArrayList[Object]].get(0))
//      predictValue
//    }
//
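//    // when several model paths exist, only the first prediction is kept (`.head`)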
//    val f2 = (m: Matrix) => {
//      models.value.map { modelPath =>
//        f(m, modelPath)
//      }.head
//    }
//
//    val func = UserDefinedFunction(f2, MatrixType, Some(Seq(MatrixType)))
//    sparkSession.udf.register(batchPredictFun, func)
//
//    // temp batch predict column name
//    val tmpPredColName = UUID.randomUUID().toString.replaceAll("-", "")
//    val pdf = sparkSession.createDataFrame(rdd, schema)
//      .selectExpr(s"${batchPredictFun}(newFeature) as ${tmpPredColName}", "originalData")
//
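//    // unpack each batch: append the i-th prediction row to the i-th original row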
//    val prdd = pdf.rdd.mapPartitions(it => {
//      var list = List.empty[Row]
//      while (it.hasNext) {
//        val e = it.next()
//        val originalData = e.getAs[mutable.WrappedArray[Row]]("originalData")
//        val newFeature = e.getAs[Matrix](tmpPredColName).rowIter.toList
//        val size = originalData.size
//        (0 until size).foreach { index =>
//          val od = originalData(index)
//          val pd = newFeature(index)
//          list +:= Row.fromSeq(od.toSeq ++ Seq(pd))
//        }
//      }
//      list.iterator
//    })
//    val pschema = df.schema.add(predictLabelColumnName, VectorType)
//    val newdf = sparkSession.createDataFrame(prdd, pschema)
//    newdf.createOrReplaceTempView(predictTableName)
//    newdf
//  }
}
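
// A minimal usage sketch, assuming the `predict` body above is re-enabled.
// All values below are hypothetical; `inputCol` and `predictTable` are
// required by the parameter checks, and `batchSize` defaults to "10":
//
//   val params = Map(
//     "inputCol"     -> "features",
//     "predictTable" -> "predicted_table",
//     "batchSize"    -> "32"
//   )
//   val predicted = new APIBatchPredict(df, "/tmp/model/path", params).predict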