All Downloads are FREE. Search and download functionalities are using the official Maven repository.

streaming.dsl.mmlib.algs.SQLDTFAlg.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.dsl.mmlib.algs

import java.io.File
import java.nio.file.{Files, Paths}
import java.util
import java.util.UUID
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

import net.sf.json.{JSONArray, JSONObject}
import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.http.client.fluent.{Form, Request}
import org.apache.spark.{TaskContext, TaskContextImpl}
import org.apache.spark.api.python.WowPythonRunner
import org.apache.spark.ml.linalg.SQLDataTypes._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, functions => F}
import org.apache.spark.util.ObjPickle._
import org.apache.spark.util.VectorSerDer._
import org.apache.spark.util.{ExternalCommandRunner, ObjPickle, TaskContextUtil, VectorSerDer}
import streaming.common.{HDFSOperator, NetUtils}
import streaming.core.strategy.platform.{PlatformManager, SparkRuntime}
import streaming.dsl.mmlib.SQLAlg
import streaming.dsl.mmlib.algs.SQLPythonFunc._
import streaming.dsl.mmlib.algs.tf.cluster.{ClusterSpec, ClusterStatus}
import scala.collection.mutable

import scala.collection.JavaConverters._

/**
  * Created by allwefantasy on 5/2/2018.
  *
  * Distributed TensorFlow training algorithm. Every entry of `fitParam` is turned
  * into one Spark task that plays the role named by its `jobName` (`ps` or `worker`)
  * and runs the user supplied Python script through [[ExternalCommandRunner]].
  * Only `train` is implemented; `load` and `predict` always throw.
  */
class SQLDTFAlg extends SQLAlg with Functions {

  /**
    * Starts one Python process per `fitParam` entry and records the outcome.
    *
    * Steps, as performed below:
    *  1. Feed the training data either to Kafka (`writeKafka`) or, when
    *     `enableDataLocal` is true, to files under the algorithm temp path.
    *  2. Inside each Spark task: register `host:port` with the driver over HTTP,
    *     poll until `fitParam.length` members are registered, then build the
    *     cluster spec JSON handed to the Python script.
    *  3. Workers run the script synchronously and copy the resulting local model
    *     (or checkpoint) directory to HDFS; a `ps` runs the script in a daemon
    *     thread and polls the driver until every worker reported completion.
    *  4. Per-task results are written as parquet to `<metaPath>/0`, the system and
    *     train parameters to `<metaPath>/1`.
    *
    * @param df     training data; its Spark session is used for all jobs
    * @param path   base model path, passed to the `SQLPythonFunc` path helpers
    * @param params user parameters; `kafkaParam.*` and `fitParam.*` are required
    * @return the result of `emptyDataFrame()(df)`
    * @throws IllegalArgumentException via `require` when `kafkaParam` or `fitParam` is missing
    */
  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {

    val keepVersion = params.getOrElse("keepVersion", "true").toBoolean

    val enableDataLocal = params.getOrElse("enableDataLocal", "false").toBoolean

    var kafkaParam = mapParams("kafkaParam", params)

    require(kafkaParam.size > 0, "kafkaParam should be configured")

    // Stays -1 when the data is shipped as local files instead of through Kafka.
    var stopFlagNum = -1
    if (!enableDataLocal) {
      val (_kafkaParam, _newRDD) = writeKafka(df, path, params)
      stopFlagNum = _newRDD.getNumPartitions
      kafkaParam = _kafkaParam
    }

    val systemParam = mapParams("systemParam", params)
    val fitParam = arrayParamsWithIndex("fitParam", params)

    require(fitParam.size > 0, "fitParam should be configured")

    // One partition per fitParam entry, so each cluster member gets its own task.
    val fitParamRDD = df.sparkSession.sparkContext.parallelize(fitParam, fitParam.length)
    val pythonPath = systemParam.getOrElse("pythonPath", "python")
    val pythonVer = systemParam.getOrElse("pythonVer", "2.7")
    val pythonParam = systemParam.getOrElse("pythonParam", "").split(",").filterNot(f => f.isEmpty)
    var tempDataLocalPath = ""
    var dataHDFSPath = ""
    // file index -> URI of the data file written below (filled only when enableDataLocal)
    val tfDataMap = new mutable.HashMap[Int, String]()

    if (enableDataLocal) {
      val dataLocalFormat = params.getOrElse("dataLocalFormat", "json")
      val dataLocalFileNum = fitParam.filter(f => f._2("jobName") == "worker").length

      println(s"enableDataLocal is enabled. we should  prepare data for $dataLocalFileNum workers")

      dataHDFSPath = SQLPythonFunc.getAlgTmpPath(path) + "/data"
      tempDataLocalPath = SQLPythonFunc.getLocalTempDataPath(path)

      // repartition(0) is rejected by Spark, so only repartition when there is
      // at least one worker; otherwise keep the original partitioning.
      val newDF = if (dataLocalFileNum > 0) {
        df.repartition(dataLocalFileNum)
      } else df
      newDF.write.format(dataLocalFormat).mode(SaveMode.Overwrite).save(dataHDFSPath)

      HDFSOperator.listFiles(dataHDFSPath).filter(f => f.getPath.getName.endsWith(s".${dataLocalFormat}")).zipWithIndex.foreach { f =>
        val (path, index) = f
        tfDataMap.put(index, path.getPath.toUri.toString)
      }
      val printData = tfDataMap.map(f => s"${f._1}=>${f._2}").mkString("\n")
      println(s"enableDataLocal is enabled.  data partition to tensorflow worker:\n ${printData}")
      // NOTE(review): the result is discarded; kept because the helper is defined elsewhere.
      emptyDataFrame()(df)

    }


    val userPythonScript = loadUserDefinePythonScript(params, df.sparkSession)

    val schema = df.schema
    var rows = Array[Array[Byte]]()
    // Currently only a single validation set, shared by all tasks, is supported.
    // NOTE(review): the validate table is pickled with the schema of `df` - confirm both tables share it.
    if (params.contains("validateTable")) {
      val validateTable = params("validateTable")
      rows = df.sparkSession.table(validateTable).rdd.mapPartitions { iter =>
        ObjPickle.pickle(iter, schema)
      }.collect()
    }
    val rowsBr = df.sparkSession.sparkContext.broadcast(rows)

    incrementVersion(path, keepVersion)

    // create a new cluster name for tensorflow
    val clusterUuid = UUID.randomUUID().toString
    val driverHost = NetUtils.getHost
    // Number of cluster members that must register before training starts.
    val LIMIT = fitParam.length
    val driverPort = PlatformManager.getRuntime.asInstanceOf[SparkRuntime].params.getOrDefault("streaming.driver.port", "9003").toString.toInt

    val wowRDD = fitParamRDD.map { paramAndIndex =>
      val f = paramAndIndex._2
      val algIndex = paramAndIndex._1

      val jobName = f("jobName")
      val taskIndex = f("taskIndex").toInt

      def isPs = {
        jobName == "ps"
      }

      val roleSpec = new JSONObject()
      roleSpec.put("jobName", jobName)
      roleSpec.put("taskIndex", taskIndex)


      val host = NetUtils.getHost
      // Hold a free port while the cluster is being assembled; it is released
      // right before the Python process is started.
      val holdPort = NetUtils.availableAndReturn(ClusterSpec.MIN_PORT_NUMBER, ClusterSpec.MAX_PORT_NUMBER)

      if (holdPort == null) {
        throw new RuntimeException(s"Fail to create tensorflow cluster, maybe executor cannot bind port ")
      }

      val port = holdPort.getLocalPort

      // POSTs this member's address and role to the given driver endpoint.
      def reportToMaster(path: String = "/cluster/register") = {
        Request.Post(s"http://${driverHost}:${driverPort}${path}").bodyForm(Form.form().add("cluster", clusterUuid).
          add("hostAndPort", s"${host}:${port}")
          .add("jobName", f("jobName"))
          .add("taskIndex", f("taskIndex"))
          .build())
          .execute().returnContent().asString()
      }

      // GETs the list of members registered so far for this cluster.
      def fetchClusterFromMaster = {
        Request.Get(s"http://${driverHost}:${driverPort}/cluster?cluster=${clusterUuid}")
          .execute().returnContent().asString()
      }

      // Extracts "host:port" for every registered member whose jobName equals `t`.
      def getHostAndPortFromJson(infos: JSONArray, t: String) = {
        infos.asScala.map(f => f.asInstanceOf[JSONObject]).filter(f => f.getString("jobName") == t).map(f => s"${f.getString("host")}:${f.getInt("port")}")
      }

      val clusterSpecRef = new AtomicReference[ClusterSpec]()
      try {
        // here we should report to driver
        reportToMaster()
        // wait until all workers have been registered
        var waitCount = 0
        val maxWaitCount = 100
        var noWait = false
        while (!noWait && waitCount < maxWaitCount) {
          val response = fetchClusterFromMaster
          val infos = JSONArray.fromObject(response)
          println(s"Waiting all worker/ps started. Wait times: ${waitCount} times. Already registered tf worker/ps: ${infos.size()}")
          if (infos.size == LIMIT) {
            val workerList = getHostAndPortFromJson(infos, "worker")
            val psList = getHostAndPortFromJson(infos, "ps")
            clusterSpecRef.set(new ClusterSpec(workerList.toList, psList.toList))
            noWait = true
          } else {
            Thread.sleep(5000)
            waitCount += 1
          }
        }
      } finally {
        // Release the held port even when registration or polling throws,
        // otherwise the socket leaks on the executor.
        NetUtils.releasePort(holdPort)
      }

      if (clusterSpecRef.get() == null) {
        throw new RuntimeException(s"Fail to create tensorflow cluster, this maybe caused by  executor cannot connect driver at ${driverHost}:${driverPort}")
      }

      val clusterSpec = clusterSpecRef.get()


      val clusterSpecJson = new JSONObject()
      clusterSpecJson.put("worker", clusterSpec.worker.asJava)
      clusterSpecJson.put("ps", clusterSpec.ps.asJava)

      println(clusterSpecJson.toString)

      var tempDataLocalPathWithAlgSuffix = tempDataLocalPath


      if (!isPs && enableDataLocal) {
        // NOTE(review): files are indexed 0 until <worker count> while algIndex is the
        // fitParam index; this lookup presumably assumes workers are listed first - confirm.
        val partitionDataInHDFS = tfDataMap(algIndex)
        tempDataLocalPathWithAlgSuffix = tempDataLocalPathWithAlgSuffix + "/" + algIndex
        println(s"Copy HDFS ${partitionDataInHDFS} to local ${tempDataLocalPathWithAlgSuffix} ")
        recordSingleLineLog(kafkaParam, s"Copy HDFS ${partitionDataInHDFS} to local ${tempDataLocalPathWithAlgSuffix} ")
        HDFSOperator.copyToLocalFile(tempLocalPath = tempDataLocalPathWithAlgSuffix + "/" + partitionDataInHDFS.split("/").last, path = partitionDataInHDFS, true)
      }

      val paramMap = new util.HashMap[String, Object]()
      // Default modelPath to the training path unless the user set one explicitly.
      val item = (if (f.contains("modelPath")) f else f + ("modelPath" -> path)).asJava

      // load resource
      var resourceParams = Map.empty[String, String]
      if (f.keys.map(_.split("\\.")(0)).toSet.contains("resource")) {
        val resources = Functions.mapParams(s"resource", f)
        resources.foreach {
          case (resourceName, resourcePath) =>
            val tempResourceLocalPath = SQLPythonFunc.getLocalTempResourcePath(resourcePath, resourceName)
            recordSingleLineLog(kafkaParam, s"resource paramter found,system will load resource ${resourcePath} in ${tempResourceLocalPath} in executor.")
            HDFSOperator.copyToLocalFile(tempResourceLocalPath, resourcePath, true)
            resourceParams += (resourceName -> tempResourceLocalPath)
            recordSingleLineLog(kafkaParam, s"resource loaded.")
        }
      }

      val pythonScript = userPythonScript.get

      val tempModelLocalPath = s"${SQLPythonFunc.getLocalBasePath}/${UUID.randomUUID().toString}/${algIndex}"
      val checkpointDir = s"${SQLPythonFunc.getLocalBasePath}/${UUID.randomUUID().toString}/${algIndex}"
      //FileUtils.forceMkdir(tempModelLocalPath)

      paramMap.put("fitParam", item)

      // Give every task its own consumer group by suffixing the fitParam index.
      val kafkaP = kafkaParam + ("group_id" -> (kafkaParam("group_id") + "_" + algIndex))
      paramMap.put("kafkaParam", kafkaP.asJava)

      val internalSystemParam = Map(
        "stopFlagNum" -> stopFlagNum,
        "tempModelLocalPath" -> tempModelLocalPath,
        "tempDataLocalPath" -> tempDataLocalPathWithAlgSuffix,
        "resource" -> resourceParams.asJava,
        "clusterSpec" -> clusterSpecJson.toString(),
        "roleSpec" -> roleSpec.toString(),
        "checkpointDir" -> checkpointDir
      )

      paramMap.put("internalSystemParam", internalSystemParam.asJava)
      paramMap.put("systemParam", systemParam.asJava)


      val command = Seq(pythonPath) ++ pythonParam ++ Seq(pythonScript.fileName)
      val taskDirectory = SQLPythonFunc.getLocalRunPath(UUID.randomUUID().toString)

      val modelTrainStartTime = System.currentTimeMillis()

      // NOTE(review): for a ps both vars are also written from the ps thread below,
      // without synchronization.
      var score = 0.0
      var trainFailFlag = false

      if (!isPs) {
        try {

          val res = ExternalCommandRunner.run(taskDirectory = taskDirectory,
            command = command,
            iter = paramMap,
            schema = MapType(StringType, MapType(StringType, StringType)),
            scriptContent = pythonScript.fileContent,
            scriptName = pythonScript.fileName,
            recordLog = SQLPythonFunc.recordAnyLog(kafkaParam),
            modelPath = path, validateData = rowsBr.value
          )

          score = recordUserLog(algIndex, pythonScript, kafkaParam, res)
        } catch {
          case e: Exception =>
            e.printStackTrace()
            trainFailFlag = true
        }
      } else {
        // Set by the ps thread once ExternalCommandRunner.run has returned; null until then.
        val pythonWorker = new AtomicReference[Process]()
        val context = TaskContext.get()

        // notice: I still get no way how to destroy the ps python worker
        // when spark existed unexpectedly.
        // we have handle the situation eg. task is cancel or
        val pythonPSThread = new Thread(new Runnable {
          override def run(): Unit = {
            TaskContextUtil.setContext(context)
            val taskDirectory = SQLPythonFunc.getLocalRunPath(UUID.randomUUID().toString)
            try {
              val res = ExternalCommandRunner.run(taskDirectory = taskDirectory,
                command = command,
                iter = paramMap,
                schema = MapType(StringType, MapType(StringType, StringType)),
                scriptContent = pythonScript.fileContent,
                scriptName = pythonScript.fileName,
                recordLog = SQLPythonFunc.recordAnyLog(kafkaParam),
                modelPath = path, validateData = rowsBr.value
              )
              pythonWorker.set(res.asInstanceOf[ {def getWorker: Process}].getWorker)
              score = recordUserLog(algIndex, pythonScript, kafkaParam, res)
            } catch {
              case e: Exception =>
                e.printStackTrace()
                trainFailFlag = true
            }
          }
        })
        pythonPSThread.setDaemon(true)
        pythonPSThread.start()


        //        val pythonPSMonitorThread = new Thread(new Runnable {
        //          override def run(): Unit = {
        //            while (!context.isInterrupted && !context.isCompleted) {
        //              Thread.sleep(2000)
        //            }
        //            if (context.isCompleted || context.isInterrupted()) {
        //              try {
        //                pythonWorker.get().destroy()
        //              } catch {
        //                case e: Exception =>
        //              }
        //            }
        //          }
        //        })
        //        pythonPSMonitorThread.setDaemon(true)
        //        pythonPSMonitorThread.start()

        var shouldWait = true

        // GETs the number of workers that have reported completion to the driver.
        def fetchClusterStatusFromMaster = {
          Request.Get(s"http://${driverHost}:${driverPort}/cluster/worker/finish?cluster=${clusterUuid}")
            .execute().returnContent().asString()
        }

        // PS should wait until
        // all worker finish their jobs.
        // PS python worker use thread.join to keep it alive .
        while (shouldWait) {
          val finishedWorkerNum = fetchClusterStatusFromMaster.toInt
          println(s"check worker finish size. targetSize:${clusterSpec.worker.size} vs ${finishedWorkerNum}")
          if (finishedWorkerNum == clusterSpec.worker.size) {
            shouldWait = false
          } else {
            // Only sleep while there is still something to wait for; an interrupt
            // ends the wait so the ps process can be torn down.
            try {
              Thread.sleep(10000)
            } catch {
              case _: InterruptedException =>
                shouldWait = false
            }
          }
        }
        // The reference is still null when the ps script has not returned (or failed);
        // guard against it so the task does not die with a NullPointerException.
        Option(pythonWorker.get()).foreach(_.destroy())
        pythonPSThread.interrupt()

      }
      if (!isPs) {
        // Only successful workers are reported, so the ps keeps waiting otherwise.
        if (!trainFailFlag) {
          reportToMaster("/cluster/worker/status")
        }
        val modelTrainEndTime = System.currentTimeMillis()

        val modelHDFSPath = SQLPythonFunc.getAlgModelPath(path, keepVersion) + "/" + algIndex
        try {
          // Save the model to HDFS.
          if (!keepVersion) {
            HDFSOperator.deleteDir(modelHDFSPath)
          }
          // Prefer the explicit model directory; fall back to the checkpoint directory.
          if (new File(tempModelLocalPath).exists()) {
            HDFSOperator.copyToHDFS(tempModelLocalPath, modelHDFSPath, true, false)
          } else {
            if (new File(checkpointDir).exists()) {
              HDFSOperator.copyToHDFS(checkpointDir, modelHDFSPath, true, false)
            }
          }

        } catch {
          case e: Exception =>
            // Surface the cause instead of failing silently, like the other handlers here.
            e.printStackTrace()
            trainFailFlag = true
        } finally {
          // delete local model
          FileUtils.deleteDirectory(new File(tempModelLocalPath))
          // delete local data
          FileUtils.deleteDirectory(new File(tempDataLocalPathWithAlgSuffix))
          FileUtils.deleteDirectory(new File(checkpointDir))
        }
        val status = if (trainFailFlag) "fail" else "success"
        Row.fromSeq(Seq(modelHDFSPath, algIndex, pythonScript.fileName, score, status, modelTrainStartTime, modelTrainEndTime, f))
      } else {
        val modelTrainEndTime = System.currentTimeMillis()
        val status = if (trainFailFlag) "fail" else "success"
        // A ps produces no model, hence the empty model path.
        Row.fromSeq(Seq("", algIndex, pythonScript.fileName, score, status, modelTrainStartTime, modelTrainEndTime, f))
      }
    }
    // Per-task training results.
    df.sparkSession.createDataFrame(wowRDD, StructType(Seq(
      StructField("modelPath", StringType),
      StructField("algIndex", IntegerType),
      StructField("alg", StringType),
      StructField("score", DoubleType),

      StructField("status", StringType),
      StructField("startTime", LongType),
      StructField("endTime", LongType),
      StructField("trainParams", MapType(StringType, StringType))
    ))).
      write.
      mode(SaveMode.Overwrite).
      parquet(SQLPythonFunc.getAlgMetalPath(path, keepVersion) + "/0")

    // System parameters and the raw train parameters, stored as a single row.
    val tempRDD = df.sparkSession.sparkContext.parallelize(Seq(Seq(Map("pythonPath" -> pythonPath, "pythonVer" -> pythonVer), params)), 1).map { f =>
      Row.fromSeq(f)
    }
    df.sparkSession.createDataFrame(tempRDD, StructType(Seq(
      StructField("systemParam", MapType(StringType, StringType)),
      StructField("trainParams", MapType(StringType, StringType))))).
      write.
      mode(SaveMode.Overwrite).
      parquet(SQLPythonFunc.getAlgMetalPath(path, keepVersion) + "/1")
    emptyDataFrame()(df)

  }

  /**
    * Not implemented for this algorithm.
    *
    * @throws RuntimeException always
    */
  override def load(sparkSession: SparkSession, _path: String, params: Map[String, String]): Any = {
    throw new RuntimeException("register is not supported yet")
  }

  /**
    * Not implemented for this algorithm.
    *
    * @throws RuntimeException always
    */
  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
    throw new RuntimeException("register is not supported yet")
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy