package org.broadinstitute.hellbender.tools.dragstr;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.Spliterator;
import java.util.Vector;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.commons.io.output.NullOutputStream;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.cmdline.programgroups.ShortVariantDiscoveryProgramGroup;
import org.broadinstitute.hellbender.engine.GATKPath;
import org.broadinstitute.hellbender.engine.GATKTool;
import org.broadinstitute.hellbender.engine.ReadsDataSource;
import org.broadinstitute.hellbender.engine.ReadsPathDataSource;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCaller;
import org.broadinstitute.hellbender.transformers.DRAGENMappingQualityReadTransformer;
import org.broadinstitute.hellbender.transformers.ReadTransformer;
import org.broadinstitute.hellbender.utils.BinaryTableReader;
import org.broadinstitute.hellbender.utils.IntervalMergingRule;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SequenceDictionaryUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.collections.AutoCloseableCollection;
import org.broadinstitute.hellbender.utils.dragstr.DragstrParamUtils;
import org.broadinstitute.hellbender.utils.dragstr.DragstrParams;
import org.broadinstitute.hellbender.utils.dragstr.STRTableFile;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.reference.AbsoluteCoordinates;
import htsjdk.samtools.CigarElement;
import htsjdk.samtools.CigarOperator;
import htsjdk.samtools.SAMFlag;
import htsjdk.samtools.SAMReadGroupRecord;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SamReaderFactory;
import htsjdk.samtools.util.IntervalTree;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMaps;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import scala.Tuple4;
/**
* Estimates the parameters for the DRAGstr model for an input sample.
*
* This tool takes in the sampling sites generated by {@link ComposeSTRTableFile} on the same reference
* as the input sample.
*
*
* The end result is a text file containing three parameter tables (GOP, GCP, API) that can be fed
* directly to {@link HaplotypeCaller} via --dragstr-params-path.
*
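* <h3>Usage example</h3>
*
* A typical invocation might look as follows (illustrative; the file names are placeholders):
*
* <pre>
* gatk CalibrateDragstrModel \
*   -R reference.fasta \
*   -I sample.bam \
*   -str str_table.zip \
*   -O sample.dragstr
* </pre>
*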
*/
@CommandLineProgramProperties(
summary = "estimates the parameters for the DRAGstr model for the input sample using the output of the ComposeSTRTable tool",
oneLineSummary = "estimates the parameters for the DRAGstr model",
programGroup = ShortVariantDiscoveryProgramGroup.class
)
@DocumentedFeature
public class CalibrateDragstrModel extends GATKTool {
public static final String STR_TABLE_PATH_SHORT_NAME = "str";
public static final String STR_TABLE_PATH_FULL_NAME = "str-table-path";
public static final String PARALLEL_FULL_NAME = "parallel";
public static final String THREADS_FULL_NAME = "threads";
public static final String SHARD_SIZE_FULL_NAME = "shard-size";
public static final String DOWN_SAMPLE_SIZE_FULL_NAME = "down-sample-size";
public static final String DEBUG_SITES_OUTPUT_FULL_NAME = "debug-sites-output";
public static final String FORCE_ESTIMATION_FULL_NAME = "force-estimation";
public static final int DEFAULT_SHARD_SIZE = 1_000_000;
public static final int DEFAULT_DOWN_SAMPLE_SIZE = 4096;
public static final int SYSTEM_SUGGESTED_THREAD_NUMBER = 0;
public static final int MINIMUM_SHARD_SIZE = 100;
public static final int MINIMUM_DOWN_SAMPLE_SIZE = 512;
@ArgumentCollection
private DragstrHyperParameters hyperParameters = new DragstrHyperParameters();
@Argument(shortName=STR_TABLE_PATH_SHORT_NAME, fullName=STR_TABLE_PATH_FULL_NAME, doc="location of the zip that contains the sampling sites for the reference")
private GATKPath strTablePath = null;
@Argument(fullName=PARALLEL_FULL_NAME, doc="run alignment data collection and estimation in parallel", optional = true)
private boolean runInParallel = false;
@Argument(fullName=THREADS_FULL_NAME, minValue = SYSTEM_SUGGESTED_THREAD_NUMBER, doc="suggested number of parallel threads to perform the estimation; "
+ "the default 0 leaves it up to the VM to decide. When set to more than 1, this will activate parallel execution even in the absence of --parallel", optional = true)
private int threads = SYSTEM_SUGGESTED_THREAD_NUMBER;
@Argument(fullName=SHARD_SIZE_FULL_NAME, doc="when running in parallel this is the suggested shard size in base pairs. " +
"The actual shard-size may vary to adapt to small contigs and the requested number of threads",
minValue = MINIMUM_SHARD_SIZE, optional = true)
private int shardSize = DEFAULT_SHARD_SIZE;
@Argument(fullName=DOWN_SAMPLE_SIZE_FULL_NAME, doc="Targeted maximum number of cases per period and repeat-count combination; " +
"the larger the value the more precise, but also the slower, the estimation.",
minValue = MINIMUM_DOWN_SAMPLE_SIZE, optional = true)
private int downsampleSize = DEFAULT_DOWN_SAMPLE_SIZE;
@Argument(fullName= StandardArgumentDefinitions.OUTPUT_LONG_NAME, shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME, doc = "where to write the parameter output file.")
private GATKPath output = null;
@Argument(fullName= DEBUG_SITES_OUTPUT_FULL_NAME, doc = "table with information gathered on the sampled sites, including which sites were downsampled, disqualified or accepted for parameter estimation", optional = true)
private String sitesOutput = null;
@Argument(fullName= FORCE_ESTIMATION_FULL_NAME, doc = "for testing purposes only; force parameter estimation even with few datapoints available", optional = true)
private boolean forceEstimation = false;
private SAMSequenceDictionary dictionary;
private SamReaderFactory factory;
public static final ReadTransformer EXTENDED_MQ_READ_TRANSFORMER = new DRAGENMappingQualityReadTransformer();
@Override
public boolean requiresReference() {
return true;
}
@Override
public boolean requiresReads() {
return true;
}
@Override
protected void onStartup() {
super.onStartup();
hyperParameters.validate();
dictionary = directlyAccessEngineReadsDataSource().getSequenceDictionary();
factory = makeSamReaderFactory();
if (runInParallel) {
if (threads == 1) {
logger.warn("parallel processing was requested but the number of threads was set to 1");
}
} else if (threads > 1) {
runInParallel = true;
}
if (runInParallel) {
if (threads == 0) {
logger.info("Running in parallel using the system suggested default thread count: " + Runtime.getRuntime().availableProcessors());
} else {
logger.info("Running in parallel using the requested number of threads: " + threads);
}
}
}
@Override
public void traverse() {
hyperParameters.validate();
dictionary = getBestAvailableSequenceDictionary();
final List<SAMReadGroupRecord> readGroups = hasReads() ? getHeaderForReads().getReadGroups() : Collections.emptyList();
final List<String> readGroupIds = readGroups.stream()
.map(SAMReadGroupRecord::getId)
.collect(Collectors.toList());
final List<String> sampleNames = readGroups.stream()
.map(SAMReadGroupRecord::getSample)
.distinct().collect(Collectors.toList());
final Optional<String> sampleName = resolveSampleName(sampleNames);
try (final PrintWriter sitesOutputWriter = openSitesOutputWriter(sitesOutput);
final STRTableFile strTable = STRTableFile.open(strTablePath)) {
checkSequenceDictionaryCompatibility(dictionary, strTable.dictionary());
final StratifiedDragstrLocusCases allSites;
final List<SimpleInterval> intervals = getTraversalIntervals();
runInParallel |= threads > 1;
if (runInParallel) {
if (threads == 1) {
logger.warn("parallel processing was requested but the number of threads was set to 1");
}
allSites = collectCaseStatsParallel(intervals, shardSize, strTable);
} else {
allSites = collectCaseStatsSequential(intervals, strTable);
}
logSiteCounts(allSites, "all loci/cases");
final StratifiedDragstrLocusCases downSampledSites = downSample(allSites, strTable, sitesOutputWriter);
logSiteCounts(downSampledSites, "all downsampled (kept) loci/cases");
final StratifiedDragstrLocusCases finalSites = downSampledSites.qualifyingOnly(hyperParameters.minDepth, hyperParameters.minMQ, 0);
logSiteCounts(finalSites, "all qualifying loci/cases");
outputDownSampledSiteDetails(downSampledSites, sitesOutputWriter, hyperParameters.minDepth, hyperParameters.minMQ, 0);
printOutput(finalSites, sampleName.orElse(null), readGroupIds);
}
}
private void printOutput(final StratifiedDragstrLocusCases finalSites, final String sampleName, final List<String> readGroups) {
final boolean enoughCases = isThereEnoughCases(finalSites);
final boolean usingDefaults = !enoughCases && !forceEstimation;
final Object[] annotations = {
"sample", (sampleName == null ? "" : sampleName),
"readGroups", (readGroups.isEmpty() ? "" : Utils.join(", ", readGroups)),
"estimatedOrDefaults", (usingDefaults ? "defaults" : (enoughCases ? "estimated" : "estimatedByForce")),
"commandLine", getCommandLine()
};
if (!usingDefaults) {
if (!enoughCases) {
logger.warn("Forcing parameters estimation using sampled down cases as requested");
} else {
logger.info("Estimating parameters using sampled down cases");
}
final DragstrParams estimate = estimateParams(finalSites);
logger.info("Done with estimation, printing output");
DragstrParamUtils.print(estimate, output, annotations);
} else {
logger.warn("Not enough cases to estimate parameters, using defaults");
DragstrParamUtils.print(DragstrParams.DEFAULT, output, annotations);
}
}
private Optional<String> resolveSampleName(final List<String> sampleNames) {
if (sampleNames.size() > 1) {
throw new GATKException("the input alignment(s) have more than one sample: " + String.join(", ", sampleNames));
} else if (sampleNames.isEmpty() || sampleNames.get(0) == null) {
logger.warn("there is no sample id in the alignment header, assuming that all reads and read/groups make reference to the same anonymous sample");
return Optional.empty();
} else {
return Optional.of(sampleNames.get(0));
}
}
private void checkSequenceDictionaryCompatibility(final SAMSequenceDictionary reference, final SAMSequenceDictionary strTable) {
final SequenceDictionaryUtils.SequenceDictionaryCompatibility compatibility = SequenceDictionaryUtils.compareDictionaries(reference, strTable, false);
switch (compatibility) {
case IDENTICAL: return;
case SUPERSET: return;
// these two below are probably never returned since we ask for no check on order,
// but we add them just in case:
case NON_CANONICAL_HUMAN_ORDER: return; // we don't care about the order.
case OUT_OF_ORDER: return; // we don't care about the order.
default:
throw new GATKException("the reference and str-table sequence dictionary are incompatible: " + compatibility);
}
}
@SuppressWarnings("deprecation")
private PrintWriter openSitesOutputWriter(final String sitesOutput) {
return sitesOutput == null ? new PrintWriter(NullOutputStream.NULL_OUTPUT_STREAM)
: new PrintWriter(BucketUtils.createFile(sitesOutput));
}
private void outputDownSampledSiteDetails(final StratifiedDragstrLocusCases finalSites,
final PrintWriter writer,
final int minDepth,
final int samplingMinMQ,
final int maxSup) {
if (sitesOutput != null) {
for (final DragstrLocusCases[] periodCases : finalSites.perPeriodAndRepeat) {
for (final DragstrLocusCases repeatCases : periodCases) {
for (final DragstrLocusCase caze : repeatCases) {
outputSiteDetails(writer, caze, caze.qualifies(minDepth, samplingMinMQ, maxSup) ? "used" : "skipped");
}
}
}
}
}
/**
* Holds the minimum counts for each period, repeat-length combo.
* If there is a lack of data for any of these, we use the default param
* tables. Missing values, rows (periods) or columns (repeat-lengths), are
* interpreted as 0.
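* For example, period 2 requires at least 200 cases for repeat-lengths 2 through 5 and imposes no minimum for other repeat-lengths.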
*/
private static final int[][] MINIMUM_CASES_BY_PERIOD_AND_LENGTH =
// @formatter:off ; prevents code reformatting by IntelliJ
// if enabled:
// Preferences > Editor > Code Style > Formatter Control
// run-length:
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10+ // period
{ {},
{0, 200, 200, 200, 200, 200, 200, 200, 200, 200, 0}, // 1
{0, 0, 200, 200, 200, 200, 0, 0, 0, 0, 0}, // 2
{0, 0, 200, 200, 200, 0, 0, 0, 0, 0, 0}, // 3
{0, 0, 200, 200, 0, 0, 0, 0, 0, 0, 0}, // 4
{0, 0, 200, 0, 0, 0, 0, 0, 0, 0, 0}, // 5
{0, 0, 200, 0, 0, 0, 0, 0, 0, 0, 0}, // 6
{0, 0, 200, 0, 0, 0, 0, 0, 0, 0, 0}, // 7
{0, 0, 200, 0, 0, 0, 0, 0, 0, 0, 0}, // 8
};
// zeros to the right are actually not necessary, but we add them to make the literal look more like a matrix.
// @formatter:on
/**
* Checks that a minimum number of cases are available in key bins (period and repeat-length combos).
*/
private boolean isThereEnoughCases(final StratifiedDragstrLocusCases allSites) {
// period 1, repeat length 1 to 9 (inclusive)
final int[][] MCBL = MINIMUM_CASES_BY_PERIOD_AND_LENGTH;
final int maxP = Math.min(hyperParameters.maxPeriod, MCBL.length - 1);
final List<Tuple4<Integer, Integer, Integer, Integer>> failingCombos = new ArrayList<>(10);
for (int i = 1; i <= maxP; i++) {
final int maxL = Math.min(hyperParameters.maxRepeatLength, MCBL[i].length - 1);
for (int j = 1; j <= maxL; j++) {
if (allSites.get(i, j).size() < MCBL[i][j]) {
failingCombos.add(new Tuple4<>(i, j, allSites.get(i, j).size(), MCBL[i][j]));
}
}
}
if (failingCombos.isEmpty()) {
return true;
} else if (forceEstimation) {
logger.warn("there is not enough data to proceed to parameter empirical estimation " +
"but user requested to force it, so we go ahead");
for (final Tuple4<Integer, Integer, Integer, Integer> failingCombo : failingCombos) {
logger.warn(String.format("(P=%d, L=%d) count %d is less than minimum required %d ",
failingCombo._1(), failingCombo._2(), failingCombo._3(), failingCombo._4()));
}
return true;
} else {
logger.warn("there is not enough data to proceed to parameter empirical estimation, using defaults instead");
return false;
}
}
/**
* Performs the final estimation step.
* @param finalSites the sites to use for the estimation.
* @return never {@code null}.
*/
private DragstrParams estimateParams(final StratifiedDragstrLocusCases finalSites) {
final DragstrParametersEstimator estimator = new DragstrParametersEstimator(hyperParameters);
return runInParallel ? Utils.runInParallel(threads, () -> estimator.estimate(finalSites)) : estimator.estimate(finalSites);
}
/**
* Downsample sites so that at most {@link #downsampleSize} cases remain for each period and repeat-length combination.
* @param allSites the sites to downsample.
* @param strTable the STR table that contains the decimation table used to generate those sites.
* @param sitesOutputWriter an optional per-site information output for debugging purposes.
* @return never {@code null}.
*/
private StratifiedDragstrLocusCases downSample(final StratifiedDragstrLocusCases allSites, final STRTableFile strTable,
final PrintWriter sitesOutputWriter) {
final STRDecimationTable decimationTable = strTable.decimationTable();
final List<PeriodAndRepeatLength> prCombos = new ArrayList<>(hyperParameters.maxPeriod * hyperParameters.maxRepeatLength);
for (int i = 1; i <= hyperParameters.maxPeriod; i++) {
for (int j = 1; j <= hyperParameters.maxRepeatLength; j++) {
prCombos.add(PeriodAndRepeatLength.of(i, j));
}
}
final Stream<PeriodAndRepeatLength> prCombosStream = runInParallel ? prCombos.parallelStream() : prCombos.stream();
final Stream<DragstrLocusCase> downsampledStream = prCombosStream
.flatMap(combo -> {
final DragstrLocusCases all = allSites.perPeriodAndRepeat[combo.period - 1][combo.repeatLength - 1];
final int decimationBit = decimationTable.decimationBit(combo.period, combo.repeatLength);
return downSample(all, decimationBit, downsampleSize, sitesOutputWriter).stream();
});
if (runInParallel) {
return Utils.runInParallel(threads,
() -> downsampledStream.collect(DragstrLocusCaseStratificator.make(hyperParameters.maxPeriod, hyperParameters.maxRepeatLength)));
} else {
return downsampledStream.collect(DragstrLocusCaseStratificator.make(hyperParameters.maxPeriod, hyperParameters.maxRepeatLength));
}
}
/**
* Pre-calculated decimation masks used depending on the final decimation bit/level.
*/
private static final long[] DECIMATION_MASKS_BY_BIT = new long[Long.SIZE];
// Code to populate DECIMATION_MASKS_BY_BIT.
static {
DECIMATION_MASKS_BY_BIT[0] = 1;
for (int i = 1, j = 0; i < Long.SIZE; i++, j++) {
DECIMATION_MASKS_BY_BIT[i] = DECIMATION_MASKS_BY_BIT[j] << 1;
DECIMATION_MASKS_BY_BIT[j] = ~DECIMATION_MASKS_BY_BIT[j];
}
DECIMATION_MASKS_BY_BIT[Long.SIZE - 1] = ~DECIMATION_MASKS_BY_BIT[Long.SIZE - 1];
}
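// After initialization DECIMATION_MASKS_BY_BIT[j] == ~(1L << j) for every j, i.e. a mask that clears
// bit j: ANDing a locus mask with it changes the mask iff bit j was set, which is how downSample below
// detects the first bit that would decimate a given locus.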
/**
* Decimates the collection of locus/cases to the downsample size provided or smaller.
*
* Notice that if we need to downsample (the input size is larger than the downsample size provided)
* we take care of not counting those cases that have zero depth toward that limit.
* This is due to the apparent behaviour in DRAGEN where those are "sort-of" filtered before
* decimation as far as meeting the final downsample size limit is concerned.
*
*
* They would usually be skipped eventually in post-downsampling filtering, but if we did not discount
* their number here we would end up downsampling some period/repeat-length combinations too much
* as compared to DRAGEN.
*
*
* This behavior in DRAGEN may well change in future releases.
*
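* For example, with {@code downsampleSize == 4096}, an input of 10,000 cases of which 500 have zero
* depth would be decimated bit by bit until at most 4,096 cases with non-zero depth remain.
*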
* @param in input collection of cases to downsample.
* @param minDecimationBit the start decimation bit. Usually the input case collection won't contain any
* cases with a lower bit set (those were already decimated).
* @param downsampleSize the target size.
* @return never {@code null}. The result contains at most {@code downsampleSize} cases, discounting cases with zero depth. It could be empty.
*/
private DragstrLocusCases downSample(final DragstrLocusCases in, final int minDecimationBit, final int downsampleSize, final PrintWriter sitesOutputWriter) {
final int inSize = in.size();
if (inSize <= downsampleSize) { // we already satisfy the imposed size limit so we do nothing.
return in;
} else {
int zeroDepth = 0;
final int[] countByFirstDecimatingBit = new int[Long.SIZE]; // sized Long.SIZE so any decimating bit up to 63 can be counted.
for (final DragstrLocusCase caze: in) {
final DragstrLocus locus = caze.getLocus();
final int depth = caze.getDepth();
if (depth <= 0) { // we discount cases with zero depth as these are going to be skipped eventually.
zeroDepth++;
continue;
}
long mask = locus.getMask();
for (int j = minDecimationBit; mask != 0 && j < Long.SIZE; j++) {
final long newMask = mask & DECIMATION_MASKS_BY_BIT[j];
if (newMask != mask) {
countByFirstDecimatingBit[j]++;
break;
}
}
}
final IntList progressiveSizes = new IntArrayList(Long.SIZE + 1);
progressiveSizes.add(inSize);
int finalSize = inSize - zeroDepth;
progressiveSizes.add(finalSize);
long filterMask = 0;
for (int j = minDecimationBit; finalSize > downsampleSize && j < Long.SIZE; j++) {
finalSize -= countByFirstDecimatingBit[j];
filterMask |= ~DECIMATION_MASKS_BY_BIT[j];
progressiveSizes.add(finalSize);
}
final DragstrLocusCases result = new DragstrLocusCases(finalSize, in.getPeriod(), in.getRepeatLength());
final DragstrLocusCases discarded = new DragstrLocusCases(in.size() - finalSize, in.getPeriod(), in.getRepeatLength());
for (final DragstrLocusCase caze: in) {
final long mask = caze.getLocus().getMask();
if ((mask & filterMask) == 0 && caze.getDepth() > 0) {
result.add(caze);
} else {
discarded.add(caze);
}
}
// Debug-log message format explained:
// period repeat-length [x0, x00, x1, x2, x3 ... xN]
// where x0 is the input size.
// x00 = x0 - #zero depth cases
// x1 = x00 - #first round of decimation
// x2 = x1 - #second round of decimation.
// ...
// xN = final size <= downsampleSize
logger.debug(() -> "" + in.getPeriod() + " " + in.getRepeatLength() + " "
+ Arrays.toString(progressiveSizes.toArray()));
// we output info about the sites that are discarded:
if (sitesOutput != null && discarded.size() > 0) {
synchronized (this) {
for (final DragstrLocusCase caze : discarded) {
outputSiteDetails(sitesOutputWriter, caze, "downsampled-out");
}
}
}
return result;
}
}
/**
* Logs case counts in a matrix where columns are periods and rows are
* repeat lengths in repeat units.
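* For example (with illustrative counts), for maxPeriod 2 and maxRepeatLength 3 the output looks like:
* <pre>
*      1      2
* 1    120    85
* 2    98     64
* 3    77     51
* </pre>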
* @param cases the cases whose counts are to be logged.
* @param title the title of the debug message.
*/
private void logSiteCounts(final StratifiedDragstrLocusCases cases, final String title) {
if (logger.isDebugEnabled()) { // pertinent to check here to save time when DEBUG is off,
// since this method is all about debug logging.
logger.debug(title);
final int[] columnWidths = IntStream.range(1, hyperParameters.maxPeriod + 1).map(period -> {
final int max = IntStream.range(1, hyperParameters.maxRepeatLength + 1).map(repeat -> cases.get(period,repeat).size())
.max().orElse(0);
return (int) Math.max(7, Math.ceil(Math.log10(max)) + 1); }).toArray();
logger.debug(" " + IntStream.range(0, hyperParameters.maxPeriod).mapToObj(i -> String.format("%-" + columnWidths[i] + "s", (i + 1))).collect(Collectors.joining()));
for (int i = 1; i <= hyperParameters.maxRepeatLength; i++) {
final int repeat = i;
logger.debug(String.format("%-4s", repeat) + " " + IntStream.range(1, hyperParameters.maxPeriod + 1)
.mapToObj(period -> String.format("%-" + columnWidths[period - 1] + "s",
cases.get(period, repeat).size())).collect(Collectors.joining("")));
}
}
}
private StratifiedDragstrLocusCases collectCaseStatsSequential(final List<SimpleInterval> intervals, final STRTableFile strTable) {
final StratifiedDragstrLocusCases result = StratifiedDragstrLocusCases.make(hyperParameters.maxPeriod, hyperParameters.maxRepeatLength);
final ReadsDataSource dataSource = directlyAccessEngineReadsDataSource();
for (final SimpleInterval interval : intervals) {
try (final BinaryTableReader<DragstrLocus> reader = strTable.locusReader(interval)) {
streamShardCasesStats(interval, readStream(dataSource, interval), reader.stream())
.peek(caze -> progressMeter.update(caze.getLocation(dictionary)))
.forEach(result::add);
} catch (final IOException ex) {
throw new GATKException("problems accessing str-table-file at " + strTablePath);
}
}
return result;
}
@SuppressWarnings("try") // silences intended use of unreferenced auto-closable within try-resource.
private StratifiedDragstrLocusCases collectCaseStatsParallel(final List<SimpleInterval> intervals, final int shardSize, final STRTableFile strTable) {
//TODO: instead of the dictionary this should take on the traversal intervals.
//TODO: currently, if the user specifies intervals, the progress-meter will show that amount of bases
// at the beginning of the reference instead.
final AbsoluteCoordinates absoluteCoordinates = AbsoluteCoordinates.of(dictionary);
final List<SimpleInterval> shards = shardIntervals(intervals, shardSize);
final Collection<ReadsPathDataSource> readSources = new Vector<>(threads);
final ThreadLocal<ReadsPathDataSource> threadReadSource = ThreadLocal.withInitial(
() -> {
final ReadsPathDataSource result = new ReadsPathDataSource(readArguments.getReadPaths(), factory);
readSources.add(result);
return result;
});
try (@SuppressWarnings("unused") final AutoCloseableCollection> readSourceCloser = new AutoCloseableCollection<>(readSources)) {
final AtomicLong numberBasesProcessed = new AtomicLong(0);
return Utils.runInParallel(Math.min(threads, shards.size()), () ->
StreamSupport.stream(new InterleavingListSpliterator<>(shards), true)
.map(shard -> {
try (final BinaryTableReader<DragstrLocus> lociReader = strTable.locusReader(shard)) {
final ReadsPathDataSource readsSource = threadReadSource.get();
final StratifiedDragstrLocusCases result = streamShardCasesStats(shard, readStream(readsSource, shard), lociReader.stream())
.collect(DragstrLocusCaseStratificator.make(hyperParameters.maxPeriod, hyperParameters.maxRepeatLength));
final int resultSize = result.size();
synchronized (numberBasesProcessed) {
final long processed = numberBasesProcessed.updateAndGet(l -> l + shard.size());
progressMeter.update(absoluteCoordinates.toSimpleInterval(processed, 1), resultSize);
}
return result;
} catch (final IOException ex) {
throw new GATKException("problems accessing the str-table-file contents at " + strTablePath, ex);
}
})
.reduce(StratifiedDragstrLocusCases::merge)
.orElseGet(() -> new StratifiedDragstrLocusCases(
hyperParameters.maxPeriod,
hyperParameters.maxRepeatLength)));
}
}
/**
* Shards the traversal intervals based on the requested target shard size.
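* For example, with {@code shardSize == 1_000_000}, a 2,500,000bp interval yields a 1,000,000bp shard
* followed by a final 1,500,000bp shard, since a remainder under 1.5x the shard size is not split further.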
* @param raw the unprocessed traversal intervals.
* @param shardSize the target shard size in base-pairs.
* @return never {@code null}, but perhaps empty if the input was empty.
*/
private List<SimpleInterval> shardIntervals(final List<SimpleInterval> raw, final int shardSize) {
final List<SimpleInterval> preSharded = sortAndMergeOverlappingIntervals(raw, dictionary);
final long size = preSharded.stream().mapToLong(SimpleInterval::size).sum();
final List<SimpleInterval> output = new ArrayList<>((int) (preSharded.size() + size / shardSize));
// if less than 1.5 x the desired shard size is left in the current interval we don't split any further:
final int shardingSizeThreshold = (int) Math.round(shardSize * 1.5);
for (final SimpleInterval in : preSharded) {
if (in.size() < shardingSizeThreshold) {
output.add(in);
} else {
int start = in.getStart();
final int inEnd = in.getEnd();
final int stop = in.getEnd() - shardingSizeThreshold + 1;
while (start < stop) {
final int end = start + shardSize - 1;
output.add(new SimpleInterval(in.getContig(), start, end));
start = end + 1;
}
if (start <= inEnd) {
output.add(new SimpleInterval(in.getContig(), start, inEnd));
}
}
}
return output;
}
/**
* If the traversal intervals contain some overlaps we need to fix them.
*/
private List<SimpleInterval> sortAndMergeOverlappingIntervals(final List<SimpleInterval> input, final SAMSequenceDictionary dictionary) {
if (isSortedAndHasNoOverlap(input, dictionary)) {
return input;
} else {
final Map<String, List<SimpleInterval>> byContig = IntervalUtils.sortAndMergeIntervals(input, dictionary, IntervalMergingRule.ALL);
return byContig.keySet().stream()
.sorted(Comparator.comparingInt(name -> dictionary.getSequence(name).getSequenceIndex()))
.flatMap(name -> byContig.get(name).stream())
.collect(Collectors.toList());
}
}
private boolean isSortedAndHasNoOverlap(final List<SimpleInterval> input, final SAMSequenceDictionary dictionary) {
if (input.isEmpty()) {
return true;
} else {
String prevCtgName = null;
int prevCtgIdx = -1;
int prevEnd = 0;
for (final SimpleInterval interval : input) {
final String ctg = interval.getContig();
final int start = interval.getStart();
final int end = interval.getEnd();
if (ctg.equals(prevCtgName)) {
if (start <= prevEnd) {
return false;
} else {
prevEnd = end;
}
} else {
final int ctgIdx = dictionary.getSequenceIndex(ctg);
if (ctgIdx <= prevCtgIdx) {
return false;
} else {
prevCtgName = ctg;
prevCtgIdx = ctgIdx;
prevEnd = end;
}
}
}
return true;
}
}
/**
 * Stream collector class defined to coalesce several stratified locus case collections.
 */
private static class DragstrLocusCaseStratificator implements Collector<DragstrLocusCase, StratifiedDragstrLocusCases, StratifiedDragstrLocusCases> {
private final int maxPeriod;
private final int maxRepeats;
private static DragstrLocusCaseStratificator make(final int maxPeriod, final int maxRepeats) {
return new DragstrLocusCaseStratificator(maxPeriod, maxRepeats);
}
private DragstrLocusCaseStratificator(final int maxPeriod, final int maxRepeats) {
this.maxPeriod = maxPeriod;
this.maxRepeats = maxRepeats;
}
@Override
public Supplier<StratifiedDragstrLocusCases> supplier() {
return () -> new StratifiedDragstrLocusCases(maxPeriod, maxRepeats);
}
@Override
public BiConsumer<StratifiedDragstrLocusCases, DragstrLocusCase> accumulator() {
return StratifiedDragstrLocusCases::add;
}
@Override
public BinaryOperator<StratifiedDragstrLocusCases> combiner() {
return StratifiedDragstrLocusCases::addAll;
}
@Override
public Function<StratifiedDragstrLocusCases, StratifiedDragstrLocusCases> finisher() {
return a -> a;
}
@Override
public Set<Characteristics> characteristics() {
return EnumSet.of(Characteristics.IDENTITY_FINISH, Characteristics.UNORDERED);
}
}
/**
 * Stream collector that coalesces the equivalent read sets overlapping an STR locus into a single {@link DragstrLocusCase}.
 */
private static class DragstrLocusCaseCollector implements Collector<EquivalentReadSet, DragstrLocusCaseCollector, DragstrLocusCase> {
private final DragstrLocus locus;
private final long strStart;
private final long strEnd;
private final long strEndPlusOne;
private final long paddedStrStart;
private final long paddedStrEnd;
private int n; // number of reads (equivalent-set members) fully overlapping the padded STR (depth).
private int k; // number of those reads presenting an indel within the STR.
private int minMQ; // minimum mapping quality observed among them.
private int nSup; // number of supplementary alignments among them.
private DragstrLocusCaseCollector(final DragstrLocus locus, final long strStart,
final long strEnd, final long paddedStrStart, final long paddedStrEnd) {
this.locus = locus;
this.strStart = strStart;
this.strEnd = strEnd;
this.strEndPlusOne = strEnd + 1;
this.paddedStrStart = paddedStrStart;
this.paddedStrEnd = paddedStrEnd;
n = k = nSup = 0;
minMQ = SAMRecord.UNKNOWN_MAPPING_QUALITY;
}
public static DragstrLocusCaseCollector create(final DragstrLocus locus, final int padding, final long contigLength) {
Utils.nonNull(locus);
Utils.validateArg(padding >= 0, "padding must be 0 or greater");
Utils.validateArg(contigLength >= 1, "contig length must be strictly positive");
final long strStart = locus.getStart();
final long strEnd = locus.getEnd();
final long paddedStrStart = Math.max(1, strStart - padding);
final long paddedStrEnd = Math.min(contigLength, strEnd + padding);
return new DragstrLocusCaseCollector(locus, strStart, strEnd, paddedStrStart, paddedStrEnd);
}
@Override
public Supplier<DragstrLocusCaseCollector> supplier() {
return () -> new DragstrLocusCaseCollector(locus, strStart, strEnd, paddedStrStart, paddedStrEnd);
}
@Override
public BiConsumer<DragstrLocusCaseCollector, EquivalentReadSet> accumulator() {
return DragstrLocusCaseCollector::collect;
}
/**
* Adds the relevant stats of the read to the collector based on its overlap
* with the STR and the presence of indel events.
*
* Assumes that the read is mapped to the same contig as the locus, so we don't test
* for that.
* @param eset the read set to collect.
*/
private void collect(final EquivalentReadSet eset) {
final int readStart = eset.getStart();
final int readEnd = eset.getEnd();
final int size = eset.size();
if (readStart <= paddedStrStart && readEnd >= paddedStrEnd) {
if (eset.isSupplementaryAlignment()) {
nSup += size;
}
minMQ = Math.min(minMQ, eset.getMappingQuality());
int refPos = readStart;
// int lengthDiff = 0;
for (final CigarElement ce : eset.getCigar()) {
final CigarOperator op = ce.getOperator();
final int length = ce.getLength();
if (op == CigarOperator.I && refPos >= strStart && refPos <= strEndPlusOne) {
k += size;
//lengthDiff += length;
} else if (op == CigarOperator.D && refPos + length - 1 >= strStart && refPos <= strEnd) {
k += size;
//lengthDiff -= length;
}
// update refPos and quit early if we have gone beyond the end of the STR.
if ((refPos += op.consumesReferenceBases() ? length : 0) > strEndPlusOne) {
break;
}
}
n += size;
}
}
private DragstrLocusCaseCollector combineWith(final DragstrLocusCaseCollector other) {
Utils.validateArg(other.locus == this.locus, "collectors at different loci cannot be combined");
final DragstrLocusCaseCollector result = new DragstrLocusCaseCollector(locus, strStart,
strEnd, paddedStrStart, paddedStrEnd);
result.k = k + other.k;
result.n = n + other.n;
result.nSup = nSup + other.nSup;
result.minMQ = Math.min(minMQ, other.minMQ);
return result;
}
private DragstrLocusCase finish() {
return DragstrLocusCase.create(locus, n, k, minMQ, nSup);
}
@Override
public BinaryOperator<DragstrLocusCaseCollector> combiner() {
return DragstrLocusCaseCollector::combineWith;
}
@Override
public Function<DragstrLocusCaseCollector, DragstrLocusCase> finisher() {
return DragstrLocusCaseCollector::finish;
}
@Override
public Set<Characteristics> characteristics() {
return Collections.emptySet();
}
}
/**
* Generates a stream of locus cases for an interval/shard.
*
* The returned stream in turn feeds on two streams: (a) the stream of reads for the interval and (b)
* the stream of str-table entry loci for that interval.
*
*
* These are processed in tandem in position order within the shard so that for every given str-table entry we get
* all the relevant reads that overlap the STR.
*
* @param shard the target shard.
* @param reads a stream on the reads in the input shard.
* @param loci a stream on the loci in the input shard.
* @return never {@code null}, perhaps an empty stream.
*/
private Stream<DragstrLocusCase> streamShardCasesStats(final SimpleInterval shard, final Stream<GATKRead> reads, final Stream<DragstrLocus> loci) {
final int contigLength = dictionary.getSequence(shard.getContig()).getSequenceLength();
return StreamSupport.stream(new Spliterator<DragstrLocusCase>() {
private final Spliterator<GATKRead> readSpliterator = reads.spliterator();
private final Spliterator<DragstrLocus> lociSpliterator = loci.spliterator();
private final ShardReadBuffer readBuffer = new ShardReadBuffer();
private GATKRead read;
private DragstrLocus locus;
/**
* Move forward in the reads stream.
*
* If it returns {@code true}, the read to process is placed in {@link #read}.
*
*
* @return {@code true} iff there is one more read to process.
*/
private boolean advanceRead() {
return readSpliterator.tryAdvance(read -> this.read = read);
}
/**
* Move forward in the locus stream.
*
* If it returns {@code true}, the locus to process is placed in {@link #locus}.
*
*
* @return {@code true} iff there is one more locus to process.
*/
private boolean advanceLocus() {
return lociSpliterator.tryAdvance(locus -> this.locus = locus);
}
@Override
public boolean tryAdvance(final Consumer<? super DragstrLocusCase> action) {
if (advanceLocus()) { // if true, sets 'locus' to the next in the stream.
readBuffer.removeUpstreamFrom((int) locus.getStart()); // flush the buffer from up-stream reads that we won't need again.
// We keep loading reads into the buffer until we reach the first read
// downstream of the current locus.
while (advanceRead()) { // if true sets 'read' to the next in the stream.
readBuffer.add(read.getAssignedStart(), read.getEnd(), read);
if (read.getAssignedStart() > locus.getEnd()) {
break;
}
}
// Now we compose the case given the locus and all overlapping reads.
final List<EquivalentReadSet> reads = readBuffer.overlapping((int) locus.getStart(), (int) locus.getEnd());
final DragstrLocusCase newCase = composeDragstrLocusCase(locus, reads, contigLength);
action.accept(newCase);
return true;
} else { // no more loci in the stream, we are finished.
return false;
}
}
@Override
public Spliterator<DragstrLocusCase> trySplit() {
return null;
}
@Override
public long estimateSize() {
return Long.MAX_VALUE; // the remaining number of cases is unknown.
}
@Override
public int characteristics() {
return 0;
}
}, false);
}
private static void outputSiteDetails(final PrintWriter writer, final DragstrLocusCase caze, final String fate) {
writer.println(Utils.join("\t", "" + caze.getLocus().getChromosomeIndex() + ':' + (caze.getLocus().getStart() - 1),
caze.getLocus().getPeriod(),
caze.getLocus().getRepeats(),
caze.getDepth(),
caze.getIndels(),
caze.getMinMQ(),
caze.getNSup(),
fate));
}
private Stream<GATKRead> readStream(final ReadsDataSource source, final SimpleInterval interval) {
final Stream<GATKRead> unfiltered = interval == null ? Utils.stream(source) : Utils.stream(source.query(interval));
return unfiltered
.filter(read -> (read.getFlags() & DISCARD_FLAG_VALUE) == 0 && read.getAssignedStart() <= read.getEnd())
.map(EXTENDED_MQ_READ_TRANSFORMER);
}
// flags for the reads that are to be discarded from analyses.
private static final EnumSet<SAMFlag> DISCARD_FLAGS = EnumSet.of(
SAMFlag.READ_UNMAPPED, SAMFlag.SECONDARY_ALIGNMENT, SAMFlag.READ_FAILS_VENDOR_QUALITY_CHECK);
private static final int DISCARD_FLAG_VALUE = DISCARD_FLAGS.stream().mapToInt(SAMFlag::intValue).sum();
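// DISCARD_FLAG_VALUE == 0x4 | 0x100 | 0x200 == 0x304: reads with any of these bits set are skipped.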
private DragstrLocusCase composeDragstrLocusCase(final DragstrLocus locus, final List<EquivalentReadSet> rawReads, final long contigLength) {
return rawReads.stream()
.collect(DragstrLocusCaseCollector.create(locus, hyperParameters.strPadding, contigLength));
}
/**
* Sets of reads that, for the intents and purposes of this model, are equivalent assuming that they are mapped on the same
* location; we don't check for that.
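* Two reads are considered equivalent when they share supplementary-alignment status, mapping quality and cigar.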
*/
private static class EquivalentReadSet {
private GATKRead example;
private int size;
public boolean belongs(final GATKRead read) {
return (read.isSupplementaryAlignment() == example.isSupplementaryAlignment()
&& read.getMappingQuality() == example.getMappingQuality()
&& read.getCigar().equals(example.getCigar()));
}
public static int hashCode(final GATKRead read) {
return (((Boolean.hashCode(read.isSupplementaryAlignment()) * 31) + read.getMappingQuality() * 31) + read.getCigar().hashCode());
}
@Override
public int hashCode() {
return hashCode(example);
}
private EquivalentReadSet(final GATKRead read) {
example = read;
size = 1;
}
public static EquivalentReadSet of(final GATKRead read) {
Utils.nonNull(read);
return new EquivalentReadSet(read);
}
public void increase(final int inc) {
size += inc;
}
public int getStart() {
return example.getStart();
}
public int getEnd() {
return example.getEnd();
}
public boolean isSupplementaryAlignment() {
return example.isSupplementaryAlignment();
}
public int size() {
return size;
}
public int getMappingQuality() {
return example.getMappingQuality();
}
public Iterable<? extends CigarElement> getCigar() {
return example.getCigar();
}
}
/**
* Simple read-buffer implementation.
*/
private static class ShardReadBuffer extends IntervalTree<Int2ObjectMap<EquivalentReadSet>> {
private static Int2ObjectMap<EquivalentReadSet> mergeEquivalentReadSets(final Int2ObjectMap<EquivalentReadSet> left,
final Int2ObjectMap<EquivalentReadSet> right) {
// receiver is the map that will collect the output, perhaps one of the inputs.
// donor is the other map.
// size-1 maps are unmodifiable singletons, so they can only be donors.
final Int2ObjectMap<EquivalentReadSet> receiver, donor;
if (left.size() > 1) {
receiver = left; donor = right;
} else if (right.size() > 1) {
receiver = right; donor = left;
} else {
receiver = new Int2ObjectOpenHashMap<>(left);
donor = right;
}
for (final EquivalentReadSet e2 : donor.values()) {
final EquivalentReadSet e1 = receiver.get(e2.hashCode());
if (e1 == null) { // if not in the receiver we simply copy it over.
receiver.put(e2.hashCode(), e2);
} else { // if present we increase the count.
e1.increase(e2.size());
}
}
return receiver;
}
public void add(final int start, final int end, final GATKRead elem) {
merge(start, end, Int2ObjectMaps.singleton(EquivalentReadSet.hashCode(elem), EquivalentReadSet.of(elem)),
ShardReadBuffer::mergeEquivalentReadSets);
}
void removeUpstreamFrom(final int start) {
final Iterator<Node<Int2ObjectMap<EquivalentReadSet>>> it = iterator();
while (it.hasNext()) {
final Node<Int2ObjectMap<EquivalentReadSet>> node = it.next();
if (node.getStart() >= start) {
break;
} else if (node.getEnd() < start) {
it.remove();
}
}
}
public List<EquivalentReadSet> overlapping(final int start, final int end) {
final Iterator<Node<Int2ObjectMap<EquivalentReadSet>>> it = this.overlappers(start, end);
if (!it.hasNext()) {
return Collections.emptyList();
} else {
final List<EquivalentReadSet> result = new ArrayList<>();
do {
final Node<Int2ObjectMap<EquivalentReadSet>> node = it.next();
result.addAll(node.getValue().values());
} while (it.hasNext());
return result;
}
}
}
/**
* Simple 2-int tuple to hold a period and repeat-length pair.
*/
private static class PeriodAndRepeatLength {
private final int period;
private final int repeatLength;
private PeriodAndRepeatLength(final int period, final int repeatLength) {
this.period = period;
this.repeatLength = repeatLength;
}
private static PeriodAndRepeatLength of(final int period, final int repeat) {
return new PeriodAndRepeatLength(period, repeat);
}
@Override
public String toString() {
return "(" + period + ',' + repeatLength + ')';
}
}
}