
/*
* Copyright 2019 The Glow Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.projectglow.vcf
import java.util
import java.util.Collections
import com.google.common.base.Splitter
import htsjdk.samtools.ValidationStringency
import htsjdk.samtools.util.OverlapDetector
import htsjdk.variant.vcf.VCFHeader
import org.apache.hadoop.io.Text
import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types.{ArrayType, BooleanType, DataType, DoubleType, IntegerType, LongType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}
import io.projectglow.common.{GenotypeFields, HasStringency, SimpleInterval, VariantSchemas}
import scala.util.control.NonFatal
/**
* Converts the raw bytes in a VCF line into an [[InternalRow]]
*
* @param header The VCF header object. This is currently only used to extract the sample IDs
* @param schema The schema of the converted rows
* @param stringency How to handle errors
* @param overlapDetectorOpt If provided, the converter will check to see if a variant passes this detector before
* parsing genotypes
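*
* A rough usage sketch (illustrative only; assumes a `header` and a Glow-compatible `schema`
* are already in scope):
* {{{
*   val converter = new VCFLineToInternalRowConverter(
*     header,
*     schema,
*     ValidationStringency.LENIENT,
*     overlapDetectorOpt = None)
*   // convert returns null for header lines and for variants rejected by the overlap detector
*   val row: InternalRow = converter.convert(new Text("1\t100\t.\tA\tT\t50\tPASS\t."))
* }}}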
*/
class VCFLineToInternalRowConverter(
header: VCFHeader,
schema: StructType,
val stringency: ValidationStringency,
overlapDetectorOpt: Option[OverlapDetector[SimpleInterval]])
extends HasStringency {
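// Reusable buffer with one slot per sample, filled with genotype structs for each parsed line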
private val genotypeHolder = new Array[Any](header.getNGenotypeSamples)
private def findFieldIdx(field: StructField): Int = {
schema.indexWhere(SQLUtils.structFieldsEqualExceptNullability(_, field))
}
private val contigIdx = findFieldIdx(VariantSchemas.contigNameField)
private val startIdx = findFieldIdx(VariantSchemas.startField)
private val namesIdx = findFieldIdx(VariantSchemas.namesField)
private val refAlleleIdx = findFieldIdx(VariantSchemas.refAlleleField)
private val altAllelesIdx = findFieldIdx(VariantSchemas.alternateAllelesField)
private val qualIdx = findFieldIdx(VariantSchemas.qualField)
private val filtersIdx = findFieldIdx(VariantSchemas.filtersField)
private val endIdx = findFieldIdx(VariantSchemas.endField)
private val genotypesIdx = schema.indexWhere(_.name == VariantSchemas.genotypesFieldName)
private val splitFromMultiIdx = findFieldIdx(VariantSchemas.splitFromMultiAllelicField)
private lazy val sampleIds = header
.getGenotypeSamples
.toArray()
.map(el => UTF8String.fromString(el.asInstanceOf[String]))
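// Maps each INFO field name (without the INFO_ prefix) to its Spark data type and its index
// in the row schema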
private val infoFields = schema
.zipWithIndex
.collect {
case (f, idx) if f.name.startsWith("INFO_") =>
val name = f.name.stripPrefix("INFO_")
val typ = f.dataType
(UTF8String.fromString(name), (typ, idx))
}
.toMap
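// FLAG-typed INFO fields; they default to false and flip to true only when the key appears
// in the INFO column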
private val flagFields = infoFields.filter { f =>
f._2._1 == BooleanType
}
private val gSchemaOpt = schema
.find(_.name == VariantSchemas.genotypesFieldName)
.map(_.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType])
private var callsIdx = -1
private var phasedIdx = -1
private var sampleIdIdx = -1
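// Maps each VCF FORMAT field name to its data type and index within the genotype struct.
// Sample id, calls, and phased are tracked separately above because they do not map 1:1 to
// FORMAT fields: GT expands to calls + phased, and the sample id comes from the header.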
private val genotypeFieldsOpt: Option[Map[UTF8String, (DataType, Int)]] = gSchemaOpt.map {
schema =>
schema
.zipWithIndex
.flatMap {
case (gf, idx) =>
if (SQLUtils.structFieldsEqualExceptNullability(gf, VariantSchemas.sampleIdField)) {
sampleIdIdx = idx
None
} else if (SQLUtils.structFieldsEqualExceptNullability(gf, VariantSchemas.callsField)) {
callsIdx = idx
None
} else if (SQLUtils.structFieldsEqualExceptNullability(gf, VariantSchemas.phasedField)) {
phasedIdx = idx
None
} else {
val vcfName = GenotypeFields.reverseAliases.getOrElse(gf.name, gf.name)
val typ = gf.dataType
Some((UTF8String.fromString(vcfName), (typ, idx)))
}
}
.toMap
}
private def set(row: InternalRow, idx: Int, value: Any): Unit = {
if (idx == -1) {
return
}
row.update(idx, value)
}
/**
* Converts a VCF line into an [[InternalRow]]
* @param line A Text object containing the VCF line
* @return The converted row, or null to represent a header line, a parsing failure, or a filtered row
*/
def convert(line: Text): InternalRow = {
var contigName: UTF8String = null
var start: Long = -1
var end: Long = -1
val row = new GenericInternalRow(schema.size)
set(row, splitFromMultiIdx, false)
// By default, FLAG fields should be false
flagFields.foreach {
case (key, (typ, idx)) =>
set(row, idx, false)
}
val ctx = new LineCtx(line)
if (ctx.isHeader) {
return null
}
contigName = ctx.parseString()
set(row, contigIdx, contigName)
ctx.expectTab()
start = ctx.parseLong() - 1
set(row, startIdx, start)
ctx.expectTab()
val names = ctx.parseStringArray()
set(row, namesIdx, ctx.toGenericArrayData(names))
ctx.expectTab()
val refAllele = ctx.parseString()
set(row, refAlleleIdx, refAllele)
end = start + refAllele.numChars()
set(row, endIdx, end)
ctx.expectTab()
val altAlleles = ctx.parseStringArray()
set(row, altAllelesIdx, ctx.toGenericArrayData(altAlleles))
ctx.expectTab()
val qual = ctx.parseDouble()
set(row, qualIdx, qual)
ctx.expectTab()
val filters = ctx.parseStringArray()
set(row, filtersIdx, ctx.toGenericArrayData(filters))
ctx.expectTab()
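// Parse the INFO column key by key. END overrides the computed end position; keys not present
// in the schema are skipped.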
while (!ctx.isTab) {
val key = ctx.parseString('=', ';')
tryWithWarning(key, FieldTypes.INFO) {
ctx.eat('=')
if (key == LineCtx.END) {
end = ctx.parseInfoVal(LongType).asInstanceOf[Long]
set(row, endIdx, end)
} else if (!infoFields.contains(key)) {
ctx.parseString(';')
} else {
val (typ, idx) = infoFields(key)
val value = ctx.parseInfoVal(typ)
set(row, idx, value)
}
}
ctx.eat(';')
}
if (overlapDetectorOpt.isDefined) {
val contigStr = if (contigName == null) null else contigName.toString
val interval = SimpleInterval(contigStr, start.toInt + 1, end.toInt)
if (!overlapDetectorOpt.get.overlapsAny(interval)) {
return null
}
}
if (genotypeFieldsOpt.isEmpty) {
return row
}
if (ctx.isLineEnd) {
return row
}
ctx.expectTab()
row.update(genotypesIdx, parseGenotypes(ctx))
row
}
/**
* Parses the genotypes from the format section of a VCF line.
*
* Basic flow:
* - Look at the format description block and map each field to an index in the genotype SQL schema and a data type
* - Iterate through the fields for each sample and update the genotype struct based on the stored indices and types
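*
* For example (illustrative values): a FORMAT block of `GT:AD:DP` with a sample column
* `0|1:10,4:14` maps GT onto the calls and phased fields (calls = [0, 1], phased = true),
* while AD and DP land at their (possibly aliased) positions in the genotype struct.
*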
* @param ctx The line context, positioned at the start of the FORMAT column
* @return An array of genotype structs, one entry per sample
*/
private def parseGenotypes(ctx: LineCtx): GenericArrayData = {
val gSchema = gSchemaOpt.get
val genotypeFields = genotypeFieldsOpt.get
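// Map each FORMAT key on this line to its (type, index) in the genotype schema; unknown keys
// are left null and skipped for each sample.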
val fieldNames = ctx.parseStringArray(':')
var gtIdx = -1
val typeAndIdx: Array[(DataType, Int)] = new Array[(DataType, Int)](fieldNames.length)
var i = 0
while (i < typeAndIdx.length) {
val name = fieldNames(i).asInstanceOf[UTF8String]
// GT maps to two fields (calls and phased), so special-case it
if (name.toString == "GT") {
gtIdx = i
}
if (genotypeFields.contains(name)) {
typeAndIdx(i) = genotypeFields(name)
}
i += 1
}
ctx.expectTab()
var sampleIdx = 0
while (!ctx.isLineEnd && sampleIdx < genotypeHolder.length) {
val gRow = new GenericInternalRow(gSchema.size)
if (sampleIdIdx != -1) {
gRow.update(sampleIdIdx, sampleIds(sampleIdx))
}
var i = 0
while (!ctx.isTab && i < typeAndIdx.length) {
tryWithWarning(fieldNames(i).asInstanceOf[UTF8String], FieldTypes.FORMAT) {
if (i == gtIdx) {
ctx.parseCallsAndPhasing(gRow, phasedIdx, callsIdx)
} else if (typeAndIdx(i) == null) {
// Eat this value as a string since we don't need the parsed value
ctx.parseString(':')
} else {
val (typ, idx) = typeAndIdx(i)
val value = ctx.parseFormatVal(typ)
gRow.update(idx, value)
}
ctx.eat(':')
i += 1
}
}
ctx.eat('\t')
genotypeHolder(sampleIdx) = gRow
sampleIdx += 1
}
new GenericArrayData(genotypeHolder)
}
private def tryWithWarning(fieldName: UTF8String, fieldType: String)(f: => Unit): Unit = {
try {
f
} catch {
case NonFatal(ex) =>
raiseValidationError(
s"Could not parse $fieldType field ${fieldName.toString}. " +
s"Exception: ${ex.getMessage}",
ex
)
}
}
}
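/**
 * Mutable cursor over the raw bytes of a single VCF line. Parsing methods advance `pos` and
 * stop at tabs, the current delimiter (';' inside an INFO value, ':' inside a FORMAT value),
 * or the end of the line. The missing value "." is returned as null.
 */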
class LineCtx(text: Text) {
val line = text.getBytes
var pos = 0
var delimiter = '\0' // unset
val longWrapper = new LongWrapper()
val intWrapper = new IntWrapper()
val stringList = new util.ArrayList[UTF8String]()
val intList = new util.ArrayList[java.lang.Integer]()
val longList = new util.ArrayList[java.lang.Long]()
val doubleList = new util.ArrayList[java.lang.Double]()
def setInInfoVal(): Unit = {
delimiter = ';'
}
def setInFormatVal(): Unit = {
delimiter = ':'
}
def printRemaining(): Unit = {
println(new String(line.slice(pos, text.getLength), "UTF-8")) // scalastyle:ignore
}
def isHeader: Boolean = {
line.isEmpty || line(0) == '#'
}
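// Reads bytes until a tab, the current delimiter, or one of the extra stop characters.
// Returns null for empty or missing (".") values.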
def parseString(extraStopChar1: Byte = '\0', extraStopChar2: Byte = '\0'): UTF8String = {
var stop = pos
while (stop < text.getLength && line(stop) != delimiter && line(stop) != '\t' &&
line(stop) != extraStopChar1 && line(stop) != extraStopChar2) {
stop += 1
}
if (stop - pos == 0) {
return null
}
val out = UTF8String.fromBytes(line, pos, stop - pos)
pos = stop
if (out == LineCtx.MISSING) {
null
} else {
out
}
}
def isTab: Boolean = {
pos >= text.getLength || line(pos) == '\t'
}
def isDelimiter: Boolean = {
pos >= text.getLength || line(pos) == delimiter
}
def isLineEnd: Boolean = {
pos >= text.getLength || line(pos) == '\n' || line(pos) == '\r'
}
def expectTab(): Unit = {
if (line(pos) != '\t') {
throw new IllegalStateException(s"Expected a tab at position $pos")
}
pos += 1
}
def eat(char: Byte): Unit = {
if (pos < text.getLength && line(pos) == char) {
pos += 1
}
}
def parseLong(): java.lang.Long = {
val s = parseString()
if (s == null) {
return null
}
require(s.toLong(longWrapper), s"Could not parse field as long")
longWrapper.value
}
def parseInt(
stopChar1: Byte = '\0',
stopChar2: Byte = '\0',
nullValue: java.lang.Integer = null): java.lang.Integer = {
val s = parseString(stopChar1, stopChar2)
if (s == null) {
return nullValue
}
require(s.toInt(intWrapper), s"Could not parse field as int")
intWrapper.value
}
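// Parses a double, handling nan/+nan/-nan and inf/+inf/-inf (case-insensitive) as NaN and
// positive/negative infinity.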
def parseDouble(stopChar: Byte = '\0'): java.lang.Double = {
val utfStr = parseString(stopChar)
if (utfStr == null) {
return null
}
val s = utfStr.toLowerCase
if (s == LineCtx.NAN) {
Double.NaN
} else if (s == LineCtx.NEG_NAN) {
Double.NaN
} else if (s == LineCtx.POS_NAN) {
Double.NaN
} else if (s == LineCtx.INF) {
Double.PositiveInfinity
} else if (s == LineCtx.POS_INF) {
Double.PositiveInfinity
} else if (s == LineCtx.NEG_INF) {
Double.NegativeInfinity
} else {
s.toString.toDouble
}
}
def parseStringArray(sep: Byte = ','): Array[AnyRef] = {
stringList.clear()
while (!isLineEnd && !isTab && !isDelimiter) {
stringList.add(parseString(sep))
eat(sep)
}
stringList.toArray()
}
def parseIntArray(): Array[AnyRef] = {
intList.clear()
while (!isLineEnd && !isTab && !isDelimiter) {
intList.add(parseInt(','))
eat(',')
}
intList.toArray()
}
def parseDoubleArray(): Array[AnyRef] = {
doubleList.clear()
while (!isLineEnd && !isTab && !isDelimiter) {
doubleList.add(parseDouble(','))
eat(',')
}
doubleList.toArray()
}
def toGenericArrayData(arr: Array[_]): GenericArrayData = {
if (arr.length == 0) {
null
} else if (allNull(arr)) {
if (arr.length == 1) {
null
} else {
// This doesn't really make sense... We represent a multi-element array that is all nulls as an
// empty list only for consistency with probabilities read by the BGEN reader.
// TODO(hhd): Fix the BGEN reader to properly insert nulls
new GenericArrayData(Array.empty[Any])
}
} else {
new GenericArrayData(arr.asInstanceOf[Array[Any]])
}
}
private def allNull(arr: Array[_]): Boolean = {
var i = 0
while (i < arr.length) {
if (arr(i) != null) {
return false
}
i += 1
}
true
}
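// Dispatches on the Spark data type of the target field. Flags parse as a bare `true`, and
// arrays of structs (annotation-style INFO fields) are split on commas and delegated to
// VariantContextToInternalRowConverter.getAnnotationArray.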
def parseByType(typ: DataType): Any = {
if (typ == IntegerType) {
parseInt()
} else if (typ == LongType) {
parseLong()
} else if (typ == BooleanType) {
true
} else if (typ == StringType) {
parseString()
} else if (typ == DoubleType) {
parseDouble()
} else if (SQLUtils.dataTypesEqualExceptNullability(typ, ArrayType(StringType))) {
toGenericArrayData(parseStringArray())
} else if (SQLUtils.dataTypesEqualExceptNullability(typ, ArrayType(IntegerType))) {
toGenericArrayData(parseIntArray())
} else if (SQLUtils.dataTypesEqualExceptNullability(typ, ArrayType(DoubleType))) {
toGenericArrayData(parseDoubleArray())
} else if (typ.isInstanceOf[ArrayType] && typ
.asInstanceOf[ArrayType]
.elementType
.isInstanceOf[StructType]) {
val strings = parseStringArray()
val list = new util.ArrayList[String](strings.length)
var i = 0
while (i < strings.length) {
list.add(strings(i).toString)
i += 1
}
new GenericArrayData(
VariantContextToInternalRowConverter.getAnnotationArray(
list,
typ.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]))
} else {
null
}
}
def parseInfoVal(typ: DataType): Any = {
setInInfoVal()
parseByType(typ)
}
def parseFormatVal(typ: DataType): Any = {
setInFormatVal()
parseByType(typ)
}
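// Parses a GT value such as "0/1" or "0|1" into the calls array and the phased flag.
// e.g. "0|1" -> calls = [0, 1], phased = true; "./." -> calls = [-1, -1], phased = false.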
def parseCallsAndPhasing(row: GenericInternalRow, phasedIdx: Int, callsIdx: Int): Unit = {
setInFormatVal()
intList.clear()
val first = parseInt('|', '/', -1)
intList.add(first)
var phased = false
if (line(pos) == '|') {
phased = true
}
eat('/')
eat('|')
while (!isTab && !isDelimiter) {
intList.add(parseInt('|', '/', -1))
eat('|')
eat('/')
}
if (phasedIdx != -1) {
row.update(phasedIdx, phased)
}
if (callsIdx != -1) {
row.update(callsIdx, new GenericArrayData(intList.toArray().asInstanceOf[Array[Any]]))
}
}
}
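// Sentinel strings recognized during parsing; the float specials are compared after lowercasing
// the input.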
object LineCtx {
val INF = UTF8String.fromString("inf")
val POS_INF = UTF8String.fromString("+inf")
val NEG_INF = UTF8String.fromString("-inf")
val NAN = UTF8String.fromString("nan")
val POS_NAN = UTF8String.fromString("+nan")
val NEG_NAN = UTF8String.fromString("-nan")
val END = UTF8String.fromString("END")
val MISSING = UTF8String.fromString(".")
}