com.google.cloud.genomics.dataflow.writers.WriteBAMTransform

/*
 * Copyright (C) 2015 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package com.google.cloud.genomics.dataflow.writers;


import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.DelegateCoder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
import com.google.cloud.dataflow.sdk.transforms.Combine;
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.Flatten;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.transforms.Sum;
import com.google.cloud.dataflow.sdk.transforms.View;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PCollectionList;
import com.google.cloud.dataflow.sdk.values.PCollectionTuple;
import com.google.cloud.dataflow.sdk.values.PCollectionView;
import com.google.cloud.dataflow.sdk.values.TupleTag;
import com.google.cloud.dataflow.sdk.values.TupleTagList;
import com.google.cloud.genomics.dataflow.functions.CombineShardsFn;
import com.google.cloud.genomics.dataflow.functions.GetReferencesFromHeaderFn;
import com.google.cloud.genomics.dataflow.functions.WriteBAIFn;
import com.google.cloud.genomics.dataflow.functions.WriteBAMFn;
import com.google.cloud.genomics.dataflow.readers.bam.HeaderInfo;
import com.google.cloud.genomics.dataflow.utils.BreakFusionTransform;
import com.google.cloud.genomics.utils.Contig;
import com.google.genomics.v1.Read;

import htsjdk.samtools.SAMTextHeaderCodec;
import htsjdk.samtools.ValidationStringency;
import htsjdk.samtools.util.BlockCompressedStreamConstants;
import htsjdk.samtools.util.StringLineReader;

import java.io.StringWriter;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

/*
 * Writes sets of reads to BAM files in parallel, then combines the files and writes an index
 * for the combined file.
 */
public class WriteBAMTransform extends PTransform<PCollectionTuple, PCollection<String>> {
  
  public static interface Options extends WriteBAMFn.Options {}

  public static TupleTag<Read> SHARDED_READS_TAG = new TupleTag<Read>(){};
  public static TupleTag<HeaderInfo> HEADER_TAG = new TupleTag<HeaderInfo>(){};
  
  private String output;
  private Pipeline pipeline;

  @Override
  public PCollection<String> apply(PCollectionTuple tuple) {
    final PCollection<HeaderInfo> header = tuple.get(HEADER_TAG);
    final PCollectionView<HeaderInfo> headerView =
        header.apply(View.<HeaderInfo>asSingleton());

    final PCollection<Read> shardedReads = tuple.get(SHARDED_READS_TAG);
    
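    // Write each input shard of reads to its own BAM shard file. The SAM header is provided
    // as a side input; the written shard names come back on the main output, and the
    // per-reference-sequence shard sizes on WriteBAMFn.SEQUENCE_SHARD_SIZES_TAG.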
    final PCollectionTuple writeBAMFilesResult = 
        shardedReads.apply(ParDo.named("Write BAM shards")
          .withSideInputs(Arrays.asList(headerView))
          .withOutputTags(WriteBAMFn.WRITTEN_BAM_NAMES_TAG, TupleTagList.of(WriteBAMFn.SEQUENCE_SHARD_SIZES_TAG))
          .of(new WriteBAMFn(headerView)));
    
    PCollection<String> writtenBAMShardNames = writeBAMFilesResult.get(WriteBAMFn.WRITTEN_BAM_NAMES_TAG);
    final PCollectionView<Iterable<String>> writtenBAMShardsView =
        writtenBAMShardNames.apply(View.<String>asIterable());
    
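    // Sum the shard sizes for each reference sequence; the per-sequence totals are handed to
    // the index-writing step below as a side input.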
    final PCollection<KV<Integer, Long>> sequenceShardSizes = writeBAMFilesResult.get(WriteBAMFn.SEQUENCE_SHARD_SIZES_TAG);
    final PCollection<KV<Integer, Long>> sequenceShardSizesCombined = sequenceShardSizes.apply(
        Combine.<Integer, Long, Long>perKey(
            new Sum.SumLongFn()));
    final PCollectionView<Iterable<KV<Integer, Long>>> sequenceShardSizesView =
        sequenceShardSizesCombined.apply(View.<KV<Integer, Long>>asIterable());
    
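    // Materialize the final BAM path as a single-element collection so the combine step below
    // runs exactly once.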
    final PCollection<String> destinationBAMPath = this.pipeline.apply(
        Create.of(this.output));
    
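    // A BGZF EOF marker (an empty gzip block) must terminate the combined BAM file; it is
    // passed to the shard combiner as a side input.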
    final PCollectionView<byte[]> eofForBAM = pipeline.apply(
        Create.of(BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK))
        .apply(View.<byte[]>asSingleton());
        
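    // Concatenate the BAM shards into the single output file, appending the EOF marker.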
    final PCollection<String> writtenBAMFile = destinationBAMPath.apply(
        ParDo.named("Combine BAM shards")
          .withSideInputs(writtenBAMShardsView, eofForBAM)
          .of(new CombineShardsFn(writtenBAMShardsView, eofForBAM)));
    
    final PCollectionView<String> writtenBAMFileView =
        writtenBAMFile.apply(View.<String>asSingleton());
    
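    // Emit one index-writing task per reference sequence declared in the header.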
    final PCollection<String> indexShards = header.apply(
        ParDo.named("Generate index shard tasks")
        .of(new GetReferencesFromHeaderFn()));
    
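    // Break fusion so the indexing tasks spread across workers, then write one BAI shard per
    // reference. Shard names go to the main output; counts of reads without coordinates go
    // to WriteBAIFn.NO_COORD_READS_COUNT_TAG.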
    final PCollectionTuple indexingResult = indexShards
        .apply(new BreakFusionTransform<String>())
        .apply(
          ParDo.named("Write index shards")
            .withSideInputs(headerView, writtenBAMFileView, sequenceShardSizesView)
            .withOutputTags(WriteBAIFn.WRITTEN_BAI_NAMES_TAG, 
                TupleTagList.of(WriteBAIFn.NO_COORD_READS_COUNT_TAG))
            .of(new WriteBAIFn(headerView, writtenBAMFileView, sequenceShardSizesView)));
    
    final PCollection<String> writtenBAIShardNames = indexingResult.get(WriteBAIFn.WRITTEN_BAI_NAMES_TAG);
    final PCollectionView<Iterable<String>> writtenBAIShardsView =
        writtenBAIShardNames.apply(View.<String>asIterable());
    
    final PCollection<Long> noCoordCounts = indexingResult.get(WriteBAIFn.NO_COORD_READS_COUNT_TAG);
    
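    // Total the no-coordinate read counts reported by the index shards.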
    final PCollection<Long> totalNoCoordCount = noCoordCounts
          .apply(new BreakFusionTransform<Long>())
          .apply(
              Combine.globally(new Sum.SumLongFn()));
    
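    // A BAI file ends with the number of unplaced (no-coordinate) reads as a little-endian
    // 64-bit value; encode the total so it can be appended as the final block of the index.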
    final PCollection<byte[]> totalNoCoordCountBytes = totalNoCoordCount.apply(
        ParDo.named("No coord count to bytes").of(new Long2BytesFn()));
    final PCollectionView<byte[]> eofForBAI = totalNoCoordCountBytes
        .apply(View.<byte[]>asSingleton());
    
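    // The index is written alongside the BAM, at the same path with a ".bai" suffix, by
    // concatenating the BAI shards and appending the no-coordinate count.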
    final PCollection<String> destinationBAIPath = this.pipeline.apply(
        Create.of(this.output + ".bai"));
    
    final PCollection<String> writtenBAIFile = destinationBAIPath.apply(
        ParDo.named("Combine BAI shards")
          .withSideInputs(writtenBAIShardsView, eofForBAI)
          .of(new CombineShardsFn(writtenBAIShardsView, eofForBAI)));
    
    final PCollection<String> writtenFileNames = PCollectionList.of(writtenBAMFile).and(writtenBAIFile)
        .apply(Flatten.<String>pCollections());
        
    return writtenFileNames;
  }
  
  /**
   * Transforms a long value to bytes (little endian order).
   * Used for transforming the no-coord. read count into bytes for writing in 
   * the footer of the BAI file.
   */
  static class Long2BytesFn extends DoFn<Long, byte[]> {
    public Long2BytesFn() {
    }
    
    @Override
    public void processElement(DoFn<Long, byte[]>.ProcessContext c) throws Exception {
      ByteBuffer b = ByteBuffer.allocate(8);
      b.order(ByteOrder.LITTLE_ENDIAN); 
      b.putLong(c.element());
      c.output(b.array());
    }
  }
  
  private WriteBAMTransform(String output, Pipeline pipeline) {
    this.output = output;
    this.pipeline = pipeline;
  }
  
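  /**
   * Convenience entry point: bundles the sharded reads and the serialized header into the
   * PCollectionTuple this transform expects, applies the transform, and returns the names
   * of the written BAM and BAI files.
   */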
  public static PCollection<String> write(PCollection<Read> shardedReads, HeaderInfo headerInfo,
      String output, Pipeline pipeline) {
    final PCollectionTuple tuple = PCollectionTuple
        .of(SHARDED_READS_TAG, shardedReads)
        .and(HEADER_TAG, pipeline.apply(Create.of(headerInfo).withCoder(HEADER_INFO_CODER)));
    return (new WriteBAMTransform(output, pipeline)).apply(tuple);
  }
  
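  // Encodes a HeaderInfo as text: the first-read contig on the first line, followed by the
  // SAM header rendered by SAMTextHeaderCodec; decoding splits on the first newline and
  // reverses both conversions.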
  static Coder<HeaderInfo> HEADER_INFO_CODER = DelegateCoder.of(
      StringUtf8Coder.of(),
      new DelegateCoder.CodingFunction<HeaderInfo, String>() {
        @Override
        public String apply(HeaderInfo info) throws Exception {
          final StringWriter stringWriter = new StringWriter();
          SAM_HEADER_CODEC.encode(stringWriter, info.header);
          return info.firstRead.toString() + "\n" + stringWriter.toString();
        }
      },
      new DelegateCoder.CodingFunction<String, HeaderInfo>() {
        @Override
        public HeaderInfo apply(String str) throws Exception {
          int newLinePos = str.indexOf("\n");
          String contigStr = str.substring(0, newLinePos);
          String headerStr = str.substring(newLinePos + 1);
          return new HeaderInfo(
              SAM_HEADER_CODEC.decode(new StringLineReader(headerStr), 
                  "HEADER_INFO_CODER"),
              Contig.parseContigsFromCommandLine(contigStr).iterator().next());
        }
      });
  
  static final SAMTextHeaderCodec SAM_HEADER_CODEC = new SAMTextHeaderCodec();
  static {
    SAM_HEADER_CODEC.setValidationStringency(ValidationStringency.SILENT);
  }
}
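Usage sketch (not part of the original source): a minimal, illustrative wiring of the transform into a Dataflow 1.x pipeline. The header, contig, read, and output path below are placeholder values, and the explicit ProtoCoder registration for the Read protobuf is an assumption; real pipelines obtain the sharded reads and HeaderInfo from this project's BAM/Genomics readers.

import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.coders.protobuf.ProtoCoder;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.genomics.dataflow.readers.bam.HeaderInfo;
import com.google.cloud.genomics.dataflow.writers.WriteBAMTransform;
import com.google.cloud.genomics.utils.Contig;
import com.google.genomics.v1.Read;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceRecord;

public class WriteBAMExample {
  public static void main(String[] args) {
    WriteBAMTransform.Options options = PipelineOptionsFactory
        .fromArgs(args).withValidation().as(WriteBAMTransform.Options.class);
    Pipeline p = Pipeline.create(options);

    // A header with one reference sequence and the contig the (placeholder) reads cover.
    SAMFileHeader samHeader = new SAMFileHeader();
    samHeader.addSequence(new SAMSequenceRecord("chr1", 1000000));
    HeaderInfo headerInfo = new HeaderInfo(samHeader, new Contig("chr1", 0, 1000000));

    // Placeholder reads; a real pipeline would produce these with the project's readers.
    // ProtoCoder.of(Read.class) is assumed here as the coder for the Read protobuf.
    PCollection<Read> shardedReads = p.apply(
        Create.of(Read.newBuilder().setFragmentName("read-1").build())
            .withCoder(ProtoCoder.of(Read.class)));

    // Writes gs://my-bucket/sample.bam and gs://my-bucket/sample.bam.bai.
    PCollection<String> written =
        WriteBAMTransform.write(shardedReads, headerInfo, "gs://my-bucket/sample.bam", p);
    p.run();
  }
}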



