All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.google.cloud.genomics.dataflow.writers.bam.WriteBAIFn Maven / Gradle / Ivy

/*
 * Copyright (C) 2016 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package com.google.cloud.genomics.dataflow.writers.bam;

import com.google.api.services.storage.Storage;
import com.google.cloud.dataflow.sdk.transforms.Aggregator;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.Max;
import com.google.cloud.dataflow.sdk.transforms.Min;
import com.google.cloud.dataflow.sdk.transforms.Sum;
import com.google.cloud.dataflow.sdk.util.GcsUtil;
import com.google.cloud.dataflow.sdk.util.Transport;
import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollectionView;
import com.google.cloud.dataflow.sdk.values.TupleTag;
import com.google.cloud.genomics.dataflow.readers.bam.BAMIO;
import com.google.cloud.genomics.dataflow.readers.bam.HeaderInfo;
import com.google.cloud.genomics.dataflow.utils.GCSOptions;
import com.google.cloud.genomics.dataflow.utils.GCSOutputOptions;
import com.google.common.base.Stopwatch;

import htsjdk.samtools.BAMShardIndexer;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMRecordIterator;
import htsjdk.samtools.SamReader;
import htsjdk.samtools.ValidationStringency;

import java.io.OutputStream;
import java.nio.channels.Channels;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;

/**
 * Writes a shard of BAM index (BAI file).
 * The input is a reference name to process and the output is the name of the file written.
 * A side output under the tag NO_COORD_READS_COUNT_TAG is a single value of the number of
 * no-coordinate reads in this shard. This is needed since the final index has to include a
 * total number summed up from the shards.
 */
public class WriteBAIFn extends DoFn {
  private static final Logger LOG = Logger.getLogger(WriteBAIFn.class.getName());

  public static interface Options extends GCSOutputOptions {}


  public static TupleTag NO_COORD_READS_COUNT_TAG = new TupleTag(){};
  public static TupleTag WRITTEN_BAI_NAMES_TAG = new TupleTag(){};

  PCollectionView writtenBAMFilerView;
  PCollectionView headerView;
  PCollectionView>> sequenceShardSizesView;

  Aggregator readCountAggregator;
  Aggregator noCoordReadCountAggregator;
  Aggregator initializedShardCount;
  Aggregator finishedShardCount;
  Aggregator shardTimeMaxSec;
  Aggregator shardTimeMinSec;
  Aggregator shardReadCountMax;
  Aggregator shardReadCountMin;

  public WriteBAIFn(PCollectionView headerView,
      PCollectionView writtenBAMFilerView,
      PCollectionView>> sequenceShardSizesView) {
    this.writtenBAMFilerView = writtenBAMFilerView;
    this.headerView = headerView;
    this.sequenceShardSizesView = sequenceShardSizesView;

    readCountAggregator = createAggregator("Indexed reads", new Sum.SumLongFn());
    noCoordReadCountAggregator = createAggregator("Indexed no coordinate reads", new Sum.SumLongFn());
    initializedShardCount = createAggregator("Initialized Indexing Shard Count", new Sum.SumIntegerFn());
    finishedShardCount = createAggregator("Finished Indexing Shard Count", new Sum.SumIntegerFn());
    shardTimeMaxSec = createAggregator("Maximum Indexing Shard Processing Time (sec)", new Max.MaxLongFn());
    shardTimeMinSec = createAggregator("Minimum Indexing Shard Processing Time (sec)", new Min.MinLongFn());
    shardReadCountMax = createAggregator("Maximum Reads Per Indexing Shard", new Max.MaxLongFn());
    shardReadCountMin = createAggregator("Minimum Reads Per Indexing Shard", new Min.MinLongFn());
  }

  @Override
  public void processElement(DoFn.ProcessContext c) throws Exception {
    initializedShardCount.addValue(1);
    Stopwatch stopWatch = Stopwatch.createStarted();

    final HeaderInfo header = c.sideInput(headerView);
    final String bamFilePath = c.sideInput(writtenBAMFilerView);
    final Iterable> sequenceShardSizes = c.sideInput(sequenceShardSizesView);

    final String sequenceName = c.element();
    final int sequenceIndex = header.header.getSequence(sequenceName).getSequenceIndex();
    final String baiFilePath = bamFilePath + "-" +
        String.format("%02d",sequenceIndex) + "-" + sequenceName + ".bai";

    long offset = 0;
    int skippedReferences  = 0;
    long bytesToProcess = 0;

    for (KV sequenceShardSize : sequenceShardSizes) {
      if (sequenceShardSize.getKey() < sequenceIndex) {
        offset += sequenceShardSize.getValue();
        skippedReferences++;
      } else if (sequenceShardSize.getKey() == sequenceIndex) {
        bytesToProcess = sequenceShardSize.getValue();
      }
    }
    LOG.info("Generating BAI index: " + baiFilePath);
    LOG.info("Reading BAM file: " + bamFilePath + " for reference " + sequenceName +
        ", skipping " + skippedReferences + " references at offset " + offset +
        ", expecting to process " + bytesToProcess + " bytes");

    Options options = c.getPipelineOptions().as(Options.class);
    final Storage.Objects storage = Transport.newStorageClient(
        options
          .as(GCSOptions.class))
          .build()
          .objects();
    final SamReader reader = BAMIO.openBAM(storage, bamFilePath, ValidationStringency.SILENT, true,
        offset);
    final OutputStream outputStream =
        Channels.newOutputStream(
            new GcsUtil.GcsUtilFactory().create(options)
              .create(GcsPath.fromUri(baiFilePath),
                  BAMIO.BAM_INDEX_FILE_MIME_TYPE));
    final BAMShardIndexer indexer = new BAMShardIndexer(outputStream, reader.getFileHeader(), sequenceIndex);

    long processedReads = 0;
    long skippedReads = 0;

    // create and write the content
    if (bytesToProcess > 0) {
      SAMRecordIterator it = reader.iterator();
      boolean foundRecords = false;
      while (it.hasNext()) {
        SAMRecord r = it.next();
        if (!r.getReferenceName().equals(sequenceName)) {
          if (foundRecords) {
            LOG.info("Finishing index building for " + sequenceName + " after processing " + processedReads);
            break;
          }
          skippedReads++;
          continue;
        } else if (!foundRecords) {
          LOG.info("Found records for refrence " + sequenceName + " after skipping " + skippedReads);
          foundRecords = true;
        }
        indexer.processAlignment(r);
        processedReads++;
      }
      it.close();
    } else {
      LOG.info("No records for refrence " + sequenceName + ": writing empty index ");
    }
    long noCoordinateReads = indexer.finish();
    c.output(baiFilePath);
    c.sideOutput(NO_COORD_READS_COUNT_TAG, noCoordinateReads);
    LOG.info("Generated " + baiFilePath + ", " + processedReads + " reads, " +
        noCoordinateReads + " no coordinate reads, " + skippedReads + ", skipped reads");
    stopWatch.stop();
    shardTimeMaxSec.addValue(stopWatch.elapsed(TimeUnit.SECONDS));
    shardTimeMinSec.addValue(stopWatch.elapsed(TimeUnit.SECONDS));
    finishedShardCount.addValue(1);
    readCountAggregator.addValue(processedReads);
    noCoordReadCountAggregator.addValue(noCoordinateReads);
    shardReadCountMax.addValue(processedReads);
    shardReadCountMin.addValue(processedReads);
  }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy