package org.dianahep

// spark related
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.sql.Row
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.sources.PrunedFilteredScan
import org.apache.spark.sql.sources.RelationProvider
import org.apache.spark.sql.types._

// hadoop hdfs 
import org.apache.hadoop.fs.{Path, FileSystem}

// sparkroot or root4j
import org.dianahep.root4j.core.RootInput
import org.dianahep.root4j._
import org.dianahep.root4j.interfaces._
import sparkroot.ast._

package object sparkroot {
  /**
   * An implicit DataFrame reader: adds a `root` method to Spark's DataFrameReader
   */
  implicit class RootDataFrameReader(reader: DataFrameReader) {
    def root(paths: String*) = reader.format("org.dianahep.sparkroot").load(paths: _*)
    def root(path: String) = reader.format("org.dianahep.sparkroot").load(path)
  }
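
  // Hypothetical usage sketch (not in the original source), assuming a Spark
  // application with an SQLContext named `sqlContext` and a local ROOT file
  // "events.root":
  //
  //   import org.dianahep.sparkroot._
  //   val df = sqlContext.read.root("events.root")
  //   df.printSchema()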

  /**
   *  ROOT TTree Iterator
   *  - iterates over the entries of the tree
   *  - pumps the data from the TTree into a Spark Row
   */
  class RootTreeIterator(tree: TTree, // TTree
    streamers: Map[String, TStreamerInfo], // a map of streamers
    requiredColumns: Array[String], // columns that are required for a query
    filters: Array[Filter]) extends Iterator[Row] {

    //  Abstract Schema Tree (ATT) built for the required columns
    private val att = buildATT(tree, streamers, requiredColumns)
    //  true while there is another entry to read
    def hasNext = containsNext(att)
    //  read the next entry and convert it into a Spark Row
    def next() = readSparkRow(att)
  }
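
  // Hypothetical sketch (not in the original source) of driving the iterator by
  // hand, assuming a readable local file "events.root" whose only TTree should
  // be read in full; passing null for requiredColumns is assumed here to request
  // every branch, mirroring how the schema-building code below calls buildATT:
  //
  //   val reader = new RootFileReader("events.root")
  //   val tree = findTree(reader.getTopDir)
  //   val rows = new RootTreeIterator(tree, arrangeStreamers(reader), null, Array.empty[Filter])
  //   while (rows.hasNext) println(rows.next())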

  /**
   * 1. Builds the Schema
   * 2. Maps execution of each file to a Tree iterator
   */
  class RootTableScan(path: String, treeName: String)(@transient val sqlContext: SQLContext) extends BaseRelation with PrunedFilteredScan {

    // path is either a directory containing ROOT files or a path to a single file
    private val inputPathFiles = {
      val hPath = new Path(path)
      val fs = hPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
      fs.listStatus(hPath).map(_.getPath.toString)
    }
    
    // build the Abstract Schema Tree (ATT) from the first input file
    private val att: core.SRType =
    {
      println(s"Building the Abstract Schema Tree... for treeName=$treeName")
      val reader = new RootFileReader(inputPathFiles.head)
      val tmp =
        if (treeName == null)
          buildATT(findTree(reader.getTopDir), arrangeStreamers(reader), null)
        else
          buildATT(reader.getKey(treeName).getObject.asInstanceOf[TTree],
            arrangeStreamers(reader), null)
      println("Done")
      tmp
    }

    // define the Spark schema from the Abstract Schema Tree
    def schema: StructType = {
      println("Building the Spark Schema")
      val s = buildSparkSchema(att)
      println("Done")
      s
    }

    // builds a scan over all input files
    def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
      println("Building Scan")
      println(requiredColumns.mkString(" "))
      println(filters.mkString(" "))
      // capture the tree name in a local val so the closure below does not
      // capture `this` (the whole relation) for serialization
      val localTreeName = treeName

      // parallelize over all the files
      val r = sqlContext.sparkContext.parallelize(inputPathFiles).
        flatMap({pathName =>
          // TODO: support HDFS (may involve changes to root4j)
          val reader = new RootFileReader(pathName)
          // get the TTree: use the tree name passed via the "tree" option if
          // present, otherwise auto-detect a TTree in the file
          val rootTree =
            if (localTreeName == null) findTree(reader)
            else reader.getKey(localTreeName).getObject.asInstanceOf[TTree]
          // the real work starts here: stream the entries of this file as Rows
          new RootTreeIterator(rootTree, arrangeStreamers(reader),
            requiredColumns, filters)
        })

      println("Done building Scan")
      r
    }
  }
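
  // Hypothetical sketch (not in the original source) of how this relation is
  // exercised: because RootTableScan implements PrunedFilteredScan, Spark hands
  // the selected columns and pushable filters to buildScan above. The directory
  // "/data/rootfiles", the tree name "Events" and the branch "muon_pt" are made
  // up for illustration:
  //
  //   val df = sqlContext.read.option("tree", "Events").root("/data/rootfiles")
  //   df.select("muon_pt").filter(df("muon_pt") > 20.0).show()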
}

/**
 *  Default Source - spark.sqlContext.read.root(filename) will be directed here!
 */
package sparkroot {
  class DefaultSource extends RelationProvider {
    def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) = {
      println(parameters)
      new RootTableScan(
        parameters.getOrElse("path", sys.error("ROOT path must be specified")),
        parameters.getOrElse("tree", null))(sqlContext)
    }
  }
}
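
// Hypothetical sketch (not in the original source) of going through the
// DefaultSource explicitly, i.e. without importing the implicit `root` method;
// the file path and the tree name "Events" are made up for illustration:
//
//   val df = sqlContext.read
//     .format("org.dianahep.sparkroot")
//     .option("tree", "Events")
//     .load("/data/events.root")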



