
/*
* Copyright 2019 The Glow Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.projectglow.bgen

import java.io.{BufferedReader, File, InputStreamReader}
import java.nio.file.Paths

import scala.collection.JavaConverters._

import com.google.common.io.LittleEndianDataInputStream
import com.google.common.util.concurrent.Striped
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
import org.apache.spark.sql.types.StructType
import org.skife.jdbi.v2.DBI
import org.skife.jdbi.v2.util.LongMapper

import io.projectglow.common.{BgenOptions, GlowLogging, WithUtils}
import io.projectglow.common.logging.{HlsEventRecorder, HlsTagValues}
import io.projectglow.sql.util.{ComDatabricksDataSource, SerializableConfiguration}

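/**
 * Spark SQL data source for reading BGEN files, registered under the short name "bgen".
 *
 * Reads are splittable; when a SQLite .bgi index file sits next to the BGEN file and the
 * [[BgenOptions.USE_INDEX_KEY]] option is enabled, the index is used to find the first variant
 * that starts in each file split. Writing through this format is not supported.
 */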
class BgenFileFormat extends FileFormat with DataSourceRegister with Serializable with GlowLogging {

  override def shortName(): String = "bgen"

  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = {
    Option(BgenSchemaInferrer.inferSchema(sparkSession, files, options))
  }

  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
    throw new UnsupportedOperationException(
      "BGEN data source does not support writing sharded BGENs; use bigbgen."
    )
  }

  override def isSplitable(
      sparkSession: SparkSession,
      options: Map[String, String],
      path: Path): Boolean = {
    true
  }

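  // buildReader parses the reader options (index usage, extension handling, sample IDs, and the
  // hard-call threshold), logs the read event, and returns a function that maps each file split
  // to an iterator of rows for the variants that start within that split.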
  override def buildReader(
      spark: SparkSession,
      dataSchema: StructType,
      partitionSchema: StructType,
      requiredSchema: StructType,
      filters: Seq[Filter],
      options: Map[String, String],
      hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {

    val useIndex = options.get(BgenOptions.USE_INDEX_KEY).forall(_.toBoolean)
    val ignoreExtension = options.get(BgenOptions.IGNORE_EXTENSION_KEY).exists(_.toBoolean)
    val sampleIdsOpt = BgenFileFormat.getSampleIds(options, hadoopConf)
    val hardCallsThreshold = options
      .getOrElse(BgenOptions.HARD_CALL_THRESHOLD_KEY, BgenOptions.HARD_CALL_THRESHOLD_VALUE)
      .toDouble

    // record bgenRead event in the log along with the option
    BgenFileFormat.logBgenRead(useIndex)

    val serializableConf = new SerializableConfiguration(hadoopConf)

    file => {
      val path = new Path(file.filePath)
      val hadoopFs = path.getFileSystem(serializableConf.value)

      nextVariantIndex(hadoopFs, file, useIndex, ignoreExtension) match {
        case None =>
          Iterator.empty
        case Some(pos) =>
          logger.info(s"Next variant index is $pos")
          val stream = hadoopFs.open(path)
          val littleEndianStream = new LittleEndianDataInputStream(stream)
          Option(TaskContext.get()).foreach { tc =>
            tc.addTaskCompletionListener[Unit] { _ =>
              stream.close()
            }
          }

          val header = new BgenHeaderReader(littleEndianStream).readHeader(sampleIdsOpt)
          val startPos = Math.max(pos, header.firstVariantOffset)
          stream.seek(startPos)

          val rowConverter = new BgenRowToInternalRowConverter(requiredSchema, hardCallsThreshold)
          val iter = new BgenFileIterator(
            header,
            littleEndianStream,
            stream,
            file.start,
            file.start + file.length
          )
          iter.init()
          iter.map(rowConverter.convertRow)
      }
    }
  }

  /**
   * Returns the position where the iterator for a given file split should start looking for
   * variants, or None if no variants start in this block.
   *
   * Note that the file iterator should not necessarily begin returning variants from the index
   * returned by this function -- it may need to skip over part of the file until it reaches
   * the allotted file split.
   */
  private def nextVariantIndex(
      hadoopFs: FileSystem,
      file: PartitionedFile,
      useIndex: Boolean,
      ignoreExtension: Boolean): Option[Long] = {
    if (!file.filePath.endsWith(BgenFileFormat.BGEN_SUFFIX) && !ignoreExtension) {
      return None
    }

    val indexFile = new Path(file.filePath + BgenFileFormat.INDEX_SUFFIX)
    if (hadoopFs.exists(indexFile) && useIndex) {
      logger.info(s"Found index file ${indexFile} for BGEN file ${file.filePath}")
      val localIdxPath = downloadIdxIfNecessary(hadoopFs, indexFile)
      val dbi = new DBI(s"jdbc:sqlite:$localIdxPath")
      WithUtils.withCloseable(dbi.open()) { handle =>
        val query = handle
          .createQuery(BgenFileFormat.NEXT_IDX_QUERY)
          .bind("pos", file.start)
          .map(LongMapper.FIRST)
        Option(query.first())
      }
    } else {
      Some(0) // have to start at beginning of file :(
    }
  }

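  // The .bgi index is a SQLite database, so it must live on the local filesystem before it can
  // be opened over JDBC. The copy is made into the JVM's temp directory only if it is not
  // already present, and is guarded by a striped lock so concurrent tasks do not clobber
  // each other.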
  private def downloadIdxIfNecessary(hadoopFs: FileSystem, path: Path): String = {
    val localDir = Paths.get(System.getProperty("java.io.tmpdir")).resolve("bgen_indices").toFile
    localDir.mkdirs()

    val localPath = s"$localDir/${path.getName.replaceAllLiterally("/", "__")}"
    WithUtils.withLock(BgenFileFormat.idxLock.get(path)) {
      if (!new File(localPath).exists()) {
        hadoopFs.copyToLocalFile(path, new Path("file:" + localPath))
      }
    }

    localPath
  }
}
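
/** [[BgenFileFormat]] additionally exposed under the com.databricks-prefixed data source name (see [[ComDatabricksDataSource]]). */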
class ComDatabricksBgenFileFormat extends BgenFileFormat with ComDatabricksDataSource
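
/**
 * Companion object holding the BGEN file suffixes, the SQL query run against the .bgi index
 * (a `Variant` table keyed by `file_start_position`), the striped lock that serializes index
 * downloads, and helpers for event logging and .sample file parsing.
 */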
object BgenFileFormat extends HlsEventRecorder {

  import io.projectglow.common.BgenOptions._

  val BGEN_SUFFIX = ".bgen"
  val INDEX_SUFFIX = ".bgi"
  val NEXT_IDX_QUERY =
    """
      |SELECT MIN(file_start_position) from Variant
      |WHERE file_start_position > :pos
    """.stripMargin
  val idxLock = Striped.lock(100)

  def logBgenRead(useIndex: Boolean): Unit = {
    recordHlsEvent(HlsTagValues.EVENT_BGEN_READ, Map(BgenOptions.USE_INDEX_KEY -> useIndex))
  }
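
  // For reference, an Oxford-style .sample file has two header lines (column names, then column
  // types) followed by one line per sample. An illustrative example, with sample IDs in the
  // default ID_2 column:
  //
  //   ID_1 ID_2 missing
  //   0 0 0
  //   fam1 sample_1 0
  //   fam2 sample_2 0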
  /**
   * If a path to an Oxford-style .sample file is provided (option: sampleFilePath), reads the
   * sample IDs from the appropriate column (option: sampleIdColumn, default: ID_2).
   *
   * If the column does not exist in the file, or if at least one row is malformed (does not
   * contain a sample ID), an IllegalArgumentException is thrown.
   * If no path is given, None is returned; if a valid path is given, a Some-wrapped list of
   * sample IDs is returned.
   */
  def getSampleIds(options: Map[String, String], hadoopConf: Configuration): Option[Seq[String]] = {
    val samplePathOpt = options.get(SAMPLE_FILE_PATH_OPTION_KEY)

    if (samplePathOpt.isEmpty) {
      // No sample file path provided
      return None
    }

    val samplePath = samplePathOpt.get
    val path = new Path(samplePath)
    val hadoopFs = path.getFileSystem(hadoopConf)
    val stream = hadoopFs.open(path)
    val streamReader = new InputStreamReader(stream)
    val bufferedReader = new BufferedReader(streamReader)

    // The first two (2) lines in a .sample file are header lines
    val sampleCol =
      options.getOrElse(SAMPLE_ID_COLUMN_OPTION_KEY, SAMPLE_ID_COLUMN_OPTION_DEFAULT_VALUE)
    val sampleColIdx = bufferedReader.readLine().split(" ").indexOf(sampleCol)
    bufferedReader.readLine() // Column-type line

    WithUtils.withCloseable(stream) { s =>
      if (sampleColIdx == -1) {
        throw new IllegalArgumentException(s"No column named $sampleCol in $samplePath.")
      }

      val sampleLines = bufferedReader
        .lines()
        .iterator()
        .asScala
        .toList

      val sampleIds = sampleLines.map { line =>
        if (line.split(" ").length > sampleColIdx) {
          line.split(" ").apply(sampleColIdx)
        } else {
          throw new IllegalArgumentException(
            s"Malformed line in $samplePath: fewer than ${sampleColIdx + 1} columns in $line"
          )
        }
      }
      Some(sampleIds)
    }
  }
}
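
// A minimal usage sketch (hypothetical paths; the sampleFilePath and sampleIdColumn option keys
// are the ones named in the getSampleIds doc comment above):
//
//   val df = spark
//     .read
//     .format("bgen")
//     .option("sampleFilePath", "/path/to/samples.sample")
//     .option("sampleIdColumn", "ID_2")
//     .load("/path/to/genotypes.bgen")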