package org.broadinstitute.hellbender.tools.dragstr;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.Spliterator;
import java.util.Vector;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.apache.commons.io.output.NullOutputStream;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.cmdline.programgroups.ShortVariantDiscoveryProgramGroup;
import org.broadinstitute.hellbender.engine.GATKPath;
import org.broadinstitute.hellbender.engine.GATKTool;
import org.broadinstitute.hellbender.engine.ReadsDataSource;
import org.broadinstitute.hellbender.engine.ReadsPathDataSource;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCaller;
import org.broadinstitute.hellbender.transformers.DRAGENMappingQualityReadTransformer;
import org.broadinstitute.hellbender.transformers.ReadTransformer;
import org.broadinstitute.hellbender.utils.BinaryTableReader;
import org.broadinstitute.hellbender.utils.IntervalMergingRule;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SequenceDictionaryUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.collections.AutoCloseableCollection;
import org.broadinstitute.hellbender.utils.dragstr.DragstrParamUtils;
import org.broadinstitute.hellbender.utils.dragstr.DragstrParams;
import org.broadinstitute.hellbender.utils.dragstr.STRTableFile;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.reference.AbsoluteCoordinates;

import htsjdk.samtools.CigarElement;
import htsjdk.samtools.CigarOperator;
import htsjdk.samtools.SAMFlag;
import htsjdk.samtools.SAMReadGroupRecord;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SamReaderFactory;
import htsjdk.samtools.util.IntervalTree;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMaps;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import scala.Tuple4;

/**
 * Estimates the parameters for the DRAGstr model for an input sample.
 * <p>
 * This tool takes in the sampling sites generated by {@link ComposeSTRTableFile} on the same reference
 * as the input sample.
 * </p>
 * <p>
 * The end result is a text file containing three parameter tables (GOP, GCP, API) that can be fed
 * directly to {@link HaplotypeCaller} using {@code --dragstr-params-path}.
 * </p>
 */
@CommandLineProgramProperties(
        summary = "estimates the parameters for the DRAGstr model for the input sample using the output of the ComposeSTRTableFile tool",
        oneLineSummary = "estimates the parameters for the DRAGstr model",
        programGroup = ShortVariantDiscoveryProgramGroup.class
)
@DocumentedFeature
public class CalibrateDragstrModel extends GATKTool {

    public static final String STR_TABLE_PATH_SHORT_NAME = "str";
    public static final String STR_TABLE_PATH_FULL_NAME = "str-table-path";
    public static final String PARALLEL_FULL_NAME = "parallel";
    public static final String THREADS_FULL_NAME = "threads";
    public static final String SHARD_SIZE_FULL_NAME = "shard-size";
    public static final String DOWN_SAMPLE_SIZE_FULL_NAME = "down-sample-size";
    public static final String DEBUG_SITES_OUTPUT_FULL_NAME = "debug-sites-output";
    public static final String FORCE_ESTIMATION_FULL_NAME = "force-estimation";

    public static final int DEFAULT_SHARD_SIZE = 1_000_000;
    public static final int DEFAULT_DOWN_SAMPLE_SIZE = 4096;
    public static final int SYSTEM_SUGGESTED_THREAD_NUMBER = 0;
    public static final int MINIMUM_SHARD_SIZE = 100;
    public static final int MINIMUM_DOWN_SAMPLE_SIZE = 512;

    @ArgumentCollection
    private DragstrHyperParameters hyperParameters = new DragstrHyperParameters();

    @Argument(shortName = STR_TABLE_PATH_SHORT_NAME,
              fullName = STR_TABLE_PATH_FULL_NAME,
              doc = "location of the zip that contains the sampling sites for the reference")
    private GATKPath strTablePath = null;

    @Argument(fullName = PARALLEL_FULL_NAME,
              doc = "run alignment data collection and estimation in parallel",
              optional = true)
    private boolean runInParallel = false;

    @Argument(fullName = THREADS_FULL_NAME,
              minValue = SYSTEM_SUGGESTED_THREAD_NUMBER,
              doc = "suggested number of parallel threads to perform the estimation; "
                  + "the default of 0 leaves it up to the VM to decide. When set to more than 1, "
                  + "this will activate parallel execution even in the absence of --parallel",
              optional = true)
    private int threads = SYSTEM_SUGGESTED_THREAD_NUMBER;

    @Argument(fullName = SHARD_SIZE_FULL_NAME,
              doc = "when running in parallel, this is the suggested shard size in base pairs. "
                  + "The actual shard size may vary to adapt to small contigs and the requested number of threads",
              minValue = MINIMUM_SHARD_SIZE,
              optional = true)
    private int shardSize = DEFAULT_SHARD_SIZE;

    @Argument(fullName = DOWN_SAMPLE_SIZE_FULL_NAME,
              doc = "targeted maximum number of cases per period and repeat-count combination; "
                  + "the larger the value, the more precise but also the slower the estimation.",
              minValue = MINIMUM_DOWN_SAMPLE_SIZE,
              optional = true)
    private int downsampleSize = DEFAULT_DOWN_SAMPLE_SIZE;

    @Argument(fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME,
              shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME,
              doc = "where to write the parameter output file.")
    private GATKPath output = null;

    @Argument(fullName = DEBUG_SITES_OUTPUT_FULL_NAME,
              doc = "table with information gathered on the sampled sites. "
                  + "Includes which sites were downsampled, disqualified or accepted for parameter estimation",
              optional = true)
    private String sitesOutput = null;

    @Argument(fullName = FORCE_ESTIMATION_FULL_NAME,
              doc = "for testing purposes only; force parameter estimation even with few data points available",
              optional = true)
    private boolean forceEstimation = false;

    private SAMSequenceDictionary dictionary;

    private SamReaderFactory factory;

    public static final ReadTransformer EXTENDED_MQ_READ_TRANSFORMER = new DRAGENMappingQualityReadTransformer();

    @Override
    public boolean requiresReference() {
        return true;
    }

    @Override
    public boolean requiresReads() {
        return true;
    }

    @Override
    protected void onStartup() {
        super.onStartup();
        hyperParameters.validate();
        dictionary = directlyAccessEngineReadsDataSource().getSequenceDictionary();
        factory = makeSamReaderFactory();
        if (runInParallel) {
            if (threads == 1) {
                logger.warn("parallel processing was requested but the number of threads was set to 1");
            }
        } else if (threads > 1) {
            runInParallel = true;
        }
        if (runInParallel) {
            if (threads == 0) {
                logger.info("Running in parallel using the system suggested default thread count: " + Runtime.getRuntime().availableProcessors());
            } else {
                logger.info("Running in parallel using the requested number of threads: " + threads);
            }
        }
    }

    @Override
    public void traverse() {
        hyperParameters.validate();
        dictionary = getBestAvailableSequenceDictionary();
        final List<SAMReadGroupRecord> readGroups = hasReads() ? getHeaderForReads().getReadGroups() : Collections.emptyList();
        final List<String> readGroupIds = readGroups.stream()
                .map(SAMReadGroupRecord::getId)
                .collect(Collectors.toList());
        final List<String> sampleNames = readGroups.stream()
                .map(SAMReadGroupRecord::getSample)
                .distinct()
                .collect(Collectors.toList());
        final Optional<String> sampleName = resolveSampleName(sampleNames);

        try (final PrintWriter sitesOutputWriter = openSitesOutputWriter(sitesOutput);
             final STRTableFile strTable = STRTableFile.open(strTablePath)) {
            checkSequenceDictionaryCompatibility(dictionary, strTable.dictionary());
            final StratifiedDragstrLocusCases allSites;
            final List<SimpleInterval> intervals = getTraversalIntervals();
            runInParallel |= threads > 1;
            if (runInParallel) {
                if (threads == 1) {
                    logger.warn("parallel processing was requested but the number of threads was set to 1");
                }
                allSites = collectCaseStatsParallel(intervals, shardSize, strTable);
            } else {
                allSites = collectCaseStatsSequential(intervals, strTable);
            }
            logSiteCounts(allSites, "all loci/cases");
            final StratifiedDragstrLocusCases downSampledSites = downSample(allSites, strTable, sitesOutputWriter);
            logSiteCounts(downSampledSites, "all downsampled (kept) loci/cases");
            final StratifiedDragstrLocusCases finalSites = downSampledSites.qualifyingOnly(hyperParameters.minDepth, hyperParameters.minMQ, 0);
            logSiteCounts(finalSites, "all qualifying loci/cases");
            outputDownSampledSiteDetails(downSampledSites, sitesOutputWriter, hyperParameters.minDepth, hyperParameters.minMQ, 0);
            printOutput(finalSites, sampleName.orElse(null), readGroupIds);
        }
    }

    private void printOutput(final StratifiedDragstrLocusCases finalSites, final String sampleName, final List<String> readGroups) {
        final boolean enoughCases = isThereEnoughCases(finalSites);
        final boolean usingDefaults = !enoughCases && !forceEstimation;
        final Object[] annotations = {
                "sample", (sampleName == null ? "" : sampleName),
                "readGroups", (readGroups.isEmpty() ? "" : Utils.join(", ", readGroups)),
                "estimatedOrDefaults", (usingDefaults ? "defaults" : (enoughCases ? "estimated" : "estimatedByForce")),
                "commandLine", getCommandLine()
        };
        if (!usingDefaults) {
            if (!enoughCases) {
                logger.warn("Forcing parameter estimation using down-sampled cases as requested");
            } else {
                logger.info("Estimating parameters using down-sampled cases");
            }
            final DragstrParams estimate = estimateParams(finalSites);
            logger.info("Done with estimation, printing output");
            DragstrParamUtils.print(estimate, output, annotations);
        } else {
            logger.warn("Not enough cases to estimate parameters, using defaults");
            DragstrParamUtils.print(DragstrParams.DEFAULT, output, annotations);
        }
    }

    private Optional<String> resolveSampleName(final List<String> sampleNames) {
        if (sampleNames.size() > 1) {
            throw new GATKException("the input alignment(s) have more than one sample: " + String.join(", ", sampleNames));
        } else if (sampleNames.isEmpty() || sampleNames.get(0) == null) {
            logger.warn("there is no sample id in the alignment header, assuming that all reads and read-groups make reference to the same anonymous sample");
            return Optional.empty();
        } else {
            return Optional.of(sampleNames.get(0));
        }
    }

    private void checkSequenceDictionaryCompatibility(final SAMSequenceDictionary reference, final SAMSequenceDictionary strTable) {
        final SequenceDictionaryUtils.SequenceDictionaryCompatibility compatibility =
                SequenceDictionaryUtils.compareDictionaries(reference, strTable, false);
        switch (compatibility) {
            case IDENTICAL:
                return;
            case SUPERSET:
                return;
            // the two cases below are probably never returned since we ask for no check on order,
            // but we handle them just in case:
            case NON_CANONICAL_HUMAN_ORDER:
                return; // we don't care about the order.
            case OUT_OF_ORDER:
                return; // we don't care about the order.
            default:
                throw new GATKException("the reference and str-table sequence dictionaries are incompatible: " + compatibility);
        }
    }

    @SuppressWarnings("deprecation")
    private PrintWriter openSitesOutputWriter(final String sitesOutput) {
        return sitesOutput == null
                ? new PrintWriter(NullOutputStream.NULL_OUTPUT_STREAM)
                : new PrintWriter(BucketUtils.createFile(sitesOutput));
    }

    private void outputDownSampledSiteDetails(final StratifiedDragstrLocusCases finalSites, final PrintWriter writer,
                                              final int minDepth, final int samplingMinMQ, final int maxSup) {
        if (sitesOutput != null) {
            for (final DragstrLocusCases[] periodCases : finalSites.perPeriodAndRepeat) {
                for (final DragstrLocusCases repeatCases : periodCases) {
                    for (final DragstrLocusCase caze : repeatCases) {
                        outputSiteDetails(writer, caze, caze.qualifies(minDepth, samplingMinMQ, maxSup) ? "used" : "skipped");
                    }
                }
            }
        }
    }

    /**
     * Holds the minimum counts for each period, repeat-length combo.
     * If there is a lack of data for any of these we use the default param
     * tables. Missing values, rows (periods) or columns (repeat-lengths) are
     * interpreted as 0. For example, the 200 at period 2, run-length 3 below means that at least
     * 200 qualifying cases are required in that bin (checked in isThereEnoughCases).
     */
    private static final int[][] MINIMUM_CASES_BY_PERIOD_AND_LENGTH =
            // @formatter:off ; prevents code reformatting by IntelliJ if enabled:
            //     Preferences > Editor > Code Style > Formatter Control
            // run-length:
            //     0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10+
            {   {},                                                         // period
                {0, 200, 200, 200, 200, 200, 200, 200, 200, 200,   0},      // 1
                {0,   0, 200, 200, 200, 200,   0,   0,   0,   0,   0},      // 2
                {0,   0, 200, 200, 200,   0,   0,   0,   0,   0,   0},      // 3
                {0,   0, 200, 200,   0,   0,   0,   0,   0,   0,   0},      // 4
                {0,   0, 200,   0,   0,   0,   0,   0,   0,   0,   0},      // 5
                {0,   0, 200,   0,   0,   0,   0,   0,   0,   0,   0},      // 6
                {0,   0, 200,   0,   0,   0,   0,   0,   0,   0,   0},      // 7
                {0,   0, 200,   0,   0,   0,   0,   0,   0,   0,   0},      // 8
            };
            // zeros to the right are not actually necessary, but we add them to make it look more like a matrix.
    // @formatter:on

    /**
     * Checks that a minimum number of cases are available in key bins (period, repeat-length combos).
     */
    private boolean isThereEnoughCases(final StratifiedDragstrLocusCases allSites) {
        // period 1, repeat length 1 to 9 (inclusive)
        final int[][] MCBL = MINIMUM_CASES_BY_PERIOD_AND_LENGTH;
        final int maxP = Math.min(hyperParameters.maxPeriod, MCBL.length - 1);
        final List<Tuple4<Integer, Integer, Integer, Integer>> failingCombos = new ArrayList<>(10);
        for (int i = 1; i <= maxP; i++) {
            final int maxL = Math.min(hyperParameters.maxRepeatLength, MCBL[i].length - 1);
            for (int j = 1; j <= maxL; j++) {
                if (allSites.get(i, j).size() < MCBL[i][j]) {
                    failingCombos.add(new Tuple4<>(i, j, allSites.get(i, j).size(), MCBL[i][j]));
                }
            }
        }
        if (failingCombos.isEmpty()) {
            return true;
        } else if (forceEstimation) {
            logger.warn("there is not enough data to proceed to empirical parameter estimation "
                    + "but the user requested to force it, so we go ahead");
            for (final Tuple4<Integer, Integer, Integer, Integer> failingCombo : failingCombos) {
                logger.warn(String.format("(P=%d, L=%d) count %d is less than minimum required %d",
                        failingCombo._1(), failingCombo._2(), failingCombo._3(), failingCombo._4()));
            }
            return true;
        } else {
            logger.warn("there is not enough data to proceed to empirical parameter estimation, using defaults instead");
            return false;
        }
    }

    /**
     * Performs the final estimation step.
     * @param finalSites the sites to use for the estimation.
     * @return never {@code null}.
     */
    private DragstrParams estimateParams(final StratifiedDragstrLocusCases finalSites) {
        final DragstrParametersEstimator estimator = new DragstrParametersEstimator(hyperParameters);
        return runInParallel
                ? Utils.runInParallel(threads, () -> estimator.estimate(finalSites))
                : estimator.estimate(finalSites);
    }

    /**
     * Downsamples sites so that at most {@link #downsampleSize} cases remain for each period and repeat-length combination.
     * @param allSites the sites to downsample.
     * @param strTable the table that contains the decimation table used to generate those sites.
     * @param sitesOutputWriter an optional per-site information output argument for debugging purposes.
     * @return never {@code null}.
     */
    private StratifiedDragstrLocusCases downSample(final StratifiedDragstrLocusCases allSites, final STRTableFile strTable,
                                                   final PrintWriter sitesOutputWriter) {
        final STRDecimationTable decimationTable = strTable.decimationTable();
        final List<PeriodAndRepeatLength> prCombos = new ArrayList<>(hyperParameters.maxPeriod * hyperParameters.maxRepeatLength);
        for (int i = 1; i <= hyperParameters.maxPeriod; i++) {
            for (int j = 1; j <= hyperParameters.maxRepeatLength; j++) {
                prCombos.add(PeriodAndRepeatLength.of(i, j));
            }
        }
        final Stream<PeriodAndRepeatLength> prCombosStream = runInParallel ? prCombos.parallelStream() : prCombos.stream();
        final Stream<DragstrLocusCase> downsampledStream = prCombosStream
                .flatMap(combo -> {
                    final DragstrLocusCases all = allSites.perPeriodAndRepeat[combo.period - 1][combo.repeatLength - 1];
                    final int decimationBit = decimationTable.decimationBit(combo.period, combo.repeatLength);
                    return downSample(all, decimationBit, downsampleSize, sitesOutputWriter).stream();
                });
        if (runInParallel) {
            return Utils.runInParallel(threads,
                    () -> downsampledStream.collect(DragstrLocusCaseStratificator.make(hyperParameters.maxPeriod, hyperParameters.maxRepeatLength)));
        } else {
            return downsampledStream.collect(DragstrLocusCaseStratificator.make(hyperParameters.maxPeriod, hyperParameters.maxRepeatLength));
        }
    }

    /**
     * Pre-calculated decimation masks used depending on the final decimation bit/level.
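     * <p>
     * After the static initializer below runs, {@code DECIMATION_MASKS_BY_BIT[j] == ~(1L << j)}: all bits set
     * except bit {@code j}. Thus {@code (mask & DECIMATION_MASKS_BY_BIT[j]) != mask} exactly when bit {@code j}
     * is set in {@code mask}, which is how the downsampling code tests decimation bits.
     * </p>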
     */
    private static final long[] DECIMATION_MASKS_BY_BIT = new long[Long.SIZE];

    // Code to populate DECIMATION_MASKS_BY_BIT.
    static {
        DECIMATION_MASKS_BY_BIT[0] = 1;
        for (int i = 1, j = 0; i < Long.SIZE; i++, j++) {
            DECIMATION_MASKS_BY_BIT[i] = DECIMATION_MASKS_BY_BIT[j] << 1;
            DECIMATION_MASKS_BY_BIT[j] = ~DECIMATION_MASKS_BY_BIT[j];
        }
        DECIMATION_MASKS_BY_BIT[Long.SIZE - 1] = ~DECIMATION_MASKS_BY_BIT[Long.SIZE - 1];
    }

    /**
     * Decimates the collection of locus/cases to the provided downsample size or smaller.
     * <p>
     * Notice that if we need to downsample (the input size is larger than the downsample size provided)
     * we take care of not counting those cases that have zero depth toward that limit.
     * This is due to the apparent behaviour in DRAGEN where those are "sort-of" filtered before
     * decimation as far as meeting the final downsample size limit is concerned.
     * </p>
     * <p>
     * They would usually be skipped eventually in post-downsampling filtering, but if we don't consider
     * their number here we end up downsampling some period, repeat-length combinations too much
     * as compared to DRAGEN.
     * </p>
     * <p>
     * This behavior in DRAGEN may well change in future releases.
     * </p>
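     * <p>
     * Illustration (numbers made up): given 10,000 input cases, 500 of them with zero depth, and a target size
     * of 4,096, decimation bits are applied from {@code minDecimationBit} upwards, each round discarding the
     * cases whose locus mask has that bit set, until at most 4,096 non-zero-depth cases remain.
     * </p>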

     * @param in input collection of cases to downsample.
     * @param minDecimationBit the start decimation bit. Usually the input cases collection won't contain any
     *                         cases with a lower bit set (those are already decimated).
     * @param downsampleSize the target size.
     * @param sitesOutputWriter where per-site debug details are written (possibly a null-sink writer).
     * @return never {@code null}. At most the result contains {@code downsampleSize} cases, discounting
     *         cases with zero depth. It could be empty.
     */
    private DragstrLocusCases downSample(final DragstrLocusCases in, final int minDecimationBit,
                                         final int downsampleSize, final PrintWriter sitesOutputWriter) {
        final int inSize = in.size();
        if (inSize <= downsampleSize) { // we already satisfy the imposed size limit so we do nothing.
            return in;
        } else {
            int zeroDepth = 0;
            // sized to the full bit range since it is indexed directly by the bit position j below
            // (the previous sizing of Long.SIZE - minDecimationBit could overflow for high bits):
            final int[] countByFirstDecimatingBit = new int[Long.SIZE];
            for (final DragstrLocusCase caze : in) {
                final DragstrLocus locus = caze.getLocus();
                final int depth = caze.getDepth();
                if (depth <= 0) { // we discount cases with zero depth as these are going to be skipped eventually.
                    zeroDepth++;
                    continue;
                }
                final long mask = locus.getMask();
                for (int j = minDecimationBit; mask != 0 && j < Long.SIZE; j++) {
                    final long newMask = mask & DECIMATION_MASKS_BY_BIT[j];
                    if (newMask != mask) {
                        countByFirstDecimatingBit[j]++;
                        break;
                    }
                }
            }
            final IntList progressiveSizes = new IntArrayList(Long.SIZE + 1);
            progressiveSizes.add(inSize);
            int finalSize = inSize - zeroDepth;
            progressiveSizes.add(finalSize);
            long filterMask = 0;
            for (int j = minDecimationBit; finalSize > downsampleSize && j < Long.SIZE; j++) {
                finalSize -= countByFirstDecimatingBit[j];
                filterMask |= ~DECIMATION_MASKS_BY_BIT[j];
                progressiveSizes.add(finalSize);
            }
            // 'kept' collects the surviving cases (no decimating bit set and depth > 0); 'removed' the decimated rest.
            final DragstrLocusCases kept = new DragstrLocusCases(finalSize, in.getPeriod(), in.getRepeatLength());
            final DragstrLocusCases removed = new DragstrLocusCases(in.size() - finalSize, in.getPeriod(), in.getRepeatLength());
            for (final DragstrLocusCase caze : in) {
                final long mask = caze.getLocus().getMask();
                if ((mask & filterMask) == 0 && caze.getDepth() > 0) {
                    kept.add(caze);
                } else {
                    removed.add(caze);
                }
            }
            // Debug-log message format explained:
            //    period repeat-length [x0, x00, x1, x2, x3 ... xN]
            // where x0 is the input size,
            //    x00 = x0 - #zero-depth cases,
            //    x1  = x00 - #first round of decimation,
            //    x2  = x1 - #second round of decimation,
            //    ...
            //    xN  = final size <= downsampleSize
            logger.debug(() -> "" + in.getPeriod() + " " + in.getRepeatLength() + " " + Arrays.toString(progressiveSizes.toArray()));
            // we output info about the sites that are discarded:
            if (sitesOutput != null && removed.size() > 0) {
                synchronized (this) {
                    for (final DragstrLocusCase caze : removed) {
                        outputSiteDetails(sitesOutputWriter, caze, "downsampled-out");
                    }
                }
            }
            return kept;
        }
    }

    /**
     * Logs case counts in a matrix where columns are periods and rows are
     * repeat lengths in repeat units.
     * @param cases the cases whose counts are to be logged.
     * @param title the title of the debug message.
     */
    private void logSiteCounts(final StratifiedDragstrLocusCases cases, final String title) {
        if (logger.isDebugEnabled()) {
            // since this method is all about debug logging, it seems pertinent to check
            // whether DEBUG is off so we can save time.
            logger.debug(title);
            final int[] columnWidths = IntStream.range(1, hyperParameters.maxPeriod + 1).map(period -> {
                final int max = IntStream.range(1, hyperParameters.maxRepeatLength + 1)
                        .map(repeat -> cases.get(period, repeat).size())
                        .max().orElse(0);
                return (int) Math.max(7, Math.ceil(Math.log10(max)) + 1);
            }).toArray();
            logger.debug(" " + IntStream.range(0, hyperParameters.maxPeriod)
                    .mapToObj(i -> String.format("%-" + columnWidths[i] + "s", (i + 1)))
                    .collect(Collectors.joining()));
            for (int i = 1; i <= hyperParameters.maxRepeatLength; i++) {
                final int repeat = i;
                logger.debug(String.format("%-4s", repeat) + " " + IntStream.range(1, hyperParameters.maxPeriod + 1)
                        .mapToObj(period -> String.format("%-" + columnWidths[period - 1] + "s", cases.get(period, repeat).size()))
                        .collect(Collectors.joining("")));
            }
        }
    }

    private StratifiedDragstrLocusCases collectCaseStatsSequential(final List<SimpleInterval> intervals, final STRTableFile strTable) {
        final StratifiedDragstrLocusCases result = StratifiedDragstrLocusCases.make(hyperParameters.maxPeriod, hyperParameters.maxRepeatLength);
        final ReadsDataSource dataSource = directlyAccessEngineReadsDataSource();
        for (final SimpleInterval interval : intervals) {
            try (final BinaryTableReader<DragstrLocus> reader = strTable.locusReader(interval)) {
                streamShardCasesStats(interval, readStream(dataSource, interval), reader.stream())
                        .peek(caze -> progressMeter.update(caze.getLocation(dictionary)))
                        .forEach(result::add);
            } catch (final IOException ex) {
                throw new GATKException("problems accessing str-table-file at " + strTablePath, ex);
            }
        }
        return result;
    }

    @SuppressWarnings("try") // silences the intended use of an unreferenced auto-closeable within try-with-resources.
    private StratifiedDragstrLocusCases collectCaseStatsParallel(final List<SimpleInterval> intervals, final int shardSize,
                                                                 final STRTableFile strTable) {
        // TODO: instead of the dictionary this should take on the traversal intervals;
        // TODO: currently, if the user specifies intervals, the progress-meter will report that amount of bases
        //       at the beginning of the reference instead.
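        // Outline of the parallel strategy implemented below: the traversal intervals are split into
        // roughly shard-size chunks; each worker thread lazily opens its own ReadsPathDataSource through
        // a ThreadLocal (all of them are closed together at the end via AutoCloseableCollection); shards
        // are processed as a parallel stream and the per-shard stratified results are merged pairwise.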
        final AbsoluteCoordinates absoluteCoordinates = AbsoluteCoordinates.of(dictionary);
        final List<SimpleInterval> shards = shardIntervals(intervals, shardSize);
        final Collection<ReadsPathDataSource> readSources = new Vector<>(threads);
        final ThreadLocal<ReadsPathDataSource> threadReadSource = ThreadLocal.withInitial(() -> {
            final ReadsPathDataSource result = new ReadsPathDataSource(readArguments.getReadPaths(), factory);
            readSources.add(result);
            return result;
        });
        try (@SuppressWarnings("unused") final AutoCloseableCollection<ReadsPathDataSource> readSourceCloser = new AutoCloseableCollection<>(readSources)) {
            final AtomicLong numberBasesProcessed = new AtomicLong(0);
            return Utils.runInParallel(Math.min(threads, shards.size()),
                    () -> StreamSupport.stream(new InterleavingListSpliterator<>(shards), true)
                            .map(shard -> {
                                try (final BinaryTableReader<DragstrLocus> lociReader = strTable.locusReader(shard)) {
                                    final ReadsPathDataSource readsSource = threadReadSource.get();
                                    final StratifiedDragstrLocusCases result =
                                            streamShardCasesStats(shard, readStream(readsSource, shard), lociReader.stream())
                                                    .collect(DragstrLocusCaseStratificator.make(hyperParameters.maxPeriod, hyperParameters.maxRepeatLength));
                                    final int resultSize = result.size();
                                    synchronized (numberBasesProcessed) {
                                        final long processed = numberBasesProcessed.updateAndGet(l -> l + shard.size());
                                        progressMeter.update(absoluteCoordinates.toSimpleInterval(processed, 1), resultSize);
                                    }
                                    return result;
                                } catch (final IOException ex) {
                                    throw new GATKException("problems accessing the str-table-file contents at " + strTablePath, ex);
                                }
                            })
                            .reduce(StratifiedDragstrLocusCases::merge)
                            .orElseGet(() -> new StratifiedDragstrLocusCases(hyperParameters.maxPeriod, hyperParameters.maxRepeatLength)));
        }
    }

    /**
     * Shards the traversal intervals based on the requested target shard size.
     * @param raw the unprocessed traversal intervals.
     * @param shardSize the target shard size in base pairs.
     * @return never {@code null}, but perhaps empty if the input was empty.
     */
    private List<SimpleInterval> shardIntervals(final List<SimpleInterval> raw, final int shardSize) {
        final List<SimpleInterval> preSharded = sortAndMergeOverlappingIntervals(raw, dictionary);
        final long size = preSharded.stream().mapToLong(SimpleInterval::size).sum();
        final List<SimpleInterval> output = new ArrayList<>((int) (preSharded.size() + size / shardSize));
        // if less than 1.5x the desired shard size is left in the current interval we don't split any further:
        final int shardingSizeThreshold = (int) Math.round(shardSize * 1.5);
        for (final SimpleInterval in : preSharded) {
            if (in.size() < shardingSizeThreshold) {
                output.add(in);
            } else {
                int start = in.getStart();
                final int inEnd = in.getEnd();
                final int stop = in.getEnd() - shardingSizeThreshold + 1;
                while (start < stop) {
                    final int end = start + shardSize - 1;
                    output.add(new SimpleInterval(in.getContig(), start, end));
                    start = end + 1;
                }
                if (start <= inEnd) {
                    output.add(new SimpleInterval(in.getContig(), start, inEnd));
                }
            }
        }
        return output;
    }

    /**
     * If the traversal intervals contain some overlaps we need to fix them.
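     * <p>
     * When fixing is needed, the intervals are merged using {@link IntervalMergingRule#ALL} and then
     * re-sorted by the contig order of the sequence dictionary.
     * </p>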
     */
    private List<SimpleInterval> sortAndMergeOverlappingIntervals(final List<SimpleInterval> input, final SAMSequenceDictionary dictionary) {
        if (isSortedAndHasNoOverlap(input, dictionary)) {
            return input;
        } else {
            final Map<String, List<SimpleInterval>> byContig = IntervalUtils.sortAndMergeIntervals(input, dictionary, IntervalMergingRule.ALL);
            return byContig.keySet().stream()
                    .sorted(Comparator.comparingInt(name -> dictionary.getSequence(name).getSequenceIndex()))
                    .flatMap(name -> byContig.get(name).stream())
                    .collect(Collectors.toList());
        }
    }

    private boolean isSortedAndHasNoOverlap(final List<SimpleInterval> input, final SAMSequenceDictionary dictionary) {
        if (input.isEmpty()) {
            return true;
        } else {
            String prevCtgName = null;
            int prevCtgIdx = -1;
            int prevEnd = 0;
            for (final SimpleInterval interval : input) {
                final String ctg = interval.getContig();
                final int start = interval.getStart();
                final int end = interval.getEnd();
                if (ctg.equals(prevCtgName)) {
                    if (start <= prevEnd) {
                        return false;
                    } else {
                        prevEnd = end;
                    }
                } else {
                    final int ctgIdx = dictionary.getSequenceIndex(ctg);
                    if (ctgIdx <= prevCtgIdx) {
                        return false;
                    } else {
                        prevCtgName = ctg;
                        prevCtgIdx = ctgIdx;
                        prevEnd = end;
                    }
                }
            }
            return true;
        }
    }

    private static class DragstrLocusCaseStratificator implements Collector<DragstrLocusCase, StratifiedDragstrLocusCases, StratifiedDragstrLocusCases> {

        private final int maxPeriod;
        private final int maxRepeats;

        private static DragstrLocusCaseStratificator make(final int maxPeriod, final int maxRepeats) {
            return new DragstrLocusCaseStratificator(maxPeriod, maxRepeats);
        }

        private DragstrLocusCaseStratificator(final int maxPeriod, final int maxRepeats) {
            this.maxPeriod = maxPeriod;
            this.maxRepeats = maxRepeats;
        }

        @Override
        public Supplier<StratifiedDragstrLocusCases> supplier() {
            return () -> new StratifiedDragstrLocusCases(maxPeriod, maxRepeats);
        }

        @Override
        public BiConsumer<StratifiedDragstrLocusCases, DragstrLocusCase> accumulator() {
            return StratifiedDragstrLocusCases::add;
        }

        @Override
        public BinaryOperator<StratifiedDragstrLocusCases> combiner() {
            return StratifiedDragstrLocusCases::addAll;
        }

        @Override
        public Function<StratifiedDragstrLocusCases, StratifiedDragstrLocusCases> finisher() {
            return a -> a;
        }

        @Override
        public Set<Characteristics> characteristics() {
            return EnumSet.of(Characteristics.IDENTITY_FINISH, Characteristics.UNORDERED);
        }
    }

    /**
     * Stream collector that coalesces the equivalent read sets overlapping a locus into a single {@link DragstrLocusCase}.
     */
    private static class DragstrLocusCaseCollector implements Collector<EquivalentReadSet, DragstrLocusCaseCollector, DragstrLocusCase> {

        private final DragstrLocus locus;
        private final long strStart;
        private final long strEnd;
        private final long strEndPlusOne;
        private final long paddedStrStart;
        private final long paddedStrEnd;
        private int n;
        private int k;
        private int minMQ;
        private int nSup;

        private DragstrLocusCaseCollector(final DragstrLocus locus, final long strStart, final long strEnd,
                                          final long paddedStrStart, final long paddedStrEnd) {
            this.locus = locus;
            this.strStart = strStart;
            this.strEnd = strEnd;
            this.strEndPlusOne = strEnd + 1;
            this.paddedStrStart = paddedStrStart;
            this.paddedStrEnd = paddedStrEnd;
            n = k = nSup = 0;
            minMQ = SAMRecord.UNKNOWN_MAPPING_QUALITY;
        }

        public static DragstrLocusCaseCollector create(final DragstrLocus locus, final int padding, final long contigLength) {
            Utils.nonNull(locus);
            Utils.validateArg(padding >= 0, "padding must be 0 or greater");
            Utils.validateArg(contigLength >= 1, "contig length must be strictly positive");
            final long strStart = locus.getStart();
            final long strEnd = locus.getEnd();
            final long paddedStrStart = Math.max(1, strStart - padding);
            final long paddedStrEnd = Math.min(contigLength, strEnd + padding);
            return new DragstrLocusCaseCollector(locus, strStart, strEnd, paddedStrStart, paddedStrEnd);
        }

        @Override
        public Supplier<DragstrLocusCaseCollector> supplier() {
            return () -> new DragstrLocusCaseCollector(locus, strStart, strEnd, paddedStrStart, paddedStrEnd);
        }

        @Override
        public BiConsumer<DragstrLocusCaseCollector, EquivalentReadSet> accumulator() {
            return DragstrLocusCaseCollector::collect;
        }

        /**
         * Adds the relevant stats of the read to the collector based on its overlap
         * with the STR and the presence of indel events.
         * <p>
         * Assumes that the read is mapped to the same contig as the locus, so we don't test
         * for that.
         * @param eset the read set to collect.
         */
        private void collect(final EquivalentReadSet eset) {
            final int readStart = eset.getStart();
            final int readEnd = eset.getEnd();
            final int size = eset.size();
            if (readStart <= paddedStrStart && readEnd >= paddedStrEnd) {
                if (eset.isSupplementaryAlignment()) {
                    nSup += size;
                }
                minMQ = Math.min(minMQ, eset.getMappingQuality());
                int refPos = readStart;
                // int lengthDiff = 0;
                for (final CigarElement ce : eset.getCigar()) {
                    final CigarOperator op = ce.getOperator();
                    final int length = ce.getLength();
                    if (op == CigarOperator.I && refPos >= strStart && refPos <= strEndPlusOne) {
                        k += size;
                        //lengthDiff += length;
                    } else if (op == CigarOperator.D && refPos + length - 1 >= strStart && refPos <= strEnd) {
                        k += size;
                        //lengthDiff -= length;
                    }
                    // update refPos and exit early if we have gone beyond the end of the STR.
                    if ((refPos += op.consumesReferenceBases() ? length : 0) > strEndPlusOne) {
                        break;
                    }
                }
                n += size;
            }
        }

        private DragstrLocusCaseCollector combineWith(final DragstrLocusCaseCollector other) {
            Utils.validateArg(other.locus == this.locus, "collectors at different loci cannot be combined");
            final DragstrLocusCaseCollector result = new DragstrLocusCaseCollector(locus, strStart, strEnd, paddedStrStart, paddedStrEnd);
            result.k = k + other.k;
            result.n = n + other.n;
            result.nSup = nSup + other.nSup;
            result.minMQ = Math.min(minMQ, other.minMQ);
            return result;
        }

        private DragstrLocusCase finish() {
            return DragstrLocusCase.create(locus, n, k, minMQ, nSup);
        }

        @Override
        public BinaryOperator<DragstrLocusCaseCollector> combiner() {
            return DragstrLocusCaseCollector::combineWith;
        }

        @Override
        public Function<DragstrLocusCaseCollector, DragstrLocusCase> finisher() {
            return DragstrLocusCaseCollector::finish;
        }

        @Override
        public Set<Characteristics> characteristics() {
            return Collections.emptySet();
        }
    }

    /**
     * Generates a stream of locus cases for an interval/shard.

     * <p>
     * The returned stream in turn feeds on two streams: (a) the stream of reads for the interval and (b)
     * the str-table entry loci for that interval.
     * </p>
     * <p>
     * These are processed together in position order within the shard so that for every given str-table entry
     * we get all the relevant reads that overlap the STR.
     * </p>
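     * <p>
     * Note that the spliterator implementation below never splits ({@code trySplit()} returns {@code null}),
     * so each shard's stream is consumed sequentially; parallelism, when enabled, happens across shards
     * rather than within one.
     * </p>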

     * @param shard the target shard.
     * @param reads a stream on the reads in the input shard.
     * @param loci a stream on the loci in the input shard.
     * @return never {@code null}, perhaps an empty stream.
     */
    private Stream<DragstrLocusCase> streamShardCasesStats(final SimpleInterval shard, final Stream<GATKRead> reads,
                                                           final Stream<DragstrLocus> loci) {
        final int contigLength = dictionary.getSequence(shard.getContig()).getSequenceLength();
        return StreamSupport.stream(new Spliterator<DragstrLocusCase>() {

            private final Spliterator<GATKRead> readSpliterator = reads.spliterator();
            private final Spliterator<DragstrLocus> lociSpliterator = loci.spliterator();
            private final ShardReadBuffer readBuffer = new ShardReadBuffer();
            private GATKRead read;
            private DragstrLocus locus;

            /**
             * Moves forward in the reads stream.

             * <p>
             * If it returns {@code true} the read to process is placed in {@link #read}.
             * </p>
             * @return {@code true} iff there is one more read to process.
             */
            private boolean advanceRead() {
                return readSpliterator.tryAdvance(read -> this.read = read);
            }

            /**
             * Moves forward in the locus stream.
             * <p>
             * If it returns {@code true} the locus to process is placed in {@link #locus}.
             * </p>
             * @return {@code true} iff there is one more locus to process.
             */
            private boolean advanceLocus() {
                return lociSpliterator.tryAdvance(locus -> this.locus = locus);
            }

            @Override
            public boolean tryAdvance(final Consumer<? super DragstrLocusCase> action) {
                if (advanceLocus()) { // if true, sets 'locus' to the next in the stream.
                    // flush the buffer of up-stream reads that we won't need again:
                    readBuffer.removeUpstreamFrom((int) locus.getStart());
                    // we keep reading reads into the buffer until we reach the first one downstream
                    // from the current subject:
                    while (advanceRead()) { // if true, sets 'read' to the next in the stream.
                        readBuffer.add(read.getAssignedStart(), read.getEnd(), read);
                        if (read.getAssignedStart() > locus.getEnd()) {
                            break;
                        }
                    }
                    // now we compose the case given the locus and all overlapping reads:
                    final List<EquivalentReadSet> reads = readBuffer.overlapping((int) locus.getStart(), (int) locus.getEnd());
                    final DragstrLocusCase newCase = composeDragstrLocusCase(locus, reads, contigLength);
                    action.accept(newCase);
                    return true;
                } else { // no more loci in the stream, we are finished.
                    return false;
                }
            }

            @Override
            public Spliterator<DragstrLocusCase> trySplit() {
                return null;
            }

            @Override
            public long estimateSize() {
                return 0;
            }

            @Override
            public int characteristics() {
                return 0;
            }
        }, false);
    }

    private static void outputSiteDetails(final PrintWriter writer, final DragstrLocusCase caze, final String fate) {
        writer.println(Utils.join("\t",
                "" + caze.getLocus().getChromosomeIndex() + ':' + (caze.getLocus().getStart() - 1),
                caze.getLocus().getPeriod(),
                caze.getLocus().getRepeats(),
                caze.getDepth(),
                caze.getIndels(),
                caze.getMinMQ(),
                caze.getNSup(),
                fate));
    }

    private Stream<GATKRead> readStream(final ReadsDataSource source, final SimpleInterval interval) {
        final Stream<GATKRead> unfiltered = interval == null
                ? Utils.stream(source)
                : Utils.stream(source.query(interval));
        return unfiltered
                .filter(read -> (read.getFlags() & DISCARD_FLAG_VALUE) == 0 && read.getAssignedStart() <= read.getEnd())
                .map(EXTENDED_MQ_READ_TRANSFORMER);
    }

    // flags for the reads that are to be discarded from analyses.
    private static final EnumSet<SAMFlag> DISCARD_FLAGS = EnumSet.of(
            SAMFlag.READ_UNMAPPED,
            SAMFlag.SECONDARY_ALIGNMENT,
            SAMFlag.READ_FAILS_VENDOR_QUALITY_CHECK);

    private static final int DISCARD_FLAG_VALUE = DISCARD_FLAGS.stream().mapToInt(SAMFlag::intValue).sum();

    private DragstrLocusCase composeDragstrLocusCase(final DragstrLocus locus, final List<EquivalentReadSet> rawReads,
                                                     final long contigLength) {
        return rawReads.stream()
                .collect(DragstrLocusCaseCollector.create(locus, hyperParameters.strPadding, contigLength));
    }

    /**
     * Sets of reads that, for the intents and purposes of this model, are equivalent assuming that they are mapped
     * to the same location; we don't check for that.
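     * <p>
     * Two reads are considered equivalent when they share the supplementary-alignment flag, mapping quality
     * and CIGAR (see {@code belongs}); the set just keeps one example read plus a running count.
     * </p>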
     */
    private static class EquivalentReadSet {

        private GATKRead example;
        private int size;

        public boolean belongs(final GATKRead read) {
            return read.isSupplementaryAlignment() == example.isSupplementaryAlignment()
                    && read.getMappingQuality() == example.getMappingQuality()
                    && read.getCigar().equals(example.getCigar());
        }

        public static int hashCode(final GATKRead read) {
            return ((Boolean.hashCode(read.isSupplementaryAlignment()) * 31) + read.getMappingQuality() * 31) + read.getCigar().hashCode();
        }

        @Override
        public int hashCode() {
            return hashCode(example);
        }

        private EquivalentReadSet(final GATKRead read) {
            example = read;
            size = 1;
        }

        public static EquivalentReadSet of(final GATKRead read) {
            Utils.nonNull(read);
            return new EquivalentReadSet(read);
        }

        public void increase(final int inc) {
            size += inc;
        }

        public int getStart() {
            return example.getStart();
        }

        public int getEnd() {
            return example.getEnd();
        }

        public boolean isSupplementaryAlignment() {
            return example.isSupplementaryAlignment();
        }

        public int size() {
            return size;
        }

        public int getMappingQuality() {
            return example.getMappingQuality();
        }

        public Iterable<CigarElement> getCigar() {
            return example.getCigar();
        }
    }

    /**
     * Simple read-buffer implementation.
     */
    private static class ShardReadBuffer extends IntervalTree<Int2ObjectMap<EquivalentReadSet>> {

        private static Int2ObjectMap<EquivalentReadSet> mergeEquivalentReadSets(final Int2ObjectMap<EquivalentReadSet> left,
                                                                                final Int2ObjectMap<EquivalentReadSet> right) {
            // receiver is the map that will collect the output, perhaps one of the inputs;
            // donor is the other map.
            // size-1 maps are unmodifiable singletons so they can only be donors.
            final Int2ObjectMap<EquivalentReadSet> receiver, donor;
            if (left.size() > 1) {
                receiver = left;
                donor = right;
            } else if (right.size() > 1) {
                receiver = right;
                donor = left;
            } else {
                receiver = new Int2ObjectOpenHashMap<>(left);
                donor = right;
            }
            for (final EquivalentReadSet e2 : donor.values()) {
                final EquivalentReadSet e1 = receiver.get(e2.hashCode());
                if (e1 == null) { // if not in the receiver we simply copy it over.
                    receiver.put(e2.hashCode(), e2);
                } else { // if present we increase the count.
                    e1.increase(e2.size());
                }
            }
            return receiver;
        }

        public void add(final int start, final int end, final GATKRead elem) {
            merge(start, end,
                  Int2ObjectMaps.singleton(EquivalentReadSet.hashCode(elem), EquivalentReadSet.of(elem)),
                  ShardReadBuffer::mergeEquivalentReadSets);
        }

        void removeUpstreamFrom(final int start) {
            final Iterator<Node<Int2ObjectMap<EquivalentReadSet>>> it = iterator();
            while (it.hasNext()) {
                final Node<Int2ObjectMap<EquivalentReadSet>> node = it.next();
                if (node.getStart() >= start) {
                    break;
                } else if (node.getEnd() < start) {
                    it.remove();
                }
            }
        }

        public List<EquivalentReadSet> overlapping(final int start, final int end) {
            final Iterator<Node<Int2ObjectMap<EquivalentReadSet>>> it = this.overlappers(start, end);
            if (!it.hasNext()) {
                return Collections.emptyList();
            } else {
                final List<EquivalentReadSet> result = new ArrayList<>();
                do {
                    final Node<Int2ObjectMap<EquivalentReadSet>> node = it.next();
                    result.addAll(node.getValue().values());
                } while (it.hasNext());
                return result;
            }
        }
    }

    /**
     * Simple 2-int tuple to hold a period and repeat-length pair.
     */
    private static class PeriodAndRepeatLength {

        private final int period;
        private final int repeatLength;

        private PeriodAndRepeatLength(final int period, final int repeatLength) {
            this.period = period;
            this.repeatLength = repeatLength;
        }

        private static PeriodAndRepeatLength of(final int period, final int repeat) {
            return new PeriodAndRepeatLength(period, repeat);
        }

        @Override
        public String toString() {
            return "(" + period + ',' + repeatLength + ')';
        }
    }
}



