All downloads are free. Search and download functionality uses the official Maven repository.

org.hammerlab.genomics.readsets.ReadSets.scala Maven / Gradle / Ivy

The newest version!
package org.hammerlab.genomics.readsets

import java.io.File

import grizzled.slf4j.Logging
import htsjdk.samtools.ValidationStringency
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.bdgenomics.adam.models.SequenceDictionary
import org.bdgenomics.adam.rdd.ADAMContext
import org.hammerlab.genomics.loci.set.LociSet
import org.hammerlab.genomics.reads.Read
import org.hammerlab.genomics.readsets.args.{ Base ⇒ BaseArgs }
import org.hammerlab.genomics.readsets.io.{ Input, InputConfig }
import org.hammerlab.genomics.readsets.rdd.ReadsRDD
import org.hammerlab.genomics.reference.{ ContigLengths, ContigName, Locus }
import org.seqdoop.hadoop_bam.util.SAMHeaderReader
import org.seqdoop.hadoop_bam.{ AnySAMInputFormat, BAMInputFormat, SAMRecordWritable }

import scala.collection.JavaConversions.seqAsJavaList


/**
 * A [[ReadSets]] contains reads from multiple inputs as well as [[SequenceDictionary]] / contig-length information
 * merged from them.
 */
case class ReadSets(readsRDDs: PerSample[ReadsRDD],
                    sequenceDictionary: SequenceDictionary,
                    contigLengths: ContigLengths)
  extends PerSample[ReadsRDD] {

  /** The [[Input]] that each sample's reads were loaded from, in sample order. */
  def inputs: PerSample[Input] = readsRDDs.map(readsRDD ⇒ readsRDD.input)

  /** How many samples this [[ReadSets]] holds (identical to `length`). */
  def numSamples: NumSamples = readsRDDs.length

  /** Each sample's name, as recorded on its [[Input]]. */
  def sampleNames: PerSample[String] = inputs.map(input ⇒ input.sampleName)

  override def length: NumSamples = readsRDDs.length

  override def apply(sampleId: SampleId): ReadsRDD = readsRDDs(sampleId)

  /**
   * The SparkContext behind the first sample's reads.
   *
   * NOTE(review): calls `head`, so this presumably assumes at least one sample is present — confirm with callers.
   */
  def sc = readsRDDs.head.reads.sparkContext

  /** Each sample's mapped reads, in sample order. */
  lazy val mappedReadsRDDs = readsRDDs.map(readsRDD ⇒ readsRDD.mappedReads)

  /** Mapped reads from every sample, combined into one RDD. */
  lazy val allMappedReads =
    sc
      .union(mappedReadsRDDs)
      .setName("unioned reads")

  /** Mapped reads from every sample, each keyed by its sample's index. */
  lazy val sampleIdxKeyedMappedReads: RDD[SampleRead] = {
    val keyedPerSample =
      mappedReadsRDDs
        .zipWithIndex
        .map {
          case (sampleReads, sampleId) ⇒
            sampleReads.map(read ⇒ (sampleId → read): SampleRead)
        }

    sc.union(keyedPerSample)
  }
}

/**
 * Loaders for [[ReadSets]]. They read BAM/SAM files via Hadoop-BAM and anything else via ADAM, apply per-input
 * filters, and merge the inputs' sequence dictionaries.
 */
object ReadSets extends Logging {

  /**
   * Load the inputs named by command-line arguments, and resolve the loci those arguments request.
   *
   * @param sc spark context
   * @param args arguments naming the inputs and the loading/filtering config
   * @return the loaded [[ReadSets]], plus the config's loci resolved against the loaded contig lengths
   */
  def apply(sc: SparkContext, args: BaseArgs): (ReadSets, LociSet) = {
    val config = args.parseConfig(sc.hadoopConfiguration)
    val readsets = apply(sc, args.inputs, config, !args.noSequenceDictionary)
    (readsets, LociSet(config.loci, readsets.contigLengths))
  }

  /**
   * Load reads from multiple files and apply the same [[InputConfig]] to each. The files' sequence dictionaries are
   * merged and checked for consistency.
   *
   * @param sc spark context
   * @param inputs files to load, one per sample
   * @param config loading/filtering config applied to every input
   * @param contigLengthsFromDictionary if true, contig lengths come from the merged sequence dictionary; otherwise they
   *                                    are computed from the loaded reads
   */
  def apply(sc: SparkContext,
            inputs: PerSample[Input],
            config: InputConfig,
            contigLengthsFromDictionary: Boolean = true): ReadSets =
    apply(sc, inputs.map((_, config)), contigLengthsFromDictionary)

  /**
   * Load reads from multiple files, allowing a different filter config for each file.
   *
   * @param sc spark context
   * @param inputsAndFilters each input paired with the [[InputConfig]] used to load and filter it
   * @param contigLengthsFromDictionary if true, contig lengths come from the merged sequence dictionary. Otherwise each
   *                                    contig's length is the largest alignment end among the (filtered) mapped reads
   *                                    on it; computing this runs a Spark job.
   * @throws IllegalArgumentException if the inputs' sequence dictionaries disagree about a contig
   */
  def apply(sc: SparkContext,
            inputsAndFilters: PerSample[(Input, InputConfig)],
            contigLengthsFromDictionary: Boolean): ReadSets = {

    val (inputs, _) = inputsAndFilters.unzip

    val (readsRDDs, sequenceDictionaries) =
      (for {
        (Input(sampleId, _, filename), config) <- inputsAndFilters
      } yield
        load(filename, sc, sampleId, config)
      ).unzip

    val sequenceDictionary = mergeSequenceDictionaries(inputs, sequenceDictionaries)

    val contigLengths: ContigLengths =
      if (contigLengthsFromDictionary)
        getContigLengths(sequenceDictionary)
      else
        // No trusted dictionary: use the furthest alignment end seen on each contig as its length.
        sc.union(readsRDDs)
          .flatMap(_.asMappedRead)
          .map(read => read.contigName -> read.end)
          .reduceByKey(_ max _)
          .collectAsMap()
          .toMap

    ReadSets(
      (for {
        (reads, input) <- readsRDDs.zip(inputs)
      } yield
        ReadsRDD(reads, input)
      )
      .toVector,
      sequenceDictionary,
      contigLengths
    )
  }

  /**
   * Wrap already-loaded per-sample reads, deriving contig lengths from `sequenceDictionary`.
   */
  def apply(readsRDDs: PerSample[ReadsRDD], sequenceDictionary: SequenceDictionary): ReadSets =
    ReadSets(
      readsRDDs,
      sequenceDictionary,
      getContigLengths(sequenceDictionary)
    )

  /**
   * Load one file of reads and apply `config`'s filters to it.
   *
   * Filenames ending in `.bam` or `.sam` are read with Hadoop-BAM; anything else is loaded through ADAM.
   *
   * @param filename name of file containing reads
   * @param sc spark context
   * @param sampleId index of the sample this file belongs to
   * @param config loading/filtering config to apply
   * @return the filtered reads, and the sequence dictionary describing the file's contigs (e.g. their lengths)
   */
  private[readsets] def load(filename: String,
                             sc: SparkContext,
                             sampleId: Int,
                             config: InputConfig): (RDD[Read], SequenceDictionary) = {

    val (allReads, sequenceDictionary) =
      if (filename.endsWith(".bam") || filename.endsWith(".sam"))
        loadFromBAM(filename, sc, sampleId, config)
      else
        loadFromADAM(filename, sc, sampleId, config)

    val reads = filterRDD(allReads, config, sequenceDictionary)

    (reads, sequenceDictionary)
  }

  /**
   * Return an RDD of Reads and the SequenceDictionary from a BAM or SAM file.
   *
   * Only `maxSplitSizeOpt` and `overlapsLociOpt` from `config` are used here (for split sizing / BAM-index interval
   * pushdown). The rest of the filtering happens in `filterRDD`.
   */
  private def loadFromBAM(filename: String,
                          sc: SparkContext,
                          sampleId: Int,
                          config: InputConfig): (RDD[Read], SequenceDictionary) = {

    val path = new Path(filename)

    val basename = new File(filename).getName
    val shortName = basename.substring(0, math.min(basename.length, 100))

    // Use a private copy of the SparkContext's Hadoop configuration. The per-input settings below (split size, index
    // intervals) must not leak into later inputs' loads, or into unrelated Hadoop reads on the same SparkContext.
    val conf = new Configuration(sc.hadoopConfiguration)

    val samHeader = SAMHeaderReader.readSAMHeaderFrom(path, conf)
    val sequenceDictionary = SequenceDictionary(samHeader)

    config
      .maxSplitSizeOpt
      .foreach(
        maxSplitSize =>
          conf.set(FileInputFormat.SPLIT_MAXSIZE, maxSplitSize.toString)
      )

    // Only a BAM load with `overlapsLoci` set should restrict splits to index intervals. Never inherit intervals from
    // the shared configuration.
    conf.unset(BAMInputFormat.INTERVALS_PROPERTY)

    config
      .overlapsLociOpt
      .foreach { overlapsLoci =>
        if (filename.endsWith(".bam")) {
          val contigLengths = getContigLengths(sequenceDictionary)

          val bamIndexIntervals =
            LociSet(
              overlapsLoci,
              contigLengths
            ).toHtsJDKIntervals

          BAMInputFormat.setIntervals(conf, bamIndexIntervals)
        } else if (filename.endsWith(".sam")) {
          warn(s"Loading SAM file: $filename with intervals specified. This requires parsing the entire file.")
        } else {
          throw new IllegalArgumentException(s"File $filename is not a BAM or SAM file")
        }
      }

    val reads: RDD[Read] =
      sc
        .newAPIHadoopFile(
          filename,
          classOf[AnySAMInputFormat],
          classOf[LongWritable],
          classOf[SAMRecordWritable],
          conf
        )
        .setName(s"Hadoop file: $shortName")
        .values
        .setName(s"Hadoop reads: $shortName")
        .map(r => Read(r.get))
        .setName(s"Guac reads: $shortName")

    (reads, sequenceDictionary)
  }

  /** Return an RDD of Reads and the SequenceDictionary from reads in ADAM format (loaded with LENIENT stringency). */
  private def loadFromADAM(filename: String,
                           sc: SparkContext,
                           sampleId: Int,
                           config: InputConfig): (RDD[Read], SequenceDictionary) = {

    logger.info(s"Using ADAM to read: $filename")

    val adamContext: ADAMContext = sc

    val alignmentRDD =
      adamContext.loadAlignments(filename, projection = None, stringency = ValidationStringency.LENIENT)

    val sequenceDictionary = alignmentRDD.sequences

    (alignmentRDD.rdd.map(Read(_, sampleId)), sequenceDictionary)
  }

  /** Build a map from contig name to contig length using the records of a sequence dictionary. */
  private def getContigLengths(sequenceDictionary: SequenceDictionary): ContigLengths = {
    val builder = Map.newBuilder[ContigName, Locus]
    sequenceDictionary.records.foreach(record => builder += ((record.name.toString, record.length)))
    builder.result
  }

  /**
   * Merge the sequence dictionaries of several inputs into one, checking consistency along the way.
   *
   * A SequenceDictionary describes the contigs that a set of reads references: their names, lengths, etc. This merges
   * the records of all `dicts` by contig name. Every input that lists a given contig must carry an identical record
   * for it. The result is sorted.
   *
   * @param inputs input files, each containing a set of reads
   * @param dicts sequence dictionaries parsed from the corresponding entries of `inputs`
   * @return a sorted SequenceDictionary containing one record per contig seen in any input
   * @throws IllegalArgumentException if two inputs carry different records for the same contig
   */
  private[readsets] def mergeSequenceDictionaries(inputs: Seq[Input],
                                                  dicts: Seq[SequenceDictionary]): SequenceDictionary = {
    val records =
      (for {
        (input, dict) <- inputs.zip(dicts)
        record <- dict.records
      } yield
        input -> record
      )
      .groupBy { case (_, record) => record.name }
      .values
      .map { inputsAndRecords =>
        val (input, record) = inputsAndRecords.head

        // Every other input's record for this contig must match the first one.
        val mismatched = inputsAndRecords.tail.filter { case (_, otherRecord) => otherRecord != record }

        if (mismatched.nonEmpty)
          throw new IllegalArgumentException(
            (
              s"Conflicting sequence records for ${record.name}:" ::
              s"${input.path}: $record" ::
              mismatched.toList.map { case (otherInput, otherRecord) => s"${otherInput.path}: $otherRecord" }
            ).mkString("\n\t")
          )

        record
      }

    new SequenceDictionary(records.toVector).sorted
  }

  /**
   * Apply `config`'s filters to an RDD of reads.
   *
   * Note that the InputConfig properties are public, and some loaders apply these filters directly while loading
   * reads, instead of filtering an existing RDD as is done here. If more filters are added, update those
   * implementations too.
   *
   * This is a static function rather than an InputConfig method because `overlapsLoci` cannot be serialized. For the
   * same reason, none of the Spark closures below may capture `config`; they only capture plain local values.
   */
  private def filterRDD(reads: RDD[Read], config: InputConfig, sequenceDictionary: SequenceDictionary): RDD[Read] = {
    val overlappingReads: RDD[Read] =
      config
        .overlapsLociOpt
        .fold(reads) { overlapsLoci =>
          val loci = LociSet(overlapsLoci, getContigLengths(sequenceDictionary))
          val broadcastLoci = reads.sparkContext.broadcast(loci)
          reads.filter(_.asMappedRead.exists(broadcastLoci.value.intersects))
        }

    val nonDuplicateReads =
      if (config.nonDuplicate) overlappingReads.filter(!_.isDuplicate)
      else overlappingReads

    val vendorPassedReads =
      if (config.passedVendorQualityChecks) nonDuplicateReads.filter(!_.failedVendorQualityChecks)
      else nonDuplicateReads

    val pairedReads =
      if (config.isPaired) vendorPassedReads.filter(_.isPaired)
      else vendorPassedReads

    // Reads that are not mapped have no alignment quality and are kept (`forall` on an empty Option is true).
    config
      .minAlignmentQualityOpt
      .fold(pairedReads) { minAlignmentQuality =>
        pairedReads.filter(
          _.asMappedRead
           .forall(_.alignmentQuality >= minAlignmentQuality)
        )
      }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy