package tech.sourced

import gopkg.in.bblfsh.sdk.v1.uast.generated.Node
import org.apache.spark.SparkException
import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import tech.sourced.engine.udf._
import tech.sourced.engine.util.Bblfsh

/**
  * Provides the [[tech.sourced.engine.Engine]] class, which is the main entry point
  * for all the analysis you might do with this library, as well as some implicits
  * that make it easier to use. In particular, it adds methods to join with other
  * "tables" directly from any [[org.apache.spark.sql.DataFrame]].
  *
  * {{{
  * import tech.sourced.engine._
  *
  * val engine = Engine(sparkSession, "/path/to/repositories")
  * }}}
  *
  * The package only exposes what is truly needed, so it does not pollute the user
  * namespace, but if you prefer not to import everything from it, you can import
  * just the [[tech.sourced.engine.Engine]] class and the
  * [[tech.sourced.engine.EngineDataFrame]] implicit class.
  *
  * {{{
  * import tech.sourced.engine.{Engine, EngineDataFrame}
  *
  * val engine = Engine(sparkSession, "/path/to/repositories")
  * }}}
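  *
  * Once you have an engine, the implicits let you chain queries across the tables.
  * A minimal sketch (it assumes the usual `engine.getRepositories` entry point and
  * repositories that contain a HEAD reference):
  *
  * {{{
  * val blobsDf = engine.getRepositories
  *   .getReferences
  *   .getHEAD
  *   .getCommits
  *   .getTreeEntries
  *   .getBlobs
  * }}}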
  */
package object engine {

  /**
    * Key used for the option to specify the path of siva files.
    */
  private[engine] val RepositoriesPathKey = "spark.tech.sourced.engine.repositories.path"

  /**
    * Key used for the option to specify the type of repositories.
    * It can be "siva", "bare" or "standard".
    */
  private[engine] val RepositoriesFormatKey = "spark.tech.sourced.engine.repositories.format"

  /**
    * Key used for the option to specify whether files should be deleted after
    * their usage or not.
    */
  private[engine] val SkipCleanupKey = "spark.tech.sourced.engine.cleanup.skip"

  /**
    * Key used for the option to specify whether read errors should be skipped
    * or not.
    */
  private[engine] val SkipReadErrorsKey = "spark.tech.sourced.engine.skip.read.errors"

  // DataSource names
  val DefaultSourceName: String = "tech.sourced.engine"
  val MetadataSourceName: String = "tech.sourced.engine.MetadataSource"

  // Table names
  val RepositoriesTable: String = "repositories"
  val ReferencesTable: String = "references"
  val CommitsTable: String = "commits"
  val TreeEntriesTable: String = "tree_entries"
  val BlobsTable: String = "blobs"
  val RepositoryHasCommitsTable: String = "repository_has_commits"


  // The keys RepositoriesPathKey, bblfshHostKey, bblfshPortKey and SkipCleanupKey must
  // start with "spark." so that they can be loaded from the "spark-defaults.conf" file.
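
  // A minimal sketch of how these options can be set when building the session
  // (the specific values shown here are illustrative, not defaults):
  //
  //   val spark = SparkSession.builder()
  //     .master("local[*]")
  //     .config(SkipCleanupKey, "true")
  //     .config(SkipReadErrorsKey, "true")
  //     .getOrCreate()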

  /**
    * Implicit class that adds some functions to the [[org.apache.spark.sql.SparkSession]].
    *
    * @param session Spark Session
    */
  implicit class SessionFunctions(session: SparkSession) {

    /**
      * Registers some user defined functions in the [[org.apache.spark.sql.SparkSession]].
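      *
      * A minimal usage sketch (`spark` is assumed to be an active session; each UDF
      * is registered under the SQL name given by its `CustomUDF.name` field):
      *
      * {{{
      * spark.registerUDFs()
      * }}}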
      */
    def registerUDFs(): Unit = {
      SessionFunctions.UDFtoRegister.foreach(customUDF => session.udf.register(
        customUDF.name,
        customUDF(session)
      ))
    }

    /**
      * Returns the number of executors that are active right now.
      *
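      * A minimal usage sketch (the factor of 4 partitions per executor is just a
      * hypothetical heuristic, not something this library prescribes):
      *
      * {{{
      * val executors = spark.currentActiveExecutors()
      * val repartitionedDf = someDf.repartition(executors * 4)
      * }}}
      *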
      * @return number of active executors
      */
    def currentActiveExecutors(): Int = {
      val sc = session.sparkContext
      val driver = sc.getConf.get("spark.driver.host")
      val executors = sc.getExecutorMemoryStatus
        .keys
        .filter(ex => ex.split(":").head != driver)
        .toArray
        .distinct
        .length

      // If there are no executors, it means it's a local job
      // so there's just one node to get the data from.
      if (executors > 0) executors else 1
    }

  }

  /**
    * Creates a new [[Node]] from a binary-encoded node given as a byte array.
    *
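    * A minimal usage sketch (here `bytes` is assumed to be one element of the
    * "uast" column produced by `extractUASTs`):
    *
    * {{{
    * val node: Node = parseUASTNode(bytes)
    * println(node.token)
    * }}}
    *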
    * @param data binary-encoded node as byte array
    * @return parsed [[Node]]
    */
  def parseUASTNode(data: Array[Byte]): Node = Node.parseFrom(data)

  /**
    * Adds some utility methods to the [[org.apache.spark.sql.DataFrame]] class
    * so you can, for example, get the references, commits, etc. from a dataframe
    * containing repositories.
    *
    * @param df the DataFrame
    */
  implicit class EngineDataFrame(df: DataFrame) {

    import df.sparkSession.implicits._

    /**
      * Returns a new [[org.apache.spark.sql.DataFrame]] with the product of joining the
      * current dataframe with the references dataframe.
      * It requires the dataframe to have an "id" column, which should be the repository
      * identifier.
      *
      * {{{
      * val refsDf = reposDf.getReferences
      * }}}
      *
      * @return new DataFrame containing also references data.
      */
    def getReferences: DataFrame = {
      checkCols(df, "id")
      val reposIdsDf = df.select($"id")
      getDataSource(ReferencesTable, df.sparkSession)
        .join(reposIdsDf, $"repository_id" === $"id")
        .drop($"id")
    }

    /**
      * Returns a new [[org.apache.spark.sql.DataFrame]] with only remote references in it.
      * If the DataFrame contains repository data it will automatically get the references for
      * those repositories.
      *
      * {{{
      * val remoteRefs = reposDf.getRemoteReferences
      * val remoteRefs2 = reposDf.getReferences.getRemoteReferences
      * }}}
      *
      * @return new DataFrame containing only remote references.
      */
    def getRemoteReferences: DataFrame = {
      if (df.columns.contains("is_remote")) {
        df.filter(df("is_remote") === true)
      } else {
        df.getReferences.getRemoteReferences
      }
    }

    /**
      * Returns a new [[org.apache.spark.sql.DataFrame]] with the product of joining the
      * current dataframe with the commits dataframe, returning only the last commit of
      * each reference (i.e. its current state).
      * It requires the current dataframe to have "repository_id" and "name" columns,
      * which identify the repository and the reference, respectively.
      *
      * {{{
      * val commitDf = refsDf.getCommits
      * }}}
      *
      * You can use [[tech.sourced.engine.EngineDataFrame#getAllReferenceCommits]] to
      * get all the commits in the references, but be aware that it is a very costly operation.
      *
      * {{{
      * val allCommitsDf = refsDf.getCommits.getAllReferenceCommits
      * }}}
      *
      * @return new DataFrame containing also commits data.
      */
    def getCommits: DataFrame = {
      checkCols(df, "repository_id")
      val refsIdsDf = df.select($"name", $"repository_id")
      val commitsDf = getDataSource(CommitsTable, df.sparkSession)
      commitsDf.join(refsIdsDf, refsIdsDf("repository_id") === commitsDf("repository_id") &&
        commitsDf("reference_name") === refsIdsDf("name"))
        .drop(refsIdsDf("name")).drop(refsIdsDf("repository_id"))
    }

    /**
      * Returns a new [[org.apache.spark.sql.DataFrame]] with all the commits in each reference.
      * After calling this method, commits may appear multiple times in your [[DataFrame]],
      * because most commits are shared amongst references. Take into account that calling
      * this method will make all further operations considerably slower, as there is much
      * more data, especially if you query blobs, which will be repeated over and over.
      *
      * For the following example, consider a master branch with 100 commits and a "foo"
      * branch that starts at the HEAD of master and adds two more commits.
      * {{{
      * > commitsDf.count()
      * 2
      * > commitsDf.getAllReferenceCommits.count()
      * 202
      * }}}
      *
      * @return dataframe with all reference commits
      */
    def getAllReferenceCommits: DataFrame = {
      if (df.columns.contains("index")) {
        df.filter("index <> '-1'")
      } else {
        df.getCommits.getAllReferenceCommits
      }
    }

    /**
      * Returns a new [[org.apache.spark.sql.DataFrame]] with the product of joining the
      * current dataframe with the tree entries dataframe.
      *
      * {{{
      * val entriesDf = commitsDf.getTreeEntries
      * }}}
      *
      * @return new DataFrame containing also tree entries data.
      */
    def getTreeEntries: DataFrame = {
      checkCols(df, "index", "hash") // references also has hash, index makes sure that is commits
      val commitsDf = df.select("hash")
      val entriesDf = getDataSource(TreeEntriesTable, df.sparkSession)
      entriesDf.join(commitsDf, entriesDf("commit_hash") === commitsDf("hash"))
        .drop($"hash")
    }

    /**
      * Returns a new [[org.apache.spark.sql.DataFrame]] with the product of joining the
      * current dataframe with the blobs dataframe. If the current dataframe does not contain
      * the tree entries data, getTreeEntries will be called automatically.
      *
      * {{{
      * val blobsDf = treeEntriesDf.getBlobs
      * val blobsDf2 = commitsDf.getBlobs // can be obtained from commits too
      * }}}
      *
      * @return new DataFrame containing also blob data.
      */
    def getBlobs: DataFrame = {
      if (!df.columns.contains("blob")) {
        df.getTreeEntries.getBlobs
      } else {
        val treesDf = df.select("path", "blob")
        val blobsDf = getDataSource(BlobsTable, df.sparkSession)
        blobsDf.join(treesDf, treesDf("blob") === blobsDf("blob_id"))
          .drop($"blob")
      }
    }

    private val HeadRef = "refs/heads/HEAD"
    private val LocalHeadRef = "HEAD"

    /**
      * Returns a new [[org.apache.spark.sql.DataFrame]] containing only the rows
      * with a HEAD reference.
      *
      * {{{
      * val headDf = refsDf.getHEAD
      * }}}
      *
      * @return new dataframe with only HEAD reference rows
      */
    def getHEAD: DataFrame = {
      if (df.schema.fieldNames.contains("reference_name")) {
        df.filter(($"reference_name" === HeadRef).or($"reference_name" === LocalHeadRef))
      } else if (df.schema.fieldNames.contains("name")) {
        df.filter(($"name" === HeadRef).or($"name" === LocalHeadRef))
      } else {
        df.getReferences.getHEAD
      }
    }

    /**
      * Returns a new [[org.apache.spark.sql.DataFrame]] containing only the rows
      * with a master reference.
      *
      * {{{
      * val masterDf = refsDf.getMaster
      * }}}
      *
      * @return new dataframe with only the master reference rows
      */
    def getMaster: DataFrame = getReference("refs/heads/master")

    /**
      * Returns a new [[org.apache.spark.sql.DataFrame]] containing only the rows
      * with a reference whose name equals the one provided.
      *
      * {{{
      * val developDf = refsDf.getReference("refs/heads/develop")
      * }}}
      *
      * @param name name of the reference to filter by
      * @return new dataframe with only the given reference rows
      */
    def getReference(name: String): DataFrame = {
      if (df.schema.fieldNames.contains("reference_name")) {
        df.filter($"reference_name" === name)
      } else if (df.schema.fieldNames.contains("name")) {
        df.filter($"name" === name)
      } else {
        df.getReferences.getReference(name)
      }
    }

    /**
      * Returns a new [[org.apache.spark.sql.DataFrame]] with a new column "lang" added
      * containing the language of the non-binary files.
      * It requires the current dataframe to have the files data:
      * "is_binary", "path" and "content".
      *
      * {{{
      * val languagesDf = filesDf.classifyLanguages
      * }}}
      *
      * @return new DataFrame containing also language data.
      */
    def classifyLanguages: DataFrame = {
      val newDf = df.withColumn("lang", typedLit(null: String))
      val encoder = RowEncoder(newDf.schema)
      newDf.map(new MapFunction[Row, Row] {
        override def call(row: Row): Row = {
          val (isBinaryIdx, pathIdx, contentIdx) = try {
            (row.fieldIndex("is_binary"), row.fieldIndex("path"), row.fieldIndex("content"))
          } catch {
            case _: IllegalArgumentException =>
              throw new SparkException(s"classifyLanguages can not be applied to this DataFrame: "
                + "unable to find all these columns: is_binary, path, content")
          }

          val (isBinary, path, content) = (
            row.getBoolean(isBinaryIdx),
            row.getString(pathIdx),
            row.getAs[Array[Byte]](contentIdx)
          )

          val lang = ClassifyLanguagesUDF.getLanguage(isBinary, path, content).orNull
          Row(row.toSeq.dropRight(1) ++ Seq(lang): _*)
        }
      }, encoder)
    }

    /**
      * Returns a new [[org.apache.spark.sql.DataFrame]] with a new column "uast" added,
      * containing the Protobuf-serialized UASTs.
      * It requires the current dataframe to have the files' data: path and content.
      * If the language is available, it will be used, e.g. to avoid parsing Text and Markdown.
      *
      * {{{
      * val uastsDf = filesDf.extractUASTs
      * }}}
      *
      * @return new DataFrame that contains Protobuf serialized UAST.
      */
    def extractUASTs(): DataFrame = {
      val newDf = df.withColumn("uast", typedLit(Seq[Array[Byte]]()))
      val configB = df.sparkSession.sparkContext.broadcast(Bblfsh.getConfig(df.sparkSession))
      val encoder = RowEncoder(newDf.schema)
      newDf.map(new MapFunction[Row, Row] {
        override def call(row: Row): Row = {
          val lang = (try {
            Some(row.fieldIndex("lang"))
          } catch {
            case _: IllegalArgumentException =>
              None
          }).map(row.getString).orNull

          val (pathIdx, contentIdx) = try {
            (row.fieldIndex("path"), row.fieldIndex("content"))
          } catch {
            case _: IllegalArgumentException =>
              throw new SparkException(s"extractUASTs can not be applied to this DataFrame: "
                + "unable to find all these columns: path, content")
          }

          val (path, content) = (row.getString(pathIdx), row.getAs[Array[Byte]](contentIdx))
          val uast = ExtractUASTsUDF.extractUASTs(path, content, lang, configB.value)
          Row(row.toSeq.dropRight(1) ++ Seq(uast): _*)
        }
      }, encoder)
    }

    /**
      * Queries a list of UAST nodes with the given query to get specific nodes,
      * and puts the result in another column.
      *
      * {{{
      * import gopkg.in.bblfsh.sdk.v1.uast.generated.{Node, Role}
      *
      * // get all identifiers
      * var identifiers = uastsDf.queryUAST("//\*[@roleIdentifier and not(@roleIncomplete)]")
      *   .collect()
      *   .map(row => row(row.fieldIndex("result")))
      *   .flatMap(_.asInstanceOf[Seq[Array[Byte]]])
      *   .map(Node.parseFrom)
      *   .map(_.token)
      *
      * // get all identifiers from column "foo" and put them in "bar"
      * identifiers = uastsDf.queryUAST("//\*[@roleIdentifier and not(@roleIncomplete)]",
      *                                 "foo", "bar")
      *   .collect()
      *   .map(row => row(row.fieldIndex("result")))
      *   .flatMap(_.asInstanceOf[Seq[Array[Byte]]])
      *   .map(Node.parseFrom)
      *   .map(_.token)
      * }}}
      *
      * @param query        xpath query
      * @param queryColumn  column containing the list of UAST nodes to query
      * @param outputColumn column where the result of the query will be placed
      * @return [[DataFrame]] with the query results in the output column ("result" by default)
      */
    def queryUAST(query: String,
                  queryColumn: String = "uast",
                  outputColumn: String = "result"): DataFrame = {
      checkCols(df, "uast", queryColumn)
      if (df.columns.contains(outputColumn)) {
        throw new SparkException(s"DataFrame already contains a column named $outputColumn")
      }

      import org.apache.spark.sql.functions.lit
      df.withColumn(outputColumn, QueryXPathUDF(df.sparkSession)(df(queryColumn), lit(query)))
    }

    /**
      * Extracts the tokens in all nodes of a given column and puts the list of retrieved
      * tokens in a new column.
      *
      * {{{
      * val tokensDf = uastDf.queryUAST("//\*[@roleIdentifier and not(@roleIncomplete)]")
      *     .extractTokens()
      * }}}
      *
      * @param queryColumn  column where the UAST nodes are.
      * @param outputColumn column to put the result
      * @return new [[DataFrame]] with the extracted tokens
      */
    def extractTokens(queryColumn: String = "result",
                      outputColumn: String = "tokens"): DataFrame = {
      checkCols(df, queryColumn)
      if (df.columns.contains(outputColumn)) {
        throw new SparkException(s"DataFrame already contains a column named $outputColumn")
      }

      df.withColumn(outputColumn, ExtractTokensUDF()(df(queryColumn)))
    }

  }

  /**
    * Returns a [[org.apache.spark.sql.DataFrame]] for the given table using the provided
    * [[org.apache.spark.sql.SparkSession]].
    *
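    * A minimal sketch (`spark` is assumed to be a session in which the engine tables
    * have already been registered):
    *
    * {{{
    * val refsDf = getDataSource(ReferencesTable, spark)
    * }}}
    *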
    * @param table   name of the table
    * @param session spark session
    * @return dataframe for the given table
    */
  private[engine] def getDataSource(table: String, session: SparkSession): DataFrame =
    session.table(table)

  /**
    * Ensures the given [[org.apache.spark.sql.DataFrame]] contains at least one of
    * the given columns.
    *
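    * A minimal sketch of how it is used within this package (the column name is
    * just illustrative):
    *
    * {{{
    * checkCols(df, "repository_id") // throws if df has no "repository_id" column
    * }}}
    *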
    * @param df   dataframe that must have the columns
    * @param cols names of the columns that are required
    * @throws SparkException if none of the given columns is present
    */
  private[engine] def checkCols(df: DataFrame, cols: String*): Unit = {
    if (!df.columns.exists(cols.contains)) {
      throw new SparkException(s"Method can not be applied to this DataFrame: "
        + s"required:'${cols.mkString(" ")}',"
        + s" actual columns:'${df.columns.mkString(" ")}'")
    }
  }

  /**
    * Object that contains the user defined functions that will be added to the session.
    */
  private[engine] object SessionFunctions {

    /**
      * List of custom functions to be registered.
      */
    val UDFtoRegister: List[CustomUDF] = List(
      ClassifyLanguagesUDF,
      ExtractUASTsUDF,
      QueryXPathUDF,
      ExtractTokensUDF,
      ConcatArrayUDF
    )
  }

}
