net.maizegenetics.analysis.imputation.ViterbiAlgorithmVariableStateNumber Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of tassel Show documentation
Show all versions of tassel Show documentation
TASSEL is a software package to evaluate trait associations, evolutionary patterns, and linkage
disequilibrium.
package net.maizegenetics.analysis.imputation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
public class ViterbiAlgorithmVariableStateNumber {
    // Adapted from Rabiner, Proceedings of the IEEE 77(2):257-286.
    // All probabilities are handled as natural logs, so products become sums.
    //
    // Initialization:
    //   d(0, i) = ln p(Obs(0) | S(i)) + ln p(S(i))
    //   where d = path score ("distance"), Obs(t) = observation t, S(i) = true state i
    // Iteration, for t = 1 .. n-1:
    //   d(t, j) = max over i { d(t-1, i) + ln p(S(j) | S(i)) + ln p(Obs(t) | S(j)) }
    //   h(t, j) = the value of i that achieves that maximum (path history)
    // Termination:
    //   choose the final state with the largest d(n-1, j)
    // Back tracking:
    //   S(t) = h(t+1, S(t+1)) decodes the best sequence
    //
    // The number of states may change from node to node: after setNode(node) the
    // transition matrix reports how many states the current node has.

    /** When the smallest finite path score drops below this, scores are shifted so the best is 0. */
    private static final double RESCALE_THRESHOLD = -1e100;

    /** Transition model; queried per node via setNode / getNumberOfStates / getLnTransitionProbability. */
    TransitionProbability myTransitionMatrix;
    /** Emission model; supplies ln p(observation | state) for a node. */
    EmissionProbability probObservationGivenState;
    /** Observations, one per node. */
    byte[] obs;
    /**
     * Back pointers: history.get(t - 1)[j] is the best previous state for state j at node t.
     * Indices are stored as bytes, so state indices above 127 do not round-trip.
     */
    ArrayList<byte[]> history;
    /** Path scores (ln) for each state of the most recently processed node. */
    double[] distance;
    /** Prior state probabilities as natural logs. */
    double[] probTrueStates;
    int numberOfObs;
    /** Not read or written by this class; retained for compatibility. */
    byte[] finalState;
    int numberOfCurrentNodeStates;
    int numberOfPreviousNodeStates;

    /**
     * Creates a decoder for one observation sequence.
     *
     * @param observations    observed values, one per node
     * @param transitionMatrix transition model (supplies per-node state counts and ln transition probabilities)
     * @param obsGivenTrue    emission model (supplies ln p(observation | state))
     * @param pTrue           prior probability of each initial state (not logged; converted here)
     */
    public ViterbiAlgorithmVariableStateNumber(byte[] observations, TransitionProbability transitionMatrix, EmissionProbability obsGivenTrue, double[] pTrue) {
        obs = observations;
        numberOfObs = obs.length;
        myTransitionMatrix = transitionMatrix;
        probObservationGivenState = obsGivenTrue;
        probTrueStates = toNaturalLog(pTrue);
        history = new ArrayList<>(numberOfObs);
    }

    /** Runs the forward pass over every observation; call before getMostProbableStateSequence(). */
    public void calculate() {
        initialize();
        for (int i = 1; i < numberOfObs; i++) {
            updateDistanceAndHistory(i);
        }
    }

    /**
     * Resets the path scores to the first node and discards any previous history,
     * so calculate() can safely be re-run (e.g. after setStateProbability).
     *
     * @throws RuntimeException if the emission model fails for any initial state
     */
    public void initialize() {
        int n = probTrueStates.length;
        numberOfCurrentNodeStates = numberOfPreviousNodeStates = n;
        distance = new double[n];
        // Without this, a second run would append to (and decode from) the previous run's back pointers.
        history.clear();
        if (numberOfObs == 0) {
            // No observation to score: the path scores are just the priors.
            System.arraycopy(probTrueStates, 0, distance, 0, n);
            return;
        }
        for (int i = 0; i < n; i++) {
            try {
                distance[i] = probObservationGivenState.getLnProbObsGivenState(i, obs[0], 0) + probTrueStates[i];
            } catch (Exception e) {
                // Leaving distance[i] at 0.0 (= ln 1) would silently make this state look certain.
                throw new RuntimeException(String.format("failed to initialize state %d at node 0", i), e);
            }
        }
    }

    /**
     * Advances the Viterbi recursion by one node: for each current state j, keeps the best
     * previous state i and records it in the history.
     *
     * @param node index of the observation being processed (must be >= 1)
     * @throws RuntimeException if a model lookup fails; the message reports the state counts
     */
    public void updateDistanceAndHistory(int node) {
        myTransitionMatrix.setNode(node);
        numberOfPreviousNodeStates = numberOfCurrentNodeStates;
        numberOfCurrentNodeStates = myTransitionMatrix.getNumberOfStates();

        double[] nextDistance = new double[numberOfCurrentNodeStates];
        byte[] nodeHistory = new byte[numberOfCurrentNodeStates];
        try {
            for (int j = 0; j < numberOfCurrentNodeStates; j++) {
                // The emission term does not depend on the previous state, so look it up once per j.
                double lnEmission = probObservationGivenState.getLnProbObsGivenState(j, obs[node], node);
                int best = 0;
                double bestScore = distance[0] + myTransitionMatrix.getLnTransitionProbability(0, j) + lnEmission;
                for (int i = 1; i < numberOfPreviousNodeStates; i++) {
                    double score = distance[i] + myTransitionMatrix.getLnTransitionProbability(i, j) + lnEmission;
                    // Strict '>' keeps the lowest index on ties.
                    if (score > bestScore) {
                        best = i;
                        bestScore = score;
                    }
                }
                nextDistance[j] = bestScore;
                nodeHistory[j] = (byte) best;
            }
        } catch (Exception e) {
            throw new RuntimeException(String.format("at node %d, distance.length = %d, numberOfPreviousNodeStates = %d, and numberOfCurrentNodeStates = %d",
                    node, distance.length, numberOfPreviousNodeStates, numberOfCurrentNodeStates), e);
        }

        distance = nextDistance;
        history.add(nodeHistory);
        rescaleIfUnderflowing();
    }

    /**
     * If the smallest finite score has fallen below RESCALE_THRESHOLD, subtracts the largest
     * score from every entry. A common shift leaves every argmax unchanged.
     */
    private void rescaleIfUnderflowing() {
        if (distance.length == 0) {
            return;
        }
        double maxd = distance[0];
        double mind = 0;
        for (double d : distance) {
            if (d > maxd) maxd = d;
            if (d != Double.NEGATIVE_INFINITY && d < mind) mind = d;
        }
        if (mind < RESCALE_THRESHOLD) {
            for (int i = 0; i < distance.length; i++) {
                distance[i] -= maxd;
            }
        }
    }

    /**
     * Decodes the most probable state sequence by back-tracking through the history.
     * State indices are carried as bytes; this assumes at most 128 states per node
     * (TODO confirm with callers).
     *
     * @return one state index per observation; empty when there are no observations
     * @throws IllegalStateException if calculate() has not been run
     */
    public byte[] getMostProbableStateSequence() {
        byte[] seq = new byte[numberOfObs];
        if (numberOfObs == 0) {
            return seq;
        }
        if (distance == null) {
            throw new IllegalStateException("calculate() must be called before getMostProbableStateSequence()");
        }
        byte lastState = 0;
        for (int i = 1; i < distance.length; i++) {
            if (distance[i] > distance[lastState]) lastState = (byte) i;
        }
        // S(t) = h(t+1, S(t+1)); history.get(t) holds the back pointers for node t+1.
        seq[numberOfObs - 1] = lastState;
        for (int t = numberOfObs - 2; t >= 0; t--) {
            seq[t] = history.get(t)[seq[t + 1]];
        }
        return seq;
    }

    /**
     * Replaces the prior state probabilities; takes effect on the next calculate().
     *
     * @param probTrueState prior probability of each initial state (not logged; converted here)
     */
    public void setStateProbability(double[] probTrueState) {
        probTrueStates = toNaturalLog(probTrueState);
    }

    /** Returns a new array holding the natural log of each input probability. */
    private static double[] toNaturalLog(double[] probabilities) {
        double[] ln = new double[probabilities.length];
        for (int i = 0; i < probabilities.length; i++) {
            ln[i] = Math.log(probabilities[i]);
        }
        return ln;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy