All Downloads are FREE. Search and download functionalities are using the official Maven repository.

dist.edu.umd.hooka.alignment.hmm.HMM Maven / Gradle / Ivy

There is a newer version: 2.0.1
Show newest version
package edu.umd.hooka.alignment.hmm;

import java.io.IOException;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;

import edu.umd.hooka.Alignment;
import edu.umd.hooka.AlignmentPosteriorGrid;
import edu.umd.hooka.Array2D;
import edu.umd.hooka.PhrasePair;
import edu.umd.hooka.alignment.PartialCountContainer;
import edu.umd.hooka.alignment.PerplexityReporter;
import edu.umd.hooka.alignment.CrossEntropyCounters;
import edu.umd.hooka.alignment.ZeroProbabilityException;
import edu.umd.hooka.alignment.model1.Model1;
import edu.umd.hooka.ttables.TTable;

/**
 * Represents an HMM that applies to a single sentence pair, which is
 * derived from the parameters stored in a TTable and an ATable object.
 * 
 * @author redpony
 *
 */
public class HMM extends Model1 {
	/** Key under which the alignment-count table is emitted by writePartialCounts. */
	public static final IntWritable ACOUNT_VOC_ID = new IntWritable(999999);
	/**
	 * Every 2-D work table below is constructed with MAX_LENGTH * MAX_LENGTH.
	 * NOTE(review): presumably an initial capacity in cells -- confirm against Array2D / IntArray2D.
	 */
	static final int MAX_LENGTH = 500;
	/** Threshold handed to AlignmentPosteriorGrid.alignPosteriorThreshold when Model 1 posteriors are set. */
	static final float THRESH =0.5f;
	
	/**
	 * (j,s) = p(f_j|e(s)).
	 * The first index is the 1-based foreign position j and the second is the
	 * state s; this is the order used by buildHMMTables and getEmissionProb.
	 */
	Array2D emission = new Array2D(MAX_LENGTH * MAX_LENGTH);

	/**
	 * (j,s) = i s.t. e(s) = e_i or -1 if n.a.
	 * (same index order as emission; i is the 1-based English position)
	 */
	IntArray2D e_coords = new IntArray2D(MAX_LENGTH * MAX_LENGTH);
	
	/**
	 * (j,s) = zero-based index, into the English word array, of the word
	 * corresponding to state s, or -1 if n.a.
	 */
	IntArray2D e_words = new IntArray2D(MAX_LENGTH * MAX_LENGTH);
	
	/**
	 * (i',i) = p(i-i'), looked up from amodel for the current English length
	 */
	Array2D transition = new Array2D(MAX_LENGTH * MAX_LENGTH);
	/** (i',i) = amodel.getCoord(i-i', l), or -1 where buildHMMTables sets no entry. */
	IntArray2D transition_coords = new IntArray2D(MAX_LENGTH * MAX_LENGTH);

	/** Work tables resized and filled by baumWelch, indexed (j, s). */
	Array2D alphas = new Array2D(MAX_LENGTH * MAX_LENGTH);
	Array2D betas  = new Array2D(MAX_LENGTH * MAX_LENGTH);

	/** Best-path scores (j, s) and the matching back-pointers to the previous state. */
	Array2D viterbi = new Array2D(MAX_LENGTH * MAX_LENGTH);
	IntArray2D backtrace = new IntArray2D(MAX_LENGTH * MAX_LENGTH);
	
	/** Alignment (jump) table supplied to the constructor; read when building the transition table. */
	ATable amodel;
	/** Cleared clone of amodel made in the constructor; emitted and then cleared by writePartialCounts. */
	ATable acounts;
	
	/** English sentence length set by buildHMMTables; -1 until then. */
	int l = -1;
	/** Foreign sentence length set by buildHMMTables; -1 until then. */
	int m = -1;
	/** Optional Model 1 posteriors installed through setModel1Posteriors; null when unused. */
	AlignmentPosteriorGrid m1_post = null;
	
	/**
	 * Installs the Model 1 posterior grid that later passes may consult.
	 *
	 * @param m1pg the posterior grid to keep; {@code null} removes any
	 *             previously installed grid
	 */
	public void setModel1Posteriors(AlignmentPosteriorGrid m1pg) {
		this.m1_post = m1pg;
	}

	protected HMM(TTable ttable, ATable atable, boolean useNull) {
		super(ttable, useNull);
		amodel = atable;
		acounts = (ATable)amodel.clone(); acounts.clear();		
	}
	
	/**
	 * Creates an HMM over the given translation and alignment tables,
	 * passing {@code false} for the superclass's {@code useNull} flag.
	 *
	 * @param ttable translation table handed to the superclass
	 * @param atable alignment table used as the model; a cleared clone of
	 *               it becomes the count accumulator
	 */
	public HMM(TTable ttable, ATable atable) {
		// Same initialisation as the three-argument constructor with useNull = false.
		this(ttable, atable, false);
	}
	
	/**
	 * Emits the superclass's partial counts, then emits this model's
	 * alignment counts under {@link #ACOUNT_VOC_ID} and clears them.
	 *
	 * @param output collector that receives the partial counts
	 * @throws IOException if the collector or the superclass call fails
	 */
	public void writePartialCounts(OutputCollector output) throws IOException
	{
		super.writePartialCounts(output);

		final PartialCountContainer alignmentCounts = new PartialCountContainer();
		alignmentCounts.setContent(this.acounts);
		output.collect(ACOUNT_VOC_ID, alignmentCounts);

		// Start accumulating from scratch once the counts have been handed off.
		this.acounts.clear();
	}
	
	/**
	 * Fills the per-sentence emission and transition tables for a sentence pair.
	 *
	 * <p>Records the English and foreign lengths in {@code l} and {@code m}.
	 * For every 1-based English position i and foreign position j it stores
	 * the translation-table value in {@code emission(j, i)}, i in
	 * {@code e_coords(j, i)} and i-1 in {@code e_words(j, i)}. For every
	 * previous position 0..l and position 1..l it stores the alignment-table
	 * value and coordinate for the jump width.
	 *
	 * @param pp the sentence pair whose word arrays drive the table sizes
	 */
	public void buildHMMTables(PhrasePair pp) {
		final int[] english = pp.getE().getWords();
		final int[] foreign = pp.getF().getWords();
		l = english.length;
		m = foreign.length;

		final int fDim = m + 1;
		final int eDim = l + 1;
		emission.resize(fDim, eDim);
		e_coords.resize(fDim, eDim);
		e_words.resize(fDim, eDim);
		e_words.fill(-1);
		e_coords.fill(-1);

		// Row/column 0 is left alone; positions inside the tables are 1-based.
		for (int ePos = 0; ePos < l; ePos++) {
			final int state = ePos + 1;
			final int eWord = english[ePos];
			for (int fPos = 0; fPos < m; fPos++) {
				final int fIdx = fPos + 1;
				final int fWord = foreign[fPos];
				e_coords.set(fIdx, state, state);
				emission.set(fIdx, state, tmodel.get(eWord, fWord));
				e_words.set(fIdx, state, ePos);
			}
		}

		transition.resize(eDim, eDim);
		transition_coords.resize(eDim, eDim);
		transition_coords.fill(-1);

		// The alignment table is keyed by jump width and English length (as a char).
		final char eLen = (char) l;
		for (int from = 0; from <= l; from++) {
			for (int to = 1; to <= l; to++) {
				final int jump = to - from;
				transition_coords.set(from, to, amodel.getCoord(jump, eLen));
				transition.set(from, to, amodel.get(jump, eLen));
			}
		}
	}
	
	/**
	 * @return the second dimension of the transition table
	 */
	public final int getNumStates() {
		final int stateCount = transition.getSize2();
		return stateCount;
	}
	
	/**
	 * Looks up a transition-table entry.
	 *
	 * @param s_prev the state being left
	 * @param s      the state being entered
	 * @return the value stored at {@code (s_prev, s)} in the transition table
	 */
	public final float getTransitionProb(int s_prev, int s) {
		final float prob = transition.get(s_prev, s);
		return prob;
	}
	
	/**
	 * Looks up an emission-table entry.
	 *
	 * @param j the foreign position (first index of the emission table)
	 * @param s the state (second index of the emission table)
	 * @return the value stored at {@code (j, s)} in the emission table
	 */
	public final float getEmissionProb(int j, int s) {
		final float prob = emission.get(j, s);
		return prob;
	}
	
	/**
	 * Adds this model's accumulated alignment counts into another table.
	 *
	 * @param ac the table that receives the counts via {@code plusEquals}
	 */
	public final void addPartialJumpCountsToATable(ATable ac) {
		final ATable pending = this.acounts;
		ac.plusEquals(pending);
	}

	/**
	 * Runs one training pass over a sentence pair and reports its statistics.
	 *
	 * <p>The pair is ignored when either side is empty or has a size of at
	 * least {@code amodel.getMaxDist() - 1}. Otherwise the HMM tables are
	 * built, {@link #baumWelch} is run, and, if a reporter is given, the
	 * negated result (truncated to {@code long}) and the foreign word count
	 * are added to the cross-entropy counters.
	 *
	 * @param pp the sentence pair to train on
	 * @param r  counter sink; may be {@code null}, in which case nothing is reported
	 */
	@Override
	public void processTrainingInstance(PhrasePair pp, Reporter r) {
		// Checks run in the original order and short-circuit on the first hit.
		if (pp.getE().size() >= amodel.getMaxDist() - 1
				|| pp.getF().size() >= amodel.getMaxDist() - 1
				|| pp.getE().size() == 0
				|| pp.getF().size() == 0) {
			return;
		}

		buildHMMTables(pp);
		final float totalLogProb = baumWelch(pp, null);
		if (r == null) {
			return;
		}
		r.incrCounter(CrossEntropyCounters.LOGPROB, (long) (-totalLogProb));
		r.incrCounter(CrossEntropyCounters.WORDCOUNT, pp.getF().size());
	}

	/**
	 * Runs the alpha/beta passes over the tables prepared by buildHMMTables.
	 *
	 * NOTE(review): this copy of the source is damaged. Wherever the original
	 * had text between a '<' and a later '>', that text is missing, so several
	 * loop headers below are truncated, the end of this method is not visible,
	 * and the code from the "PARTIAL COUNTS" section onward runs straight into
	 * the body of a different method (it uses viterbi, backtrace, reporter,
	 * sentence and res, and ends with "return res") whose header was lost.
	 * Restore the file from the upstream repository before changing any logic.
	 *
	 * The only return visible here is the 0.0f returned after a
	 * ZeroProbabilityException; the normal return statement is in the lost part.
	 *
	 * @return negative log probability of sentence
	 */
	public final float baumWelch(PhrasePair pp, AlignmentPosteriorGrid pg) {

		initializeCountTableForSentencePair(pp);

		int[] obs = pp.getF().getWords();
		// J is one more than the number of foreign words; positions 1..J-1 are the words.
		int J = obs.length + 1;
		int numStates = getNumStates();
		// Local l shadows the field of the same name.
		int l = pp.getE().getWords().length;
		// anorms[j] holds whatever alphas.normalizeColumn(j) returns for column j.
		float[] anorms = new float[J];
		alphas.resize(J + 1, getNumStates());
		betas.resize(J + 1, getNumStates());
		alphas.set(0, 0, 1.0f); anorms[0]=1.0f;
		// When Model 1 posteriors were supplied, threshold them into a hard alignment.
		Alignment m1a = null;
		if (m1_post != null)
			m1a = m1_post.alignPosteriorThreshold(THRESH);
		for (int j = 1; j < J; j++) {
			//System.out.println("J="+j);
			for (int s = 0; s < numStates; s++) {
				float alpha = 0.0f;
				float m1boost = 1.0f;
				float m1penalty = 0.0f;
				boolean use_m1 = false;
				if (m1a != null && m1a.isFAligned(j-1)) {
					float m1post = 0.0f;
					use_m1 = true;
					// NOTE(review): the next line is truncated by the extraction damage;
					// the loop over i, the loop over s_prev and the declaration of
					// trans are missing.
					for (int i=0; i 0 && m1a.aligned(j-1, s-1))
							trans = m1boost;
						else
							trans *= m1penalty;
					}
					alpha += alphas.get(j - 1, s_prev) * trans;
				}
				alpha *= getEmissionProb(j, s);
				//System.out.println(" ep:" + hmm.getEmissionProb(s, j));
				alphas.set(j, s, alpha);
			}
			//anorms[j] = 1.0f;
			try {
				anorms[j] = alphas.normalizeColumn(j);
			} catch (ZeroProbabilityException ex) {
				// Column j could not be normalized: report the pair and give up on it.
				this.notifyUnalignablePair(pp, ex.getMessage());
				return 0.0f;
			}
		}
		// NOTE(review): the next line is truncated by the extraction damage; the
		// initialisation that used "s" and the header of the descending loop
		// over j have been fused together.
		for (int s=1; s=1; j--) {
			//System.out.println("J="+j);
			for (int s = 0; s < numStates; s++) {
				float beta = 0.0f;
				float m1boost = 1.0f;
				float m1penalty = 0.0f;
				boolean use_m1 = false;
				if (m1a != null && m1a.isFAligned(j-1)) {
					float m1post = 0.0f;
					use_m1 = true;
					// NOTE(review): truncated line, same damage as in the alpha pass;
					// the loop over s_next and the declaration of trans are missing.
					for (int i=0; i 0 && m1a.aligned(j-1, s-1))
							trans = m1boost;
						else
							trans *= m1penalty;
					}
					beta += betas.get(j+1, s_next) *
					  trans *
					  getEmissionProb(j+1, s_next);
				}
				
				// Scale by the same per-column factor that was stored during the alpha pass.
				beta /= anorms[j];
				//System.out.println("  s="+s+ "  b:"+beta);
				betas.set(j, s, beta);
			}
		}
		
		// PARTIAL COUNTS FOR EMISSIONS (WORD TRANSLATION)
		float totalProb[] = new float[J];
		// NOTE(review): the next line is truncated by the extraction damage.
		// Everything after it appears to belong to a separate best-path method
		// (header lost) that fills viterbi/backtrace and returns an alignment.
		for (int j=1; j 0 && m1a.aligned(j-1, s-1))
							trans = m1boost;
						else
							trans *= m1penalty;
					}
					// Candidate score: previous best score plus log transition plus log emission.
					float cur = (float)(viterbi.get(j - 1, s_prev) +
					    Math.log(trans) +
					    emitLogProb);
					//System.out.println(" s'="+s_prev + "  cur="+cur);
					if (cur > best) {
						best = cur;
						best_s = s_prev;
						//System.out.println("new best: " + s + " " + best_s);
					}
				}
				//System.out.println(" s_best="+best_s + "  cur="+best);
				viterbi.set(j, s, best);
				if (best != Float.NEGATIVE_INFINITY)
					valid = true;
				backtrace.set(j, s, best_s);
			}
			// if we don't know how to generate some column
			// create a uniform distribution over the states
			// and assume the previous state was the best
			if (!valid) {
				float best = Float.NEGATIVE_INFINITY;
				int bests = -1;
				for (int s = 1; s < numStates; s++) {
					if (viterbi.get(j-1, s) > best) {
						best = viterbi.get(j-1, s);
						bests = s;
					}
				}
				// Every state in column j gets score 0.0f and points back at that best previous state.
				for (int s = 1; s < numStates; s++) {
					viterbi.set(j, s, 0.0f);
					backtrace.set(j, s, bests);
				}
			}
		}
		//System.out.println(viterbi);
		// Pick the highest-scoring state (1..numStates-1) in the last column.
		float best = Float.NEGATIVE_INFINITY;
		int best_s = -1;
		for (int s = 1; s < numStates; s++) {
			if (viterbi.get(J-1, s) > best) {
				best = viterbi.get(J-1,s);
				best_s = s;
			}
		}
		//System.out.println("vit: " + best + "j-1="+(J-1));
		reporter.addFactor(best, J - 1);
		//System.out.println(viterbi);
		// Walk the back-pointers from the last foreign position to the first.
		int e = best_s;
		for (int f=J-1; f>0; f--) {
			if (e <= 0) {
				throw new ZeroProbabilityException("  Error f=" +f+" e="+e+
						"  sentence + \n" + viterbi + "\n" + emission + "\n" + transition + "\n" + backtrace);
			} else {
				// Only cells with a negative score produce an alignment link; cells
				// with score 0.0f (such as the fallback columns above) are skipped.
				if (viterbi.get(f, e) < 0.0) {
					// hack to avoid errors
					try {
						int af = f-1;
						int ae = e_words.get(f, e);
						if (ae >= 0)
							res.align(af, ae);
						//else
						//	System.err.println("ALIGN NULL TO " + af);
					} catch (RuntimeException ex) {
						throw new RuntimeException("Caught " + ex + "\nvit(f,e)="+viterbi.get(f,e)+"  size(f,e)=" + sentence.getF().size() +","+ sentence.getE().size() + " Error f=" +f+" e="+e+
								"  sentence + \n" + viterbi + "\n" + emission + "\n" + transition + "\n" + backtrace + "\n" + e_words);
					}
				}
				e = backtrace.get(f, e);
			}
		}
		return res;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy