All Downloads are FREE. Search and download functionalities are using the official Maven repository.

streaming.dsl.mmlib.algs.python.PythonTrain.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.dsl.mmlib.algs.python

import java.io.{File, FileWriter}
import java.util
import java.util.UUID

import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.mlsql.session.MLSQLException
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SaveMode}
import org.apache.spark.util._
import streaming.common.ScalaMethodMacros._
import streaming.common.{HDFSOperator, NetUtils}
import streaming.dsl.ScriptSQLExec
import streaming.dsl.mmlib.algs.SQLPythonFunc._
import streaming.dsl.mmlib.algs.{Functions, SQLPythonAlg, SQLPythonFunc}

import scala.collection.JavaConverters._

class PythonTrain extends Functions with Serializable {
  def train_per_partition(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    val keepVersion = params.getOrElse("keepVersion", "false").toBoolean
    val keepLocalDirectory = params.getOrElse("keepLocalDirectory", "false").toBoolean

    val partitionKey = params.get("partitionKey")

    var kafkaParam = mapParams("kafkaParam", params)

    val enableDataLocal = params.getOrElse("enableDataLocal", "true").toBoolean

    var stopFlagNum = -1
    if (!enableDataLocal) {
      require(kafkaParam.size > 0, "We detect that you do not set enableDataLocal true, " +
        "in this situation, kafkaParam should be configured")
      val (_kafkaParam, _newRDD) = writeKafka(df, path, params)
      stopFlagNum = _newRDD.getNumPartitions
      kafkaParam = _kafkaParam
    }

    // find python project
    val pythonProject = PythonAlgProject.loadProject(params, df.sparkSession)
    incrementVersion(path, keepVersion)

    val mlsqlContext = ScriptSQLExec.contextGetOrForTest()

    val pythonProjectPath = Option(pythonProject.get.filePath)

    val projectName = pythonProject.get.projectName
    val projectType = pythonProject.get.scriptType

    var fitParam = mapParams("fitParam", params)


    def cleanNumber() = {
      fitParam.map { f =>
        var key = f._1
        val value = f._2
        try {
          val prefix = key.split("\\.")(0)
          (prefix).toLong
          key = key.substring(prefix.length + 1)
        } catch {
          case e: Exception =>
        }
        (key, value)

      }
    }

    fitParam = cleanNumber()

    val systemParam = mapParams("systemParam", params)
    val mlflowConfig = MLFlowConfig.buildFromSystemParam(systemParam)
    val pythonConfig = PythonConfig.buildFromSystemParam(systemParam)

    val appName = df.sparkSession.sparkContext.getConf.get("spark.app.name")

    val envs = EnvConfig.buildFromSystemParam(systemParam) ++ Map(BasicCondaEnvManager.MLSQL_INSTNANCE_NAME_KEY -> appName)


    if (!keepVersion) {
      if (path.contains("..") || path == "/" || path.split("/").length < 3) {
        throw new MLSQLException("path should at least three layer")
      }
      HDFSOperator.deleteDir(SQLPythonFunc.getAlgModelPath(path, keepVersion))
    }

    val wowRDD = df.toJSON.rdd.mapPartitionsWithIndex { case (algIndex, iter) =>

      ScriptSQLExec.setContext(mlsqlContext)

      val localPathConfig = LocalPathConfig.buildFromParams(path)
      var tempDataLocalPathWithAlgSuffix = localPathConfig.localDataPath

      if (enableDataLocal) {
        tempDataLocalPathWithAlgSuffix = tempDataLocalPathWithAlgSuffix + "/" + algIndex
        val msg = s"dataLocalFormat enabled ,system will generate data in ${tempDataLocalPathWithAlgSuffix} "
        logInfo(format(msg))
        if (!new File(tempDataLocalPathWithAlgSuffix).exists()) {
          FileUtils.forceMkdir(new File(tempDataLocalPathWithAlgSuffix))
        }
        val localFile = new File(tempDataLocalPathWithAlgSuffix, UUID.randomUUID().toString + ".json")
        val localFileWriter = new FileWriter(localFile)
        try {
          iter.foreach { line =>
            localFileWriter.write(line + "\n")
          }
        } finally {
          localFileWriter.close()
        }
      }

      val paramMap = new util.HashMap[String, Object]()
      var item = fitParam.asJava
      if (!fitParam.contains("modelPath")) {
        item = (fitParam + ("modelPath" -> path)).asJava
      }

      val resourceParams = new ResourceManager(fitParam).loadResourceInTrain

      val taskDirectory = localPathConfig.localRunPath + "/" + projectName
      val tempModelLocalPath = s"${localPathConfig.localModelPath}/${algIndex}"

      paramMap.put("fitParam", item)

      if (!kafkaParam.isEmpty) {
        val kafkaP = kafkaParam + ("group_id" -> (kafkaParam("group_id") + "_" + algIndex))
        paramMap.put("kafkaParam", kafkaP.asJava)
      }

      val internalSystemParam = Map(
        str[RunPythonConfig.InternalSystemParam](_.stopFlagNum) -> stopFlagNum,
        str[RunPythonConfig.InternalSystemParam](_.tempModelLocalPath) -> tempModelLocalPath,
        str[RunPythonConfig.InternalSystemParam](_.tempDataLocalPath) -> tempDataLocalPathWithAlgSuffix,
        str[RunPythonConfig.InternalSystemParam](_.resource) -> resourceParams.asJava
      )

      paramMap.put(RunPythonConfig.internalSystemParam, internalSystemParam.asJava)

      pythonProjectPath match {
        case Some(_) =>
          if (projectType == MLFlow) {
            logInfo(format(s"'${pythonProjectPath.get}' is MLflow project. download it from [${pythonProject}] to local [${taskDirectory}]"))
            SQLPythonAlg.downloadPythonProject(taskDirectory, pythonProjectPath)
          } else {
            val localPath = taskDirectory + "/" + pythonProjectPath.get.split("/").last
            logInfo(format(s"'${pythonProjectPath.get}' is Normal project. download it from [${pythonProject}] to local [${localPath}]"))
            SQLPythonAlg.downloadPythonProject(localPath, pythonProjectPath)
          }


        case None =>
          // this will not happen, cause even script is a project contains only one python script file.
          throw new RuntimeException("The project or script you configured in pythonScriptPath is not a validate project")
      }


      val command = new PythonAlgExecCommand(pythonProject.get, None, None, envs).
        generateCommand(MLProject.train_command)


      val execDesc =
        s"""
           |
           |----------------------------------------------------------------
           |host: ${NetUtils.getHost}
           |
           |Command: ${command.mkString(" ")}
           |TaskDirectory: ${taskDirectory}
           |DataDirectory: ${tempDataLocalPathWithAlgSuffix}
           |ModelDirectory: ${tempModelLocalPath}
           |
           |Notice:
           |
           |If you wanna keep [${taskDirectory}] for debug, please set
           |keepLocalDirectory=true in train statement.
           |
           |e.g:
           |
           |train data as PythonParallelExt.`/tmp/abc` where keepLocalDirectory="true";
           |----------------------------------------------------------------
         """.stripMargin
      logInfo(format(execDesc))

      val modelTrainStartTime = System.currentTimeMillis()

      var score = 0.0
      var trainFailFlag = false
      val runner = new PythonProjectExecuteRunner(
        taskDirectory = taskDirectory,
        keepLocalDirectory = keepLocalDirectory,
        envVars = envs,
        logCallback = (msg) => {
          ScriptSQLExec.setContextIfNotPresent(mlsqlContext)
          logInfo(format(msg))
        }
      )
      try {

        val res = runner.run(
          command = command,
          params = paramMap,
          schema = MapType(StringType, MapType(StringType, StringType)),
          scriptContent = pythonProject.get.fileContent,
          scriptName = pythonProject.get.fileName,
          validateData = Array()
        )

        def filterScore(str: String) = {
          if (str != null && str.startsWith("mlsql_validation_score:")) {
            str.split(":").last.toDouble
          } else 0d
        }

        val scores = res.map { f =>
          logInfo(format(f))
          filterScore(f)
        }.toSeq
        score = if (scores.size > 0) scores.head else 0d
      } catch {
        case e: Exception =>
          logError(format_cause(e))
          e.printStackTrace()
          trainFailFlag = true
      }

      val modelTrainEndTime = System.currentTimeMillis()

      val partitionName = partitionKey match {
        case Some(i) => s"${i}=" + algIndex
        case None => algIndex
      }

      val modelHDFSPath = SQLPythonFunc.getAlgModelPath(path, keepVersion) + "/" + partitionName
      try {
        // If training failed, we do not need
        // copy model to hdfs
        if (!trainFailFlag) {
          //copy model to HDFS
          val fs = FileSystem.get(new Configuration())
          if (!keepVersion) {
            fs.delete(new Path(modelHDFSPath), true)
          }
          fs.copyFromLocalFile(new Path(tempModelLocalPath),
            new Path(modelHDFSPath))
        }
      } catch {
        case e: Exception =>
          e.printStackTrace()
          trainFailFlag = true
      } finally {
        // delete resource
        resourceParams.foreach { rp =>
          FileUtils.deleteQuietly(new File(rp._2))
        }
        // delete local model
        FileUtils.deleteQuietly(new File(tempModelLocalPath))
        // delete local data
        if (!keepLocalDirectory) {
          FileUtils.deleteQuietly(new File(tempDataLocalPathWithAlgSuffix))
        }
      }
      val status = if (trainFailFlag) "fail" else "success"
      val row = Row.fromSeq(Seq(
        modelHDFSPath,
        algIndex,
        pythonProject.get.fileName,
        score,
        status,
        modelTrainStartTime,
        modelTrainEndTime,
        fitParam,
        execDesc
      ))
      Seq(row).toIterator
    }

    df.sparkSession.createDataFrame(wowRDD, PythonTrainingResultSchema.algSchema).write.mode(SaveMode.Overwrite).parquet(SQLPythonFunc.getAlgMetalPath(path, keepVersion) + "/0")

    val tempRDD = df.sparkSession.sparkContext.parallelize(Seq(Seq(
      Map(
        str[PythonConfig](_.pythonPath) -> pythonConfig.pythonPath,
        str[PythonConfig](_.pythonVer) -> pythonConfig.pythonVer
      ), params)), 1).map { f =>
      Row.fromSeq(f)
    }
    df.sparkSession.createDataFrame(tempRDD, PythonTrainingResultSchema.trainParamsSchema).
      write.
      mode(SaveMode.Overwrite).
      parquet(SQLPythonFunc.getAlgMetalPath(path, keepVersion) + "/1")

    df.sparkSession.read.parquet(SQLPythonFunc.getAlgMetalPath(path, keepVersion) + "/0")
  }

  def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {

    val keepVersion = params.getOrElse("keepVersion", "true").toBoolean
    val keepLocalDirectory = params.getOrElse("keepLocalDirectory", "false").toBoolean

    var kafkaParam = mapParams("kafkaParam", params)

    // save data to hdfs and broadcast validate table
    val dataManager = new DataManager(df, path, params)
    val enableDataLocal = dataManager.enableDataLocal
    val dataHDFSPath = dataManager.saveDataToHDFS
    val rowsBr = dataManager.broadCastValidateTable

    var stopFlagNum = -1
    if (!enableDataLocal) {
      require(kafkaParam.size > 0, "We detect that you do not set enableDataLocal true, " +
        "in this situation, kafkaParam should be configured")
      val (_kafkaParam, _newRDD) = writeKafka(df, path, params)
      stopFlagNum = _newRDD.getNumPartitions
      kafkaParam = _kafkaParam
    }

    val systemParam = mapParams("systemParam", params)
    var fitParam = arrayParamsWithIndex("fitParam", params)

    if (fitParam.size == 0) {
      logWarning(format("fitParam is not configured, we will use empty configuration"))
      fitParam = Array(0 -> Map[String, String]())
    }
    val fitParamRDD = df.sparkSession.sparkContext.parallelize(fitParam, fitParam.length)


    //configuration about project and env
    val mlflowConfig = MLFlowConfig.buildFromSystemParam(systemParam)
    val pythonConfig = PythonConfig.buildFromSystemParam(systemParam)

    val appName = df.sparkSession.sparkContext.getConf.get("spark.app.name")
    val envs = EnvConfig.buildFromSystemParam(systemParam) ++ Map(BasicCondaEnvManager.MLSQL_INSTNANCE_NAME_KEY -> appName)


    // find python project
    val pythonProject = PythonAlgProject.loadProject(params, df.sparkSession)


    incrementVersion(path, keepVersion)

    val mlsqlContext = ScriptSQLExec.contextGetOrForTest()

    val pythonProjectPath = Option(pythonProject.get.filePath)

    val projectName = pythonProject.get.projectName
    val projectType = pythonProject.get.scriptType

    val wowRDD = fitParamRDD.map { paramAndIndex =>

      ScriptSQLExec.setContext(mlsqlContext)


      val f = paramAndIndex._2
      val algIndex = paramAndIndex._1

      val localPathConfig = LocalPathConfig.buildFromParams(path)
      var tempDataLocalPathWithAlgSuffix = localPathConfig.localDataPath

      if (enableDataLocal) {
        tempDataLocalPathWithAlgSuffix = tempDataLocalPathWithAlgSuffix + "/" + algIndex
        val msg = s"dataLocalFormat enabled ,system will generate data in ${tempDataLocalPathWithAlgSuffix} "
        logInfo(format(msg))
        HDFSOperator.copyToLocalFile(tempLocalPath = tempDataLocalPathWithAlgSuffix, path = dataHDFSPath, true)
      }

      val paramMap = new util.HashMap[String, Object]()
      var item = f.asJava
      if (!f.contains("modelPath")) {
        item = (f + ("modelPath" -> path)).asJava
      }

      val resourceParams = new ResourceManager(f).loadResourceInTrain

      val taskDirectory = localPathConfig.localRunPath + "/" + projectName
      val tempModelLocalPath = s"${localPathConfig.localModelPath}/${algIndex}"

      paramMap.put("fitParam", item)

      if (!kafkaParam.isEmpty) {
        val kafkaP = kafkaParam + ("group_id" -> (kafkaParam("group_id") + "_" + algIndex))
        paramMap.put("kafkaParam", kafkaP.asJava)
      }

      val internalSystemParam = Map(
        str[RunPythonConfig.InternalSystemParam](_.stopFlagNum) -> stopFlagNum,
        str[RunPythonConfig.InternalSystemParam](_.tempModelLocalPath) -> tempModelLocalPath,
        str[RunPythonConfig.InternalSystemParam](_.tempDataLocalPath) -> tempDataLocalPathWithAlgSuffix,
        str[RunPythonConfig.InternalSystemParam](_.resource) -> resourceParams.asJava
      )

      paramMap.put(RunPythonConfig.internalSystemParam, internalSystemParam.asJava)
      paramMap.put(RunPythonConfig.systemParam, systemParam.asJava)

      pythonProjectPath match {
        case Some(_) =>
          if (projectType == MLFlow) {
            logInfo(format(s"'${pythonProjectPath.get}' is MLflow project. download it from [${pythonProject}] to local [${taskDirectory}]"))
            SQLPythonAlg.downloadPythonProject(taskDirectory, pythonProjectPath)
          } else {
            val localPath = taskDirectory + "/" + pythonProjectPath.get.split("/").last
            logInfo(format(s"'${pythonProjectPath.get}' is Normal project. download it from [${pythonProject}] to local [${localPath}]"))
            SQLPythonAlg.downloadPythonProject(localPath, pythonProjectPath)
          }


        case None =>
          // this will not happen, cause even script is a project contains only one python script file.
          throw new RuntimeException("The project or script you configured in pythonScriptPath is not a validate project")
      }


      val command = new PythonAlgExecCommand(pythonProject.get, Option(mlflowConfig), Option(pythonConfig), envs).
        generateCommand(MLProject.train_command)


      val execDesc =
        s"""
           |
           |----------------------------------------------------------------
           |host: ${NetUtils.getHost}
           |
           |Command: ${command.mkString(" ")}
           |TaskDirectory: ${taskDirectory}
           |DataDirectory: ${tempDataLocalPathWithAlgSuffix}
           |ModelDirectory: ${tempModelLocalPath}
           |
           |Notice:
           |
           |If you wanna keep [${taskDirectory}] for debug, please set
           |keepLocalDirectory=true in train statement.
           |
           |e.g:
           |
           |train data as PythonAlg.`/tmp/abc` where keepLocalDirectory="true";
           |----------------------------------------------------------------
         """.stripMargin
      logInfo(format(execDesc))

      val modelTrainStartTime = System.currentTimeMillis()

      var score = 0.0
      var trainFailFlag = false
      val runner = new PythonProjectExecuteRunner(
        taskDirectory = taskDirectory,
        keepLocalDirectory = keepLocalDirectory,
        envVars = envs,
        logCallback = (msg) => {
          ScriptSQLExec.setContextIfNotPresent(mlsqlContext)
          logInfo(format(msg))
        }
      )
      try {

        val res = runner.run(
          command = command,
          params = paramMap,
          schema = MapType(StringType, MapType(StringType, StringType)),
          scriptContent = pythonProject.get.fileContent,
          scriptName = pythonProject.get.fileName,
          validateData = rowsBr.value
        )

        def filterScore(str: String) = {
          if (str != null && str.startsWith("mlsql_validation_score:")) {
            str.split(":").last.toDouble
          } else 0d
        }

        val scores = res.map { f =>
          logInfo(format(f))
          filterScore(f)
        }.toSeq
        score = if (scores.size > 0) scores.head else 0d
      } catch {
        case e: Exception =>
          logError(format_cause(e))
          e.printStackTrace()
          trainFailFlag = true
      }

      val modelTrainEndTime = System.currentTimeMillis()

      val modelHDFSPath = SQLPythonFunc.getAlgModelPath(path, keepVersion) + "/" + algIndex
      try {
        // If training failed, we do not need
        // copy model to hdfs
        if (!trainFailFlag) {
          //copy model to HDFS
          val fs = FileSystem.get(new Configuration())
          if (!keepVersion) {
            fs.delete(new Path(modelHDFSPath), true)
          }
          fs.copyFromLocalFile(new Path(tempModelLocalPath),
            new Path(modelHDFSPath))
        }
      } catch {
        case e: Exception =>
          e.printStackTrace()
          trainFailFlag = true
      } finally {
        // delete local model
        FileUtils.forceDelete(new File(tempModelLocalPath))
        // delete local data
        if (!keepLocalDirectory) {
          FileUtils.forceDelete(new File(tempDataLocalPathWithAlgSuffix))
        }
        // delete resource
        resourceParams.foreach { rp =>
          FileUtils.forceDelete(new File(rp._2))
        }
      }
      val status = if (trainFailFlag) "fail" else "success"
      Row.fromSeq(Seq(
        modelHDFSPath,
        algIndex,
        pythonProject.get.fileName,
        score,
        status,
        modelTrainStartTime,
        modelTrainEndTime,
        f,
        execDesc
      ))
    }

    df.sparkSession.createDataFrame(wowRDD, PythonTrainingResultSchema.algSchema).write.mode(SaveMode.Overwrite).parquet(SQLPythonFunc.getAlgMetalPath(path, keepVersion) + "/0")

    val tempRDD = df.sparkSession.sparkContext.parallelize(Seq(Seq(
      Map(
        str[PythonConfig](_.pythonPath) -> pythonConfig.pythonPath,
        str[PythonConfig](_.pythonVer) -> pythonConfig.pythonVer
      ), params)), 1).map { f =>
      Row.fromSeq(f)
    }
    df.sparkSession.createDataFrame(tempRDD, PythonTrainingResultSchema.trainParamsSchema).
      write.
      mode(SaveMode.Overwrite).
      parquet(SQLPythonFunc.getAlgMetalPath(path, keepVersion) + "/1")

    df.sparkSession.read.parquet(SQLPythonFunc.getAlgMetalPath(path, keepVersion) + "/0")
  }


}






© 2015 - 2024 Weber Informatics LLC | Privacy Policy