All Downloads are FREE. Search and download functionalities are using the official Maven repository.

tech.sourced.gitbase.spark.udf.Uast.scala Maven / Gradle / Ivy

The newest version!
package tech.sourced.gitbase.spark.udf

import gopkg.in.bblfsh.sdk.v1.protocol.generated.Status
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types.{BinaryType, DataType}
import org.bblfsh.client.BblfshClient
import tech.sourced.enry.Enry

case class Uast(exprs: Seq[Expression]) extends Expression with CodegenFallback {

  override def nullable: Boolean = true

  override def toString: String = s"uast(${exprs.mkString(", ")})"

  override def children: Seq[Expression] = exprs

  override def eval(inputRow: InternalRow): Any = {
    val args = exprs.map(_.eval(inputRow))
    var query = ""
    var content: Array[Byte] = null
    var lang = ""

    args.length match {
      case 1 =>
        content = args.head.asInstanceOf[Array[Byte]]
      case 2 =>
        content = args.head.asInstanceOf[Array[Byte]]
        lang = Option(args(1)).map(_.toString).orNull
      case 3 =>
        content = args.head.asInstanceOf[Array[Byte]]
        lang = Option(args(1)).map(_.toString).orNull
        query = Option(args(2)).map(_.toString).orNull
      case ln =>
        throw new SparkException(
          s"invalid arguments provided to uast, expecting 1, 2 or 3, got $ln"
        )
    }

    if (lang == null || content == null || query == null) {
      return null
    }

    if (lang == "") {
      lang = Enry.getLanguage("file", content)
    }

    Uast.get(content, lang, query).orNull
  }

  override def dataType: DataType = BinaryType

}

object Uast extends CustomExprFunction with Logging {

  override def name: String = "uast"
  override def function: FunctionBuilder = exprs => Uast(exprs)

  def apply(cols: Column*): Column = new Column(Uast(cols.map(_.expr)))

  def get(content: Array[Byte],
          lang: String = "",
          query: String = ""): Option[Array[Byte]] =
    try {
      if (content == null || content.isEmpty) {
        return None
      }

      if (!BblfshUtils.isSupportedLanguage(lang)) {
        return None
      }

      val res = BblfshUtils
        .getClient().parse("", new String(content, "UTF-8"), lang)

      if (res.status != Status.OK) {
        log.warn(s"couldn't get UAST : error ${res.status}: ${res.errors.mkString("; ")}")
        return None
      }

      val nodes = query match {
        case "" => Seq(res.uast.get)
        case q: String => BblfshClient.filter(res.uast.get, q)
      }

      BblfshUtils.marshalNodes(nodes)
    } catch {
      case e@(_: RuntimeException | _: Exception) =>
        log.error(s"couldn't get UAST: $e")
        None
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy