/*
* LingPipe v. 4.1.0
* Copyright (C) 2003-2011 Alias-i
*
* This program is licensed under the Alias-i Royalty Free License
* Version 1 WITHOUT ANY WARRANTY, without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the Alias-i
* Royalty Free License Version 1 for more details.
*
* You should have received a copy of the Alias-i Royalty Free License
* Version 1 along with this program; if not, visit
* http://alias-i.com/lingpipe/licenses/lingpipe-license-1.txt or contact
* Alias-i, Inc. at 181 North 11th Street, Suite 401, Brooklyn, NY 11211,
* +1 (718) 290-9170.
*/
package com.aliasi.lm;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.io.BitInput;
import com.aliasi.io.BitOutput;
import com.aliasi.stats.Model;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Strings;
import java.io.Externalizable;
import java.io.InputStream;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.LinkedList;
/**
 * An <code>NGramProcessLM</code> provides a dynamic conditional
 * process language model for which training, estimation, and
 * pruning may be interleaved.  A process language model normalizes
 * probabilities for a given length of input.
 *
 * <p>The model may be compiled to an object output stream; the
 * compiled model read back in will be an instance of {@link
 * CompiledNGramProcessLM}.
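 *
 * <p>For example, a model may be trained and then used for
 * estimation as follows (an illustrative sketch):
 *
 * <blockquote><pre>
 * NGramProcessLM lm = ...;
 * lm.train("hello world");
 * double log2P = lm.log2Estimate("hello");</pre></blockquote>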
*
 * <p>This class implements a generative language model based on the
 * chain rule, as specified by {@link LanguageModel.Conditional}.
 * The maximum likelihood estimator (see {@link CharSeqCounter})
 * is smoothed by linear interpolation with the next lower-order
 * context model:
 *
 * <blockquote><pre>
 * P'(c[k] | c[j],...,c[k-1])
 *     = lambda(c[j],...,c[k-1])
 *       * P_ML(c[k] | c[j],...,c[k-1])
 *     + (1 - lambda(c[j],...,c[k-1]))
 *       * P'(c[k] | c[j+1],...,c[k-1])
 * </pre></blockquote>
*
 * The <code>P_ML</code> terms in the above definition
 * are maximum likelihood estimates based on frequency:
 *
 * <blockquote><pre>
 * P_ML(c[k] | c[j],...,c[k-1])
 *     = count(c[j],...,c[k-1],c[k])
 *     / extCount(c[j],...,c[k-1])
 * </pre></blockquote>
*
 * The <code>count</code> is just the number of times a given string
 * has been seen in the data, whereas <code>extCount</code> is the
 * number of times an extension to the string has been seen in the
 * data:
 *
 * <blockquote><pre>
 * extCount(c[1],...,c[n])
 *     = Σ_d count(c[1],...,c[n],d)
 * </pre></blockquote>
*
 * <p>In the parametric Witten-Bell method, the interpolation ratio
 * <code>lambda</code> is defined based on extensions of the context
 * of estimation to be:
 *
 * <blockquote><pre>
 * lambda(c[1],...,c[n])
 *     = extCount(c[1],...,c[n])
 *     / (extCount(c[1],...,c[n])
 *        + L * numExtensions(c[1],...,c[n]))
 * </pre></blockquote>
*
 * where <code>c[1],...,c[n]</code> is the conditioning context for
 * estimation, <code>extCount</code> is as defined above,
 * <code>numExtensions</code> is the number of distinct extensions
 * of a context:
 *
 * <blockquote><pre>
 * numExtensions(c[1],...,c[n])
 *     = cardinality( { d | count(c[1],...,c[n],d) &gt; 0 } )
 * </pre></blockquote>
 *
 * and <code>L</code> is a hyperparameter of the distribution
 * (described below).
*
 * <p>As a base case, <code>P(c[k])</code> is interpolated with the
 * uniform distribution <code>P_U</code>, with interpolation defined
 * as usual with the argument to <code>lambda</code> being the
 * empty (i.e. zero length) sequence:
 *
 * <blockquote><pre>
 * P(d) = lambda() * P_ML(d)
 *      + (1 - lambda()) * P_U(d)
 * </pre></blockquote>
 *
 * The uniform distribution <code>P_U</code> only depends on the
 * number of possible characters used in training and tests:
 *
 * <blockquote><pre>
 * P_U(c) = 1/alphabetSize
 * </pre></blockquote>
 *
 * where <code>alphabetSize</code> is the maximum number of distinct
 * characters in this model.
*
 * <p>The free hyperparameter <code>L</code> in the smoothing equation
 * determines the balance between higher-order and lower-order models.
 * A higher value for <code>L</code> gives more of the weight to
 * lower-order contexts.  As the amount of data grows against a fixed
 * alphabet of characters, the impact of <code>L</code> is reduced.
 * In Witten and Bell's original paper, the hyperparameter
 * <code>L</code> was set to 1.0, which is not a particularly good
 * choice for most text sources.  A value of the lambda factor that is
 * roughly the length of the longest n-gram seems to be a good rule of
 * thumb.
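 *
 * <p>For instance, with hypothetical counts, if the context
 * <code>"ab"</code> has been extended 20 times in training
 * (<code>extCount("ab") = 20</code>) by 4 distinct following
 * characters (<code>numExtensions("ab") = 4</code>), then with
 * <code>L = 5.0</code>:
 *
 * <blockquote><pre>
 * lambda("ab") = 20 / (20 + 5.0 * 4) = 0.5
 * </pre></blockquote>
 *
 * so the maximum likelihood estimate and the lower-order estimate
 * each receive half of the weight.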
*
 * <p>Methods are provided for computing a sample cross-entropy
 * rate for a character sequence.  The sample cross-entropy rate
 * <code>H(c[1],...,c[n]; P_M)</code> for the sequence
 * <code>c[1],...,c[n]</code> in probability model <code>P_M</code>
 * is defined to be the negative average log (base 2) probability
 * of the characters in the sequence according to the model.  In
 * symbols:
 *
 * <blockquote><pre>
 * H(c[1],...,c[n]; P_M) = (- log2 P_M(c[1],...,c[n])) / n
 * </pre></blockquote>
*
 * <p>The cross-entropy rate of a distribution <code>P'</code>
 * with respect to a distribution <code>P</code> is defined by:
 *
 * <blockquote><pre>
 * H(P',P) = - Σ_x P(x) * log2 P'(x)
 * </pre></blockquote>
*
 * <p>The Shannon-McMillan-Breiman theorem shows that as the length
 * of the sample drawn from the true distribution <code>P</code>
 * grows, the sample cross-entropy rate approaches the actual
 * cross-entropy rate.  In symbols:
 *
 * <blockquote><pre>
 * H(P,P_M) = lim_{n-&gt;infinity} H(c[1],...,c[n]; P_M)
 * </pre></blockquote>
*
 * <p>The entropy of a distribution <code>P</code> is defined by its
 * cross-entropy against itself, <code>H(P,P)</code>.  A
 * distribution's entropy is a lower bound on its cross-entropy; in
 * symbols, <code>H(P',P) &gt;= H(P,P)</code> for all distributions
 * <code>P'</code>, with equality if and only if <code>P'</code>
 * and <code>P</code> are the same distribution.
*
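 * <p>For example, the sample cross-entropy rate of a string under
 * a trained model may be computed with {@link
 * #log2Estimate(CharSequence)} (an illustrative sketch):
 *
 * <blockquote><pre>
 * NGramProcessLM lm = ...;
 * String text = ...;
 * double xEntropyRate = -lm.log2Estimate(text) / text.length();</pre></blockquote>
 *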
 * <h3>Pruning</h3>
 *
 * <p>Models may be pruned by pruning the underlying substring
 * counter for the language model.  This counter is returned by
 * the method {@link #substringCounter()}.  See the class
 * documentation for the result class {@link TrieCharSeqCounter}
 * for more information.
 *
 * <h3>Serialization</h3>
 *
 * <p>Models may be serialized in the usual way by creating an
 * object output object and writing the object:
 *
 * <blockquote><pre>
 * NGramProcessLM lm = ...;
 * ObjectOutput out = ...;
 * out.writeObject(lm);</pre></blockquote>
 *
 * Reading just reverses the process:
 *
 * <blockquote><pre>
 * ObjectInput in = ...;
 * NGramProcessLM lm = (NGramProcessLM) in.readObject();</pre></blockquote>
 *
 * Serialization is based on the methods {@link #writeTo(OutputStream)}
 * and {@link #readFrom(InputStream)}.  These write compressed forms
 * of the model to streams in binary format.
 *
 * <p><b>Warning:</b> The object input and output used for
 * serialization must extend {@link InputStream} and {@link
 * OutputStream}.  The only implementations of {@link ObjectInput}
 * and {@link ObjectOutput} as of the 1.6 JDK do extend the streams,
 * so this will only be a problem with customized object input or
 * output objects.  If you need this method to work with custom
 * input and output objects that do not extend the corresponding
 * streams, drop us a line and we can perhaps refactor the output
 * methods to remove this restriction.
 *
 * <h3>References</h3>
 *
 * <p>For information on the Witten-Bell interpolation method, see:
 *
 * <ul>
 * <li>
 * Witten, Ian H. and Timothy C. Bell. 1991. The zero-frequency
 * problem: estimating the probabilities of novel events in adaptive
 * text compression. <i>IEEE Transactions on Information Theory</i>
 * 37(4).
 * </li>
 * </ul>
 */
/**
 * Construct an n-gram process language model with the specified
 * number of characters, interpolation parameter, and character
 * sequence counter.  The counter argument allows serialized
 * counters to be read back in and used to create an n-gram
 * process LM.
 *
 * @param numChars Maximum number of characters in training and
 * test data.
 * @param lambdaFactor Interpolation parameter (see class doc).
 * @param counter Character sequence counter to use.
 * @throws IllegalArgumentException If the number of characters is
 * not between 1 and the maximum number of characters, or if the
 * lambda factor is not greater than or equal to 0.
 */
public NGramProcessLM(int numChars,
                      double lambdaFactor,
                      TrieCharSeqCounter counter) {
    mMaxNGram = counter.mMaxLength;
    setLambdaFactor(lambdaFactor); // checks range
    setNumChars(numChars);
    mTrieCharSeqCounter = counter;
}

/**
 * Writes this language model to the specified output stream.
 *
 * <p>A language model is written using a {@link BitOutput}
 * wrapped around the specified output stream.  This bit output is
 * used to delta encode the maximum n-gram, number of characters,
 * lambda factor times 1,000,000, and then the underlying sequence
 * counter using {@link
 * TrieCharSeqCounter#writeCounter(CharSeqCounter,TrieWriter,int)}.
 * The bit output is flushed, but the output stream is not closed.
 *
 * <p>A language model can be read and written using the following
 * code, given a file <code>f</code>:
 *
 * <blockquote><pre>
 * NGramProcessLM lm = ...;
 * File f = ...;
 * OutputStream out = new FileOutputStream(f);
 * BufferedOutputStream bufOut = new BufferedOutputStream(out);
 * lm.writeTo(bufOut);
 * bufOut.close();
 *
 * ...
 * InputStream in = new FileInputStream(f);
 * BufferedInputStream bufIn = new BufferedInputStream(in);
 * NGramProcessLM lm2 = NGramProcessLM.readFrom(bufIn);
 * bufIn.close();</pre></blockquote>
 *
 * @param out Output stream to which to write language model.
 * @throws IOException If there is an underlying I/O error.
 */
public void writeTo(OutputStream out) throws IOException {
    BitOutput bitOut = new BitOutput(out);
    writeTo(bitOut);
    bitOut.flush();
}

void writeTo(BitOutput bitOut) throws IOException {
    bitOut.writeDelta(mMaxNGram);
    bitOut.writeDelta(mNumChars);
    bitOut.writeDelta((int) (mLambdaFactor * 1000000));
    BitTrieWriter trieWriter = new BitTrieWriter(bitOut);
    TrieCharSeqCounter.writeCounter(mTrieCharSeqCounter,trieWriter,
                                    mMaxNGram);
}

/**
 * Reads a language model from the specified input stream.
 *
 * <p>See {@link #writeTo(OutputStream)} for information on the
 * binary I/O format.
 *
 * @param in Input stream from which to read a language model.
 * @return The language model read from the stream.
 * @throws IOException If there is an underlying I/O error.
 */
public static NGramProcessLM readFrom(InputStream in)
    throws IOException {

    BitInput bitIn = new BitInput(in);
    return readFrom(bitIn);
}

static NGramProcessLM readFrom(BitInput bitIn) throws IOException {
    int maxNGram = (int) bitIn.readDelta();
    int numChars = (int) bitIn.readDelta();
    double lambdaFactor = bitIn.readDelta() / 1000000.0;
    BitTrieReader trieReader = new BitTrieReader(bitIn);
    TrieCharSeqCounter counter
        = TrieCharSeqCounter.readCounter(trieReader,maxNGram);
    return new NGramProcessLM(numChars,lambdaFactor,counter);
}

public double log2Prob(CharSequence cSeq) {
    return log2Estimate(cSeq);
}

public double prob(CharSequence cSeq) {
    return java.lang.Math.pow(2.0,log2Estimate(cSeq));
}

public final double log2Estimate(CharSequence cSeq) {
    char[] cs = Strings.toCharArray(cSeq);
    return log2Estimate(cs,0,cs.length);
}

public final double log2Estimate(char[] cs, int start, int end) {
    Strings.checkArgsStartEnd(cs,start,end);
    double sum = 0.0;
    for (int i = start+1; i <= end; ++i)
        sum += log2ConditionalEstimate(cs,start,i);
    return sum;
}

public void train(CharSequence cSeq) {
    train(cSeq,1);
}

public void train(CharSequence cSeq, int incr) {
    char[] cs = Strings.toCharArray(cSeq);
    train(cs,0,cs.length,incr);
}

public void train(char[] cs, int start, int end) {
    train(cs,start,end,1);
}

public void train(char[] cs, int start, int end, int incr) {
    Strings.checkArgsStartEnd(cs,start,end);
    mTrieCharSeqCounter.incrementSubstrings(cs,start,end,incr);
}

/**
 * Implements the object handler interface over character
 * sequences for training.  The implementation delegates to {@link
 * #train(CharSequence)}.
 *
 * @param cSeq Character sequence on which to train.
 */
public void handle(CharSequence cSeq) {
    train(cSeq);
}

/**
 * Trains the specified conditional outcome(s) of the specified
 * character slice given the background slice.
 *
 * <p>This method is just shorthand for incrementing the counts of
 * all substrings of <code>cs</code> from position
 * <code>start</code> to <code>end-1</code> inclusive, then
 * decrementing all of the counts of substrings from position
 * <code>start</code> to <code>condEnd-1</code>.  For instance, if
 * <code>cs</code> is <code>"abcde".toCharArray()</code>, then
 * calling <code>trainConditional(cs,0,5,2)</code> will increment
 * the counts of <code>cde</code> given <code>ab</code>, but will
 * not increment the counts of <code>ab</code> directly.  This
 * increases the following probabilities:
 *
 * <blockquote><pre>
 * P('e'|"abcd")
 * P('e'|"bcd")
 * P('e'|"cd")
 * P('e'|"d")
 * P('e'|"")
 *
 * P('d'|"abc")
 * P('d'|"bc")
 * P('d'|"c")
 * P('d'|"")
 *
 * P('c'|"ab")
 * P('c'|"b")
 * P('c'|"")</pre></blockquote>
 *
 * but does not increase the following probabilities:
 *
 * <blockquote><pre>
 * P('b'|"a")
 * P('b'|"")
 *
 * P('a'|"")</pre></blockquote>
 *
* @param cs Array of characters.
* @param start Start position for slice.
* @param end One past end position for slice.
* @param condEnd One past the end of the conditional portion of
* the slice.
*/
public void trainConditional(char[] cs, int start, int end,
int condEnd) {
Strings.checkArgsStartEnd(cs,start,end);
Strings.checkArgsStartEnd(cs,start,condEnd);
if (condEnd > end) {
String msg = "Conditional end must be < end."
+ " Found condEnd=" + condEnd
+ " end=" + end;
throw new IllegalArgumentException(msg);
}
if (condEnd == end) return;
mTrieCharSeqCounter.incrementSubstrings(cs,start,end);
mTrieCharSeqCounter.decrementSubstrings(cs,start,condEnd);
}
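/**
 * Returns the characters observed in the training data for this
 * language model.
 *
 * @return The characters observed in the training data.
 */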
public char[] observedCharacters() {
return mTrieCharSeqCounter.observedCharacters();
}
/**
* Writes a compiled version of this process language model to the
* specified object output.
*
 * <p>The object written will be an instance of {@link
 * CompiledNGramProcessLM}.  It may be read in by casting the
 * result of {@link ObjectInput#readObject()}.
 *
 * <p>Compilation is time consuming because it must traverse the
 * entire trie structure and, for each node, estimate its log
 * probability and, if it is internal, its log interpolation value.
 * Given that the time taken is proportional to the size of the
 * trie, pruning first may greatly speed up this operation and
 * reduce the size of the compiled object that is written.
*
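 * <p>For example (an illustrative sketch; the object streams are
 * assumed to be constructed and closed by the caller):
 *
 * <blockquote><pre>
 * NGramProcessLM lm = ...;
 * ObjectOutput objOut = ...;
 * lm.compileTo(objOut);
 *
 * ...
 * ObjectInput objIn = ...;
 * CompiledNGramProcessLM compiledLm
 *     = (CompiledNGramProcessLM) objIn.readObject();</pre></blockquote>
 *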
* @param objOut Object output to which a compiled version of this
 * language model will be written.
* @throws IOException If there is an I/O exception writing the
* compiled object.
*/
public void compileTo(ObjectOutput objOut) throws IOException {
objOut.writeObject(new Externalizer(this));
}
public double log2ConditionalEstimate(CharSequence cSeq) {
return log2ConditionalEstimate(cSeq,mMaxNGram,mLambdaFactor);
}
public double log2ConditionalEstimate(char[] cs, int start, int end) {
return log2ConditionalEstimate(cs,start,end,mMaxNGram,mLambdaFactor);
}
/**
* Returns the substring counter for this language model.
* Modifying the counts in the returned counter, such as by
* pruning, will change the estimates in this language model.
*
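 * <p>For instance, assuming the returned {@link TrieCharSeqCounter}
 * supports count-based pruning via a <code>prune(int)</code>
 * method, a model might be pruned as follows (an illustrative
 * sketch):
 *
 * <blockquote><pre>
 * NGramProcessLM lm = ...;
 * TrieCharSeqCounter counter = lm.substringCounter();
 * counter.prune(2); // drop substrings with count &lt; 2</pre></blockquote>
 *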
* @return Substring counter for this language model.
*/
public TrieCharSeqCounter substringCounter() {
return mTrieCharSeqCounter;
}
/**
* Returns the maximum n-gram length for this model.
*
* @return The maximum n-gram length for this model.
*/
public int maxNGram() {
return mMaxNGram;
}
/**
* Returns the log (base 2) conditional estimate of the last
* character in the specified character sequence given the
* previous characters based only on counts of n-grams up to the
* specified maximum n-gram. If the maximum n-gram argument is
* greater than or equal to the one supplied at construction time,
 * the results will be the same as the ordinary conditional
* estimate.
*
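 * <p>For instance (an illustrative sketch with hypothetical
 * arguments):
 *
 * <blockquote><pre>
 * NGramProcessLM lm = ...;
 * // log2 P('e'|"abcd") using at most 3-gram counts, lambda factor 4.0
 * double log2P = lm.log2ConditionalEstimate("abcde",3,4.0);</pre></blockquote>
 *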
* @param cSeq Character sequence to estimate.
* @param maxNGram Maximum length of n-gram count to use for
* estimate.
* @param lambdaFactor Value of interpolation hyperparameter for
* this estimate.
* @return Log (base 2) conditional estimate.
* @throws IllegalArgumentException If the character sequence is not at
* least one character long.
*/
public double log2ConditionalEstimate(CharSequence cSeq, int maxNGram,
double lambdaFactor) {
char[] cs = Strings.toCharArray(cSeq);
return log2ConditionalEstimate(cs,0,cs.length,maxNGram,lambdaFactor);
}
/**
* Returns the log (base 2) conditional estimate for a specified
* character slice with a specified maximum n-gram and specified
* hyperparameter.
* @param cs Underlying character array.
* @param start Index of first character in slice.
* @param end Index of one past last character in slice.
* @param maxNGram Maximum length of n-gram to use in estimates.
* @param lambdaFactor Value of interpolation hyperparameter.
* @return Log (base 2) conditional estimate of the last character
* in the slice given the previous characters.
 * @throws IndexOutOfBoundsException If the start index and end
 * index minus one are out of range of the character array.
 * @throws IllegalArgumentException If the character slice is less
 * than one character long.
*/
public double log2ConditionalEstimate(char[] cs, int start, int end,
int maxNGram, double lambdaFactor) {
if (end <= start) {
String msg = "Conditional estimates require at least one character.";
throw new IllegalArgumentException(msg);
}
Strings.checkArgsStartEnd(cs,start,end);
checkMaxNGram(maxNGram);
checkLambdaFactor(lambdaFactor);
int maxUsableNGram = Math.min(maxNGram,mMaxNGram);
if (start == end) return 0.0;
double currentEstimate = mUniformEstimate;
int contextEnd = end-1;
int longestContextStart = Math.max(start,end-maxUsableNGram);
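// Interpolate from the empty context up to the longest available
// context, smoothing each maximum likelihood estimate with the
// estimate from the next-shorter context; a context never observed
// in training ends the loop.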
for (int currentContextStart = contextEnd;
currentContextStart >= longestContextStart;
--currentContextStart) {
long contextCount
= mTrieCharSeqCounter.extensionCount(cs,currentContextStart,contextEnd);
if (contextCount == 0) break;
long outcomeCount = mTrieCharSeqCounter.count(cs,currentContextStart,end);
double lambda = lambda(cs,currentContextStart,contextEnd,lambdaFactor);
currentEstimate
= lambda * (((double)outcomeCount) / (double)contextCount)
+ (1.0 - lambda) * currentEstimate;
}
return com.aliasi.util.Math.log2(currentEstimate);
}
/**
* Returns the interpolation ratio for the specified character
* slice interpreted as a context. The hyperparameter used is
 * that returned by {@link #getLambdaFactor()}.  The definition of
 * <code>lambda()</code> is provided in the class documentation
 * above.
*
* @param cs Underlying character array.
* @param start Index of first character in slice.
 * @param end Index of one past last character in slice.
 * @return Interpolation ratio for the specified context.
* @throws IndexOutOfBoundsException If the start index and end
* index minus one are out of range of the character array.
*/
double lambda(char[] cs, int start, int end) {
return lambda(cs,start,end,getLambdaFactor());
}
/**
* Returns the interpolation ratio for the specified character
* slice interpreted as a context with the specified
 * hyperparameter.  The definition of <code>lambda()</code> is
 * provided in the class documentation above.
 *
* @param cs Underlying character array.
* @param start Index of first character in slice.
* @param end Index of one past last character in slice.
 * @param lambdaFactor Value for interpolation ratio hyperparameter.
 * @return Interpolation ratio for the specified context.
* @throws IndexOutOfBoundsException If the start index and end
* index minus one are out of range of the character array.
*/
double lambda(char[] cs, int start, int end, double lambdaFactor) {
checkLambdaFactor(lambdaFactor);
Strings.checkArgsStartEnd(cs,start,end);
double count = mTrieCharSeqCounter.extensionCount(cs,start,end);
if (count <= 0.0) return 0.0;
double numOutcomes = mTrieCharSeqCounter.numCharactersFollowing(cs,start,end);
return lambda(count,numOutcomes,lambdaFactor);
}
/**
* Returns the current setting of the interpolation ratio
* hyperparameter. See the class documentation above for
* information on how the interpolation ratio is used in
* estimates.
*
* @return The current setting of the interpolation ratio
* hyperparameter.
*/
public double getLambdaFactor() {
return mLambdaFactor;
}
/**
* Sets the value of the interpolation ratio hyperparameter
* to the specified value. See the class documentation above
* for information on how the interpolation ratio is used in estimates.
*
* @param lambdaFactor New value for interpolation ratio
* hyperparameter.
* @throws IllegalArgumentException If the value is not greater
* than or equal to zero.
*/
public final void setLambdaFactor(double lambdaFactor) {
checkLambdaFactor(lambdaFactor);
mLambdaFactor = lambdaFactor;
}
/**
* Sets the number of characters for this language model. All
* subsequent estimates will be based on this number. See the
* class definition above for information on how the number of
* character is used to determine the base case uniform
* distribution.
*
* @param numChars New number of characters for this language model.
* @throws IllegalArgumentException If the number of characters is
 * less than <code>0</code> or more than
 * <code>Character.MAX_VALUE</code>.
*/
public final void setNumChars(int numChars) {
checkNumChars(numChars);
mNumChars = numChars;
mUniformEstimate = 1.0 / (double)mNumChars;
mLog2UniformEstimate
= com.aliasi.util.Math.log2(mUniformEstimate);
}
/**
* Returns a string-based representation of this language model.
*
* @return A string-based representation of this language model.
*/
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
toStringBuilder(sb);
return sb.toString();
}
void toStringBuilder(StringBuilder sb) {
sb.append("Max NGram=" + mMaxNGram + " ");
sb.append("Num characters=" + mNumChars + "\n");
sb.append("Trie of counts=\n");
mTrieCharSeqCounter.toStringBuilder(sb);
}
// need this for the process model to get boundaries right
void decrementUnigram(char c) {
decrementUnigram(c,1);
}
void decrementUnigram(char c, int count) {
mTrieCharSeqCounter.decrementUnigram(c,count);
}
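// Witten-Bell interpolation ratio as defined in the class
// documentation: lambda = extCount / (extCount + L * numExtensions).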
private double lambda(double count, double numOutcomes,
double lambdaFactor) {
return count
/ (count + lambdaFactor * numOutcomes);
}
private double lambda(Node node) {
double count = node.contextCount(Strings.EMPTY_CHAR_ARRAY,0,0);
double numOutcomes = node.numOutcomes(Strings.EMPTY_CHAR_ARRAY,0,0);
return lambda(count,numOutcomes,mLambdaFactor);
}
private int lastInternalNodeIndex() {
int last = 1;
LinkedList