/*
 * All Downloads are FREE. Search and download functionalities are using the official Maven repository.
 *
 * net.maizegenetics.pangenome.api.ReferenceRangeEmissionProbability Maven / Gradle / Ivy
 *
 * There is a newer version: 1.10
 * Show newest version
 */
package net.maizegenetics.pangenome.api;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.stream.Collectors;

import org.apache.commons.math3.distribution.BinomialDistribution;

import com.google.common.collect.Multiset;

import net.maizegenetics.analysis.imputation.EmissionProbability;

public class ReferenceRangeEmissionProbability extends EmissionProbability {
	private final NavigableMap> myNodeMap;
	private final List myRanges;
	private final Map hapidCountMap;
	private final Map hapidExclusionCountMap;
	private final double Pcorrect;
	private final METHOD myMethod;
	private final double myAdjustmentFactor;
	
	private int myCurrentAnchor = -1;
	double[] anchorProbabilities;
//	private int minCountPerNode = 1;
	
	public enum METHOD {inclusionOnly, allCounts, allCountsWeighted}
	
	private ReferenceRangeEmissionProbability(NavigableMap> nodes, Map hapidCounts, double probCorrect) {
		myNodeMap = nodes;
		myRanges = new ArrayList(myNodeMap.keySet());
		hapidCountMap = hapidCounts;
		hapidExclusionCountMap = null;
		Pcorrect = probCorrect;
		myMethod = METHOD.inclusionOnly;
		myAdjustmentFactor = 1;
	}
	
	private ReferenceRangeEmissionProbability(NavigableMap> nodes, List> hapidCountList,
			METHOD method, double adjustment, double probCorrect) {
		myNodeMap = nodes;
		myRanges = new ArrayList(myNodeMap.keySet());
		hapidCountMap = hapidCountList.get(0);
		hapidExclusionCountMap = hapidCountList.get(1);
		Pcorrect = probCorrect;
		myMethod = method;
		myAdjustmentFactor = adjustment;
	}
	
	public static ReferenceRangeEmissionProbability getInstanceFromHapidCounts(TreeMap> nodes, Multiset hapidCounts, double probCorrect) {
		Map hapmap = hapidCounts.entrySet().stream().collect(Collectors.toMap(ent -> ent.getElement(), ent -> ent.getCount()));
		return new ReferenceRangeEmissionProbability(nodes, hapmap, probCorrect);
	}
	
	public static ReferenceRangeEmissionProbability getInstanceFromNodeCounts(TreeMap> nodes, Multiset nodeCounts, double probCorrect) {
		Map hapmap = nodeCounts.entrySet().stream().collect(Collectors.toMap(ent -> ent.getElement().id(), ent -> ent.getCount()));
		return new ReferenceRangeEmissionProbability(nodes, hapmap, probCorrect);
	}
	
	public static ReferenceRangeEmissionProbability getInstanceFromHapidCountMap(TreeMap> nodes, Map hapidCountMap, double probCorrect) {
		return new ReferenceRangeEmissionProbability(nodes, hapidCountMap, probCorrect);
	}
	
	//emission probability methods using include and exclude counts:
	//for consensus haplotypes (haplotypes with non-zero total counts)
	//P(obs|state) = pbinom(trials = nExcludes + nIncludes, successes = nExcludes, p = pErr) (Sum)
	// or P(obs|state) = pbinom(trials = max(nExcludes + nIncludes), successes = nExcludes, p = pErr) (Max)
	//
	//for "other" haplotypes (haplotypes with zero total counts)
	//P(obs|state) = average(P(obs|state)) weighted by taxa count (weighted)
	//P(obs|state) = x * average(P(obs|state)) weighted by taxa count, where value of x is chosen to maximize number of correct choices of nodes on path (adjusted)
	
	@Override
	public double getProbObsGivenState(int state, int anchor) {
		switch(myMethod) {
		case inclusionOnly:
			return inclusionOnly(state, anchor);
		case allCounts:
			return allCounts(state, anchor);
		case allCountsWeighted:
			return allCountsWeighted(state, anchor);
		default:
			return Double.NaN;
		}
	}
	
	@Override
	public double getProbObsGivenState(int state, int obs, int node) {
		return getProbObsGivenState(state, node);
	}

	private double inclusionOnly(int state, int anchor) {
		//anchor is the ith ref-range in a list (myRanges)
		//probability is binomial probability of nsuccess out of ntrials with probability Pcorrect[anchor]
		//ntrials = sum Of node counts, nsuccess is the count for this node (state)
		
		int[] nodeCounts = myNodeMap.get(myRanges.get(anchor)).stream()
				.mapToInt(node -> hapidCountMap.getOrDefault(node.id(), 0).intValue())
				.toArray();
		int totalCount = Arrays.stream(nodeCounts).sum();
		int maxCount = Arrays.stream(nodeCounts).max().orElse(0);
		BinomialDistribution binom = new BinomialDistribution(maxCount, Pcorrect);
		return binom.probability(nodeCounts[state]);
	}
	
	public double allCounts_oops(int state, int anchor) {
		//P(obs|state) = pbinom(trials = nExcludes + nIncludes, successes = nExcludes, p = pErr) (Sum)
		//need to deal with nodes with too few observations (include + exclude)
		if (myCurrentAnchor != anchor) {
			myCurrentAnchor = anchor;
			anchorProbabilities(anchor);
		}
		
		int myNodeId = myNodeMap.get(myRanges.get(anchor)).get(state).id();
		int includeCount = hapidCountMap.getOrDefault(myNodeId, 0);
		int excludeCount = hapidExclusionCountMap.getOrDefault(myNodeId, 0);
		int totalCount = includeCount + excludeCount;
		BinomialDistribution binom = new BinomialDistribution(totalCount, 1 - Pcorrect);
		return binom.probability(excludeCount);
	}
	
	public double allCounts(int state, int anchor) {
		//P(obs|state) = pbinom(trials = nExcludes + nIncludes, successes = nExcludes, p = pErr) (Sum)
		//need to deal with nodes with too few observations (include + exclude)
		if (myCurrentAnchor != anchor) {
			myCurrentAnchor = anchor;
			anchorProbabilities(anchor);
		}
		return anchorProbabilities[state];
	}
	
	public double allCountsWeighted(int state, int anchor) {
		if (myCurrentAnchor != anchor) {
			myCurrentAnchor = anchor;
			weightedAnchorProbabilities(anchor);
		}
		return anchorProbabilities[state];
	}
	
	private void anchorProbabilities(int anchor) {
		ReferenceRange range = myRanges.get(anchor);
		List nodes = myNodeMap.get(range);
		int[] nodeCounts = nodes.stream()
				.mapToInt(node -> hapidCountMap.getOrDefault(node.id(), 0).intValue())
				.toArray();
		int[] exclusionCounts = nodes.stream()
				.mapToInt(node -> hapidExclusionCountMap.getOrDefault(node.id(), 0).intValue())
				.toArray();
		int n = nodeCounts.length;
		int[] totalCounts = new int[n];
		for (int i = 0; i < n; i++) totalCounts[i] = nodeCounts[i] + exclusionCounts[i];
		double[] stateProbs = new double[n];
		double totalProb =  0;
		double probCount = 0;
		
		//for each node, if total > 0, calculate the probability from pbinom(total, excluded, 1-Pcorrect)
		//if total = 0, use avg probability
		for (int i = 0; i < n; i++) {
			if (totalCounts[i] > 0) {
				//calculate prob and add total
				BinomialDistribution binom = new BinomialDistribution(totalCounts[i], 1 - Pcorrect);
				stateProbs[i] = binom.probability(exclusionCounts[i]);
				totalProb += stateProbs[i];
				probCount++;
			}
		}
		double avgProb;
		if (probCount > 0) avgProb = totalProb / probCount;
		else avgProb = (double) 1.0 / n;
		for (int i = 0; i < n; i++) {
			if (totalCounts[i] == 0) {
				//set to avg prob
				stateProbs[i] = avgProb;
			}
		}
		anchorProbabilities = stateProbs;
	}
	
	private void weightedAnchorProbabilities(int anchor) {
		ReferenceRange range = myRanges.get(anchor);
		List nodes = myNodeMap.get(range);
		int[] taxaCounts = nodes.stream().mapToInt(node -> node.numTaxa()).toArray();
		int[] nodeCounts = nodes.stream()
				.mapToInt(node -> hapidCountMap.getOrDefault(node.id(), 0).intValue())
				.toArray();
		int[] exclusionCounts = nodes.stream()
				.mapToInt(node -> hapidExclusionCountMap.getOrDefault(node.id(), 0).intValue())
				.toArray();
		int n = nodeCounts.length;
		int[] totalCounts = new int[n];
		for (int i = 0; i < n; i++) totalCounts[i] = nodeCounts[i] + exclusionCounts[i];
		double[] stateProbs = new double[n];
		double totalProb =  0;
		double probCount = 0;
		
		//for each node, if total > 0, calculate the probability from pbinom(total, excluded, 1-Pcorrect)
		//if total = 0, use avg probability
		for (int i = 0; i < n; i++) {
			if (totalCounts[i] > 0) {
				//calculate prob and add total
				BinomialDistribution binom = new BinomialDistribution(totalCounts[i], 1 - Pcorrect);
				stateProbs[i] = binom.probability(exclusionCounts[i]);
				totalProb += stateProbs[i] * taxaCounts[i];
				probCount += taxaCounts[i];
			}
		}
		double avgProb;
		if (probCount > 0) avgProb = totalProb / probCount;
		else avgProb = (double) 1.0 / n;
		for (int i = 0; i < n; i++) {
			if (totalCounts[i] == 0) {
				//set to avg prob
				stateProbs[i] = avgProb;
			}
		}
		anchorProbabilities = stateProbs;
	}
	
	public double weightedSum(int state, int anchor) {
		return 0;
	}
	
	public double weightedMax(int state, int anchor) {
		return 0;
	}

	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder();
		sb.append("ReferenceRangeEmissionProbability:\n");
		sb.append("Pcorrect = ").append(Pcorrect).append("\n");
		sb.append("myMethod = ").append(myMethod.name()).append("\n");
		sb.append("myAdjustmentFactor = ").append(myAdjustmentFactor).append("\n");
		return sb.toString();
	}

	public static class Builder {
		private NavigableMap> nodeMap = null;
		private Map hapidInclusionCountMap = null;
		private Map hapidExclusionCountMap = null;
		private double probCorrect = 0.99;
		private METHOD myMethod = METHOD.inclusionOnly;
		private double myAdjustmentFactor = 1;

		public Builder nodeMap(NavigableMap> nodeMap) {
			this.nodeMap = nodeMap;
			return this;
		}
		
		public Builder inclusionCountMap(Map inclusionCounts) {
			hapidInclusionCountMap = inclusionCounts;
			return this;
		}
		
		public Builder exclusionCountMap(Map exclusionCounts) {
			hapidExclusionCountMap = exclusionCounts;
			return this;
		}

		public Builder probabilityCorrect(double pcorrect) {
			probCorrect = pcorrect;
			return this;
		}
		
		public Builder method(METHOD emissionMethod) {
			myMethod = emissionMethod;
			return this;
		}
		
		public Builder adjustment(double factor) {
			myAdjustmentFactor = factor;
			return this;
		}
		
		public ReferenceRangeEmissionProbability build() {
			if (nodeMap == null) 
				throw new IllegalArgumentException("nodeMap required for building ReferenceRangeEmissionProbability");
			if (hapidInclusionCountMap == null) 
				throw new IllegalArgumentException("inclusion count map required for building ReferenceRangeEmissionProbability");
			if (myMethod == METHOD.inclusionOnly) {
				return new ReferenceRangeEmissionProbability(nodeMap, hapidInclusionCountMap, probCorrect);
			} else {
				if (hapidExclusionCountMap == null) 
					throw new IllegalArgumentException("exclusion count map required for building ReferenceRangeEmissionProbability");
				List> countList = new ArrayList<>();
				countList.add(hapidInclusionCountMap);
				countList.add(hapidExclusionCountMap);
				return new ReferenceRangeEmissionProbability(nodeMap, countList, myMethod, myAdjustmentFactor, probCorrect);
						
			}
		}
	}
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy