All downloads are free. Search and download functionality uses the official Maven repository.

net.maizegenetics.pangenome.hapCalling.ConvertReadsToPathUsingHMM Maven / Gradle / Ivy

There is a newer version: 1.10
Show newest version
package net.maizegenetics.pangenome.hapCalling;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multiset;

import net.maizegenetics.analysis.imputation.BackwardForwardVariableStateNumber;
import net.maizegenetics.analysis.imputation.EmissionProbability;
import net.maizegenetics.analysis.imputation.ViterbiAlgorithmVariableStateNumber;
import net.maizegenetics.dna.map.Chromosome;
import net.maizegenetics.pangenome.api.*;
import net.maizegenetics.taxa.TaxaList;

import net.maizegenetics.taxa.Taxon;
import org.apache.log4j.Logger;

import java.io.FileNotFoundException;
import java.io.PrintWriter;
import java.util.*;
import java.util.stream.Collectors;

public class ConvertReadsToPathUsingHMM {
    private static Logger myLogger = Logger.getLogger(ConvertReadsToPathUsingHMM.class);
    private Multiset myHapidCounts = null;
    private Map myHapidCountMap = null;
    private HaplotypeGraph myGraph;
    private List pathGammas = null;
    private Multimap myReadMap = null;

    //parameters
    private int minReadsPerRefRange = 0;
    private int maxReadsPerRefRangeKB = 10000;
    private String myTaxaListString = null;
    private TaxaList myTaxaList = null;
    private double minTransitionProb = 0.001;
    private double probReadMappedCorrectly = 0.99;
    private double transitionProbSameTaxon = 0.99;
    private String targetTaxon = null;
    private boolean removeRangesWithEqualCounts = true;

    public ConvertReadsToPathUsingHMM() {

    }

    /**
     * This method filters a HaplotypeGraph and sets the myGraph field of the class to the result.
     * The resulting filtered HaplotypeGraph will be used for any subsequent method calls.
     * @param graph the HaplotypeGraph to be filtered
     */
    public ConvertReadsToPathUsingHMM filterHaplotypeGraph(HaplotypeGraph graph) {
        int numberOfReads = myReadMap.entries().stream().mapToInt(ent -> ent.getValue().getCount()).sum();
        myLogger.info("Filtering graph based on read mappings for " + numberOfReads + " reads.");
        if (numberOfReads == 0) throw new IllegalArgumentException("myReadMap has not reads.");

        HaplotypeGraph hapGraph = filterOnTaxa(graph);

        FilterGraphPlugin myFilter = new FilterGraphPlugin(null,false);

        List rangesToRemove = new ArrayList<>();
        int numberOfRangesFromDB = hapGraph.numberOfRanges();
        int countOfLowReadCount = 0;
        int countOfhighKbCount = 0;
        int countOfAllEqual = 0;
        for (ReferenceRange range : hapGraph.referenceRanges()) {
            Collection readMappings = myReadMap.get(range);

            int totalReadCount = readMappings == null? 0 : readMappings.stream().mapToInt(set -> set.getCount()).sum();

            double countPerKB = (double) totalReadCount * 1000 / (range.end() - range.start() + 1);
            int nNodes = hapGraph.nodes(range).size();

            //remove any range with too few reads or too many reads
            if (totalReadCount < minReadsPerRefRange) countOfLowReadCount++;
            if (countPerKB > maxReadsPerRefRangeKB) countOfhighKbCount++;

            if (totalReadCount < minReadsPerRefRange || countPerKB > maxReadsPerRefRangeKB)
                rangesToRemove.add(range);

                //checks for all counts equal
            else if (removeRangesWithEqualCounts && totalReadCount > 0){
                boolean allEqual = true;
                Multiset hapidCounts = HashMultiset.create();
                for (HapIdSetCount mapping : myReadMap.get(range)) {
                    for (Integer hapid : mapping.getHapIdSet()) hapidCounts.add(hapid, mapping.getCount());
                }

                //if any nodes have 0 count then allEqual = false
                //otherwise, if all nodes have the same count allEqual = true
                if (hapidCounts.elementSet().size() < nNodes) {
                    allEqual = false;
                } else {
                    int firstNodeCount = hapidCounts.entrySet().iterator().next().getCount();
                    for (Multiset.Entry ent : hapidCounts.entrySet()) {
                        if (ent.getCount() != firstNodeCount) {
                            allEqual = false;
                            break;
                        }
                    }
                }
                if (allEqual) {
                    rangesToRemove.add(range);
                    countOfAllEqual++;
                }
            }
        }

        myLogger.info(String.format("total ranges = %d, number of ranges removed = %d", numberOfRangesFromDB, rangesToRemove.size()));
        myLogger.info(String.format("number of ranges with low read counts = %d, high count per kb = %d, counts all equal = %d",
                countOfLowReadCount, countOfhighKbCount, countOfAllEqual));
        myFilter.refRanges(rangesToRemove);

        myLogger.debug(String.format("before filtering hapgraph: %d nodes.%n", hapGraph.numberOfNodes()));

        myGraph = myFilter.filter(hapGraph);

        myLogger.debug(String.format("after filtering hapgraph: %d nodes.%n", myGraph.numberOfNodes()));

        //if there are no nodes left throw an exception. Determine if the haplotype method has no reads.
        if (myGraph.numberOfNodes() < 1) {
            //what hapids did the read mappings come from?
            List methodsList = ReadMappingUtils.getHaplotypeMethodsForReadMappings(myReadMap, 1000);
            String msg = "Method names for read mapping ids: " + methodsList.stream().collect(Collectors.joining(","));
            myLogger.info(msg);
            throw new IllegalArgumentException("The filtered graph has no nodes.");
        }

        return this;
    }

    /**
     * This method filters a HaplotypeGraph and sets the myGraph field of the class to the result.
     * The resulting filtered HaplotypeGraph will be used for any subsequent method calls.
     * @param graph the HaplotypeGraph to be filtered
     */
    public ConvertReadsToPathUsingHMM filterHaplotypeGraph(HaplotypeGraph graph, List rangesToKeep) {
        HaplotypeGraph hapGraph = filterOnTaxa(graph);

        FilterGraphPlugin myFilter = new FilterGraphPlugin(null,false);

        List rangesToRemove = new ArrayList<>();

        int numberOfRangesFromDB = hapGraph.numberOfRanges();
        int countOfLowReadCount = 0;
        int countOfhighKbCount = 0;
        int countOfAllEqual = 0;
        for (ReferenceRange range : hapGraph.referenceRanges()) {
            Collection readMappings = myReadMap.get(range);

            int totalReadCount = readMappings == null? 0 : readMappings.stream().mapToInt(set -> set.getCount()).sum();

            double countPerKB = (double) totalReadCount / (range.end() - range.start() + 1) * 1000;
            int nNodes = hapGraph.nodes(range).size();

            //remove any range with too few reads, too many reads, or too many nodes
            if (totalReadCount < minReadsPerRefRange) countOfLowReadCount++;
            if (countPerKB > maxReadsPerRefRangeKB) countOfhighKbCount++;

            if (totalReadCount < minReadsPerRefRange || countPerKB > maxReadsPerRefRangeKB)
                rangesToRemove.add(range);
            else if (rangesToKeep != null && !rangesToKeep.contains(range)) {
                rangesToRemove.add(range);
            }
            else if (removeRangesWithEqualCounts && totalReadCount > 0){
                //checks for all counts equal
                boolean allEqual = true;
                Multiset hapidCounts = HashMultiset.create();
                for (HapIdSetCount mapping : myReadMap.get(range)) {
                    for (Integer hapid : mapping.getHapIdSet()) hapidCounts.add(hapid, mapping.getCount());
                }

                //if any nodes have 0 count then allEqual = false
                //otherwise, if all nodes have the same count allEqual = true
                if (hapidCounts.elementSet().size() < nNodes) {
                    allEqual = false;
                } else {
                    int firstNodeCount = hapidCounts.entrySet().iterator().next().getCount();
                    for (Multiset.Entry ent : hapidCounts.entrySet()) {
                        if (ent.getCount() != firstNodeCount) {
                            allEqual = false;
                            break;
                        }
                    }
                }
                if (allEqual) {
                    rangesToRemove.add(range);
                    countOfAllEqual++;
                }
            }
        }

        myLogger.info(String.format("total ranges = %d, number of ranges removed = %d", numberOfRangesFromDB, rangesToRemove.size()));
        myLogger.info(String.format("number of ranges with low read counts = %d, high count per kb = %d, counts all equal = %d",
                countOfLowReadCount, countOfhighKbCount, countOfAllEqual));
        myFilter.refRanges(rangesToRemove);

        myLogger.debug(String.format("before filtering hapgraph: %d nodes.%n", hapGraph.numberOfNodes()));

        myGraph = myFilter.filter(hapGraph);

        myLogger.debug(String.format("after filtering hapgraph: %d nodes.%n", myGraph.numberOfNodes()));

        //if there are no nodes left throw an exception. Determine if the haplotype method has no reads.
        if (myGraph.numberOfNodes() < 1) {
            //what hapids did the read mappings come from?
            List methodsList = ReadMappingUtils.getHaplotypeMethodsForReadMappings(myReadMap, 1000);
            String msg = "Method names for read mapping ids: " + methodsList.stream().collect(Collectors.joining(","));
            myLogger.info(msg);
            throw new IllegalArgumentException("The filtered graph has no nodes.");
        }

        //add missing sequence nodes
        myGraph = CreateGraphUtils.addMissingSequenceNodes(myGraph);

        return this;

    }

    public HaplotypeGraph filterOnTaxa(HaplotypeGraph graph) {
        int startNumberOfNodes = graph.numberOfNodes();
        int startNumberOfTaxa = graph.totalNumberTaxa();
        int startNumberOfRanges = graph.numberOfRanges();

        if (myTaxaList != null && myTaxaList.size() > 0) {
            FilterGraphPlugin myFilter = new FilterGraphPlugin(null,false);
            HaplotypeGraph filteredGraph = myFilter.taxaList(myTaxaList).filter(graph);
            myLogger.debug(String.format("Numbers before filtering on taxa: nodes = %d, ranges = %d, taxa = %d",
                    startNumberOfNodes, startNumberOfRanges, startNumberOfTaxa));
            myLogger.debug(String.format("Numbers after filtering on taxa: nodes = %d, ranges = %d, taxa = %d",
                    filteredGraph.numberOfNodes(), filteredGraph.totalNumberTaxa(), filteredGraph.numberOfRanges()));
            return filteredGraph;
        } else if (myTaxaListString != null) {
            FilterGraphPlugin myFilter = new FilterGraphPlugin(null,false);
            HaplotypeGraph filteredGraph = myFilter.taxaList(myTaxaListString).filter(graph);
            myLogger.debug(String.format("Numbers before filtering on taxa: nodes = %d, ranges = %d, taxa = %d",
                    startNumberOfNodes, startNumberOfRanges, startNumberOfTaxa));
            myLogger.debug(String.format("Numbers after filtering on taxa: nodes = %d, ranges = %d, taxa = %d",
                    filteredGraph.numberOfNodes(), filteredGraph.totalNumberTaxa(), filteredGraph.numberOfRanges()));
            return filteredGraph;
        }

        return graph;
    }

    /**
     * Writes the names of the taxa in the HaplotypeGraph, myGraph, to System.out
     */
    public void listTaxa() {
        System.out.println("taxa in graph:");
        myGraph.taxaInGraph().stream().forEach(System.out::println);
    }

    /**
     * @return	The list of HaplotypeNodes on the most likely Path for a Multiset of hapids. That is,
     * 	given a HaplotypeGraph and a Multiset of hapids, the nodes on the path through the graph that is most likely
     * 	to have generated the multiset of hapids.
     */
    public List haplotypeCountsToPath() {
        //instantiate emission and transition probabilities

        List pathNodes = new ArrayList<>();
        for (Chromosome chromosome : myGraph.chromosomes()) {
            myLogger.info("Getting path for chromosome " + chromosome.getName());
            NavigableMap> rangeToNodesMap = myGraph.tree(chromosome);

            EmissionProbability emissionProb = new HaplotypeEmissionProbability(rangeToNodesMap, myReadMap, probReadMappedCorrectly);

            ArrayList> anchorNodeList = new ArrayList<>(rangeToNodesMap.values());
            ReferenceRangeTransitionProbability transitionProb = new ReferenceRangeTransitionProbability(anchorNodeList, myGraph, minTransitionProb);


            int numberOfAnchors = anchorNodeList.size();
            byte[] obs = new byte[numberOfAnchors];

            int numberOfNodesInFirstRange = rangeToNodesMap.values().iterator().next().size();
            double[] probStartNode = startProbabilities(numberOfNodesInFirstRange);

            ViterbiAlgorithmVariableStateNumber va = new ViterbiAlgorithmVariableStateNumber(obs, transitionProb, emissionProb, probStartNode);
            va.initialize();
            va.calculate();

            byte[] mostProbableNodes = va.getMostProbableStateSequence();

            //convert the results to a HaplotypePath
            for (int anc = 0; anc < numberOfAnchors; anc++) {
                pathNodes.add(anchorNodeList.get(anc).get(mostProbableNodes[anc]));
            }
        }

        return pathNodes;
    }

    /**
     * @return a list of node probabilities for each range in the input graph.
     */
    public List haplotypeCountsToPathProbability() {

        pathGammas = new ArrayList<>();
        List> pathNodeLists = new ArrayList<>();

        for (Chromosome chromosome : myGraph.chromosomes()) {
            myLogger.info("Getting path for chromosome " + chromosome.getName());
            NavigableMap> rangeToNodesMap = myGraph.tree(chromosome);
            pathNodeLists.addAll(rangeToNodesMap.values());

            myLogger.info("Extracted graph tree for chromosome " + chromosome.getName());

            long startTime =  System.currentTimeMillis();

            EmissionProbability emissionProb = new HaplotypeEmissionProbability(rangeToNodesMap, myReadMap, probReadMappedCorrectly);
            myLogger.info(String.format("emission probability set up in %d ms.", System.currentTimeMillis() - startTime));
            myLogger.info(emissionProb.toString());

            startTime =  System.currentTimeMillis();
            ArrayList> anchorNodeList = new ArrayList<>(rangeToNodesMap.values());
            ReferenceRangeTransitionProbability transitionProb = new ReferenceRangeTransitionProbability(anchorNodeList, myGraph, minTransitionProb);

            myLogger.info(String.format("transition probability set up in %d ms.", System.currentTimeMillis() - startTime));
            startTime =  System.currentTimeMillis();

            int numberOfAnchors = anchorNodeList.size();
            int[] obs = new int[numberOfAnchors];
            int numberOfNodesInFirstRange = rangeToNodesMap.values().iterator().next().size();
            double[] probStartNode = startProbabilities(numberOfNodesInFirstRange);

            startTime =  System.currentTimeMillis();
            BackwardForwardVariableStateNumber bfmodel = new BackwardForwardVariableStateNumber();
            bfmodel.emission(emissionProb)
                    .transition(transitionProb)
                    .initialStateProbability(probStartNode)
                    .observations(obs)
                    .calculateAlpha()
                    .calculateBeta();
            pathGammas.addAll(bfmodel.gamma());

        }

        return pathGammas;
    }

    public List nodeListFromProbabilities(double minP, String infoFilename) {
        List nodeList = new ArrayList<>();

        Iterator probIter = pathGammas.iterator();
        Iterator rangeIter = myGraph.referenceRangeList().iterator();

        while(probIter.hasNext()) {
            double[] probs = probIter.next();
            ReferenceRange range = rangeIter.next();

            int maxIndex = 0;
            for (int i = 1; i < probs.length; i++) {
                if (probs[i] > probs[maxIndex]) maxIndex = i;
            }

            if (probs[maxIndex] >= minP) nodeList.add(myGraph.nodes(range).get(maxIndex));
        }

        if (infoFilename != null) {
            final String TAB = "\t";
            try {
                PrintWriter pw = new PrintWriter(infoFilename);
                pw.println("chr\tstart\thasTarget\tprob\ttaxa");
                probIter = pathGammas.iterator();
                rangeIter = myGraph.referenceRangeList().iterator();
                while(probIter.hasNext()) {
                    double[] probs = probIter.next();
                    ReferenceRange range = rangeIter.next();
                    List myNodes = myGraph.nodes(range);
                    int nodeCount = 0;
                    for (HaplotypeNode node : myNodes) {
                        pw.print(range.chromosome().getName() + TAB); //chr
                        pw.print(Integer.toString(range.start()) + TAB); //start
                        pw.print(Boolean.toString(node.taxaList().indexOf(targetTaxon) >= 0) + TAB); //hasTarget
                        pw.print(probs[nodeCount++] + TAB); //prob
                        String taxaNames = node.taxaList().stream().map(Taxon::getName).collect(Collectors.joining(","));
                        pw.print(taxaNames + TAB); //nTaxa
                        pw.println();
                    }
                }
                pw.close();
            } catch (FileNotFoundException e) {
                myLogger.error(e.getMessage());
                myLogger.error(String.format("Unable to open %s for output in ConvertReadsToPathUsingHMM.nodeListFromProbabilities", infoFilename));
            }
        }

        return nodeList;
    }

    public List nodeListFromProbabilities(double minP) {
        return nodeListFromProbabilities(minP, null);
    }

    public double[] startProbabilities(int numberOfNodes) {
        double initp = 1.0 / ((double) numberOfNodes);
        double[] startp = new double[numberOfNodes];
        Arrays.fill(startp, initp);
        return startp;
    }

    /**
     * @param hapGraph	a HaplotypeGraph
     * @param hapidCounts	a Multiset of hapids used to retrieve hapid counts
     * @return	for each range in the HaplotypeGraph, the probability that a read was mapped to the correct node
     */
    public double[] probabilityOfBeingCorrect(HaplotypeGraph hapGraph, Multiset hapidCounts) {
        return hapGraph.referenceRanges().stream().mapToDouble(rr -> {
            List hnList = hapGraph.nodes(rr);
            int[] counts = hnList.stream().map(node -> hapidCounts.count(node.id())).mapToInt(Integer::intValue).toArray();
            return nodeCorrectProbability(counts);
        }).toArray();
    }

    /**
     * @param hapidCounts	a Multiset of hapids used to retrieve hapid counts
     * @param rangeToNodesMap	a map with ReferenceRange as key and the List of nodes at that range as the associated value
     * @return	for each range in the HaplotypeGraph, the probability that a read was mapped to the correct node
     */
    public double[] probabilityOfBeingCorrect(Multiset hapidCounts, TreeMap> rangeToNodesMap) {
        double[] prob = new double[rangeToNodesMap.size()];
        int rangeCounter = 0;
        for (Map.Entry> entry : rangeToNodesMap.entrySet()) {
            int[] counts = entry.getValue().stream().map(node -> hapidCounts.count(node.id())).mapToInt(Integer::intValue).toArray();
            prob[rangeCounter++] = nodeCorrectProbability(counts);
        }
        return prob;
    }

    /**
     * @param hapidCountMap	a Multiset of hapids used to retrieve hapid counts
     * @param rangeToNodesMap	a map with ReferenceRange as key and the List of nodes at that range as the associated value
     * @return	for each range in the HaplotypeGraph, the probability that a read was mapped to the correct node
     */
    public double[] probabilityOfBeingCorrect(Map hapidCountMap, TreeMap> rangeToNodesMap) {
        double[] prob = new double[rangeToNodesMap.size()];
        int rangeCounter = 0;
        for (Map.Entry> entry : rangeToNodesMap.entrySet()) {
            int[] counts = entry.getValue().stream().map(node -> hapidCountMap.getOrDefault(node, 0)).mapToInt(Integer::intValue).toArray();
            prob[rangeCounter++] = nodeCorrectProbability(counts);
        }
        return prob;
    }

    private double nodeCorrectProbability(int[] counts) {
        int total = 0;
        int max = 0;
        for (int cnt:counts) {
            total += cnt;
            max = Math.max(max, cnt);
        }
        return (double) max/total;
    }

    //getters
    public HaplotypeGraph filteredGraph() {return myGraph; }

    //setters
    public ConvertReadsToPathUsingHMM hapidCountMap(Map countmap) {
        myHapidCountMap = countmap;
        return this;
    }

    public ConvertReadsToPathUsingHMM minReadsPerRange(int minReads) {
        minReadsPerRefRange = minReads;
        return this;
    }

    public ConvertReadsToPathUsingHMM maxReadsPerRangeKB(int maxReads) {
        maxReadsPerRefRangeKB = maxReads;
        return this;
    }

    public ConvertReadsToPathUsingHMM taxaFilterList(String taxaNames) {
        myTaxaListString = taxaNames;
        return this;
    }

    public ConvertReadsToPathUsingHMM taxaFilterList(TaxaList listOfTaxa) {
        myTaxaList = listOfTaxa;
        return this;
    }

    public ConvertReadsToPathUsingHMM probabilityReadMappingCorrect(double probCorrect) {
        probReadMappedCorrectly = probCorrect;
        return this;
    }

    public ConvertReadsToPathUsingHMM minTransitionProbability(double minprob) {
        minTransitionProb = minprob;
        return this;
    }

    public ConvertReadsToPathUsingHMM transitionProbabilitySameTaxon(double p) {
        transitionProbSameTaxon = p;
        return this;
    }

    public ConvertReadsToPathUsingHMM targetTaxon(String taxonName) {
        targetTaxon = taxonName;
        return this;
    }

    public ConvertReadsToPathUsingHMM readMap(Multimap readMap) {
        myReadMap = readMap;
        return this;
    }

    public ConvertReadsToPathUsingHMM removeRangesWithEqualCounts(boolean remove) {
        removeRangesWithEqualCounts = remove;
        return this;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy