All Downloads are FREE. Search and download functionalities are using the official Maven repository.

tech.sourced.gitbase.spark.udf.BblfshUtils.scala Maven / Gradle / Ivy

The newest version!
package tech.sourced.gitbase.spark.udf

import java.nio.ByteBuffer

import gopkg.in.bblfsh.sdk.v1.uast.generated.Node
import org.apache.spark.internal.Logging
import org.bblfsh.client.BblfshClient

import scala.collection.mutable.ArrayBuffer
import scala.util.{Try, Success, Failure}

object BblfshUtils extends Logging {
  /** Key used for the option to specify the host of the bblfsh grpc service. */
  val hostKey = "spark.tech.sourced.bblfsh.grpc.host"

  /** Key used for the option to specify the port of the bblfsh grpc service. */
  val portKey = "spark.tech.sourced.bblfsh.grpc.port"

  /** Default bblfsh host. */
  val defaultHost = "0.0.0.0"

  /** Default bblfsh port. */
  val defaultPort = 9432

  /** Size in bytes of the `Int` length prefix written before every serialized node. */
  private val LengthPrefixBytes = 4

  // Lazily-initialised shared state. Both fields are only read/written while holding
  // this object's monitor (see `getClient` and `isSupportedLanguage`).
  private var client: BblfshClient = _
  private var supportedLanguages: Seq[String] = _

  /**
   * Returns the shared bblfsh client, creating it on first call.
   *
   * Host and port are read from `spark.conf` under [[hostKey]] / [[portKey]], falling back
   * to [[defaultHost]] / [[defaultPort]] when unset. If reading or parsing the port fails,
   * a warning (including the cause) is logged and [[defaultPort]] is used instead.
   */
  def getClient(): BblfshClient = synchronized {
    if (client == null) {
      val host = spark.conf.get(hostKey, defaultHost)
      val port = Try(spark.conf.get(portKey, defaultPort.toString).toInt) match {
        case Success(p) => p
        case Failure(e) =>
          // Pass the cause along so misconfigurations are diagnosable from the logs.
          log.warn(s"couldn't get value for config key ${portKey}, " +
            s"default value ${defaultPort} will be used", e)
          defaultPort
      }
      client = BblfshClient(host, port)
    }
    client
  }

  /**
   * Tells whether bblfsh reports `lang` as a supported language (case-insensitive).
   *
   * The language list is obtained once via `getClient().supportedLanguages()` and cached.
   * If that call throws, nothing is cached and the next call retries.
   *
   * @param lang language name to look up
   * @return `true` if the lower-cased `lang` is among the cached supported languages
   */
  def isSupportedLanguage(lang: String): Boolean = {
    // Initialise under the same monitor as `getClient` so concurrent callers neither race
    // on the shared var nor issue duplicate RPCs; the lookup itself runs outside the lock.
    val languages = synchronized {
      if (supportedLanguages == null) {
        supportedLanguages = getClient()
          .supportedLanguages()
          .languages
          .map(_.language.toLowerCase())
      }
      supportedLanguages
    }

    languages.contains(lang.toLowerCase)
  }

  /**
   * Serializes nodes into a single blob that [[unmarshalNodes]] can decode.
   *
   * Each non-null node is written as a 4-byte length prefix (ByteBuffer's default
   * byte order) followed by its `toByteArray` bytes. `null` elements are skipped.
   *
   * @param nodes nodes to serialize; may be `null`
   * @return `None` when `nodes` is null, empty, or contains only `null`s (mirroring
   *         [[unmarshalNodes]], which maps an empty blob to `None`); otherwise the blob
   */
  def marshalNodes(nodes: Seq[Node]): Option[Array[Byte]] = {
    val serialized: Seq[Array[Byte]] =
      if (nodes == null) Seq.empty else nodes.filter(_ != null).map(_.toByteArray)

    if (serialized.isEmpty) {
      None
    } else {
      val size = serialized.foldLeft(0)((sum, sn) => sum + LengthPrefixBytes + sn.length)
      val result = new Array[Byte](size)
      val buffer = ByteBuffer.wrap(result)
      serialized.foreach(sn => buffer.putInt(sn.length).put(sn))
      Some(result)
    }
  }

  /**
   * Decodes a blob produced by [[marshalNodes]] back into nodes, preserving order.
   *
   * @param blob encoded nodes; may be `null`
   * @return `None` when `blob` is null or empty; otherwise the decoded nodes
   * @throws IllegalArgumentException if a length prefix is truncated, negative, or
   *                                  larger than the bytes remaining in the blob
   */
  def unmarshalNodes(blob: Array[Byte]): Option[Seq[Node]] = {
    if (blob == null || blob.isEmpty) {
      None
    } else {
      val nodes = ArrayBuffer[Node]()
      val buffer = ByteBuffer.wrap(blob)
      while (buffer.hasRemaining()) {
        require(buffer.remaining() >= LengthPrefixBytes,
          s"malformed node blob: truncated length prefix at offset ${buffer.position()}")
        val size = buffer.getInt()
        // Validate before allocating so a corrupt prefix cannot request a huge array.
        require(size >= 0 && size <= buffer.remaining(),
          s"malformed node blob: invalid node length $size at offset " +
            s"${buffer.position() - LengthPrefixBytes}")
        val sn = new Array[Byte](size)
        buffer.get(sn)
        nodes += Node.parseFrom(sn)
      }
      Some(nodes.toList)
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy