// PHG - Practical Haplotype Graph
package net.maizegenetics.pangenome.hapCalling;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multiset;
import net.maizegenetics.analysis.imputation.BackwardForwardVariableStateNumber;
import net.maizegenetics.analysis.imputation.EmissionProbability;
import net.maizegenetics.analysis.imputation.ViterbiAlgorithmVariableStateNumber;
import net.maizegenetics.dna.map.Chromosome;
import net.maizegenetics.pangenome.api.*;
import net.maizegenetics.taxa.TaxaList;
import net.maizegenetics.taxa.Taxon;
import org.apache.log4j.Logger;
import java.io.FileNotFoundException;
import java.io.PrintWriter;
import java.util.*;
import java.util.stream.Collectors;
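/**
* Converts read mappings to a path through a HaplotypeGraph using a hidden Markov model. The single
* most likely path can be found with the Viterbi algorithm (haplotypeCountsToPath), or per-range
* posterior node probabilities can be computed with the forward-backward algorithm
* (haplotypeCountsToPathProbability).
*
* <p>A minimal usage sketch; {@code graph} and {@code readMap} are hypothetical, pre-built inputs:
* <pre>{@code
* List<HaplotypeNode> path = new ConvertReadsToPathUsingHMM()
*         .readMap(readMap)                      // Multimap<ReferenceRange, HapIdSetCount>
*         .probabilityReadMappingCorrect(0.99)
*         .filterHaplotypeGraph(graph)
*         .haplotypeCountsToPath();
* }</pre>
*/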
public class ConvertReadsToPathUsingHMM {
private static Logger myLogger = Logger.getLogger(ConvertReadsToPathUsingHMM.class);
private Multiset<Integer> myHapidCounts = null;
private Map<HaplotypeNode, Integer> myHapidCountMap = null;
private HaplotypeGraph myGraph;
private List<double[]> pathGammas = null;
private Multimap<ReferenceRange, HapIdSetCount> myReadMap = null;
//parameters
private int minReadsPerRefRange = 0;
private int maxReadsPerRefRangeKB = 10000;
private String myTaxaListString = null;
private TaxaList myTaxaList = null;
private double minTransitionProb = 0.001;
private double probReadMappedCorrectly = 0.99;
private double transitionProbSameTaxon = 0.99;
private String targetTaxon = null;
private boolean removeRangesWithEqualCounts = true;
public ConvertReadsToPathUsingHMM() {
}
/**
* Filters a HaplotypeGraph based on the read mappings and sets the myGraph field of the class to the result.
* Reference ranges with too few reads, too many reads per kB, or (optionally) identical read counts
* across all nodes are removed. The resulting filtered HaplotypeGraph will be used for any subsequent method calls.
* @param graph the HaplotypeGraph to be filtered
* @return this object, for method chaining
*/
public ConvertReadsToPathUsingHMM filterHaplotypeGraph(HaplotypeGraph graph) {
int numberOfReads = myReadMap.entries().stream().mapToInt(ent -> ent.getValue().getCount()).sum();
myLogger.info("Filtering graph based on read mappings for " + numberOfReads + " reads.");
if (numberOfReads == 0) throw new IllegalArgumentException("myReadMap has no reads.");
HaplotypeGraph hapGraph = filterOnTaxa(graph);
FilterGraphPlugin myFilter = new FilterGraphPlugin(null, false);
List<ReferenceRange> rangesToRemove = new ArrayList<>();
int numberOfRangesFromDB = hapGraph.numberOfRanges();
int countOfLowReadCount = 0;
int countOfHighKbCount = 0;
int countOfAllEqual = 0;
for (ReferenceRange range : hapGraph.referenceRanges()) {
Collection<HapIdSetCount> readMappings = myReadMap.get(range);
int totalReadCount = readMappings == null ? 0 : readMappings.stream().mapToInt(HapIdSetCount::getCount).sum();
double countPerKB = (double) totalReadCount * 1000 / (range.end() - range.start() + 1);
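//e.g., 50 reads mapped to a 5,000 bp range gives 10 reads per kB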
int nNodes = hapGraph.nodes(range).size();
//remove any range with too few reads or too many reads
if (totalReadCount < minReadsPerRefRange) countOfLowReadCount++;
if (countPerKB > maxReadsPerRefRangeKB) countOfHighKbCount++;
if (totalReadCount < minReadsPerRefRange || countPerKB > maxReadsPerRefRangeKB)
rangesToRemove.add(range);
//check whether all nodes have the same read count; such ranges provide no evidence for distinguishing haplotypes
else if (removeRangesWithEqualCounts && totalReadCount > 0){
boolean allEqual = true;
Multiset<Integer> hapidCounts = HashMultiset.create();
for (HapIdSetCount mapping : myReadMap.get(range)) {
for (Integer hapid : mapping.getHapIdSet()) hapidCounts.add(hapid, mapping.getCount());
}
//if any nodes have 0 count then allEqual = false
//otherwise, if all nodes have the same count allEqual = true
if (hapidCounts.elementSet().size() < nNodes) {
allEqual = false;
} else {
int firstNodeCount = hapidCounts.entrySet().iterator().next().getCount();
for (Multiset.Entry<Integer> ent : hapidCounts.entrySet()) {
if (ent.getCount() != firstNodeCount) {
allEqual = false;
break;
}
}
}
if (allEqual) {
rangesToRemove.add(range);
countOfAllEqual++;
}
}
}
myLogger.info(String.format("total ranges = %d, number of ranges removed = %d", numberOfRangesFromDB, rangesToRemove.size()));
myLogger.info(String.format("number of ranges with low read counts = %d, high count per kb = %d, counts all equal = %d",
countOfLowReadCount, countOfHighKbCount, countOfAllEqual));
myFilter.refRanges(rangesToRemove);
myLogger.debug(String.format("before filtering hapgraph: %d nodes.%n", hapGraph.numberOfNodes()));
myGraph = myFilter.filter(hapGraph);
myLogger.debug(String.format("after filtering hapgraph: %d nodes.%n", myGraph.numberOfNodes()));
//if there are no nodes left, log the methods that produced the read mappings, then throw an exception
if (myGraph.numberOfNodes() < 1) {
//what hapids did the read mappings come from?
List<String> methodsList = ReadMappingUtils.getHaplotypeMethodsForReadMappings(myReadMap, 1000);
String msg = "Method names for read mapping ids: " + methodsList.stream().collect(Collectors.joining(","));
myLogger.info(msg);
throw new IllegalArgumentException("The filtered graph has no nodes.");
}
return this;
}
/**
* Filters a HaplotypeGraph based on the read mappings and sets the myGraph field of the class to the result,
* keeping only the reference ranges in rangesToKeep. Missing sequence nodes are added to the filtered graph.
* The resulting filtered HaplotypeGraph will be used for any subsequent method calls.
* @param graph the HaplotypeGraph to be filtered
* @param rangesToKeep the reference ranges to retain; if null, no ranges are removed by this check
* @return this object, for method chaining
*/
public ConvertReadsToPathUsingHMM filterHaplotypeGraph(HaplotypeGraph graph, List<ReferenceRange> rangesToKeep) {
HaplotypeGraph hapGraph = filterOnTaxa(graph);
FilterGraphPlugin myFilter = new FilterGraphPlugin(null, false);
List<ReferenceRange> rangesToRemove = new ArrayList<>();
int numberOfRangesFromDB = hapGraph.numberOfRanges();
int countOfLowReadCount = 0;
int countOfHighKbCount = 0;
int countOfAllEqual = 0;
for (ReferenceRange range : hapGraph.referenceRanges()) {
Collection<HapIdSetCount> readMappings = myReadMap.get(range);
int totalReadCount = readMappings == null ? 0 : readMappings.stream().mapToInt(HapIdSetCount::getCount).sum();
double countPerKB = (double) totalReadCount / (range.end() - range.start() + 1) * 1000;
int nNodes = hapGraph.nodes(range).size();
//remove any range with too few reads, too many reads, or not in the list of ranges to keep
if (totalReadCount < minReadsPerRefRange) countOfLowReadCount++;
if (countPerKB > maxReadsPerRefRangeKB) countOfHighKbCount++;
if (totalReadCount < minReadsPerRefRange || countPerKB > maxReadsPerRefRangeKB)
rangesToRemove.add(range);
else if (rangesToKeep != null && !rangesToKeep.contains(range)) {
rangesToRemove.add(range);
}
else if (removeRangesWithEqualCounts && totalReadCount > 0){
//check whether all nodes have the same read count; such ranges provide no evidence for distinguishing haplotypes
boolean allEqual = true;
Multiset<Integer> hapidCounts = HashMultiset.create();
for (HapIdSetCount mapping : myReadMap.get(range)) {
for (Integer hapid : mapping.getHapIdSet()) hapidCounts.add(hapid, mapping.getCount());
}
//if any nodes have 0 count then allEqual = false
//otherwise, if all nodes have the same count allEqual = true
if (hapidCounts.elementSet().size() < nNodes) {
allEqual = false;
} else {
int firstNodeCount = hapidCounts.entrySet().iterator().next().getCount();
for (Multiset.Entry<Integer> ent : hapidCounts.entrySet()) {
if (ent.getCount() != firstNodeCount) {
allEqual = false;
break;
}
}
}
if (allEqual) {
rangesToRemove.add(range);
countOfAllEqual++;
}
}
}
myLogger.info(String.format("total ranges = %d, number of ranges removed = %d", numberOfRangesFromDB, rangesToRemove.size()));
myLogger.info(String.format("number of ranges with low read counts = %d, high count per kb = %d, counts all equal = %d",
countOfLowReadCount, countOfHighKbCount, countOfAllEqual));
myFilter.refRanges(rangesToRemove);
myLogger.debug(String.format("before filtering hapgraph: %d nodes.%n", hapGraph.numberOfNodes()));
myGraph = myFilter.filter(hapGraph);
myLogger.debug(String.format("after filtering hapgraph: %d nodes.%n", myGraph.numberOfNodes()));
//if there are no nodes left, log the methods that produced the read mappings, then throw an exception
if (myGraph.numberOfNodes() < 1) {
//what hapids did the read mappings come from?
List<String> methodsList = ReadMappingUtils.getHaplotypeMethodsForReadMappings(myReadMap, 1000);
String msg = "Method names for read mapping ids: " + methodsList.stream().collect(Collectors.joining(","));
myLogger.info(msg);
throw new IllegalArgumentException("The filtered graph has no nodes.");
}
//add missing sequence nodes
myGraph = CreateGraphUtils.addMissingSequenceNodes(myGraph);
return this;
}
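/**
* Filters the graph to the taxa in myTaxaList or, if only the string form was set, myTaxaListString.
* If neither taxa filter has been set, the input graph is returned unchanged.
* @param graph the HaplotypeGraph to be filtered
* @return the taxa-filtered HaplotypeGraph
*/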
public HaplotypeGraph filterOnTaxa(HaplotypeGraph graph) {
int startNumberOfNodes = graph.numberOfNodes();
int startNumberOfTaxa = graph.totalNumberTaxa();
int startNumberOfRanges = graph.numberOfRanges();
if (myTaxaList != null && myTaxaList.size() > 0) {
FilterGraphPlugin myFilter = new FilterGraphPlugin(null, false);
HaplotypeGraph filteredGraph = myFilter.taxaList(myTaxaList).filter(graph);
myLogger.debug(String.format("Numbers before filtering on taxa: nodes = %d, ranges = %d, taxa = %d",
startNumberOfNodes, startNumberOfRanges, startNumberOfTaxa));
myLogger.debug(String.format("Numbers after filtering on taxa: nodes = %d, ranges = %d, taxa = %d",
filteredGraph.numberOfNodes(), filteredGraph.totalNumberTaxa(), filteredGraph.numberOfRanges()));
return filteredGraph;
} else if (myTaxaListString != null) {
FilterGraphPlugin myFilter = new FilterGraphPlugin(null, false);
HaplotypeGraph filteredGraph = myFilter.taxaList(myTaxaListString).filter(graph);
myLogger.debug(String.format("Numbers before filtering on taxa: nodes = %d, ranges = %d, taxa = %d",
startNumberOfNodes, startNumberOfRanges, startNumberOfTaxa));
myLogger.debug(String.format("Numbers after filtering on taxa: nodes = %d, ranges = %d, taxa = %d",
filteredGraph.numberOfNodes(), filteredGraph.totalNumberTaxa(), filteredGraph.numberOfRanges()));
return filteredGraph;
}
return graph;
}
/**
* Writes the names of the taxa in the HaplotypeGraph, myGraph, to System.out
*/
public void listTaxa() {
System.out.println("taxa in graph:");
myGraph.taxaInGraph().forEach(System.out::println);
}
/**
* @return the list of HaplotypeNodes on the most likely path through myGraph for the read mappings.
* That is, given the HaplotypeGraph and the read mappings, the nodes on the path through the graph
* that is most likely to have generated those mappings, found with the Viterbi algorithm.
*/
public List<HaplotypeNode> haplotypeCountsToPath() {
//instantiate emission and transition probabilities
List<HaplotypeNode> pathNodes = new ArrayList<>();
for (Chromosome chromosome : myGraph.chromosomes()) {
myLogger.info("Getting path for chromosome " + chromosome.getName());
NavigableMap<ReferenceRange, List<HaplotypeNode>> rangeToNodesMap = myGraph.tree(chromosome);
EmissionProbability emissionProb = new HaplotypeEmissionProbability(rangeToNodesMap, myReadMap, probReadMappedCorrectly);
ArrayList<List<HaplotypeNode>> anchorNodeList = new ArrayList<>(rangeToNodesMap.values());
ReferenceRangeTransitionProbability transitionProb = new ReferenceRangeTransitionProbability(anchorNodeList, myGraph, minTransitionProb);
int numberOfAnchors = anchorNodeList.size();
byte[] obs = new byte[numberOfAnchors];
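//obs is left at zeros; the per-range emission probabilities come from the HaplotypeEmissionProbability built from the read mappings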
int numberOfNodesInFirstRange = rangeToNodesMap.values().iterator().next().size();
double[] probStartNode = startProbabilities(numberOfNodesInFirstRange);
ViterbiAlgorithmVariableStateNumber va = new ViterbiAlgorithmVariableStateNumber(obs, transitionProb, emissionProb, probStartNode);
va.initialize();
va.calculate();
byte[] mostProbableNodes = va.getMostProbableStateSequence();
//convert the most probable state sequence to a list of HaplotypeNodes
for (int anc = 0; anc < numberOfAnchors; anc++) {
pathNodes.add(anchorNodeList.get(anc).get(mostProbableNodes[anc]));
}
}
return pathNodes;
}
/**
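* Runs the forward-backward algorithm over each chromosome and stores the resulting posterior node
* probabilities in the pathGammas field.
*
* <p>A minimal sketch of the intended call sequence; {@code converter} is an instance of this class
* and the 0.8 threshold is illustrative only:
* <pre>{@code
* converter.haplotypeCountsToPathProbability();
* List<HaplotypeNode> confidentNodes = converter.nodeListFromProbabilities(0.8);
* }</pre>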
* @return a list of node probabilities for each range in the input graph.
*/
public List<double[]> haplotypeCountsToPathProbability() {
pathGammas = new ArrayList<>();
List<List<HaplotypeNode>> pathNodeLists = new ArrayList<>();
for (Chromosome chromosome : myGraph.chromosomes()) {
myLogger.info("Getting path for chromosome " + chromosome.getName());
NavigableMap<ReferenceRange, List<HaplotypeNode>> rangeToNodesMap = myGraph.tree(chromosome);
pathNodeLists.addAll(rangeToNodesMap.values());
myLogger.info("Extracted graph tree for chromosome " + chromosome.getName());
long startTime = System.currentTimeMillis();
EmissionProbability emissionProb = new HaplotypeEmissionProbability(rangeToNodesMap, myReadMap, probReadMappedCorrectly);
myLogger.info(String.format("emission probability set up in %d ms.", System.currentTimeMillis() - startTime));
myLogger.info(emissionProb.toString());
startTime = System.currentTimeMillis();
ArrayList<List<HaplotypeNode>> anchorNodeList = new ArrayList<>(rangeToNodesMap.values());
ReferenceRangeTransitionProbability transitionProb = new ReferenceRangeTransitionProbability(anchorNodeList, myGraph, minTransitionProb);
myLogger.info(String.format("transition probability set up in %d ms.", System.currentTimeMillis() - startTime));
int numberOfAnchors = anchorNodeList.size();
int[] obs = new int[numberOfAnchors];
int numberOfNodesInFirstRange = rangeToNodesMap.values().iterator().next().size();
double[] probStartNode = startProbabilities(numberOfNodesInFirstRange);
BackwardForwardVariableStateNumber bfmodel = new BackwardForwardVariableStateNumber();
bfmodel.emission(emissionProb)
.transition(transitionProb)
.initialStateProbability(probStartNode)
.observations(obs)
.calculateAlpha()
.calculateBeta();
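//gamma() holds, for each reference range, the posterior probability of each node given all of the read mappings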
pathGammas.addAll(bfmodel.gamma());
}
return pathGammas;
}
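/**
* For each reference range, selects the node with the highest posterior probability from pathGammas
* and keeps it only if that probability is at least minP. haplotypeCountsToPathProbability() must be
* called first to populate pathGammas.
* @param minP the minimum posterior probability required to report a node for a range
* @param infoFilename if not null, a tab-delimited per-node report (chr, start, hasTarget relative to targetTaxon, prob, taxa) is written to this file
* @return the list of most probable HaplotypeNodes meeting the minP threshold
*/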
public List<HaplotypeNode> nodeListFromProbabilities(double minP, String infoFilename) {
List<HaplotypeNode> nodeList = new ArrayList<>();
Iterator<double[]> probIter = pathGammas.iterator();
Iterator<ReferenceRange> rangeIter = myGraph.referenceRangeList().iterator();
while(probIter.hasNext()) {
double[] probs = probIter.next();
ReferenceRange range = rangeIter.next();
int maxIndex = 0;
for (int i = 1; i < probs.length; i++) {
if (probs[i] > probs[maxIndex]) maxIndex = i;
}
if (probs[maxIndex] >= minP) nodeList.add(myGraph.nodes(range).get(maxIndex));
}
if (infoFilename != null) {
final String TAB = "\t";
try {
PrintWriter pw = new PrintWriter(infoFilename);
pw.println("chr\tstart\thasTarget\tprob\ttaxa");
probIter = pathGammas.iterator();
rangeIter = myGraph.referenceRangeList().iterator();
while(probIter.hasNext()) {
double[] probs = probIter.next();
ReferenceRange range = rangeIter.next();
List<HaplotypeNode> myNodes = myGraph.nodes(range);
int nodeCount = 0;
for (HaplotypeNode node : myNodes) {
pw.print(range.chromosome().getName() + TAB); //chr
pw.print(Integer.toString(range.start()) + TAB); //start
pw.print(Boolean.toString(node.taxaList().indexOf(targetTaxon) >= 0) + TAB); //hasTarget
pw.print(probs[nodeCount++] + TAB); //prob
String taxaNames = node.taxaList().stream().map(Taxon::getName).collect(Collectors.joining(","));
pw.print(taxaNames + TAB); //taxa
pw.println();
}
}
pw.close();
} catch (FileNotFoundException e) {
myLogger.error(e.getMessage());
myLogger.error(String.format("Unable to open %s for output in ConvertReadsToPathUsingHMM.nodeListFromProbabilities", infoFilename));
}
}
return nodeList;
}
public List<HaplotypeNode> nodeListFromProbabilities(double minP) {
return nodeListFromProbabilities(minP, null);
}
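/**
* Builds a uniform initial state distribution for the HMM. For example, numberOfNodes = 4 yields
* {0.25, 0.25, 0.25, 0.25}.
* @param numberOfNodes the number of nodes (states) in the first reference range
* @return an array of equal start probabilities summing to 1
*/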
public double[] startProbabilities(int numberOfNodes) {
double initp = 1.0 / ((double) numberOfNodes);
double[] startp = new double[numberOfNodes];
Arrays.fill(startp, initp);
return startp;
}
/**
* @param hapGraph a HaplotypeGraph
* @param hapidCounts a Multiset of hapids used to retrieve hapid counts
* @return for each range in the HaplotypeGraph, the probability that a read was mapped to the correct node,
* estimated as the maximum node count divided by the total count at that range
*/
public double[] probabilityOfBeingCorrect(HaplotypeGraph hapGraph, Multiset<Integer> hapidCounts) {
return hapGraph.referenceRanges().stream().mapToDouble(rr -> {
List<HaplotypeNode> hnList = hapGraph.nodes(rr);
int[] counts = hnList.stream().mapToInt(node -> hapidCounts.count(node.id())).toArray();
return nodeCorrectProbability(counts);
}).toArray();
}
/**
* @param hapidCounts a Multiset of hapids used to retrieve hapid counts
* @param rangeToNodesMap a map with ReferenceRange as key and the List of nodes at that range as the associated value
* @return for each range in the map, the probability that a read was mapped to the correct node,
* estimated as the maximum node count divided by the total count at that range
*/
public double[] probabilityOfBeingCorrect(Multiset<Integer> hapidCounts, TreeMap<ReferenceRange, List<HaplotypeNode>> rangeToNodesMap) {
double[] prob = new double[rangeToNodesMap.size()];
int rangeCounter = 0;
for (Map.Entry<ReferenceRange, List<HaplotypeNode>> entry : rangeToNodesMap.entrySet()) {
int[] counts = entry.getValue().stream().mapToInt(node -> hapidCounts.count(node.id())).toArray();
prob[rangeCounter++] = nodeCorrectProbability(counts);
}
return prob;
}
/**
* @param hapidCountMap a map of HaplotypeNode to read count used to retrieve node counts
* @param rangeToNodesMap a map with ReferenceRange as key and the List of nodes at that range as the associated value
* @return for each range in the map, the probability that a read was mapped to the correct node,
* estimated as the maximum node count divided by the total count at that range
*/
public double[] probabilityOfBeingCorrect(Map<HaplotypeNode, Integer> hapidCountMap, TreeMap<ReferenceRange, List<HaplotypeNode>> rangeToNodesMap) {
double[] prob = new double[rangeToNodesMap.size()];
int rangeCounter = 0;
for (Map.Entry<ReferenceRange, List<HaplotypeNode>> entry : rangeToNodesMap.entrySet()) {
int[] counts = entry.getValue().stream().mapToInt(node -> hapidCountMap.getOrDefault(node, 0)).toArray();
prob[rangeCounter++] = nodeCorrectProbability(counts);
}
return prob;
}
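/**
* Returns the maximum count divided by the total count. For example, counts of {6, 3, 1} give
* 6 / 10 = 0.6, and a value of 1.0 means every read hit a single node. Returns 0 when no reads mapped.
*/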
private double nodeCorrectProbability(int[] counts) {
int total = 0;
int max = 0;
for (int cnt : counts) {
total += cnt;
max = Math.max(max, cnt);
}
//guard against 0/0 when no reads mapped to any node in the range
if (total == 0) return 0.0;
return (double) max / total;
}
//getters
public HaplotypeGraph filteredGraph() { return myGraph; }
//setters
public ConvertReadsToPathUsingHMM hapidCountMap(Map<HaplotypeNode, Integer> countmap) {
myHapidCountMap = countmap;
return this;
}
public ConvertReadsToPathUsingHMM minReadsPerRange(int minReads) {
minReadsPerRefRange = minReads;
return this;
}
public ConvertReadsToPathUsingHMM maxReadsPerRangeKB(int maxReads) {
maxReadsPerRefRangeKB = maxReads;
return this;
}
public ConvertReadsToPathUsingHMM taxaFilterList(String taxaNames) {
myTaxaListString = taxaNames;
return this;
}
public ConvertReadsToPathUsingHMM taxaFilterList(TaxaList listOfTaxa) {
myTaxaList = listOfTaxa;
return this;
}
public ConvertReadsToPathUsingHMM probabilityReadMappingCorrect(double probCorrect) {
probReadMappedCorrectly = probCorrect;
return this;
}
public ConvertReadsToPathUsingHMM minTransitionProbability(double minprob) {
minTransitionProb = minprob;
return this;
}
public ConvertReadsToPathUsingHMM transitionProbabilitySameTaxon(double p) {
transitionProbSameTaxon = p;
return this;
}
public ConvertReadsToPathUsingHMM targetTaxon(String taxonName) {
targetTaxon = taxonName;
return this;
}
public ConvertReadsToPathUsingHMM readMap(Multimap<ReferenceRange, HapIdSetCount> readMap) {
myReadMap = readMap;
return this;
}
public ConvertReadsToPathUsingHMM removeRangesWithEqualCounts(boolean remove) {
removeRangesWithEqualCounts = remove;
return this;
}
}