dist.edu.umd.hooka.alignment.hmm.HMM Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cloud9 Show documentation
Show all versions of cloud9 Show documentation
University of Maryland's Hadoop Library
package edu.umd.hooka.alignment.hmm;
import java.io.IOException;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;
import edu.umd.hooka.Alignment;
import edu.umd.hooka.AlignmentPosteriorGrid;
import edu.umd.hooka.Array2D;
import edu.umd.hooka.PhrasePair;
import edu.umd.hooka.alignment.PartialCountContainer;
import edu.umd.hooka.alignment.PerplexityReporter;
import edu.umd.hooka.alignment.CrossEntropyCounters;
import edu.umd.hooka.alignment.ZeroProbabilityException;
import edu.umd.hooka.alignment.model1.Model1;
import edu.umd.hooka.ttables.TTable;
/**
* Represents an HMM that applies to a single sentence pair, which is
* derived from the parameters stored in a TTable and an ATable object.
*
* @author redpony
*
*/
public class HMM extends Model1 {
/**
* Key under which {@link #writePartialCounts} emits the accumulated
* jump (alignment) counts held in {@code acounts}.
*/
public static final IntWritable ACOUNT_VOC_ID = new IntWritable(999999);
/**
* MAX_LENGTH * MAX_LENGTH is the value handed to the constructor of every
* per-sentence table below (presumably an initial capacity; confirm
* against Array2D / IntArray2D).
*/
static final int MAX_LENGTH = 500;
/**
* Threshold passed to AlignmentPosteriorGrid.alignPosteriorThreshold when
* Model 1 posteriors have been supplied via setModel1Posteriors.
*/
static final float THRESH =0.5f;
/**
* (j,s) = p(f_j|e(s)). Indexed foreign position first, state second
* (see getEmissionProb); both indices are 1-based in buildHMMTables.
*/
Array2D emission = new Array2D(MAX_LENGTH * MAX_LENGTH);
/**
* (j,s) = i s.t. e(s) = e_i (1-based) or -1 if n.a.
*/
IntArray2D e_coords = new IntArray2D(MAX_LENGTH * MAX_LENGTH);
/**
* (j,s) = 0-based position in the English sentence of the word
* corresponding to state s, or -1 if n.a.
*/
IntArray2D e_words = new IntArray2D(MAX_LENGTH * MAX_LENGTH);
/**
* (i',i) = p(i-i'), looked up in amodel for the current English length l.
*/
Array2D transition = new Array2D(MAX_LENGTH * MAX_LENGTH);
/**
* (i',i) = value of amodel.getCoord(i-i', l) for the jump i-i',
* or -1 where buildHMMTables sets no transition.
*/
IntArray2D transition_coords = new IntArray2D(MAX_LENGTH * MAX_LENGTH);
/** Forward values, filled and column-normalized in baumWelch. */
Array2D alphas = new Array2D(MAX_LENGTH * MAX_LENGTH);
/** Backward values, filled in baumWelch and scaled by the forward norms. */
Array2D betas = new Array2D(MAX_LENGTH * MAX_LENGTH);
/** (j,s) = best log score found for state s at foreign position j. */
Array2D viterbi = new Array2D(MAX_LENGTH * MAX_LENGTH);
/** (j,s) = previous state that produced the score stored in viterbi(j,s). */
IntArray2D backtrace = new IntArray2D(MAX_LENGTH * MAX_LENGTH);
/** Jump-distribution parameters supplied to the constructor. */
ATable amodel;
/** Cleared clone of amodel; emitted and reset by writePartialCounts. */
ATable acounts;
/** Number of English words in the current sentence pair (-1 until buildHMMTables runs). */
int l = -1;
/** Number of foreign words in the current sentence pair (-1 until buildHMMTables runs). */
int m = -1;
/** Optional Model 1 posteriors; null unless setModel1Posteriors was called. */
AlignmentPosteriorGrid m1_post = null;
/**
 * Stores the Model 1 posterior grid that baumWelch consults (when non-null)
 * to derive a thresholded alignment.
 *
 * @param m1pg the posterior grid to use, or {@code null} to disable it
 */
public void setModel1Posteriors(AlignmentPosteriorGrid m1pg) {
    this.m1_post = m1pg;
}
protected HMM(TTable ttable, ATable atable, boolean useNull) {
super(ttable, useNull);
amodel = atable;
acounts = (ATable)amodel.clone(); acounts.clear();
}
/**
 * Creates an HMM over the given translation and jump tables, passing
 * {@code false} as the {@code useNull} flag.
 *
 * @param ttable translation table
 * @param atable jump-distribution table used as the model parameters
 */
public HMM(TTable ttable, ATable atable) {
    // Same initialization as the three-argument constructor with useNull = false.
    this(ttable, atable, false);
}
/**
 * Emits the partial counts written by the superclass, then emits the
 * accumulated jump counts under {@link #ACOUNT_VOC_ID} and clears them.
 *
 * @param output collector that receives the partial counts
 * @throws IOException declared because {@code OutputCollector.collect} throws it
 */
public void writePartialCounts(OutputCollector output) throws IOException {
    super.writePartialCounts(output);

    final PartialCountContainer jumpCounts = new PartialCountContainer();
    jumpCounts.setContent(this.acounts);
    output.collect(ACOUNT_VOC_ID, jumpCounts);

    // Clear the accumulator once its contents have been handed to the collector.
    this.acounts.clear();
}
/**
 * Populates the per-sentence emission and transition tables for the given
 * sentence pair and records its lengths in {@code l} (English) and
 * {@code m} (foreign).
 *
 * Emission-side tables are indexed (foreign position, state) and the
 * transition tables (previous state, state); states and foreign positions
 * are 1-based.
 *
 * @param pp the sentence pair whose words parameterize the tables
 */
public void buildHMMTables(PhrasePair pp) {
    final int[] english = pp.getE().getWords();
    final int[] foreign = pp.getF().getWords();
    l = english.length;
    m = foreign.length;

    emission.resize(m + 1, l + 1);
    e_coords.resize(m + 1, l + 1);
    e_words.resize(m + 1, l + 1);
    e_words.fill(-1);
    e_coords.fill(-1);

    // One state per English word: state (ePos + 1) emits foreign word
    // (fPos + 1) with the translation probability from tmodel.
    for (int ePos = 0; ePos < l; ePos++) {
        final int state = ePos + 1;
        for (int fPos = 0; fPos < m; fPos++) {
            final int column = fPos + 1;
            e_coords.set(column, state, state);
            emission.set(column, state, tmodel.get(english[ePos], foreign[fPos]));
            e_words.set(column, state, ePos);
        }
    }

    transition.resize(l + 1, l + 1);
    transition_coords.resize(l + 1, l + 1);
    transition_coords.fill(-1);

    // The jump table is keyed by jump width and the English length as a char.
    final char lengthKey = (char) l;
    for (int from = 0; from <= l; from++) {
        for (int to = 1; to <= l; to++) {
            final int jump = to - from;
            transition_coords.set(from, to, amodel.getCoord(jump, lengthKey));
            transition.set(from, to, amodel.get(jump, lengthKey));
        }
    }
}
/**
 * @return the second dimension of the transition table, as reported by
 *         {@code getSize2()}
 */
public final int getNumStates() {
    final int stateCount = this.transition.getSize2();
    return stateCount;
}
/**
 * Looks up a transition probability in the per-sentence transition table.
 *
 * @param s_prev the state being left
 * @param s      the state being entered
 * @return the value stored at (s_prev, s)
 */
public final float getTransitionProb(int s_prev, int s) {
    final float probability = this.transition.get(s_prev, s);
    return probability;
}
/**
 * Looks up an emission probability in the per-sentence emission table.
 *
 * @param j foreign position (first index of the table)
 * @param s state (second index of the table)
 * @return the value stored at (j, s)
 */
public final float getEmissionProb(int j, int s) {
    final float probability = this.emission.get(j, s);
    return probability;
}
/**
 * Adds this model's accumulated jump counts into the supplied table via
 * {@code plusEquals}.
 *
 * @param ac the table that receives the counts
 */
public final void addPartialJumpCountsToATable(ATable ac) {
    ac.plusEquals(this.acounts);
}
/**
 * Runs one training pass over a sentence pair: builds the per-sentence
 * tables, runs baumWelch, and reports the result to the counters.
 *
 * Pairs with an empty side, or with a side whose length is at least
 * {@code amodel.getMaxDist() - 1}, are skipped without any effect.
 *
 * @param pp the sentence pair to train on
 * @param r  reporter that receives the counter updates; may be {@code null}
 */
@Override
public void processTrainingInstance(PhrasePair pp, Reporter r) {
    final int lengthLimit = amodel.getMaxDist() - 1;

    final int englishLength = pp.getE().size();
    if (englishLength == 0 || englishLength >= lengthLimit) {
        return;
    }
    final int foreignLength = pp.getF().size();
    if (foreignLength == 0 || foreignLength >= lengthLimit) {
        return;
    }

    buildHMMTables(pp);
    final float totalLogProb = baumWelch(pp, null);

    if (r == null) {
        return;
    }
    r.incrCounter(CrossEntropyCounters.LOGPROB, (long) (-totalLogProb));
    r.incrCounter(CrossEntropyCounters.WORDCOUNT, foreignLength);
}
/**
* @return negative log probability of sentence
*/
public final float baumWelch(PhrasePair pp, AlignmentPosteriorGrid pg) {
initializeCountTableForSentencePair(pp);
int[] obs = pp.getF().getWords();
int J = obs.length + 1;
int numStates = getNumStates();
int l = pp.getE().getWords().length;
float[] anorms = new float[J];
alphas.resize(J + 1, getNumStates());
betas.resize(J + 1, getNumStates());
alphas.set(0, 0, 1.0f); anorms[0]=1.0f;
Alignment m1a = null;
if (m1_post != null)
m1a = m1_post.alignPosteriorThreshold(THRESH);
for (int j = 1; j < J; j++) {
//System.out.println("J="+j);
for (int s = 0; s < numStates; s++) {
float alpha = 0.0f;
float m1boost = 1.0f;
float m1penalty = 0.0f;
boolean use_m1 = false;
if (m1a != null && m1a.isFAligned(j-1)) {
float m1post = 0.0f;
use_m1 = true;
for (int i=0; i 0 && m1a.aligned(j-1, s-1))
trans = m1boost;
else
trans *= m1penalty;
}
alpha += alphas.get(j - 1, s_prev) * trans;
}
alpha *= getEmissionProb(j, s);
//System.out.println(" ep:" + hmm.getEmissionProb(s, j));
alphas.set(j, s, alpha);
}
//anorms[j] = 1.0f;
try {
anorms[j] = alphas.normalizeColumn(j);
} catch (ZeroProbabilityException ex) {
this.notifyUnalignablePair(pp, ex.getMessage());
return 0.0f;
}
}
for (int s=1; s=1; j--) {
//System.out.println("J="+j);
for (int s = 0; s < numStates; s++) {
float beta = 0.0f;
float m1boost = 1.0f;
float m1penalty = 0.0f;
boolean use_m1 = false;
if (m1a != null && m1a.isFAligned(j-1)) {
float m1post = 0.0f;
use_m1 = true;
for (int i=0; i 0 && m1a.aligned(j-1, s-1))
trans = m1boost;
else
trans *= m1penalty;
}
beta += betas.get(j+1, s_next) *
trans *
getEmissionProb(j+1, s_next);
}
beta /= anorms[j];
//System.out.println(" s="+s+ " b:"+beta);
betas.set(j, s, beta);
}
}
// PARTIAL COUNTS FOR EMISSIONS (WORD TRANSLATION)
float totalProb[] = new float[J];
for (int j=1; j 0 && m1a.aligned(j-1, s-1))
trans = m1boost;
else
trans *= m1penalty;
}
float cur = (float)(viterbi.get(j - 1, s_prev) +
Math.log(trans) +
emitLogProb);
//System.out.println(" s'="+s_prev + " cur="+cur);
if (cur > best) {
best = cur;
best_s = s_prev;
//System.out.println("new best: " + s + " " + best_s);
}
}
//System.out.println(" s_best="+best_s + " cur="+best);
viterbi.set(j, s, best);
if (best != Float.NEGATIVE_INFINITY)
valid = true;
backtrace.set(j, s, best_s);
}
// if we don't know how to generate some column
// create a uniform distribution over the states
// and assume the previous state was the best
if (!valid) {
float best = Float.NEGATIVE_INFINITY;
int bests = -1;
for (int s = 1; s < numStates; s++) {
if (viterbi.get(j-1, s) > best) {
best = viterbi.get(j-1, s);
bests = s;
}
}
for (int s = 1; s < numStates; s++) {
viterbi.set(j, s, 0.0f);
backtrace.set(j, s, bests);
}
}
}
//System.out.println(viterbi);
float best = Float.NEGATIVE_INFINITY;
int best_s = -1;
for (int s = 1; s < numStates; s++) {
if (viterbi.get(J-1, s) > best) {
best = viterbi.get(J-1,s);
best_s = s;
}
}
//System.out.println("vit: " + best + "j-1="+(J-1));
reporter.addFactor(best, J - 1);
//System.out.println(viterbi);
int e = best_s;
for (int f=J-1; f>0; f--) {
if (e <= 0) {
throw new ZeroProbabilityException(" Error f=" +f+" e="+e+
" sentence + \n" + viterbi + "\n" + emission + "\n" + transition + "\n" + backtrace);
} else {
if (viterbi.get(f, e) < 0.0) {
// hack to avoid errors
try {
int af = f-1;
int ae = e_words.get(f, e);
if (ae >= 0)
res.align(af, ae);
//else
// System.err.println("ALIGN NULL TO " + af);
} catch (RuntimeException ex) {
throw new RuntimeException("Caught " + ex + "\nvit(f,e)="+viterbi.get(f,e)+" size(f,e)=" + sentence.getF().size() +","+ sentence.getE().size() + " Error f=" +f+" e="+e+
" sentence + \n" + viterbi + "\n" + emission + "\n" + transition + "\n" + backtrace + "\n" + e_words);
}
}
e = backtrace.get(f, e);
}
}
return res;
}
}