All Downloads are FREE. Search and download functionality uses the official Maven repository.

io.projectglow.bgen.BigBgenDatasource.scala Maven / Gradle / Ivy

/*
 * Copyright 2019 The Glow Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.projectglow.bgen

import java.io.ByteArrayOutputStream

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLUtils}
import org.apache.spark.sql.sources.DataSourceRegister

import io.projectglow.common.logging.{HlsEventRecorder, HlsTagValues}
import io.projectglow.sql.BigFileDatasource
import io.projectglow.sql.util.ComDatabricksDataSource

/**
 * Spark data source registered under the short name `bigbgen`.
 *
 * All serialization work is delegated to the companion object's `serializeDataFrame`. How the
 * returned per-partition byte arrays are assembled into output is decided by
 * [[BigFileDatasource]], which is not visible here (presumably concatenated into one file).
 */
class BigBgenDatasource extends BigFileDatasource with DataSourceRegister {

  /** Name accepted by Spark's `format(...)` for this source. */
  override def shortName(): String = "bigbgen"

  /** Forwards to [[BigBgenDatasource.serializeDataFrame]] with the same arguments. */
  override def serializeDataFrame(
      options: Map[String, String],
      data: DataFrame): RDD[Array[Byte]] =
    BigBgenDatasource.serializeDataFrame(options, data)
}

/**
 * Variant of [[BigBgenDatasource]] that additionally mixes in [[ComDatabricksDataSource]].
 * Any behavior beyond the parent's comes from that trait (not visible here); presumably it
 * provides a legacy `com.databricks` alias — TODO confirm.
 */
class ComDatabricksBigBgenDatasource
    extends BigBgenDatasource
    with ComDatabricksDataSource

/**
 * Serialization logic for [[BigBgenDatasource]]: turns a DataFrame into BGEN-encoded byte
 * chunks, one chunk per partition, and records a usage event describing the writer options.
 */
object BigBgenDatasource extends HlsEventRecorder {

  import io.projectglow.common.BgenOptions._

  /**
   * Builds [[BigBgenOptions]] from the user-supplied writer options. Any key that is absent
   * falls back to the corresponding default constant from `BgenOptions`.
   *
   * @throws NumberFormatException if an integer-valued option is not a valid `Int`
   * @throws IllegalArgumentException if the phasing option is not a valid boolean
   */
  private def parseOptions(options: Map[String, String]): BigBgenOptions = {
    BigBgenOptions(
      bitsPerProb = intOption(options, BITS_PER_PROB_KEY, BITS_PER_PROB_DEFAULT_VALUE),
      maxPloidy = intOption(options, MAX_PLOIDY_KEY, MAX_PLOIDY_VALUE),
      defaultPloidy = intOption(options, DEFAULT_PLOIDY_KEY, DEFAULT_PLOIDY_VALUE),
      defaultPhasing = booleanOption(options, DEFAULT_PHASING_KEY, DEFAULT_PHASING_VALUE)
    )
  }

  /**
   * Reads `key` (or `default` when absent) as an `Int`. On a malformed value, the error
   * message names the option and the raw value.
   */
  private def intOption(options: Map[String, String], key: String, default: String): Int = {
    val raw = options.getOrElse(key, default)
    try {
      raw.toInt
    } catch {
      case e: NumberFormatException =>
        // Rethrow the same exception type so callers that catch it are unaffected.
        val detailed =
          new NumberFormatException(s"BGEN option '$key' must be an integer, but got '$raw'")
        detailed.initCause(e)
        throw detailed
    }
  }

  /**
   * Reads `key` (or `default` when absent) as a `Boolean`. On a malformed value, the error
   * message names the option and the raw value.
   */
  private def booleanOption(
      options: Map[String, String],
      key: String,
      default: String): Boolean = {
    val raw = options.getOrElse(key, default)
    try {
      raw.toBoolean
    } catch {
      case e: IllegalArgumentException =>
        throw new IllegalArgumentException(
          s"BGEN option '$key' must be 'true' or 'false', but got '$raw'",
          e)
    }
  }

  /** Records an HLS usage event carrying the effective (parsed) writer options. */
  private def logBgenWrite(parsedOptions: BigBgenOptions): Unit = {
    val logOptions = Map(
      BITS_PER_PROB_KEY -> parsedOptions.bitsPerProb,
      MAX_PLOIDY_KEY -> parsedOptions.maxPloidy,
      DEFAULT_PLOIDY_KEY -> parsedOptions.defaultPloidy,
      DEFAULT_PHASING_KEY -> parsedOptions.defaultPhasing
    )
    recordHlsEvent(HlsTagValues.EVENT_BGEN_WRITE, logOptions)
  }

  /**
   * Encodes `data` as BGEN, producing one byte array per partition of the input.
   *
   * Only partition 0 is asked to write the header. Every writer receives the total row count
   * (`numVariants`), presumably for the header's variant count. Obtaining that count runs a
   * separate Spark job before the RDD below is evaluated, so the input is computed twice.
   *
   * @param options writer options; see `parseOptions` for keys and defaults
   * @param data    rows to encode; its schema is handed to each [[BgenRecordWriter]]
   * @return an RDD whose elements are the serialized bytes of each partition
   * @throws NumberFormatException if an integer-valued option is malformed
   * @throws IllegalArgumentException if the phasing option is malformed
   */
  def serializeDataFrame(options: Map[String, String], data: DataFrame): RDD[Array[Byte]] = {

    val parsedOptions = parseOptions(options)
    logBgenWrite(parsedOptions)

    // Pull these out as locals so the task closure captures them rather than the DataFrame.
    val dSchema = data.schema
    // `count()` is a side-effecting Spark action, so it is called with explicit parentheses.
    val numVariants = data.count()
    val rawRdd = data.queryExecution.toRdd

    // With zero partitions no task would run, so not even the header would be emitted. The
    // replacement RDD from SQLUtils is presumably non-empty in partitions (TODO confirm).
    val inputRdd = if (rawRdd.getNumPartitions == 0) {
      logger.warn("Writing BGEN header only as the input DataFrame has zero partitions.")
      SQLUtils.createEmptyRDD(data.sparkSession)
    } else {
      rawRdd
    }

    inputRdd.mapPartitionsWithIndex {
      case (idx, it) =>
        val baos = new ByteArrayOutputStream()

        val writeHeader = idx == 0
        val writer = new BgenRecordWriter(
          baos,
          dSchema,
          writeHeader,
          numVariants,
          parsedOptions.bitsPerProb,
          parsedOptions.maxPloidy,
          parsedOptions.defaultPloidy,
          parsedOptions.defaultPhasing
        )

        it.foreach { row =>
          writer.write(row)
        }

        // Close before reading the buffer so any data the writer emits on close is included.
        writer.close()
        Iterator(baos.toByteArray)
    }
  }
}

/**
 * Parsed writer settings for the `bigbgen` data source. Each field is forwarded unchanged to
 * `BgenRecordWriter`.
 *
 * @param bitsPerProb    value of the bits-per-probability option
 * @param maxPloidy      value of the max-ploidy option
 * @param defaultPloidy  value of the default-ploidy option
 * @param defaultPhasing value of the default-phasing option
 */
case class BigBgenOptions(bitsPerProb: Int, maxPloidy: Int, defaultPloidy: Int, defaultPhasing: Boolean)




© 2015 - 2025 Weber Informatics LLC | Privacy Policy