/*
* Copyright (C) 2015 Google Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.google.cloud.genomics.dataflow.writers;
import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.DelegateCoder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
import com.google.cloud.dataflow.sdk.transforms.Combine;
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.Flatten;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.transforms.Sum;
import com.google.cloud.dataflow.sdk.transforms.View;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PCollectionList;
import com.google.cloud.dataflow.sdk.values.PCollectionTuple;
import com.google.cloud.dataflow.sdk.values.PCollectionView;
import com.google.cloud.dataflow.sdk.values.TupleTag;
import com.google.cloud.dataflow.sdk.values.TupleTagList;
import com.google.cloud.genomics.dataflow.functions.CombineShardsFn;
import com.google.cloud.genomics.dataflow.functions.GetReferencesFromHeaderFn;
import com.google.cloud.genomics.dataflow.functions.WriteBAIFn;
import com.google.cloud.genomics.dataflow.functions.WriteBAMFn;
import com.google.cloud.genomics.dataflow.readers.bam.HeaderInfo;
import com.google.cloud.genomics.dataflow.utils.BreakFusionTransform;
import com.google.cloud.genomics.utils.Contig;
import com.google.genomics.v1.Read;
import htsjdk.samtools.SAMTextHeaderCodec;
import htsjdk.samtools.ValidationStringency;
import htsjdk.samtools.util.BlockCompressedStreamConstants;
import htsjdk.samtools.util.StringLineReader;
import java.io.StringWriter;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
/*
* Writes sets of reads to BAM files in parallel, then combines the files and writes an index
* for the combined file.
*/
public class WriteBAMTransform extends PTransform<PCollectionTuple, PCollection<String>> {
public static interface Options extends WriteBAMFn.Options {}
public static TupleTag<Read> SHARDED_READS_TAG = new TupleTag<Read>(){};
public static TupleTag<HeaderInfo> HEADER_TAG = new TupleTag<HeaderInfo>(){};
private String output;
private Pipeline pipeline;
@Override
public PCollection<String> apply(PCollectionTuple tuple) {
final PCollection<HeaderInfo> header = tuple.get(HEADER_TAG);
final PCollectionView<HeaderInfo> headerView =
header.apply(View.<HeaderInfo>asSingleton());
final PCollection<Read> shardedReads = tuple.get(SHARDED_READS_TAG);
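// Write each shard of reads to its own BAM file, emitting the written shard file names
// and the per-reference-sequence data sizes as tagged outputs.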
final PCollectionTuple writeBAMFilesResult =
shardedReads.apply(ParDo.named("Write BAM shards")
.withSideInputs(Arrays.asList(headerView))
.withOutputTags(WriteBAMFn.WRITTEN_BAM_NAMES_TAG, TupleTagList.of(WriteBAMFn.SEQUENCE_SHARD_SIZES_TAG))
.of(new WriteBAMFn(headerView)));
PCollection<String> writtenBAMShardNames = writeBAMFilesResult.get(WriteBAMFn.WRITTEN_BAM_NAMES_TAG);
final PCollectionView<Iterable<String>> writtenBAMShardsView =
writtenBAMShardNames.apply(View.<String>asIterable());
final PCollection<KV<Integer, Long>> sequenceShardSizes = writeBAMFilesResult.get(WriteBAMFn.SEQUENCE_SHARD_SIZES_TAG);
final PCollection<KV<Integer, Long>> sequenceShardSizesCombined = sequenceShardSizes.apply(
Combine.perKey(
new Sum.SumLongFn()));
final PCollectionView<Iterable<KV<Integer, Long>>> sequenceShardSizesView =
sequenceShardSizesCombined.apply(View.<KV<Integer, Long>>asIterable());
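// Concatenate the written BAM shards into the single destination file, appending the
// BGZF end-of-file marker (an empty gzip block).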
final PCollection<String> destinationBAMPath = this.pipeline.apply(
Create.of(this.output));
final PCollectionView<byte[]> eofForBAM = pipeline.apply(
Create.of(BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK))
.apply(View.<byte[]>asSingleton());
final PCollection<String> writtenBAMFile = destinationBAMPath.apply(
ParDo.named("Combine BAM shards")
.withSideInputs(writtenBAMShardsView, eofForBAM)
.of(new CombineShardsFn(writtenBAMShardsView, eofForBAM)));
final PCollectionView<String> writtenBAMFileView =
writtenBAMFile.apply(View.<String>asSingleton());
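// Generate one indexing task per reference sequence in the header and write a BAI shard
// for each; BreakFusionTransform keeps these tasks from being fused with the upstream step
// so they can fan out across workers.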
final PCollection<String> indexShards = header.apply(
ParDo.named("Generate index shard tasks")
.of(new GetReferencesFromHeaderFn()));
final PCollectionTuple indexingResult = indexShards
.apply(new BreakFusionTransform<String>())
.apply(
ParDo.named("Write index shards")
.withSideInputs(headerView, writtenBAMFileView, sequenceShardSizesView)
.withOutputTags(WriteBAIFn.WRITTEN_BAI_NAMES_TAG,
TupleTagList.of(WriteBAIFn.NO_COORD_READS_COUNT_TAG))
.of(new WriteBAIFn(headerView, writtenBAMFileView, sequenceShardSizesView)));
final PCollection<String> writtenBAIShardNames = indexingResult.get(WriteBAIFn.WRITTEN_BAI_NAMES_TAG);
final PCollectionView<Iterable<String>> writtenBAIShardsView =
writtenBAIShardNames.apply(View.<String>asIterable());
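// Sum the counts of reads that have no alignment coordinates; the total is appended
// (as 8 little-endian bytes) to the end of the combined BAI file.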
final PCollection<Long> noCoordCounts = indexingResult.get(WriteBAIFn.NO_COORD_READS_COUNT_TAG);
final PCollection<Long> totalNoCoordCount = noCoordCounts
.apply(new BreakFusionTransform<Long>())
.apply(
Combine.globally(new Sum.SumLongFn()));
final PCollection<byte[]> totalNoCoordCountBytes = totalNoCoordCount.apply(
ParDo.named("No coord count to bytes").of(new Long2BytesFn()));
final PCollectionView<byte[]> eofForBAI = totalNoCoordCountBytes
.apply(View.<byte[]>asSingleton());
final PCollection<String> destinationBAIPath = this.pipeline.apply(
Create.of(this.output + ".bai"));
final PCollection<String> writtenBAIFile = destinationBAIPath.apply(
ParDo.named("Combine BAI shards")
.withSideInputs(writtenBAIShardsView, eofForBAI)
.of(new CombineShardsFn(writtenBAIShardsView, eofForBAI)));
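// The transform's output is the names of the two files written: the BAM and its BAI index.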
final PCollection<String> writtenFileNames = PCollectionList.of(writtenBAMFile).and(writtenBAIFile)
.apply(Flatten.<String>pCollections());
return writtenFileNames;
}
/**
* Transforms a long value to bytes (little endian order).
* Used for transforming the no-coord. read count into bytes for writing in
* the footer of the BAI file.
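* For example, the value 1 is emitted as the bytes {0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}.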
*/
static class Long2BytesFn extends DoFn<Long, byte[]> {
public Long2BytesFn() {
}
@Override
public void processElement(DoFn<Long, byte[]>.ProcessContext c) throws Exception {
ByteBuffer b = ByteBuffer.allocate(8);
b.order(ByteOrder.LITTLE_ENDIAN);
b.putLong(c.element());
c.output(b.array());
}
}
private WriteBAMTransform(String output, Pipeline pipeline) {
this.output = output;
this.pipeline = pipeline;
}
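/**
 * Writes the given reads to {@code output} as a single BAM file and writes an
 * accompanying {@code output + ".bai"} index.
 *
 * <p>A minimal usage sketch (the options setup and the read/header sources shown here
 * are hypothetical placeholders):
 * <pre>{@code
 * Pipeline p = Pipeline.create(options);
 * PCollection<Read> reads = ...;  // sharded reads, e.g. produced by a BAM reader transform
 * HeaderInfo headerInfo = ...;    // SAM header plus the contig of the first read
 * PCollection<String> writtenFiles =
 *     WriteBAMTransform.write(reads, headerInfo, "gs://my-bucket/output.bam", p);
 * p.run();
 * }</pre>
 */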
public static PCollection<String> write(PCollection<Read> shardedReads, HeaderInfo headerInfo,
String output, Pipeline pipeline) {
final PCollectionTuple tuple = PCollectionTuple
.of(SHARDED_READS_TAG, shardedReads)
.and(HEADER_TAG, pipeline.apply(Create.of(headerInfo).withCoder(HEADER_INFO_CODER)));
return (new WriteBAMTransform(output, pipeline)).apply(tuple);
}
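// Serializes a HeaderInfo as text: the first line holds the contig of the first read,
// followed by the SAM header rendered by SAMTextHeaderCodec.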
static Coder<HeaderInfo> HEADER_INFO_CODER = DelegateCoder.of(
StringUtf8Coder.of(),
new DelegateCoder.CodingFunction<HeaderInfo, String>() {
@Override
public String apply(HeaderInfo info) throws Exception {
final StringWriter stringWriter = new StringWriter();
SAM_HEADER_CODEC.encode(stringWriter, info.header);
return info.firstRead.toString() + "\n" + stringWriter.toString();
}
},
new DelegateCoder.CodingFunction<String, HeaderInfo>() {
@Override
public HeaderInfo apply(String str) throws Exception {
int newLinePos = str.indexOf("\n");
String contigStr = str.substring(0, newLinePos);
String headerStr = str.substring(newLinePos + 1);
return new HeaderInfo(
SAM_HEADER_CODEC.decode(new StringLineReader(headerStr),
"HEADER_INFO_CODER"),
Contig.parseContigsFromCommandLine(contigStr).iterator().next());
}
});
static final SAMTextHeaderCodec SAM_HEADER_CODEC = new SAMTextHeaderCodec();
static {
SAM_HEADER_CODEC.setValidationStringency(ValidationStringency.SILENT);
}
}