All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.google.cloud.genomics.dataflow.functions.WriteBAMFn Maven / Gradle / Ivy

There is a newer version: v1-0.8
Show newest version
/*
 * Copyright (C) 2015 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package com.google.cloud.genomics.dataflow.functions;

import com.google.api.services.storage.Storage;
import com.google.cloud.dataflow.sdk.transforms.Aggregator;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.Max;
import com.google.cloud.dataflow.sdk.transforms.Min;
import com.google.cloud.dataflow.sdk.transforms.Sum;
import com.google.cloud.dataflow.sdk.transforms.Sum.SumIntegerFn;
import com.google.cloud.dataflow.sdk.util.GcsUtil;
import com.google.cloud.dataflow.sdk.util.Transport;
import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollectionView;
import com.google.cloud.dataflow.sdk.values.TupleTag;
import com.google.cloud.genomics.dataflow.readers.bam.BAMIO;
import com.google.cloud.genomics.dataflow.readers.bam.HeaderInfo;
import com.google.cloud.genomics.dataflow.utils.GCSOptions;
import com.google.cloud.genomics.dataflow.utils.GCSOutputOptions;
import com.google.cloud.genomics.dataflow.utils.TruncatedOutputStream;
import com.google.cloud.genomics.utils.Contig;
import com.google.cloud.genomics.utils.grpc.ReadUtils;
import com.google.common.base.Stopwatch;
import com.google.genomics.v1.Read;

import htsjdk.samtools.BAMBlockWriter;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.util.BlockCompressedStreamConstants;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.Channels;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;

/*
 * Writes a PCollection of Reads to a BAM file.
 * Assumes sharded execution and writes each bundle as a separate BAM file, outputting
 * its name at the end of the bundle.
 */
public class WriteBAMFn extends DoFn {

  public static interface Options extends GCSOutputOptions {}

  private static final Logger LOG = Logger.getLogger(WriteBAMFn.class.getName());
  public static TupleTag WRITTEN_BAM_NAMES_TAG = new TupleTag(){};
  public static TupleTag> SEQUENCE_SHARD_SIZES_TAG = new TupleTag>(){};

  final PCollectionView headerView;
  Storage.Objects storage;
  Aggregator readCountAggregator;
  Aggregator unmappedReadCountAggregator;
  Aggregator initializedShardCount;
  Aggregator finishedShardCount;
  Aggregator shardTimeMaxSec;
  Aggregator shardReadCountMax;
  Aggregator shardReadCountMin;
  Aggregator outOfOrderCount;
  
  Stopwatch stopWatch;
  int readCount;
  int unmappedReadCount;
  String shardName;
  TruncatedOutputStream ts;
  BAMBlockWriter bw;
  Contig shardContig;
  Options options;
  HeaderInfo headerInfo;
  int sequenceIndex;
  
  SAMRecord prevRead = null;
  long minAlignment = Long.MAX_VALUE;
  long maxAlignment = Long.MIN_VALUE;
  boolean hadOutOfOrder = false;

  public WriteBAMFn(final PCollectionView headerView) {
    this.headerView = headerView;
    readCountAggregator = createAggregator("Written reads", new SumIntegerFn());
    unmappedReadCountAggregator = createAggregator("Written unmapped reads", new SumIntegerFn());
    initializedShardCount = createAggregator("Initialized Write Shard Count", new Sum.SumIntegerFn());
    finishedShardCount = createAggregator("Finished Write Shard Count", new Sum.SumIntegerFn());
    shardTimeMaxSec = createAggregator("Maximum Write Shard Processing Time (sec)", new Max.MaxLongFn());
    shardReadCountMax = createAggregator("Maximum Reads Per Shard", new Max.MaxLongFn());
    shardReadCountMin = createAggregator("Minimum Reads Per Shard", new Min.MinLongFn());
    outOfOrderCount  = createAggregator("Out of order reads",  new Sum.SumIntegerFn());
  }
  
  @Override
  public void startBundle(DoFn.Context c) throws IOException {
    LOG.info("Starting bundle ");
    storage = Transport.newStorageClient(c.getPipelineOptions().as(GCSOptions.class)).build().objects();
    
    initializedShardCount.addValue(1);
    stopWatch = Stopwatch.createStarted();
    
    options = c.getPipelineOptions().as(Options.class);
    
    readCount = 0;
    unmappedReadCount = 0;
    headerInfo = null;
    prevRead = null;
    minAlignment = Long.MAX_VALUE;
    maxAlignment = Long.MIN_VALUE;
    hadOutOfOrder = false;
  }
  
  @Override
  public void finishBundle(DoFn.Context c) throws IOException {
    bw.close();
    shardTimeMaxSec.addValue(stopWatch.elapsed(TimeUnit.SECONDS));
    LOG.info("Finished writing " + shardContig);
    finishedShardCount.addValue(1);
    final long bytesWritten = ts.getBytesWrittenExceptingTruncation();
    LOG.info("Wrote " + readCount + " reads, " + unmappedReadCount + " unmapped, into " + shardName + 
        (hadOutOfOrder ? "ignored out of order" : "") + ", wrote " + bytesWritten + " bytes");
    readCountAggregator.addValue(readCount);
    unmappedReadCountAggregator.addValue(unmappedReadCount);
    final long totalReadCount = (long)readCount + (long)unmappedReadCount;
    shardReadCountMax.addValue(totalReadCount);
    shardReadCountMin.addValue(totalReadCount);
    c.output(shardName);
    c.sideOutput(SEQUENCE_SHARD_SIZES_TAG, KV.of(sequenceIndex, bytesWritten));
  }
  
  @Override
  public void processElement(DoFn.ProcessContext c)
      throws Exception {
   
    if (headerInfo == null) {
      headerInfo = c.sideInput(headerView);
    }
    final Read read = c.element();
    
    if (readCount == 0) {
      
      shardContig = KeyReadsFn.shardKeyForRead(read, 1);
      sequenceIndex = headerInfo.header.getSequenceIndex(shardContig.referenceName);
      final boolean isFirstShard = headerInfo.shardHasFirstRead(shardContig);
      final String outputFileName = options.getOutput();
      shardName = outputFileName + "-" + String.format("%012d", sequenceIndex) + "-"
          + shardContig.referenceName
          + ":" + String.format("%012d", shardContig.start);
      LOG.info("Writing shard file " + shardName);
      final OutputStream outputStream = 
          Channels.newOutputStream(
              new GcsUtil.GcsUtilFactory().create(options)
                .create(GcsPath.fromUri(shardName), 
                    BAMIO.BAM_INDEX_FILE_MIME_TYPE));
      ts = new TruncatedOutputStream(
          outputStream, BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK.length);
      bw = new BAMBlockWriter(ts, null /*file*/);
      bw.setSortOrder(headerInfo.header.getSortOrder(), true);
      bw.setHeader(headerInfo.header);
      if (isFirstShard) {
        LOG.info("First shard - writing header to " + shardName);
        bw.writeHeader(headerInfo.header);
      }
    }
    SAMRecord samRecord = ReadUtils.makeSAMRecord(read, headerInfo.header);
    if (prevRead != null && prevRead.getAlignmentStart() > samRecord.getAlignmentStart()) {
      LOG.info("Out of order read " + prevRead.getAlignmentStart() + " " + 
          samRecord.getAlignmentStart() + " during writing of shard " + shardName + 
          " after processing " + readCount + " reads, min seen alignment is " + 
          minAlignment + " and max is " + maxAlignment + ", this read is " + 
          (samRecord.getReadUnmappedFlag() ? "unmapped" : "mapped") + " and its mate is " + 
          (samRecord.getMateUnmappedFlag() ? "unmapped" : "mapped"));
      outOfOrderCount.addValue(1);
      readCount++;
      hadOutOfOrder = true;
      return;
    }
    minAlignment = Math.min(minAlignment, samRecord.getAlignmentStart());
    maxAlignment = Math.max(maxAlignment, samRecord.getAlignmentStart());
    prevRead = samRecord;
    if (samRecord.getReadUnmappedFlag()) {
      if (!samRecord.getMateUnmappedFlag()) {
        samRecord.setReferenceName(samRecord.getMateReferenceName());
        samRecord.setAlignmentStart(samRecord.getMateAlignmentStart());
      }
      unmappedReadCount++;
    }
    bw.addAlignment(samRecord);
    readCount++;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy