All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.projectglow.sql.SqlExtensionProvider.scala Maven / Gradle / Ivy

/*
 * Copyright 2019 The Glow Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.projectglow.sql

import java.util.{List => JList, Map => JMap}

import scala.collection.JavaConverters._

import org.apache.spark.sql.{SQLUtils, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.yaml.snakeyaml.Yaml

import io.projectglow.SparkShim._
import io.projectglow.common.WithUtils
import io.projectglow.sql.expressions._
import io.projectglow.sql.optimizer.{ReplaceExpressionsRule, ResolveAggregateFunctionsRule, ResolveExpandStructRule, ResolveGenotypeFields}

// TODO(hhd): Spark 3.0 allows extensions to register functions. After Spark 3.0 is released,
// we should move all extensions into this class.
/**
 * Spark session extension entry point for Glow.
 *
 * When applied to a [[SparkSessionExtensions]] instance, injects every rule in
 * `resolutionRules` as a resolution rule and every rule in `optimizations` as an optimizer rule.
 */
class GlowSQLExtensions extends (SparkSessionExtensions => Unit) {

  /** Rules injected via `injectResolutionRule`, in this order. */
  val resolutionRules: Seq[Rule[LogicalPlan]] = Seq(
    ReplaceExpressionsRule,
    ResolveAggregateFunctionsRule,
    ResolveExpandStructRule,
    ResolveGenotypeFields
  )

  /** Rules injected via `injectOptimizerRule`; currently empty. */
  val optimizations: Seq[Rule[LogicalPlan]] = Seq.empty

  /**
   * Injects all resolution rules, then all optimizer rules, into `extensions`.
   *
   * @param extensions the extension registry to populate
   */
  def apply(extensions: SparkSessionExtensions): Unit = {
    // Each builder ignores the session it is handed and returns the shared rule object.
    for (rule <- resolutionRules) {
      extensions.injectResolutionRule(_ => rule)
    }
    for (rule <- optimizations) {
      extensions.injectOptimizerRule(_ => rule)
    }
  }
}

/**
 * Registers SQL functions in a Spark [[FunctionRegistry]] based on a YAML definition file.
 *
 * The YAML document is read as a map of groups. Each group holds a `functions` list, and each
 * function entry supplies `name`, `expr_class`, `doc`, `since` and an `args` list. Each argument
 * supplies `name` and `doc`, and may carry the boolean flags `is_optional` and `is_var_args`.
 */
object SqlExtensionProvider {
  private val FUNCTION_YAML_PATH = "functions.yml"

  /**
   * Reads a boolean flag from a YAML argument description.
   *
   * @return the flag's value, or `false` when the key is absent
   */
  private def hasFlag(arg: JMap[String, Any], flag: String): Boolean = {
    arg.asScala.get(flag).exists(_.asInstanceOf[Boolean])
  }

  /**
   * Loads every function definition from the YAML resource at `resourcePath`, flattening the
   * `functions` lists of all groups into one collection.
   *
   * @param resourcePath path of the YAML resource, resolved against the context class loader
   * @return the function definitions; empty if the YAML document yields no groups
   * @throws IllegalArgumentException if no resource exists at `resourcePath`
   */
  private def loadFunctionDefinitions(
      resourcePath: String
  ): Iterable[JMap[String, Any]] = {
    val stream = Thread
      .currentThread()
      .getContextClassLoader
      .getResourceAsStream(resourcePath)
    // getResourceAsStream reports a missing resource by returning null; fail here with the
    // offending path instead of a context-free error deeper in the YAML parser.
    if (stream == null) {
      throw new IllegalArgumentException(
        s"Could not find function definition resource '$resourcePath' on the classpath"
      )
    }

    val yml = new Yaml()
    WithUtils.withCloseable(stream) { in =>
      val groups = yml.loadAs(in, classOf[JMap[String, JMap[String, Any]]])
      // A null result (expected for an empty YAML document) means there is nothing to register.
      if (groups == null) {
        Iterable.empty[JMap[String, Any]]
      } else {
        groups
          .values()
          .asScala
          .flatMap(
            group =>
              group
                .asScala("functions")
                .asInstanceOf[JList[JMap[String, Any]]]
                .asScala
          )
      }
    }
  }

  /** Builds the exception raised when a function is called with an unsupported arity. */
  private def parameterError(functionName: String, params: Int): Exception = {
    SQLUtils.newAnalysisException(
      s"Invalid number of parameters for function '$functionName': $params"
    )
  }

  /**
   * Renders the argument descriptions as one `name: doc` line per argument. Optional arguments
   * are tagged `(optional)` and var-args arguments `(repeated)`.
   */
  private def makeArgsDoc(args: Seq[JMap[String, Any]]): String = {
    args.map { _arg =>
      val arg = _arg.asScala
      val suffix =
        if (hasFlag(_arg, "is_optional")) {
          " (optional)"
        } else if (hasFlag(_arg, "is_var_args")) {
          " (repeated)"
        } else {
          ""
        }
      s"${arg("name")}: ${arg("doc")} $suffix"
    }.mkString("\n")
  }

  /**
   * Using the argument descriptions from the YAML file and the runtime argument expressions,
   * create the list of constructor parameters for the expression class.
   *
   * Each declared argument contributes at most one constructor parameter:
   *  - an optional argument with no matching input contributes nothing;
   *  - a var-args argument contributes the remaining inputs as a single `Seq[Expression]`;
   *  - any other argument contributes the input at its position.
   *
   * @param functionName function name, used only in error messages
   * @param args argument descriptions from the YAML file
   * @param exprs the expressions supplied at the call site
   * @return the constructor parameters, in declaration order
   * @throws Exception (from `parameterError`) if a required argument has no input, or if more
   *                   inputs were supplied than the declared arguments can accept
   */
  private def makeChildren(
      functionName: String,
      args: Seq[JMap[String, Any]],
      exprs: Seq[Expression]): Seq[AnyRef] = {
    // The surplus-input check below runs on the last declared argument, so it never fires for
    // a function that declares none. Reject inputs to such a function explicitly.
    if (args.isEmpty && exprs.nonEmpty) {
      throw parameterError(functionName, exprs.size)
    }

    val lastIdx = args.size - 1
    args.zipWithIndex.flatMap {
      case (arg, idx) =>
        val hasInput = idx < exprs.size
        // If the argument is optional and doesn't have a matching input, don't add a new
        // expression to the list of children.
        if (hasFlag(arg, "is_optional") && !hasInput) {
          None
          // If we have a var args argument, the child expressions from here on are part of
          // the var args list.
        } else if (hasFlag(arg, "is_var_args")) {
          Some(exprs.slice(idx, exprs.size))
        } else if (!hasInput) {
          // A required argument was not supplied.
          throw parameterError(functionName, exprs.size)
        } else if (idx == lastIdx && exprs.size != args.size) {
          // We are on the last declared argument and it has an input, so a size mismatch can
          // only mean surplus inputs.
          throw parameterError(functionName, exprs.size)
        } else {
          Some(exprs(idx))
        }
    }
  }

  /**
   * Register SQL functions based on a yaml function definition file.
   *
   * For every function definition, a builder is registered that loads the class named by
   * `expr_class` through the context class loader, selects a public constructor whose parameter
   * count matches the resolved arguments, instantiates it and passes the result through
   * `ExpressionHelper.rewrite`.
   *
   * @param conf the SQL configuration (not read by this method)
   * @param functionRegistry the registry that receives the functions
   * @param resourcePath path of the YAML definition resource
   */
  def registerFunctions(
      conf: SQLConf,
      functionRegistry: FunctionRegistry,
      resourcePath: String = FUNCTION_YAML_PATH): Unit = {

    loadFunctionDefinitions(resourcePath).foreach { _function =>
      val function = _function.asScala
      val name = function("name").asInstanceOf[String]
      val id = FunctionIdentifier(name)
      val exprClass = function("expr_class").asInstanceOf[String]
      val args = function("args").asInstanceOf[JList[JMap[String, Any]]].asScala
      val info = createExpressionInfo(
        exprClass,
        null,
        name,
        function("doc").asInstanceOf[String],
        makeArgsDoc(args),
        "",
        "",
        function("since").asInstanceOf[String]
      )
      functionRegistry.registerFunction(
        id,
        info,
        exprs => {
          // The class is resolved on each invocation of the builder, not at registration time.
          val clazz = Class.forName(
            exprClass,
            true,
            Thread.currentThread().getContextClassLoader
          )
          val constructorArgs = makeChildren(id.funcName, args, exprs)
          val constructor = clazz
            .getConstructors
            .find(_.getParameterCount == constructorArgs.size)
            .getOrElse(throw parameterError(id.funcName, exprs.size))

          val expr = try {
            constructor
              .newInstance(constructorArgs: _*)
              .asInstanceOf[Expression]
          } catch {
            // Reflection wraps anything thrown by the constructor; rethrow the underlying
            // exception so the caller sees the real failure.
            case e: java.lang.reflect.InvocationTargetException if e.getCause != null =>
              throw e.getCause
          }

          ExpressionHelper.rewrite(expr)
        }
      )
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy