All downloads are FREE. The search and download functionality uses the official Maven repository.
Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
edu.stanford.nlp.sequences.BeamBestSequenceFinder Maven / Gradle / Ivy
package edu.stanford.nlp.sequences;
import edu.stanford.nlp.util.Beam;
import edu.stanford.nlp.util.Scored;
import edu.stanford.nlp.util.ScoredComparator;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Arrays;
/**
* A class capable of computing the best sequence given a SequenceModel. Uses beam search.
* @author Dan Klein
* @author Teg Grenager ([email protected] )
*/
public class BeamBestSequenceFinder implements BestSequenceFinder {
/**
 * A small hand-built {@code SequenceModel} used only by {@code main} to try out the search.
 * The padded sequence has 17 positions: two left-padding slots, 13 real slots and two
 * right-padding slots. A window scores {@code pos} when every tag in it equals the reference
 * tag, and 0 otherwise.
 */
private static class TestSequenceModel implements SequenceModel {

  /** Reference tags for the whole padded sequence (padding included). */
  private final int[] correctTags = {0, 0, 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, 3, 2, 1, 0, 0};

  /** Candidates offered at real positions whose reference tag is 4 or more. */
  private final int[] allTags = {1, 2, 3, 4, 5, 6, 7, 8, 9};

  /** Candidates offered at real positions whose reference tag is below 4. */
  private final int[] midTags = {0, 1, 2, 3};

  /** The single candidate offered at padding positions. */
  private final int[] nullTags = {0};

  /** Number of real (non-padding) positions. */
  public int length() {
    return correctTags.length - (leftWindow() + rightWindow());
  }

  /** Number of positions to the left of {@code pos} that a window score looks at. */
  public int leftWindow() {
    return 2;
  }

  /** Number of positions to the right of {@code pos} that a window score looks at. */
  public int rightWindow() {
    return 2;
  }

  /**
   * Returns the candidate tags for a padded position. The returned array is one of this
   * model's internal arrays, not a copy.
   */
  public int[] getPossibleValues(int pos) {
    final int firstReal = leftWindow();
    final boolean insideSequence = pos >= firstReal && pos < firstReal + length();
    if (!insideSequence) {
      return nullTags;
    }
    return correctTags[pos] < 4 ? midTags : allTags;
  }

  /** Whole-sequence scoring is not implemented by this test model. */
  public double scoreOf(int[] sequence) {
    throw new UnsupportedOperationException();
  }

  /**
   * Scores the window {@code [pos - leftWindow(), pos + rightWindow()]} of {@code tags}.
   *
   * @return {@code pos} if every tag in the window equals the reference tag, otherwise 0
   */
  public double scoreOf(int[] tags, int pos) {
    final int firstReal = leftWindow();
    final int endReal = firstReal + length(); // exclusive
    final int windowEnd = pos + rightWindow();
    boolean allMatch = true;
    boolean allOnes = true;
    int loc = pos - leftWindow();
    while (loc <= windowEnd) {
      final int tag = tags[loc];
      if (tag != correctTags[loc]) {
        allMatch = false;
      }
      // Only real positions count towards the "all ones" check; padding is ignored.
      final boolean realPosition = loc >= firstReal && loc < endReal;
      if (realPosition && tag != 1) {
        allOnes = false;
      }
      loc++;
    }
    if (allMatch) {
      return pos;
    }
    if (allOnes) {
      // Currently scored like any other mismatch; the original left
      // (length()/2-1) commented out as an alternative value here.
      return 0;
    }
    return 0;
  }

  /**
   * Scores every candidate tag at {@code pos}, in the order given by
   * {@link #getPossibleValues(int)}.
   *
   * <p>Note: {@code tags[pos]} is overwritten with each candidate in turn and is left
   * holding the last candidate when this method returns.
   */
  public double[] scoresOf(int[] tags, int pos) {
    final int[] candidates = getPossibleValues(pos);
    final double[] result = new double[candidates.length];
    int slot = 0;
    for (int candidate : candidates) {
      tags[pos] = candidate;
      result[slot] = scoreOf(tags, pos);
      slot++;
    }
    return result;
  }
}
/**
 * Runs the finder (beam size 4, exhaustive start enabled) on the built-in test model and
 * prints the resulting tag sequence to standard output.
 *
 * @param args ignored
 */
public static void main(String[] args) {
  SequenceModel model = new TestSequenceModel();
  BestSequenceFinder finder = new BeamBestSequenceFinder(4, true);
  int[] best = finder.bestSequence(model);
  String rendered = Arrays.toString(best);
  System.out.println("The best sequence is .... " + rendered);
}
/**
 * Scratch array shared by every {@code TagSeq}. {@code TagSeq.tmpTags} allocates it on first
 * use, replaces it when a larger size is requested, and returns this same array to its caller.
 * NOTE(review): static mutable state like this looks unsafe if several finders run
 * concurrently -- confirm before using from multiple threads.
 */
private static int[] tmp;
private static class TagSeq implements Scored {
/**
 * One cell of a singly linked list of tags. A sequence keeps a reference to the cell of its
 * most recent tag; {@code last} leads back to the cell of the tag added before it.
 */
private static class TagList {
  /** Cell of the previously added tag, or {@code null} for the first tag. */
  TagList last;
  /** The tag stored in this cell; -1 until assigned. */
  int tag = -1;
}
/** Running score of this sequence; starts at zero. */
private double score;

/** Returns the current running score of this sequence. */
public double score() {
  return this.score;
}
/** Number of tags added to this sequence so far; starts at zero. */
private int size;

/** Returns how many tags this sequence currently holds. */
public int size() {
  return this.size;
}
/** Cell of the most recently added tag, or {@code null} while the sequence is empty. */
private TagList info;
/**
 * Writes the most recent tags of this sequence into the shared scratch array and returns it.
 *
 * <p>Tags are written at their own sequence positions, starting at index {@code size() - 1}
 * and moving towards index 0. Up to {@code count + 1} tags are written (the loop continues
 * while the countdown is still {@code >= 0}); entries at other indices keep whatever the
 * scratch array held before.
 *
 * @param count bound on how far back to copy (see above)
 * @param s minimum length required of the scratch array; it is reallocated if shorter
 * @return the shared scratch array itself, not a copy
 */
public int[] tmpTags(int count, int s) {
  final boolean needsAllocation = (tmp == null) || (tmp.length < s);
  if (needsAllocation) {
    tmp = new int[s];
  }
  int budget = count;
  int slot = size() - 1;
  for (TagList node = info; node != null && budget >= 0; node = node.last) {
    tmp[slot] = node.tag;
    slot--;
    budget--;
  }
  return tmp;
}
/**
 * Returns a freshly allocated array with all tags of this sequence in the order they were
 * added (oldest first).
 */
public int[] tags() {
  final int n = size();
  final int[] out = new int[n];
  // The list is linked newest-to-oldest, so fill the array from the back.
  int slot = n - 1;
  TagList node = info;
  while (node != null) {
    out[slot] = node.tag;
    slot--;
    node = node.last;
  }
  return out;
}
/**
 * Appends {@code tag} to this sequence without changing its score.
 *
 * @param tag the tag to append
 */
public void extendWith(int tag) {
  TagList cell = new TagList();
  cell.tag = tag;
  cell.last = info; // link back to the previously newest cell
  info = cell;
  size++;
}
public void extendWith(int tag, SequenceModel ts, int s) {
extendWith(tag);
int[] tags = tmpTags(ts.leftWindow() + 1 + ts.rightWindow(), s);
score += ts.scoreOf(tags, size() - ts.rightWindow() - 1);
//for (int i=0; i= leftWindow + rightWindow) {
nextSeq.extendWith(tags[pos][nextTagNum], ts, size);
} else {
nextSeq.extendWith(tags[pos][nextTagNum]);
}
//System.out.println("Created: "+nextSeq.score()+" %% "+arrayToString(nextSeq.tags(), nextSeq.size()));
newBeam.add(nextSeq);
// System.out.println("Beam size: "+newBeam.size()+" of "+beamSize);
//System.out.println("Best is: "+((Scored)newBeam.iterator().next()).score());
}
}
System.out.println(" done");
if (recenter) {
double max = Double.NEGATIVE_INFINITY;
for (Iterator beamI = newBeam.iterator(); beamI.hasNext();) {
TagSeq tagSeq = (TagSeq) beamI.next();
if (tagSeq.score > max) {
max = tagSeq.score;
}
}
for (Iterator beamI = newBeam.iterator(); beamI.hasNext();) {
TagSeq tagSeq = (TagSeq) beamI.next();
tagSeq.score -= max;
}
}
}
try {
TagSeq bestSeq = (TagSeq) newBeam.iterator().next();
int[] seq = bestSeq.tags();
return seq;
} catch (NoSuchElementException e) {
System.err.println("Beam empty -- no best sequence.");
return null;
}
/*
int[] tempTags = new int[padLength];
// Set up product space sizes
int[] productSizes = new int[padLength];
int curProduct = 1;
for (int i=0; i leftWindow+rightWindow)
curProduct /= tagNum[pos-leftWindow-rightWindow-1]; // shift off
curProduct *= tagNum[pos]; // shift on
productSizes[pos-rightWindow] = curProduct;
}
// Score all of each window's options
double[][] windowScore = new double[padLength][];
for (int pos=leftWindow; pos= pos-leftWindow; curPos--) {
tempTags[curPos] = tags[curPos][p % tagNum[curPos]];
p /= tagNum[curPos];
if (curPos > pos)
shift *= tagNum[curPos];
}
if (tempTags[pos] == tags[pos][0]) {
// get all tags at once
double[] scores = ts.scoresOf(tempTags, pos);
// fill in the relevant windowScores
for (int t = 0; t < tagNum[pos]; t++) {
windowScore[pos][product+t*shift] = scores[t];
}
}
}
}
// Set up score and backtrace arrays
double[][] score = new double[padLength][];
int[][] trace = new int[padLength][];
for (int pos=0; pos score[pos][product]) {
score[pos][product] = predScore;
trace[pos][product] = predProduct;
}
}
}
}
}
// Project the actual tag sequence
double bestFinalScore = Double.NEGATIVE_INFINITY;
int bestCurrentProduct = -1;
for (int product=0; product bestFinalScore) {
bestCurrentProduct = product;
bestFinalScore = score[leftWindow+length-1][product];
}
}
int lastProduct = bestCurrentProduct;
for (int last=padLength-1; last>=length-1; last--) {
tempTags[last] = tags[last][lastProduct % tagNum[last]];
lastProduct /= tagNum[last];
}
for (int pos=leftWindow+length-2; pos>=leftWindow; pos--) {
int bestNextProduct = bestCurrentProduct;
bestCurrentProduct = trace[pos+1][bestNextProduct];
tempTags[pos-leftWindow] = tags[pos-leftWindow][bestCurrentProduct / (productSizes[pos]/tagNum[pos-leftWindow])];
}
return tempTags;
*/
}
/*
public int[] bestSequenceOld(TagScorer ts) {
// Set up tag options
int length = ts.length();
int leftWindow = ts.leftWindow();
int rightWindow = ts.rightWindow();
int padLength = length+leftWindow+rightWindow;
int[][] tags = new int[padLength][];
int[] tagNum = new int[padLength];
for (int pos = 0; pos < padLength; pos++) {
tags[pos] = ts.tagsAt(pos);
tagNum[pos] = tags[pos].length;
}
int[] tempTags = new int[padLength];
// Set up product space sizes
int[] productSizes = new int[padLength];
int curProduct = 1;
for (int i=0; i leftWindow+rightWindow)
curProduct /= tagNum[pos-leftWindow-rightWindow-1]; // shift off
curProduct *= tagNum[pos]; // shift on
productSizes[pos-rightWindow] = curProduct;
}
// Score all of each window's options
double[][] windowScore = new double[padLength][];
for (int pos=leftWindow; pos= pos-leftWindow; curPos--) {
tempTags[curPos] = tags[curPos][p % tagNum[curPos]];
p /= tagNum[curPos];
}
windowScore[pos][product] = ts.scoreOf(tempTags, pos);
}
}
// Set up score and backtrace arrays
double[][] score = new double[padLength][];
int[][] trace = new int[padLength][];
for (int pos=0; pos score[pos][product]) {
score[pos][product] = predScore;
trace[pos][product] = predProduct;
}
}
}
}
}
// Project the actual tag sequence
double bestFinalScore = Double.NEGATIVE_INFINITY;
int bestCurrentProduct = -1;
for (int product=0; product bestFinalScore) {
bestCurrentProduct = product;
bestFinalScore = score[leftWindow+length-1][product];
}
}
int lastProduct = bestCurrentProduct;
for (int last=padLength-1; last>=length-1; last--) {
tempTags[last] = tags[last][lastProduct % tagNum[last]];
lastProduct /= tagNum[last];
}
for (int pos=leftWindow+length-2; pos>=leftWindow; pos--) {
int bestNextProduct = bestCurrentProduct;
bestCurrentProduct = trace[pos+1][bestNextProduct];
tempTags[pos-leftWindow] = tags[pos-leftWindow][bestCurrentProduct / (productSizes[pos]/tagNum[pos-leftWindow])];
}
return tempTags;
}
*/
/**
 * Creates a finder with the given beam size, with both {@code exhaustiveStart} and
 * {@code recenter} switched off.
 *
 * @param beamSize the beam size to use
 */
public BeamBestSequenceFinder(int beamSize) {
  this(beamSize, false);
}
/**
 * Creates a finder with the given beam size and {@code exhaustiveStart} setting;
 * {@code recenter} is passed as {@code false}.
 *
 * @param beamSize the beam size to use
 * @param exhaustiveStart stored on the finder; NOTE(review): presumably relaxes pruning for
 *     the first few positions -- confirm against {@code bestSequence}
 */
public BeamBestSequenceFinder(int beamSize, boolean exhaustiveStart) {
this(beamSize, exhaustiveStart, false);
}
/**
 * Creates a finder, storing all three settings on the instance.
 *
 * @param beamSize the beam size to use
 * @param exhaustiveStart stored on the finder; NOTE(review): presumably relaxes pruning for
 *     the first few positions -- confirm against {@code bestSequence}
 * @param recenter when {@code true}, the search subtracts the highest score in the new beam
 *     from the score of every sequence in it
 */
public BeamBestSequenceFinder(int beamSize, boolean exhaustiveStart, boolean recenter) {
  this.beamSize = beamSize;
  this.recenter = recenter;
  this.exhaustiveStart = exhaustiveStart;
}
}