
org.broadinstitute.hellbender.tools.copynumber.segmentation.MultisampleMultidimensionalKernelSegmenter Maven / Gradle / Ivy
The newest version!
package org.broadinstitute.hellbender.tools.copynumber.segmentation;
import htsjdk.samtools.util.Locatable;
import htsjdk.samtools.util.OverlapDetector;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.tools.copynumber.arguments.CopyNumberArgumentValidationUtils;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.AbstractLocatableCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.AllelicCountCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.CopyRatioCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.SimpleIntervalCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.LocatableMetadata;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.AllelicCount;
import org.broadinstitute.hellbender.tools.copynumber.utils.segmentation.KernelSegmenter;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
/**
* Segments copy-ratio data and/or alternate-allele-fraction data from one or more samples using kernel segmentation.
* Copy-ratio intervals and/or allele-fraction sites must be identical in all samples. Segments do not span chromosomes.
* If both types of data are provided, only the first allele-fraction site in each copy-ratio interval is used, and
* the alternate-allele fraction in copy-ratio intervals that do not contain any sites is imputed to be balanced at 0.5.
*
* @author Samuel Lee <[email protected]>
*/
public final class MultisampleMultidimensionalKernelSegmenter {
private static final Logger logger = LogManager.getLogger(MultisampleMultidimensionalKernelSegmenter.class);
private enum Mode {
COPY_RATIO_ONLY, ALLELE_FRACTION_ONLY, COPY_RATIO_AND_ALLELE_FRACTION
}
private static final int MIN_NUM_POINTS_REQUIRED_PER_CHROMOSOME = 10;
//assume alternate-allele fraction is 0.5 for missing data
private static final SimpleInterval DUMMY_INTERVAL = new SimpleInterval("DUMMY", 1, 1);
private static final AllelicCount BALANCED_ALLELIC_COUNT = new AllelicCount(DUMMY_INTERVAL, 1, 1);
//Gaussian kernel for a specified variance; if variance is zero, use a linear kernel
private static final Function> KERNEL =
standardDeviation -> standardDeviation == 0.
? (x, y) -> x * y
: (x, y) -> new NormalDistribution(null, x, standardDeviation).density(y);
private static final class MultidimensionalPoint implements Locatable {
private final SimpleInterval interval;
private final double[] log2CopyRatios;
private final double[] alternateAlleleFractions;
MultidimensionalPoint(final SimpleInterval interval,
final double[] log2CopyRatios,
final double[] alternateAlleleFractions) {
this.interval = interval;
this.log2CopyRatios = log2CopyRatios;
this.alternateAlleleFractions = alternateAlleleFractions;
}
@Override
public String getContig() {
return interval.getContig();
}
@Override
public int getStart() {
return interval.getStart();
}
@Override
public int getEnd() {
return interval.getEnd();
}
}
private final Mode mode;
private final int numSamples;
private final int numPointsCopyRatio;
private final int numPointsAlleleFraction;
private final LocatableMetadata metadata;
private final Map> multidimensionalPointsPerChromosome;
/**
* @param denoisedCopyRatiosPerSample non-empty; all copy-ratio intervals identical across samples;
* number of samples and order identical to that in {@code allelicCountsPerSample};
* pass list of empty {@link CopyRatioCollection}s for {@code ALLELE_FRACTION_ONLY} mode
* @param allelicCountsPerSample non-empty; all allele-fraction sites identical across samples;
* number of samples and order identical to that in {@code denoisedCopyRatiosPerSample};
* pass list of empty {@link AllelicCountCollection}s for {@code COPY_RATIO_ONLY} mode
*/
public MultisampleMultidimensionalKernelSegmenter(final List denoisedCopyRatiosPerSample,
final List allelicCountsPerSample) {
validateInputs(denoisedCopyRatiosPerSample, allelicCountsPerSample);
numSamples = denoisedCopyRatiosPerSample.size();
final CopyRatioCollection denoisedCopyRatiosFirstSample = denoisedCopyRatiosPerSample.get(0);
final AllelicCountCollection allelicCountsFirstSample = allelicCountsPerSample.get(0);
metadata = denoisedCopyRatiosFirstSample.getMetadata();
numPointsCopyRatio = denoisedCopyRatiosFirstSample.size();
numPointsAlleleFraction = allelicCountsFirstSample.size();
if (numPointsAlleleFraction == 0) {
mode = Mode.COPY_RATIO_ONLY;
multidimensionalPointsPerChromosome = IntStream.range(0, numPointsCopyRatio).boxed()
.map(i -> new MultidimensionalPoint(
denoisedCopyRatiosFirstSample.getRecords().get(i).getInterval(),
denoisedCopyRatiosPerSample.stream()
.mapToDouble(cr -> cr.getRecords().get(i).getLog2CopyRatioValue())
.toArray(),
null))
.collect(Collectors.groupingBy(
MultidimensionalPoint::getContig,
LinkedHashMap::new,
Collectors.toList()));
} else if (numPointsCopyRatio == 0) {
mode = Mode.ALLELE_FRACTION_ONLY;
multidimensionalPointsPerChromosome = IntStream.range(0, numPointsAlleleFraction).boxed()
.map(i -> new MultidimensionalPoint(
allelicCountsFirstSample.getRecords().get(i).getInterval(),
null,
allelicCountsPerSample.stream()
.mapToDouble(ac -> ac.getRecords().get(i).getAlternateAlleleFraction())
.toArray()))
.collect(Collectors.groupingBy(
MultidimensionalPoint::getContig,
LinkedHashMap::new,
Collectors.toList()));
} else {
mode = Mode.COPY_RATIO_AND_ALLELE_FRACTION;
final OverlapDetector allelicCountOverlapDetector = allelicCountsFirstSample.getOverlapDetector();
final Comparator comparator = denoisedCopyRatiosFirstSample.getComparator();
final Map allelicSiteToIndexMap = IntStream.range(0, numPointsAlleleFraction).boxed()
.collect(Collectors.toMap(
i -> allelicCountsFirstSample.getRecords().get(i).getInterval(),
Function.identity(),
(u, v) -> {
throw new GATKException.ShouldNeverReachHereException("Cannot have duplicate sites.");
}, //sites should already be distinct
LinkedHashMap::new));
final Map intervalIndexToSiteIndexMap = IntStream.range(0, numPointsCopyRatio).boxed()
.collect(Collectors.toMap(
Function.identity(),
i -> allelicCountOverlapDetector.getOverlaps(denoisedCopyRatiosFirstSample.getRecords().get(i)).stream()
.map(AllelicCount::getInterval)
.min(comparator::compare)
.map(allelicSiteToIndexMap::get)
.orElse(-1),
(u, v) -> {
throw new GATKException.ShouldNeverReachHereException("Cannot have duplicate indices.");
},
LinkedHashMap::new));
final int numAllelicCountsToUse = (int) intervalIndexToSiteIndexMap.values().stream()
.filter(i -> i != -1)
.count();
logger.info(String.format("Using first allelic-count site in each copy-ratio interval (%d / %d) for multidimensional segmentation...",
numAllelicCountsToUse, numPointsAlleleFraction));
multidimensionalPointsPerChromosome = IntStream.range(0, numPointsCopyRatio).boxed()
.map(i -> new MultidimensionalPoint(
denoisedCopyRatiosFirstSample.getRecords().get(i).getInterval(),
denoisedCopyRatiosPerSample.stream()
.mapToDouble(denoisedCopyRatios -> denoisedCopyRatios.getRecords().get(i).getLog2CopyRatioValue())
.toArray(),
allelicCountsPerSample.stream()
.map(allelicCounts -> intervalIndexToSiteIndexMap.get(i) != -1
? allelicCounts.getRecords().get(intervalIndexToSiteIndexMap.get(i))
: BALANCED_ALLELIC_COUNT)
.mapToDouble(AllelicCount::getAlternateAlleleFraction)
.toArray()))
.collect(Collectors.groupingBy(
MultidimensionalPoint::getContig,
LinkedHashMap::new,
Collectors.toList()));
}
}
private static void validateInputs(final List denoisedCopyRatiosPerSample,
final List allelicCountsPerSample) {
Utils.nonEmpty(denoisedCopyRatiosPerSample);
Utils.nonEmpty(allelicCountsPerSample);
Utils.validateArg(denoisedCopyRatiosPerSample.size() == allelicCountsPerSample.size(),
"Number of copy-ratio and allelic-count collections must be equal.");
Utils.validateArg(IntStream.range(0, denoisedCopyRatiosPerSample.size())
.allMatch(i -> denoisedCopyRatiosPerSample.get(i).getMetadata().equals(allelicCountsPerSample.get(i).getMetadata())),
"Metadata do not match across copy-ratio and allelic-count collections for the samples. " +
"Check that the sample orders for the corresponding inputs are identical.");
CopyNumberArgumentValidationUtils.getValidatedSequenceDictionary(
Stream.of(denoisedCopyRatiosPerSample, allelicCountsPerSample)
.flatMap(Collection::stream)
.toArray(AbstractLocatableCollection[]::new));
Utils.validateArg((int) denoisedCopyRatiosPerSample.stream()
.map(CopyRatioCollection::getIntervals)
.distinct()
.count() == 1,
"Copy-ratio intervals must be identical across all samples.");
Utils.validateArg((int) allelicCountsPerSample.stream()
.map(AllelicCountCollection::getIntervals)
.distinct()
.count() == 1,
"Allelic-count sites must be identical across all samples.");
}
/**
* Segments the internally held {@link CopyRatioCollection} and {@link AllelicCountCollection}
* using a separate {@link KernelSegmenter} for each chromosome.
* @param kernelVarianceCopyRatio variance of the Gaussian kernel used for copy-ratio data;
* if zero, a linear kernel is used instead
* @param kernelVarianceAlleleFraction variance of the Gaussian kernel used for allele-fraction data;
* if zero, a linear kernel is used instead
* @param kernelScalingAlleleFraction relative scaling S of the kernel K_AF for allele-fraction data
* to the kernel K_CR for copy-ratio data;
* the total kernel is K_CR + S * K_AF
*/
public SimpleIntervalCollection findSegmentation(final int maxNumSegmentsPerChromosome,
final double kernelVarianceCopyRatio,
final double kernelVarianceAlleleFraction,
final double kernelScalingAlleleFraction,
final int kernelApproximationDimension,
final List windowSizes,
final double numChangepointsPenaltyLinearFactor,
final double numChangepointsPenaltyLogLinearFactor) {
ParamUtils.isPositive(maxNumSegmentsPerChromosome, "Maximum number of segments must be positive.");
ParamUtils.isPositiveOrZero(kernelVarianceCopyRatio, "Variance of copy-ratio Gaussian kernel must be non-negative (if zero, a linear kernel will be used).");
ParamUtils.isPositiveOrZero(kernelVarianceAlleleFraction, "Variance of allele-fraction Gaussian kernel must be non-negative (if zero, a linear kernel will be used).");
ParamUtils.isPositiveOrZero(kernelScalingAlleleFraction, "Scaling of allele-fraction Gaussian kernel must be non-negative.");
ParamUtils.isPositive(kernelApproximationDimension, "Dimension of kernel approximation must be positive.");
Utils.validateArg(windowSizes.stream().allMatch(ws -> ws > 0), "Window sizes must all be positive.");
Utils.validateArg(new HashSet<>(windowSizes).size() == windowSizes.size(), "Window sizes must all be unique.");
ParamUtils.isPositiveOrZero(numChangepointsPenaltyLinearFactor,
"Linear factor for the penalty on the number of changepoints per chromosome must be non-negative.");
ParamUtils.isPositiveOrZero(numChangepointsPenaltyLogLinearFactor,
"Log-linear factor for the penalty on the number of changepoints per chromosome must be non-negative.");
final BiFunction kernel = constructKernel(
kernelVarianceCopyRatio, kernelVarianceAlleleFraction, kernelScalingAlleleFraction);
final int maxNumChangepointsPerChromosome = maxNumSegmentsPerChromosome - 1;
logger.info(String.format("Finding changepoints in (%d, %d) data points and %d chromosomes across %d sample(s)...",
numPointsCopyRatio, numPointsAlleleFraction, multidimensionalPointsPerChromosome.size(), numSamples));
//loop over chromosomes, find changepoints, and create segments
final List segments = new ArrayList<>();
for (final String chromosome : multidimensionalPointsPerChromosome.keySet()) {
final List multidimensionalPointsInChromosome = multidimensionalPointsPerChromosome.get(chromosome);
final int numMultidimensionalPointsInChromosome = multidimensionalPointsInChromosome.size();
logger.info(String.format("Finding changepoints in %d data points in chromosome %s...",
numMultidimensionalPointsInChromosome, chromosome));
if (numMultidimensionalPointsInChromosome < MIN_NUM_POINTS_REQUIRED_PER_CHROMOSOME) {
logger.warn(String.format("Number of points in chromosome %s (%d) is less than that required (%d), skipping segmentation...",
chromosome, numMultidimensionalPointsInChromosome, MIN_NUM_POINTS_REQUIRED_PER_CHROMOSOME));
final int start = multidimensionalPointsInChromosome.get(0).getStart();
final int end = multidimensionalPointsInChromosome.get(numMultidimensionalPointsInChromosome - 1).getEnd();
segments.add(new SimpleInterval(chromosome, start, end));
continue;
}
final List changepoints = new ArrayList<>(new KernelSegmenter<>(multidimensionalPointsInChromosome)
.findChangepoints(maxNumChangepointsPerChromosome, kernel, kernelApproximationDimension,
windowSizes, numChangepointsPenaltyLinearFactor, numChangepointsPenaltyLogLinearFactor, KernelSegmenter.ChangepointSortOrder.INDEX));
if (!changepoints.contains(numMultidimensionalPointsInChromosome)) {
changepoints.add(numMultidimensionalPointsInChromosome - 1);
}
int previousChangepoint = -1;
for (final int changepoint : changepoints) {
final int start = multidimensionalPointsPerChromosome.get(chromosome).get(previousChangepoint + 1).getStart();
final int end = multidimensionalPointsPerChromosome.get(chromosome).get(changepoint).getEnd();
segments.add(new SimpleInterval(chromosome, start, end));
previousChangepoint = changepoint;
}
}
logger.info(String.format("Found %d segments in %d chromosomes across %d sample(s).", segments.size(), multidimensionalPointsPerChromosome.size(), numSamples));
return new SimpleIntervalCollection(metadata, segments);
}
private BiFunction constructKernel(final double kernelVarianceCopyRatio,
final double kernelVarianceAlleleFraction,
final double kernelScalingAlleleFraction) {
final double standardDeviationCopyRatio = Math.sqrt(kernelVarianceCopyRatio);
final double standardDeviationAlleleFraction = Math.sqrt(kernelVarianceAlleleFraction);
switch (mode) {
case COPY_RATIO_ONLY:
return (p1, p2) -> {
double sum = 0.;
for (int sampleIndex = 0; sampleIndex < numSamples; sampleIndex++) {
sum += KERNEL.apply(standardDeviationCopyRatio).apply(p1.log2CopyRatios[sampleIndex], p2.log2CopyRatios[sampleIndex]);
}
return sum;
};
case ALLELE_FRACTION_ONLY:
return (p1, p2) -> {
double sum = 0.;
for (int sampleIndex = 0; sampleIndex < numSamples; sampleIndex++) {
sum += KERNEL.apply(standardDeviationAlleleFraction).apply(p1.alternateAlleleFractions[sampleIndex], p2.alternateAlleleFractions[sampleIndex]);
}
return sum;
};
case COPY_RATIO_AND_ALLELE_FRACTION:
return (p1, p2) -> {
double sum = 0.;
for (int sampleIndex = 0; sampleIndex < numSamples; sampleIndex++) {
sum += KERNEL.apply(standardDeviationCopyRatio).apply(p1.log2CopyRatios[sampleIndex], p2.log2CopyRatios[sampleIndex]) +
kernelScalingAlleleFraction * KERNEL.apply(standardDeviationAlleleFraction).apply(p1.alternateAlleleFractions[sampleIndex], p2.alternateAlleleFractions[sampleIndex]);
}
return sum;
};
default:
throw new GATKException.ShouldNeverReachHereException("Encountered unknown Mode.");
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy