All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.projectglow.vcf.VCFWriterUtils.scala Maven / Gradle / Ivy

/*
 * Copyright 2019 The Glow Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.projectglow.vcf

import htsjdk.variant.variantcontext.{VariantContext, VariantContextBuilder}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{ArrayType, StructType}

import io.projectglow.common.GlowLogging

object VCFWriterUtils extends GlowLogging {

  def throwMixedSamplesFailure(): Unit = {
    throw new IllegalArgumentException("Cannot mix missing and non-missing sample IDs.")
  }

  def throwSampleInferenceFailure(): Unit = {
    throw new IllegalArgumentException(
      "Cannot infer sample ids because they are not the same in every row.")
  }

  /**
   * Infer sample IDs from a genomic DataFrame.
   *
   * - If there are no genotypes, there are no sample IDs.
   * - If there are genotypes and sample IDs are all...
   *   - Present, the sample IDs are found by unifying sample IDs across all rows.
   *   - Missing, checks that the rows have the same number of genotypes.
   */
  def inferSampleIdsIfPresent(data: DataFrame): SampleIdInfo = {
    val genotypeSchemaOpt = data
      .schema
      .find(_.name == "genotypes")
      .map(_.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType])
    if (genotypeSchemaOpt.isEmpty) {
      logger.info("No genotypes column, no sample IDs will be inferred.")
      return SampleIds(Seq.empty)
    }
    val genotypeSchema = genotypeSchemaOpt.get

    import data.sparkSession.implicits._
    val hasSampleIdsColumn = genotypeSchema.exists(_.name == "sampleId")

    if (hasSampleIdsColumn) {
      val distinctSampleIds = data
        .selectExpr("explode(genotypes.sampleId)")
        .distinct()
        .as[String]
        .collect
      val numPresentSampleIds = distinctSampleIds.count(!sampleIsMissing(_))

      if (numPresentSampleIds > 0) {
        if (numPresentSampleIds < distinctSampleIds.length) {
          throwMixedSamplesFailure()
        }
        return SampleIds(distinctSampleIds)
      }
    }

    val numGenotypesPerRow = data
      .selectExpr("size(genotypes)")
      .distinct()
      .as[Int]
      .collect

    if (numGenotypesPerRow.length > 1) {
      throw new IllegalArgumentException(
        "Rows contain varying number of missing samples; cannot infer sample IDs.")
    }

    logger.warn("Detected missing sample IDs, inferring sample IDs.")
    InferSampleIds
  }

  def sampleIsMissing(s: String): Boolean = {
    s == null || s.isEmpty
  }

  def convertVcAttributesToStrings(vc: VariantContext): VariantContextBuilder = {
    val vcBuilder = new VariantContextBuilder(vc)
    val iterator = vc.getAttributes.entrySet().iterator()
    while (iterator.hasNext) {
      // parse to string, then write, as the VCF encoder messes up double precisions
      val entry = iterator.next()
      vcBuilder.attribute(
        entry.getKey,
        VariantContextToInternalRowConverter.parseObjectAsString(entry.getValue))
    }
    vcBuilder
  }
}

case class SampleIds(unsortedSampleIds: Seq[String]) extends SampleIdInfo {
  val sortedSampleIds: Seq[String] = unsortedSampleIds.sorted
}
case object InferSampleIds extends SampleIdInfo {
  def fromNumberMissing(numMissingSamples: Int): Seq[String] = {
    (1 to numMissingSamples).map { idx =>
      "sample_" + idx
    }
  }
}

sealed trait SampleIdInfo




© 2015 - 2025 Weber Informatics LLC | Privacy Policy