All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.maizegenetics.analysis.imputation.ViterbiAlgorithm Maven / Gradle / Ivy

package net.maizegenetics.analysis.imputation;

public class ViterbiAlgorithm {

	//adapted from Rabiner Proceedings of the IEEE 77(2):257-286
	//initialize
	//d(0, i) = p(Obs-0|S(i)*p(S(i))
	//where d0 = path length, Obs-0 = observation 0, S0 = true state 0
	//iterate:
	//for t = 1 to n
	//d(t,S(j)) = max(j){d(t-1,S(i)) * p[S(j)|S(i)] * p[Obs(t)|S(j)]}
	//h(t, j) = the value of i that maximizes distance
	//where h() = path history
	//termination
	//choose state that maximizes path length
	//back tracking
	//S(t) = h(t+1, S(t+1)), to decode best sequence
	
	TransitionProbability myTransitionMatrix;
	EmissionProbability probObservationGivenState;
	byte[] obs;
	byte[][] history;
	double[] distance;
	int numberOfStates; //number of true states;
	double[] probTrueStates; //ln of probabilities
	int numberOfObs;
	byte[] finalState;
	
	
	public ViterbiAlgorithm(byte[] observations, TransitionProbability transitionMatrix, EmissionProbability obsGivenTrue, double[] pTrue) {
		obs = observations;
		numberOfObs = obs.length;
		numberOfStates = transitionMatrix.getNumberOfStates();
		
		myTransitionMatrix = transitionMatrix;
		probObservationGivenState = obsGivenTrue;
		probTrueStates = new double[pTrue.length];
		for (int i = 0; i < pTrue.length; i++) {
			probTrueStates[i] = Math.log(pTrue[i]);
		}
		
		history = new byte[numberOfStates][numberOfObs];
		distance = new double[numberOfStates];
	}
	
	public void calculate() {
		initialize();
		for (int i = 1; i < numberOfObs; i++) {
			updateDistanceAndHistory(i);
		}
	}
	
	public void initialize() {
		for (int i = 0; i < numberOfStates; i++) {
			try{
				distance[i] = probObservationGivenState.getLnProbObsGivenState(i, obs[0], 0) + probTrueStates[i];
			} catch(Exception e) {
				e.printStackTrace();
			}
		}
	}
	
	public void updateDistanceAndHistory(int node) {
		double[][] candidateDistance = new double[numberOfStates][numberOfStates];
		myTransitionMatrix.setNode(node);
		for (int i = 0; i < numberOfStates; i++) {
			for (int j = 0; j < numberOfStates; j++) {
				candidateDistance[i][j] = distance[i] + myTransitionMatrix.getLnTransitionProbability(i, j) + probObservationGivenState.getLnProbObsGivenState(j, obs[node], node);
			}
		}

		//find the maxima
		int[] max = new int[numberOfStates];
		for (int i = 0; i < numberOfStates; i++) {
			for (int j = 0; j < numberOfStates; j++) {
				if (candidateDistance[i][j] > candidateDistance[max[j]][j]) max[j] = i;
			}
		}

		//update distance and history
		for (int j = 0; j < numberOfStates; j++) {
			distance[j] = candidateDistance[max[j]][j];
			history[j][node] = (byte) max[j];
		}
		
		//if the min distance is less than -1e100, subtract the max distance;
		double maxd = distance[0];
		double mind = 0;
		for (int i = 0; i < numberOfStates; i++) {
			if (distance[i] > maxd) maxd = distance[i];
			if (distance[i] != Double.NEGATIVE_INFINITY && distance[i] < mind) mind = distance[i];
		}
		if (mind < -1e100) {
			for (int i = 0; i < numberOfStates; i++) {
				distance[i] -= maxd;
			}
		}
	}
	
	//decode the most probable state sequence
	public byte[] getMostProbableStateSequence() {
		byte[] seq = new byte[numberOfObs];
		byte finalState = 0;
		for (int i = 1; i < numberOfStates; i++) {
			if (distance[i] > distance[finalState]) finalState = (byte) i;
		}
		
		//S(t) = h(t+1, S(t+1)), to decode best sequence
		seq[numberOfObs - 1] = finalState;
		for (int i = numberOfObs - 2; i >= 0; i--) {
			 seq[i] = history[seq[i + 1]][i + 1];
		}
		return seq;
	}
		
	public void setStateProbability(double[] probTrueState) {
		int n = probTrueState.length;
		probTrueStates = new double[n];
		for (int i = 0; i < n; i++) {
			probTrueStates[i] = Math.log(probTrueState[i]);
		}
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy