org.broadinstitute.hellbender.tools.copynumber.models.MultidimensionalModeller Maven / Gradle / Ivy
The newest version!
package org.broadinstitute.hellbender.tools.copynumber.models;
import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.util.OverlapDetector;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.copynumber.arguments.CopyNumberArgumentValidationUtils;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.AllelicCountCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.CopyRatioCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.ModeledSegmentCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.SimpleIntervalCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.SampleLocatableMetadata;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.SimpleLocatableMetadata;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.AllelicCount;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.CopyRatio;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.ModeledSegment;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
/**
* Represents a segmented model for copy ratio and allele fraction.
*
* @author Samuel Lee <[email protected]>
*/
public final class MultidimensionalModeller {
    private static final Logger logger = LogManager.getLogger(MultidimensionalModeller.class);

    private final SampleLocatableMetadata metadata;
    private final CopyRatioCollection denoisedCopyRatios;
    private final OverlapDetector<CopyRatio> copyRatioMidpointOverlapDetector;
    private final AllelicCountCollection allelicCounts;
    private final OverlapDetector<AllelicCount> allelicCountOverlapDetector;
    private final AlleleFractionPrior alleleFractionPrior;

    private CopyRatioModeller copyRatioModeller;
    private AlleleFractionModeller alleleFractionModeller;

    private SimpleIntervalCollection currentSegments;
    private final List<ModeledSegment> modeledSegments = new ArrayList<>();

    //similar-segment merging may leave model in a state where it is not properly fit (deciles may be estimated naively)
    private boolean isModelFit;

    private final int numSamplesCopyRatio;
    private final int numBurnInCopyRatio;
    private final int numSamplesAlleleFraction;
    private final int numBurnInAlleleFraction;

    /**
     * Constructs a copy-ratio and allele-fraction modeller, specifying number of total samples
     * and number of burn-in samples for Markov-Chain Monte Carlo model fitting.
     * An initial model fit is performed.
     *
     * @param segments                  initial segmentation; must contain at least one segment
     * @param denoisedCopyRatios        denoised copy ratios used to fit the copy-ratio model
     * @param allelicCounts             allelic counts used to fit the allele-fraction model
     * @param alleleFractionPrior       prior for the allele-fraction model
     * @param numSamplesCopyRatio       total number of copy-ratio MCMC samples; must exceed {@code numBurnInCopyRatio}
     * @param numBurnInCopyRatio        number of copy-ratio burn-in samples; must be non-negative
     * @param numSamplesAlleleFraction  total number of allele-fraction MCMC samples; must exceed {@code numBurnInAlleleFraction}
     * @param numBurnInAlleleFraction   number of allele-fraction burn-in samples; must be non-negative
     */
    public MultidimensionalModeller(final SimpleIntervalCollection segments,
                                    final CopyRatioCollection denoisedCopyRatios,
                                    final AllelicCountCollection allelicCounts,
                                    final AlleleFractionPrior alleleFractionPrior,
                                    final int numSamplesCopyRatio,
                                    final int numBurnInCopyRatio,
                                    final int numSamplesAlleleFraction,
                                    final int numBurnInAlleleFraction) {
        Utils.nonNull(segments);
        Utils.nonNull(denoisedCopyRatios);
        Utils.nonNull(allelicCounts);
        Utils.nonNull(alleleFractionPrior);
        ParamUtils.isPositiveOrZero(numBurnInCopyRatio, "Number of burn-in copy-ratio samples must be non-negative.");
        Utils.validateArg(numBurnInCopyRatio < numSamplesCopyRatio, "Number of copy-ratio samples must be greater than number of burn-in copy-ratio samples.");
        ParamUtils.isPositiveOrZero(numBurnInAlleleFraction, "Number of burn-in allele-fraction samples must be non-negative.");
        Utils.validateArg(numBurnInAlleleFraction < numSamplesAlleleFraction, "Number of allele-fraction samples must be greater than number of burn-in allele-fraction samples.");
        metadata = CopyNumberArgumentValidationUtils.getValidatedMetadata(denoisedCopyRatios, allelicCounts);
        //called only for its validation side effect; the returned dictionary is not needed here
        CopyNumberArgumentValidationUtils.getValidatedSequenceDictionary(segments, denoisedCopyRatios, allelicCounts);
        ParamUtils.isPositive(segments.size(), "Number of segments must be positive.");
        currentSegments = segments;
        this.denoisedCopyRatios = denoisedCopyRatios;
        copyRatioMidpointOverlapDetector = denoisedCopyRatios.getMidpointOverlapDetector();
        this.allelicCounts = allelicCounts;
        allelicCountOverlapDetector = allelicCounts.getOverlapDetector();
        this.alleleFractionPrior = alleleFractionPrior;
        this.numSamplesCopyRatio = numSamplesCopyRatio;
        this.numBurnInCopyRatio = numBurnInCopyRatio;
        this.numSamplesAlleleFraction = numSamplesAlleleFraction;
        this.numBurnInAlleleFraction = numBurnInAlleleFraction;
        logger.info("Fitting initial model...");
        fitModel();
    }

    /**
     * Returns the current modeled segments wrapped with the sample metadata.
     * A copy of the internal list is passed so that later smoothing, which mutates the internal list,
     * cannot alter a previously returned collection.
     */
    public ModeledSegmentCollection getModeledSegments() {
        return new ModeledSegmentCollection(metadata, new ArrayList<>(modeledSegments));
    }

    /**
     * Performs Markov-Chain Monte Carlo model fitting using the
     * number of total samples and number of burn-in samples specified at construction,
     * then rebuilds the list of {@link ModeledSegment} from the resulting posterior summaries.
     */
    private void fitModel() {
        //perform MCMC to generate posterior samples
        logger.info("Fitting copy-ratio model...");
        copyRatioModeller = new CopyRatioModeller(denoisedCopyRatios, currentSegments);
        copyRatioModeller.fitMCMC(numSamplesCopyRatio, numBurnInCopyRatio);
        logger.info("Fitting allele-fraction model...");
        alleleFractionModeller = new AlleleFractionModeller(allelicCounts, currentSegments, alleleFractionPrior);
        alleleFractionModeller.fitMCMC(numSamplesAlleleFraction, numBurnInAlleleFraction);

        //update list of ModeledSegment with new PosteriorSummaries
        modeledSegments.clear();
        final List<SimpleInterval> segments = currentSegments.getRecords();
        final List<ModeledSegment.SimplePosteriorSummary> segmentMeansPosteriorSummaries =
                copyRatioModeller.getSegmentMeansPosteriorSummaries();
        final List<ModeledSegment.SimplePosteriorSummary> minorAlleleFractionsPosteriorSummaries =
                alleleFractionModeller.getMinorAlleleFractionsPosteriorSummaries();
        for (int segmentIndex = 0; segmentIndex < currentSegments.size(); segmentIndex++) {
            final SimpleInterval segment = segments.get(segmentIndex);
            final int numPointsCopyRatio = copyRatioMidpointOverlapDetector.getOverlaps(segment).size();
            final int numPointsAlleleFraction = allelicCountOverlapDetector.getOverlaps(segment).size();
            final ModeledSegment.SimplePosteriorSummary segmentMeansPosteriorSummary = segmentMeansPosteriorSummaries.get(segmentIndex);
            final ModeledSegment.SimplePosteriorSummary minorAlleleFractionPosteriorSummary = minorAlleleFractionsPosteriorSummaries.get(segmentIndex);
            modeledSegments.add(new ModeledSegment(
                    segment, numPointsCopyRatio, numPointsAlleleFraction, segmentMeansPosteriorSummary, minorAlleleFractionPosteriorSummary));
        }
        isModelFit = true;
    }

    /**
     * Repeatedly merges adjacent similar segments until no merge occurs or the iteration limit is reached,
     * and guarantees that the model is fully fit on return.
     *
     * @param maxNumSmoothingIterations                         maximum number of merge iterations; must be non-negative
     * @param numSmoothingIterationsPerFit                      refit the model after every this-many iterations;
     *                                                          if this is zero, no refitting will be performed between
     *                                                          smoothing iterations; must be non-negative
     * @param smoothingCredibleIntervalThresholdCopyRatio       credible-interval threshold for copy-ratio similarity;
     *                                                          must be non-negative
     * @param smoothingCredibleIntervalThresholdAlleleFraction  credible-interval threshold for allele-fraction similarity;
     *                                                          must be non-negative
     */
    public void smoothSegments(final int maxNumSmoothingIterations,
                               final int numSmoothingIterationsPerFit,
                               final double smoothingCredibleIntervalThresholdCopyRatio,
                               final double smoothingCredibleIntervalThresholdAlleleFraction) {
        ParamUtils.isPositiveOrZero(maxNumSmoothingIterations,
                "The maximum number of smoothing iterations must be non-negative.");
        ParamUtils.isPositiveOrZero(numSmoothingIterationsPerFit,
                "The number of smoothing iterations per fit must be non-negative.");
        ParamUtils.isPositiveOrZero(smoothingCredibleIntervalThresholdCopyRatio,
                "The copy-ratio credible-interval threshold for segmentation smoothing must be non-negative.");
        ParamUtils.isPositiveOrZero(smoothingCredibleIntervalThresholdAlleleFraction,
                "The allele-fraction credible-interval threshold for segmentation smoothing must be non-negative.");

        logger.info(String.format("Initial number of segments before smoothing: %d", modeledSegments.size()));
        //perform iterations of similar-segment merging until all similar segments are merged
        for (int numIterations = 1; numIterations <= maxNumSmoothingIterations; numIterations++) {
            logger.info(String.format("Smoothing iteration: %d", numIterations));
            final int prevNumSegments = modeledSegments.size();
            //if not refitting, posterior summaries of merged segments are only naive approximations
            final boolean doModelFit =
                    numSmoothingIterationsPerFit > 0 && numIterations % numSmoothingIterationsPerFit == 0;
            performSmoothingIteration(
                    smoothingCredibleIntervalThresholdCopyRatio, smoothingCredibleIntervalThresholdAlleleFraction, doModelFit);
            if (modeledSegments.size() == prevNumSegments) {
                break;
            }
        }
        if (!isModelFit) {
            //make sure final model is completely fit (i.e., posterior modes are specified)
            fitModel();
        }
        logger.info(String.format("Final number of segments after smoothing: %d", modeledSegments.size()));
    }

    /**
     * Performs one iteration of similar-segment merging on the list of {@link ModeledSegment} held internally.
     * Markov-Chain Monte Carlo model fitting is optionally performed after each iteration using the
     * number of total samples and number of burn-in samples specified at construction.
     * @param intervalThresholdSegmentMean threshold number of credible intervals for segment-mean similarity
     * @param intervalThresholdMinorAlleleFraction threshold number of credible intervals for minor-allele-fraction similarity
     * @param doModelFit if true, refit MCMC model after merging
     */
    private void performSmoothingIteration(final double intervalThresholdSegmentMean,
                                           final double intervalThresholdMinorAlleleFraction,
                                           final boolean doModelFit) {
        logger.info("Number of segments before smoothing iteration: " + modeledSegments.size());
        final List<ModeledSegment> mergedSegments = SimilarSegmentUtils.mergeSimilarSegments(
                modeledSegments, intervalThresholdSegmentMean, intervalThresholdMinorAlleleFraction);
        logger.info("Number of segments after smoothing iteration: " + mergedSegments.size());
        currentSegments = new SimpleIntervalCollection(
                new SimpleLocatableMetadata(metadata.getSequenceDictionary()),
                mergedSegments.stream().map(ModeledSegment::getInterval).collect(Collectors.toList()));
        if (doModelFit) {
            fitModel();
        } else {
            //keep the naively merged posterior summaries; the model must be refit before results are final
            modeledSegments.clear();
            modeledSegments.addAll(mergedSegments);
            isModelFit = false;
        }
    }

    /**
     * Writes posterior summaries for the global model parameters to two files,
     * one for the copy-ratio model and one for the allele-fraction model.
     * The model is refit first if it is not currently fully fit.
     *
     * @param copyRatioParameterFile       destination for copy-ratio global-parameter deciles
     * @param alleleFractionParameterFile  destination for allele-fraction global-parameter deciles
     */
    public void writeModelParameterFiles(final File copyRatioParameterFile,
                                         final File alleleFractionParameterFile) {
        Utils.nonNull(copyRatioParameterFile);
        Utils.nonNull(alleleFractionParameterFile);
        ensureModelIsFit();
        logger.info(String.format("Writing posterior summaries for copy-ratio global parameters to %s...", copyRatioParameterFile.getAbsolutePath()));
        copyRatioModeller.getGlobalParameterDeciles().write(copyRatioParameterFile);
        logger.info(String.format("Writing posterior summaries for allele-fraction global parameters to %s...", alleleFractionParameterFile.getAbsolutePath()));
        alleleFractionModeller.getGlobalParameterDeciles().write(alleleFractionParameterFile);
    }

    @VisibleForTesting
    CopyRatioModeller getCopyRatioModeller() {
        return copyRatioModeller;
    }

    @VisibleForTesting
    AlleleFractionModeller getAlleleFractionModeller() {
        return alleleFractionModeller;
    }

    //refits the model (with a warning) if a previous smoothing iteration left it only naively fit
    private void ensureModelIsFit() {
        if (!isModelFit) {
            logger.warn("Attempted to write results to file when model was not completely fit. Performing model fit now.");
            fitModel();
        }
    }

    /**
     * Contains private methods for similar-segment merging.
     */
    private static final class SimilarSegmentUtils {
        private SimilarSegmentUtils() {
        }

        /**
         * Returns a new, modifiable list of segments with similar segments (i.e., adjacent segments with both
         * segment-mean and minor-allele-fractions posteriors similar; posteriors are similar if the difference between
         * posterior central tendencies is less than intervalThreshold times the posterior credible interval of either summary)
         * merged. The list of segments is traversed once from beginning to end, and each segment is checked for similarity
         * with the segment to the right and merged until it is no longer similar.
         * Segments on different contigs are never merged.
         * @param intervalThresholdSegmentMean threshold number of credible intervals for segment-mean similarity
         * @param intervalThresholdMinorAlleleFraction threshold number of credible intervals for minor-allele-fraction similarity
         */
        private static List<ModeledSegment> mergeSimilarSegments(final List<ModeledSegment> segments,
                                                                 final double intervalThresholdSegmentMean,
                                                                 final double intervalThresholdMinorAlleleFraction) {
            final List<ModeledSegment> mergedSegments = new ArrayList<>(segments.size());
            if (segments.isEmpty()) {
                return mergedSegments;
            }
            //accumulate the current run of similar segments; emit it when the next segment is dissimilar
            //(single linear pass, avoiding repeated ArrayList.remove calls)
            ModeledSegment current = segments.get(0);
            for (int index = 1; index < segments.size(); index++) {
                final ModeledSegment next = segments.get(index);
                if (current.getContig().equals(next.getContig()) &&
                        areSimilar(current, next,
                                intervalThresholdSegmentMean, intervalThresholdMinorAlleleFraction)) {
                    current = merge(current, next);
                } else {
                    mergedSegments.add(current);
                    current = next;
                }
            }
            mergedSegments.add(current);
            return mergedSegments;
        }

        //checks similarity of posterior summaries to within a credible-interval threshold;
        //posterior summaries are similar if the difference between posterior central tendencies is less than
        //intervalThreshold times the credible-interval width of either summary;
        //a NaN central tendency (missing posterior) is considered similar to anything
        private static boolean areSimilar(final ModeledSegment.SimplePosteriorSummary summary1,
                                          final ModeledSegment.SimplePosteriorSummary summary2,
                                          final double intervalThreshold) {
            if (Double.isNaN(summary1.getDecile50()) || Double.isNaN(summary2.getDecile50())) {
                return true;
            }
            final double absoluteDifference = Math.abs(summary1.getDecile50() - summary2.getDecile50());
            return absoluteDifference < intervalThreshold * (summary1.getDecile90() - summary1.getDecile10()) ||
                    absoluteDifference < intervalThreshold * (summary2.getDecile90() - summary2.getDecile10());
        }

        //checks similarity of modeled segments to within credible-interval thresholds for segment mean and minor allele fraction
        private static boolean areSimilar(final ModeledSegment segment1,
                                          final ModeledSegment segment2,
                                          final double intervalThresholdSegmentMean,
                                          final double intervalThresholdMinorAlleleFraction) {
            return areSimilar(segment1.getLog2CopyRatioSimplePosteriorSummary(), segment2.getLog2CopyRatioSimplePosteriorSummary(), intervalThresholdSegmentMean) &&
                    areSimilar(segment1.getMinorAlleleFractionSimplePosteriorSummary(), segment2.getMinorAlleleFractionSimplePosteriorSummary(), intervalThresholdMinorAlleleFraction);
        }

        //merges posterior summaries naively by approximating posteriors as normal and taking their product
        //(precision-weighted mean); a NaN central tendency marks a missing posterior and is ignored
        private static ModeledSegment.SimplePosteriorSummary merge(final ModeledSegment.SimplePosteriorSummary summary1,
                                                                   final ModeledSegment.SimplePosteriorSummary summary2) {
            if (Double.isNaN(summary2.getDecile50())) {
                return summary1;    //covers the case where both are NaN
            }
            if (Double.isNaN(summary1.getDecile50())) {
                return summary2;
            }
            //use credible half-interval as standard deviation
            final double standardDeviation1 = (summary1.getDecile90() - summary1.getDecile10()) / 2.;
            final double standardDeviation2 = (summary2.getDecile90() - summary2.getDecile10()) / 2.;
            if (standardDeviation1 == 0. || standardDeviation2 == 0.) {
                //a zero-width summary has infinite precision and dominates the precision-weighted mean;
                //the general formula below would evaluate to NaN (infinity times zero) in this case
                final double pointMass;
                if (standardDeviation1 == 0. && standardDeviation2 == 0.) {
                    pointMass = (summary1.getDecile50() + summary2.getDecile50()) / 2.;
                } else if (standardDeviation1 == 0.) {
                    pointMass = summary1.getDecile50();
                } else {
                    pointMass = summary2.getDecile50();
                }
                return new ModeledSegment.SimplePosteriorSummary(pointMass, pointMass, pointMass);
            }
            final double variance = 1. / (1. / Math.pow(standardDeviation1, 2.) + 1. / Math.pow(standardDeviation2, 2.));
            final double mean =
                    (summary1.getDecile50() / Math.pow(standardDeviation1, 2.) + summary2.getDecile50() / Math.pow(standardDeviation2, 2.))
                            * variance;
            final double standardDeviation = Math.sqrt(variance);
            return new ModeledSegment.SimplePosteriorSummary(mean, mean - standardDeviation, mean + standardDeviation);
        }

        //merges two modeled segments: spans both intervals, sums point counts, and merges both posterior summaries
        private static ModeledSegment merge(final ModeledSegment segment1,
                                            final ModeledSegment segment2) {
            return new ModeledSegment(mergeSegments(segment1.getInterval(), segment2.getInterval()),
                    segment1.getNumPointsCopyRatio() + segment2.getNumPointsCopyRatio(),
                    segment1.getNumPointsAlleleFraction() + segment2.getNumPointsAlleleFraction(),
                    merge(segment1.getLog2CopyRatioSimplePosteriorSummary(), segment2.getLog2CopyRatioSimplePosteriorSummary()),
                    merge(segment1.getMinorAlleleFractionSimplePosteriorSummary(), segment2.getMinorAlleleFractionSimplePosteriorSummary()));
        }

        //returns the smallest interval covering both segments; both must be on the same contig
        private static SimpleInterval mergeSegments(final SimpleInterval segment1,
                                                    final SimpleInterval segment2) {
            Utils.validateArg(segment1.getContig().equals(segment2.getContig()),
                    String.format("Cannot join segments %s and %s on different chromosomes.", segment1.toString(), segment2.toString()));
            final int start = Math.min(segment1.getStart(), segment2.getStart());
            final int end = Math.max(segment1.getEnd(), segment2.getEnd());
            return new SimpleInterval(segment1.getContig(), start, end);
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy