All Downloads are FREE. Search and download functionalities are using the official Maven repository.

streaming.dsl.mmlib.algs.SQLPythonEnvExt.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.dsl.mmlib.algs

import org.apache.spark.ml.param.Param
import org.apache.spark.ps.cluster.Message
import org.apache.spark.ps.cluster.Message.CreateOrRemovePythonCondaEnvResponse
import org.apache.spark.scheduler.cluster.PSDriverEndpoint
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.mlsql.session.MLSQLException
import org.apache.spark.sql.{DataFrame, SparkSession}
import streaming.common.HDFSOperator
import streaming.core.strategy.platform.{PlatformManager, SparkRuntime}
import streaming.dsl.mmlib._
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
import streaming.dsl.mmlib.algs.python.BasicCondaEnvManager

/**
  * 2019-01-16 WilliamZhu([email protected])
  */
class SQLPythonEnvExt(override val uid: String) extends SQLAlg with WowParams {

  def this() = this(BaseParams.randomUID())

  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    val spark = df.sparkSession

    params.get(command.name).map { s =>
      set(command, s)
      s
    }.getOrElse {
      throw new MLSQLException(s"${command.name} is required")
    }

    params.get(condaYamlFilePath.name).map { s =>
      set(condaYamlFilePath, s)
    }.getOrElse {
      params.get(condaFile.name).map { s =>
        val condaContent = spark.table(s).head().getString(0)
        val baseFile = path + "/__mlsql_temp_dir__/conda"
        val fileName = "conda.yaml"
        HDFSOperator.saveFile(baseFile, fileName, Seq(("", condaContent)).iterator)
        set(condaYamlFilePath, baseFile + "/" + fileName)
      }.getOrElse {
        throw new MLSQLException(s"${condaFile.name} || ${condaYamlFilePath} is required")
      }

    }

    val wowCommand = $(command) match {
      case "create" => Message.AddEnvCommand
      case "remove" => Message.RemoveEnvCommand
    }
    val appName = spark.sparkContext.getConf.get("spark.app.name")
    val remoteCommand = Message.CreateOrRemovePythonCondaEnv($(condaYamlFilePath), params ++ Map(BasicCondaEnvManager.MLSQL_INSTNANCE_NAME_KEY -> appName), wowCommand)

    val response = if (spark.sparkContext.isLocal) {
      val psDriverBackend = PlatformManager.getRuntime.asInstanceOf[SparkRuntime].localSchedulerBackend
      psDriverBackend.localEndpoint.askSync[CreateOrRemovePythonCondaEnvResponse](remoteCommand, PSDriverEndpoint.MLSQL_DEFAULT_RPC_TIMEOUT(spark.sparkContext.getConf))
    } else {
      val psDriverBackend = PlatformManager.getRuntime.asInstanceOf[SparkRuntime].psDriverBackend
      psDriverBackend.psDriverRpcEndpointRef.askSync[CreateOrRemovePythonCondaEnvResponse](remoteCommand, PSDriverEndpoint.MLSQL_DEFAULT_RPC_TIMEOUT(spark.sparkContext.getConf))
    }
    import spark.implicits._
    spark.createDataset[CreateOrRemovePythonCondaEnvResponse](Seq(response)).toDF()
  }


  override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    train(df, path, params)
  }

  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = throw new RuntimeException("register is not support")

  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = throw new RuntimeException("register is not support")

  final val command: Param[String] = new Param[String](this, "command", "create|remove", isValid = (s: String) => {
    s == "create" || s == "remove"
  })

  final val condaYamlFilePath: Param[String] = new Param[String](this, "condaYamlFilePath", "the conda file path")
  final val condaFile: Param[String] = new Param[String](this, "condaFile", "variable ref configured by set")

  override def explainParams(sparkSession: SparkSession): DataFrame = _explainParams(sparkSession)

  override def codeExample: Code = Code(SQLCode,
    """
      |```sql
      |set dependencies='''
      |name: tutorial4
      |dependencies:
      |  - python=3.6
      |  - pip
      |  - pip:
      |    - --index-url https://mirrors.aliyun.com/pypi/simple/
      |    - numpy==1.14.3
      |    - kafka==1.3.5
      |    - pyspark==2.3.2
      |    - pandas==0.22.0
      |''';
      |
      |load script.`dependencies` as dependencies;
      |run command as PythonEnvExt.`/tmp/jack` where condaFile="dependencies" and command="create";
      |```
      |
    """.stripMargin)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy