org.deeplearning4j.util.Viterbi Maven / Gradle / Ivy
package org.deeplearning4j.util;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.berkeley.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.deeplearning4j.nn.api.Persistable;
import java.io.InputStream;
import java.io.OutputStream;
/**
* Based on the impl from:
* https://gist.github.com/rmcgibbo/3915977
*
*/
public class Viterbi implements Persistable {
private double metaStability = 0.9;
private double pCorrect = 0.99;
private INDArray possibleLabels;
private int states;
private double logPCorrect = FastMath.log(pCorrect);
private double logPIncorrect = FastMath.log(1 - pCorrect / states - 1);
private double logMetaInstability = Math.log(metaStability);
private double logOfDiangnalTProb;
private double logStates;
/**
* The possible outcomes for the chain.
* This should be the labels in the form of the possible outcomes (1,2,3)
* not the binarized label matrix
* @param possibleLabels the possible labels of the markov chain
*/
public Viterbi(INDArray possibleLabels) {
this.possibleLabels = possibleLabels;
this.states = possibleLabels.length();
logOfDiangnalTProb = FastMath.log(1 - metaStability / states - 1);
this.logStates = FastMath.log(states);
}
/**
* Decodes the given labels, assuming its a binary label matrix
* @param labels the labels as a binary label matrix
* @return the decoded labels and the most likely outcome of the sequence
*/
public Pair decode(INDArray labels) {
return decode(labels,true);
}
/**
* Decodes a series of labels
* @param labels the labels to decode
* @param binaryLabelMatrix whether the label is a binary label matrix
* @return the most likely sequence and the sequence labels
*/
public Pair decode(INDArray labels,boolean binaryLabelMatrix) {
INDArray outcomeSequence = labels.isColumnVector() || labels.isRowVector() || binaryLabelMatrix ? toOutcomesFromBinaryLabelMatrix(labels) : labels;
int frames = outcomeSequence.length();
INDArray V = Nd4j.ones(frames, states);
INDArray pointers = Nd4j.zeros(frames,states);
INDArray assigned = V.getRow(0);
assigned.assign(logPCorrect - logStates);
V.putRow(0,assigned);
V.put(0, (int) outcomeSequence.getDouble(0), logPCorrect - logStates);
for(int t = 1; t < frames; t++) {
for(int k = 0; k < states; k++) {
INDArray rowLogProduct = rowOfLogTransitionMatrix(k).add(V.getRow(t - 1));
int maxVal = Nd4j.getBlasWrapper().iamax(rowLogProduct);
double argMax = rowLogProduct.max(Integer.MAX_VALUE).getDouble(0);
V.put(t,k,argMax);
int element = (int) outcomeSequence.getDouble(t);
if(k == element)
V.put(t,k,logPCorrect + maxVal);
else
V.put(t,k, logPIncorrect + maxVal);
}
}
INDArray rectified = Nd4j.zeros(frames);
rectified.put(rectified.length() - 1,V.getRow(frames - 1).max(Integer.MAX_VALUE));
for(int t = rectified.length() - 2; t > 0; t--) {
rectified.putScalar(t,pointers.getDouble(t + 1,(int) rectified.getDouble(t + 1)));
}
return new Pair<>(V.getRow(frames - 1).max(Integer.MAX_VALUE).getDouble(0),rectified);
}
private INDArray rowOfLogTransitionMatrix(int k) {
INDArray row = Nd4j.ones(1,states).muli(logOfDiangnalTProb);
row.putScalar(k,logMetaInstability);
return row;
}
private INDArray toOutcomesFromBinaryLabelMatrix(INDArray outcomes) {
INDArray ret = Nd4j.create(outcomes.rows(),1);
for(int i = 0; i < outcomes.rows(); i++)
ret.put(i,0, Nd4j.getBlasWrapper().iamax(outcomes.getRow(i)));
return ret;
}
@Override
public void write(OutputStream os) {
SerializationUtils.writeObject(this,os);
}
@Override
public void load(InputStream is) {
Viterbi ret = SerializationUtils.readObject(is);
this.states = ret.states;
this.logStates = ret.logStates;
this.metaStability = ret.metaStability;
this.logMetaInstability = ret.logMetaInstability;
this.logOfDiangnalTProb = ret.logOfDiangnalTProb;
this.logPCorrect = ret.logPCorrect;
this.pCorrect = ret.pCorrect;
}
public double getMetaStability() {
return metaStability;
}
public void setMetaStability(double metaStability) {
this.metaStability = metaStability;
}
public double getpCorrect() {
return pCorrect;
}
public void setpCorrect(double pCorrect) {
this.pCorrect = pCorrect;
}
public INDArray getPossibleLabels() {
return possibleLabels;
}
public void setPossibleLabels(INDArray possibleLabels) {
this.possibleLabels = possibleLabels;
}
public int getStates() {
return states;
}
public void setStates(int states) {
this.states = states;
}
public double getLogPCorrect() {
return logPCorrect;
}
public void setLogPCorrect(double logPCorrect) {
this.logPCorrect = logPCorrect;
}
public double getLogPIncorrect() {
return logPIncorrect;
}
public void setLogPIncorrect(double logPIncorrect) {
this.logPIncorrect = logPIncorrect;
}
public double getLogMetaInstability() {
return logMetaInstability;
}
public void setLogMetaInstability(double logMetaInstability) {
this.logMetaInstability = logMetaInstability;
}
public double getLogOfDiangnalTProb() {
return logOfDiangnalTProb;
}
public void setLogOfDiangnalTProb(double logOfDiangnalTProb) {
this.logOfDiangnalTProb = logOfDiangnalTProb;
}
public double getLogStates() {
return logStates;
}
public void setLogStates(double logStates) {
this.logStates = logStates;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy