All downloads are free. The search and download functionality uses the official Maven repository.

tech.sourced.engine.util.Bblfsh.scala Maven / Gradle / Ivy

There is a newer version: 1.0.0
Show newest version
package tech.sourced.engine.util

import java.nio.charset.StandardCharsets

import gopkg.in.bblfsh.sdk.v1.protocol.generated.Status
import gopkg.in.bblfsh.sdk.v1.uast.generated.Node
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.bblfsh.client.BblfshClient

/**
  * Helpers to obtain UASTs from a bblfsh gRPC service and to query them.
  *
  * Both the configuration and the client are created lazily on first use and
  * then cached in this object, so later calls reuse the first values seen.
  */
object Bblfsh extends Logging {

  /**
    * Connection settings of the bblfsh gRPC service.
    *
    * @param host Host of the bblfsh service
    * @param port Port of the bblfsh service
    */
  case class Config(host: String, port: Int)

  /** Languages whose UAST will not be retrieved. */
  val excludedLangs: Set[String] = Set("markdown", "text")

  /** Key used for the option to specify the host of the bblfsh grpc service. */
  val hostKey: String = "spark.tech.sourced.bblfsh.grpc.host"

  /** Key used for the option to specify the port of the bblfsh grpc service. */
  val portKey: String = "spark.tech.sourced.bblfsh.grpc.port"

  /** Default bblfsh host. */
  val defaultHost: String = "0.0.0.0"

  /** Default bblfsh port. */
  val defaultPort: Int = 9432

  // Lazily initialised caches; `null` means "not created yet".
  private var config: Config = _
  private var client: BblfshClient = _

  /**
    * Returns the configuration for bblfsh.
    *
    * The configuration is read from the session only on the first call
    * (falling back to [[defaultHost]] and [[defaultPort]] for missing keys);
    * every later call returns that cached value, whatever session is given.
    *
    * @param session Spark session
    * @return bblfsh configuration
    */
  def getConfig(session: SparkSession): Config = {
    if (config == null) {
      val host = session.conf.get(hostKey, defaultHost)
      val port = session.conf.get(portKey, defaultPort.toString).toInt
      config = Config(host, port)
    }

    config
  }

  /**
    * Returns the shared bblfsh client, creating it on first use.
    *
    * Only the configuration passed on the first call is used to build the
    * client; afterwards the cached instance is returned as is.
    *
    * @param config bblfsh configuration
    * @return bblfsh client
    */
  private def getClient(config: Config): BblfshClient = synchronized {
    if (client == null) {
      client = BblfshClient(config.host, config.port)
    }

    client
  }

  /**
    * Extracts the UAST using bblfsh.
    *
    * Nothing is requested for languages listed in [[excludedLangs]]. When the
    * service does not answer with an OK status, or answers OK without a UAST,
    * a warning is logged and an empty sequence is returned.
    *
    * @param path    File path
    * @param content File content
    * @param lang    File language, or null if it has not been classified yet
    * @param config  bblfsh configuration
    * @return List of uast nodes binary-encoded as a byte array
    */
  def extractUAST(path: String,
                  content: Array[Byte],
                  lang: String,
                  config: Config): Seq[Array[Byte]] = {

    //FIXME(bzz): not everything is UTF-8 encoded :/

    // if lang == null, it hasn't been classified yet
    // so rely on bblfsh to guess this file's language
    val excluded = Option(lang).exists(l => excludedLangs.contains(l.toLowerCase()))

    if (excluded) {
      Seq.empty[Array[Byte]]
    } else {
      val contentStr = new String(content, StandardCharsets.UTF_8)
      val parsed = getClient(config).parse(path, content = contentStr, lang = lang)
      if (parsed.status == Status.OK) {
        // An OK response is not guaranteed to carry a UAST, so never call
        // `.get` on it: a missing tree is treated like any other parse failure.
        parsed.uast match {
          case Some(uast) =>
            Seq(uast.toByteArray)
          case None =>
            logWarning(s"${parsed.status} $path: response did not contain an UAST")
            Seq.empty[Array[Byte]]
        }
      } else {
        logWarning(s"${parsed.status} $path: ${parsed.errors.mkString("; ")}")
        Seq.empty[Array[Byte]]
      }
    }
  }

  /**
    * Filter an UAST node using the given query.
    *
    * @param node   An UAST node
    * @param query  XPath expression
    * @param config bblfsh configuration
    * @return UAST list of filtered nodes
    */
  def filter(node: Node, query: String, config: Config): List[Node] = {
    getClient(config).filter(node, query)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy