package org.broadinstitute.hellbender.engine.spark;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import org.apache.spark.SparkFiles;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.hellbender.engine.*;
import org.broadinstitute.hellbender.tools.DownsampleableSparkReadShard;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.activityprofile.ActivityProfileState;
import org.broadinstitute.hellbender.utils.activityprofile.ActivityProfileStateRange;
import org.broadinstitute.hellbender.utils.downsampling.PositionalDownsampler;
import org.broadinstitute.hellbender.utils.downsampling.ReadsDownsampler;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import scala.Tuple2;
import javax.annotation.Nullable;
import java.util.Iterator;
import java.util.List;
import java.util.function.Supplier;

/**
 * Find assembly regions from reads in a distributed Spark setting. There are two algorithms available, <i>fast</i>,
 * which looks for assembly regions in each read shard in parallel, and <i>strict</i>, which looks for assembly regions
 * in each contig in parallel. Fast mode may produce read shard boundary artifacts for assembly regions compared to the
 * walker version. Strict mode should be identical to the walker version, at the cost of increased runtime compared to
 * the fast version.
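 *
 * <p>A minimal usage sketch for the fast variant (hypothetical: {@code ctx}, {@code reads},
 * {@code intervalShards}, {@code evaluatorSupplierBroadcast}, and the argument collections are
 * assumed to be set up by the calling tool):</p>
 * <pre>{@code
 * JavaRDD<AssemblyRegionWalkerContext> contexts = FindAssemblyRegionsSpark.getAssemblyRegionsFast(
 *         ctx, reads, header, header.getSequenceDictionary(), referenceFileName, features,
 *         intervalShards, evaluatorSupplierBroadcast, shardingArgs, assemblyRegionArgs, false);
 * }</pre>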
 */
public class FindAssemblyRegionsSpark {

    /**
     * Get an RDD of assembly regions for the given reads and intervals using the fast algorithm (looks for
     * assembly regions in each read shard in parallel).
     *
     * @param ctx the Spark context
     * @param reads the coordinate-sorted reads
     * @param header the header for the reads
     * @param sequenceDictionary the sequence dictionary for the reads
     * @param referenceFileName the file name for the reference
     * @param features source of arbitrary features (may be null)
     * @param intervalShards the sharded intervals to find assembly regions for
     * @param assemblyRegionEvaluatorSupplierBroadcast broadcast of a supplier of the evaluator used to determine whether a locus is active
     * @param shardingArgs the arguments for sharding reads
     * @param assemblyRegionArgs the arguments for finding assembly regions
     * @param shuffle whether to use a shuffle or not when sharding reads
     * @return an RDD of assembly-region walker contexts (each an assembly region with its reference and feature context)
     */
    public static JavaRDD<AssemblyRegionWalkerContext> getAssemblyRegionsFast(
            final JavaSparkContext ctx,
            final JavaRDD<GATKRead> reads,
            final SAMFileHeader header,
            final SAMSequenceDictionary sequenceDictionary,
            final String referenceFileName,
            final FeatureManager features,
            final List<ShardBoundary> intervalShards,
            final Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcast,
            final AssemblyRegionReadShardArgumentCollection shardingArgs,
            final AssemblyRegionArgumentCollection assemblyRegionArgs,
            final boolean shuffle) {
        JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, reads, GATKRead.class, sequenceDictionary, intervalShards, shardingArgs.readShardSize, shuffle);
        Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
        return shardedReads.mapPartitions(getAssemblyRegionsFunctionFast(referenceFileName, bFeatureManager, header,
                assemblyRegionEvaluatorSupplierBroadcast, assemblyRegionArgs));
    }

    private static FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegionWalkerContext> getAssemblyRegionsFunctionFast(
            final String referenceFileName,
            final Broadcast<FeatureManager> bFeatureManager,
            final SAMFileHeader header,
            final Broadcast<Supplier<AssemblyRegionEvaluator>> supplierBroadcast,
            final AssemblyRegionArgumentCollection assemblyRegionArgs) {
        return (FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegionWalkerContext>) shardedReadIterator -> {
            final ReferenceDataSource reference = referenceFileName == null ? null : new ReferenceFileSource(IOUtils.getPath(SparkFiles.get(referenceFileName)));
            final FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
            final AssemblyRegionEvaluator assemblyRegionEvaluator = supplierBroadcast.getValue().get(); // one AssemblyRegionEvaluator instance per Spark partition
            final ReadsDownsampler readsDownsampler = assemblyRegionArgs.maxReadsPerAlignmentStart > 0 ?
                    new PositionalDownsampler(assemblyRegionArgs.maxReadsPerAlignmentStart, header) : null;
            Iterator<Iterator<AssemblyRegionWalkerContext>> iterators = Utils.stream(shardedReadIterator)
                    .map(shardedRead -> new ShardToMultiIntervalShardAdapter<>(
                            new DownsampleableSparkReadShard(
                                    new ShardBoundary(shardedRead.getInterval(), shardedRead.getPaddedInterval()), shardedRead, readsDownsampler)))
                    .map(downsampledShardedRead -> {
                        final Iterator<AssemblyRegion> assemblyRegionIter = new AssemblyRegionIterator(
                                new ShardToMultiIntervalShardAdapter<>(downsampledShardedRead),
                                header, reference, features, assemblyRegionEvaluator, assemblyRegionArgs);
                        return Utils.stream(assemblyRegionIter).map(assemblyRegion ->
                                new AssemblyRegionWalkerContext(assemblyRegion,
                                        new ReferenceContext(reference, assemblyRegion.getPaddedSpan()),
                                        new FeatureContext(features, assemblyRegion.getPaddedSpan()))).iterator();
                    }).iterator();
            return Iterators.concat(iterators);
        };
    }

    /**
     * Get an RDD of assembly regions for the given reads and intervals using the strict algorithm (looks for
     * assembly regions in each contig in parallel).
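     *
     * <p>A minimal usage sketch (hypothetical; the surrounding setup is the same as for
     * {@link #getAssemblyRegionsFast}, with the extra shuffle being the price of walker-identical output):</p>
     * <pre>{@code
     * JavaRDD<AssemblyRegionWalkerContext> contexts = FindAssemblyRegionsSpark.getAssemblyRegionsStrict(
     *         ctx, reads, header, header.getSequenceDictionary(), referenceFileName, features,
     *         intervalShards, evaluatorSupplierBroadcast, shardingArgs, assemblyRegionArgs, false);
     * }</pre>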
     * @param ctx the Spark context
     * @param reads the coordinate-sorted reads
     * @param header the header for the reads
     * @param sequenceDictionary the sequence dictionary for the reads
     * @param referenceFileName the file name for the reference
     * @param features source of arbitrary features (may be null)
     * @param intervalShards the sharded intervals to find assembly regions for
     * @param assemblyRegionEvaluatorSupplierBroadcast broadcast of a supplier of the evaluator used to determine whether a locus is active
     * @param shardingArgs the arguments for sharding reads
     * @param assemblyRegionArgs the arguments for finding assembly regions
     * @param shuffle whether to use a shuffle or not when sharding reads
     * @return an RDD of assembly-region walker contexts (each an assembly region with its reference and feature context)
     */
    public static JavaRDD<AssemblyRegionWalkerContext> getAssemblyRegionsStrict(
            final JavaSparkContext ctx,
            final JavaRDD<GATKRead> reads,
            final SAMFileHeader header,
            final SAMSequenceDictionary sequenceDictionary,
            final String referenceFileName,
            final FeatureManager features,
            final List<ShardBoundary> intervalShards,
            final Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcast,
            final AssemblyRegionReadShardArgumentCollection shardingArgs,
            final AssemblyRegionArgumentCollection assemblyRegionArgs,
            final boolean shuffle) {
        JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, reads, GATKRead.class, sequenceDictionary, intervalShards, shardingArgs.readShardSize, shuffle);
        Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);

        // 1. Calculate activity for each locus in the desired intervals, in parallel.
        JavaRDD<ActivityProfileStateRange> activityProfileStates = shardedReads.mapPartitions(getActivityProfileStatesFunction(referenceFileName, bFeatureManager, header,
                assemblyRegionEvaluatorSupplierBroadcast, assemblyRegionArgs));

        // 2. Group by contig. We need to do this so we can perform the band pass filter over the whole contig, so we
        // produce assembly regions that are identical to those produced by AssemblyRegionWalker.
        // This step requires a shuffle, but the amount of data in the ActivityProfileStateRange should be small, so it
        // should not be prohibitive.
        JavaPairRDD<String, Iterable<ActivityProfileStateRange>> contigToGroupedStates = activityProfileStates
                .keyBy((Function<ActivityProfileStateRange, String>) range -> range.getContig())
                .groupByKey();

        // 3. Run the band pass filter to find AssemblyRegions. The filtering is fairly cheap, so it should be fast
        // even though it has to scan a whole contig. Note that we *don't* fill in reads here, since after we have found
        // the assembly regions we want to do assembly using the full resources of the cluster. So if we have
        // very small assembly region objects, then we can repartition them for redistribution across the cluster,
        // at which point the reads can be filled in. (See next step.)
        JavaRDD<ReadlessAssemblyRegion> readlessAssemblyRegions = contigToGroupedStates
                .flatMap(getReadlessAssemblyRegionsFunction(header, assemblyRegionArgs));
        // Repartition to distribute the data evenly across the cluster again.
        readlessAssemblyRegions = readlessAssemblyRegions.repartition(readlessAssemblyRegions.getNumPartitions());

        // 4. Fill in the reads. Each shard is an assembly region, with its overlapping reads.
        JavaRDD<Shard<GATKRead>> assemblyRegionShardedReads = SparkSharder.shard(ctx, reads, GATKRead.class, header.getSequenceDictionary(), readlessAssemblyRegions, shardingArgs.readShardSize);

        // 5. Convert shards to assembly regions. Reads downsampling is done again here. Note it will only be
        // consistent with the downsampling done in step 1 when https://github.com/broadinstitute/gatk/issues/5437 is in.
        JavaRDD<AssemblyRegion> assemblyRegions = assemblyRegionShardedReads.mapPartitions((FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegion>) shardedReadIterator -> {
            final ReadsDownsampler readsDownsampler = assemblyRegionArgs.maxReadsPerAlignmentStart > 0 ?
                    new PositionalDownsampler(assemblyRegionArgs.maxReadsPerAlignmentStart, header) : null;
            return Utils.stream(shardedReadIterator)
                    .map(shardedRead -> toAssemblyRegion(shardedRead, header, readsDownsampler)).iterator();
        });

        // 6. Add reference and feature context.
        return assemblyRegions.mapPartitions(getAssemblyRegionWalkerContextFunction(referenceFileName, bFeatureManager));
    }

    private static FlatMapFunction<Iterator<Shard<GATKRead>>, ActivityProfileStateRange> getActivityProfileStatesFunction(
            final String referenceFileName,
            final Broadcast<FeatureManager> bFeatureManager,
            final SAMFileHeader header,
            final Broadcast<Supplier<AssemblyRegionEvaluator>> supplierBroadcast,
            final AssemblyRegionArgumentCollection assemblyRegionArgs) {
        return (FlatMapFunction<Iterator<Shard<GATKRead>>, ActivityProfileStateRange>) shardedReadIterator -> {
            final ReferenceDataSource reference = referenceFileName == null ? null : new ReferenceFileSource(IOUtils.getPath(SparkFiles.get(referenceFileName)));
            final FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
            final AssemblyRegionEvaluator assemblyRegionEvaluator = supplierBroadcast.getValue().get(); // one AssemblyRegionEvaluator instance per Spark partition
            return Utils.stream(shardedReadIterator)
                    .map(shardedRead -> {
                        final ReadsDownsampler readsDownsampler = assemblyRegionArgs.maxReadsPerAlignmentStart > 0 ?
                                new PositionalDownsampler(assemblyRegionArgs.maxReadsPerAlignmentStart, header) : null;
                        return new ShardToMultiIntervalShardAdapter<>(
                                new DownsampleableSparkReadShard(
                                        new ShardBoundary(shardedRead.getInterval(), shardedRead.getPaddedInterval()), shardedRead, readsDownsampler));
                    })
                    .map(shardedRead -> {
                        final Iterator<ActivityProfileState> activityProfileStateIter = new ActivityProfileStateIterator(
                                new ShardToMultiIntervalShardAdapter<>(shardedRead),
                                header, reference, features, assemblyRegionEvaluator
                        );
                        return new ActivityProfileStateRange(shardedRead, activityProfileStateIter);
                    }).iterator();
        };
    }

    private static FlatMapFunction<Tuple2<String, Iterable<ActivityProfileStateRange>>, ReadlessAssemblyRegion> getReadlessAssemblyRegionsFunction(
            final SAMFileHeader header,
            final AssemblyRegionArgumentCollection assemblyRegionArgs) {
        return (FlatMapFunction<Tuple2<String, Iterable<ActivityProfileStateRange>>, ReadlessAssemblyRegion>) iter ->
                // Iterators.transform requires a Guava Function; each AssemblyRegion is converted to a lightweight
                // ReadlessAssemblyRegion so that only region boundaries (not reads) are carried through the shuffle.
                Iterators.transform(
                        new AssemblyRegionFromActivityProfileStateIterator(
                                ActivityProfileStateRange.toIteratorActivityProfileState(iter._2.iterator()),
                                header,
                                assemblyRegionArgs.minAssemblyRegionSize,
                                assemblyRegionArgs.maxAssemblyRegionSize,
                                assemblyRegionArgs.assemblyRegionPadding,
                                assemblyRegionArgs.activeProbThreshold,
                                assemblyRegionArgs.maxProbPropagationDistance), new com.google.common.base.Function<AssemblyRegion, ReadlessAssemblyRegion>() {
                            @Nullable
                            @Override
                            public ReadlessAssemblyRegion apply(@Nullable AssemblyRegion input) {
                                return new ReadlessAssemblyRegion(input);
                            }
                        });
    }

    private static AssemblyRegion toAssemblyRegion(Shard<GATKRead> shard, SAMFileHeader header, ReadsDownsampler readsDownsampler) {
        Shard<GATKRead> downsampledShardedRead =
                new DownsampleableSparkReadShard(
                        new ShardBoundary(shard.getInterval(), shard.getPaddedInterval()), shard, readsDownsampler);
        // TODO: interfaces could be improved to avoid casting
        ReadlessAssemblyRegion readlessAssemblyRegion = (ReadlessAssemblyRegion) ((ShardBoundaryShard<GATKRead>) shard).getShardBoundary();
        // Recover the padding that was applied to the region: take the larger of the left and right padding,
        // since either side may have been truncated at a contig boundary.
        int extension = Math.max(shard.getInterval().getStart() - shard.getPaddedInterval().getStart(), shard.getPaddedInterval().getEnd() - shard.getInterval().getEnd());
        AssemblyRegion assemblyRegion = new AssemblyRegion(shard.getInterval(), readlessAssemblyRegion.isActive(), extension, header);
        assemblyRegion.addAll(Lists.newArrayList(downsampledShardedRead));
        return assemblyRegion;
    }

    private static FlatMapFunction<Iterator<AssemblyRegion>, AssemblyRegionWalkerContext> getAssemblyRegionWalkerContextFunction(
            final String referenceFileName,
            final Broadcast<FeatureManager> bFeatureManager) {
        return (FlatMapFunction<Iterator<AssemblyRegion>, AssemblyRegionWalkerContext>) assemblyRegionIter -> {
            final ReferenceDataSource reference = referenceFileName == null ? null : new ReferenceFileSource(IOUtils.getPath(SparkFiles.get(referenceFileName)));
            final FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
            return Utils.stream(assemblyRegionIter).map(assemblyRegion ->
                    new AssemblyRegionWalkerContext(assemblyRegion,
                            new ReferenceContext(reference, assemblyRegion.getPaddedSpan()),
                            new FeatureContext(features, assemblyRegion.getPaddedSpan()))).iterator();
        };
    }
}