All Downloads are FREE. The search and download functionality uses the official Maven repository.

streaming.udf.PythonRuntimeCompileUDAF.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StructType}
import org.python.core.{Py, PyObject}
import streaming.dsl.ScriptSQLExec
import streaming.dsl.mmlib.algs.ScriptUDFCacheKey
import streaming.jython.{JythonUtils, PythonInterp}

/**
  * [[RuntimeCompileUDAF]] implementation for scripts written in Python and run through Jython.
  *
  * The generated [[UserDefinedAggregateFunction]] forwards every callback to an attribute of the
  * same name looked up on a Python object created from the script.
  *
  * Created by fchen on 2018/11/15.
  */
object PythonRuntimeCompileUDAF extends RuntimeCompileUDAF {
  /**
    * Validate the source code.
    *
    * No validation is performed here; every script is accepted.
    *
    * @param sourceCode the script text (ignored)
    * @return always `true`
    */
  override def check(sourceCode: String): Boolean = true

  /**
    * Compile the source code.
    *
    * @param scriptCacheKey cache key carrying the original script text and the class name
    * @return the result of `PythonInterp.compilePython` for that script and class name
    */
  override def compile(scriptCacheKey: ScriptUDFCacheKey): AnyRef =
    PythonInterp.compilePython(scriptCacheKey.originalCode, scriptCacheKey.className)


  /**
    * Build a [[UserDefinedAggregateFunction]] that delegates to the Python script.
    *
    * The script object is expected to expose the callables `inputSchema`, `dataType`,
    * `bufferSchema`, `deterministic`, `initialize`, `update`, `merge` and `evaluate`, since those
    * are the attributes looked up below.
    *
    * @param scriptCacheKey cache key identifying the script to execute
    * @return an aggregate function whose callbacks call into the script
    */
  override def generateFunction(scriptCacheKey: ScriptUDFCacheKey): UserDefinedAggregateFunction = {

    new UserDefinedAggregateFunction with Serializable {

      // Context captured when this function object is constructed; `wrap` re-installs it
      // (only if none is present) before each call into the script.
      val c = ScriptSQLExec.contextGetOrForTest()

      /**
        * Run `fn` after making sure the captured context is installed.
        * Exceptions thrown by `fn` propagate unchanged.
        */
      private def wrap[T](fn: => T): T = {
        ScriptSQLExec.setContextIfNotPresent(c)
        fn
      }

      // Script object obtained through `driverExecute`. It is @transient, so it is not
      // serialized together with this function object.
      @transient val objectUsingInDriver: PyObject = wrap {
        driverExecute(scriptCacheKey).asInstanceOf[PyObject].__call__()
      }

      // Script object obtained through `executorExecute`; created lazily on first use.
      lazy val objectUsingInExecutor: PyObject = wrap {
        executorExecute(scriptCacheKey).asInstanceOf[PyObject].__call__()
      }


      // Schema/metadata values are computed once, eagerly, from the driver-side script object.
      val _inputSchema: PyObject = objectUsingInDriver.__getattr__("inputSchema").__call__()
      val _dataType: PyObject = objectUsingInDriver.__getattr__("dataType").__call__()
      val _bufferSchema: PyObject = objectUsingInDriver.__getattr__("bufferSchema").__call__()
      val _deterministic: PyObject = objectUsingInDriver.__getattr__("deterministic").__call__()

      /** Input schema as returned by the script's `inputSchema()`. */
      override def inputSchema: StructType = wrap {
        _inputSchema.__tojava__(classOf[StructType]).asInstanceOf[StructType]
      }

      /** Result type as returned by the script's `dataType()`. */
      override def dataType: DataType = wrap {
        _dataType.__tojava__(classOf[DataType]).asInstanceOf[DataType]
      }

      /** Aggregation buffer schema as returned by the script's `bufferSchema()`. */
      override def bufferSchema: StructType = wrap {
        _bufferSchema.__tojava__(classOf[StructType]).asInstanceOf[StructType]
      }

      /**
        * Whether the script declares itself deterministic.
        *
        * The converted value is read through its string form: `true`/`false` (any case) are
        * accepted, and any other value must parse as an integer, where zero means `false` and
        * non-zero means `true`. Previously only the integers 0 and 1 were accepted and any other
        * integer raised a `MatchError`.
        *
        * @throws NumberFormatException if the value is neither a boolean word nor an integer
        */
      override def deterministic: Boolean = wrap {
        val raw = String.valueOf(JythonUtils.toJava(_deterministic))
        if (raw.equalsIgnoreCase("true")) true
        else if (raw.equalsIgnoreCase("false")) false
        else raw.toInt != 0
      }

      // Bound script methods, resolved lazily from the executor-side script object.
      lazy val _update: PyObject = objectUsingInExecutor.__getattr__("update")
      lazy val _merge: PyObject = objectUsingInExecutor.__getattr__("merge")
      lazy val _initialize: PyObject = objectUsingInExecutor.__getattr__("initialize")
      lazy val _evaluate: PyObject = objectUsingInExecutor.__getattr__("evaluate")

      /** Forward to the script's `update(buffer, input)`; its return value is ignored. */
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = wrap {
        _update.__call__(Py.java2py(buffer), Py.java2py(input))
        ()
      }

      /** Forward to the script's `merge(buffer1, buffer2)`; its return value is ignored. */
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = wrap {
        _merge.__call__(Py.java2py(buffer1), Py.java2py(buffer2))
        ()
      }

      /** Forward to the script's `initialize(buffer)`; its return value is ignored. */
      override def initialize(buffer: MutableAggregationBuffer): Unit = wrap {
        _initialize.__call__(Py.java2py(buffer))
        ()
      }

      /** Forward to the script's `evaluate(buffer)` and convert the result with `JythonUtils`. */
      override def evaluate(buffer: Row): Any = wrap {
        JythonUtils.toJava(_evaluate.__call__(Py.java2py(buffer)))
      }


    }
  }

  /** Language identifier handled by this implementation. */
  override def lang: String = "python"
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy