/*
* LingPipe v. 4.1.0
* Copyright (C) 2003-2011 Alias-i
*
* This program is licensed under the Alias-i Royalty Free License
* Version 1 WITHOUT ANY WARRANTY, without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the Alias-i
* Royalty Free License Version 1 for more details.
*
* You should have received a copy of the Alias-i Royalty Free License
* Version 1 along with this program; if not, visit
* http://alias-i.com/lingpipe/licenses/lingpipe-license-1.txt or contact
* Alias-i, Inc. at 181 North 11th Street, Suite 401, Brooklyn, NY 11211,
* +1 (718) 290-9170.
*/
package com.aliasi.chunk;
import com.aliasi.tokenizer.TokenCategorizer;
import com.aliasi.tokenizer.Tokenizer;
import com.aliasi.tokenizer.TokenizerFactory;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.Strings;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
/**
 * A <code>TrainTokenShapeChunker</code> is used to train a token and
 * shape-based chunker.
 *
 * <p>Estimation is based on a joint model of tags
 * <code>T1,...,TN</code> and tokens <code>W1,...,WN</code>, which is
 * approximated with a limited history and smoothed using linear
 * interpolation.
 *
 * <p>By the chain rule:
 *
 * <pre>
 * P(W1,...,WN,T1,...,TN)
 *   = P(W1,T1) * P(W2,T2|W1,T1) * P(W3,T3|W1,W2,T1,T2)
 *     * ... * P(WN,TN|W1,...,WN-1,T1,...,TN-1)</pre>
 *
 * The longer contexts are approximated with the two previous
 * tokens and one previous tag:
 *
 * <pre>
 * P(WN,TN|W1,...,WN-1,T1,...,TN-1)
 *   ~ P(WN,TN|WN-2,WN-1,TN-1)</pre>
*
 * The shorter contexts are padded with tags and tokens for the
 * beginning of a stream, and an additional end-of-stream symbol is
 * trained after the last symbol in the input.
 *
 * <p>The joint model is further decomposed into a conditional tag
 * model and a conditional token model by the chain rule:
 *
 * <pre>
 * P(WN,TN|WN-2,WN-1,TN-1)
 *   = P(TN|WN-2,WN-1,TN-1)
 *   * P(WN|WN-2,WN-1,TN-1,TN)</pre>
*
 * The token model is further approximated as:
 *
 * <pre>
 * P(WN|WN-2,WN-1,TN-1,TN)
 *   ~ P(WN|WN-1,interior(TN-1),TN)</pre>
 *
 * where <code>interior(TN-1)</code> is the interior version of a
 * tag; for instance:
 *
 * <pre>
 * interior("ST_PERSON").equals("PERSON")
 * interior("PERSON").equals("PERSON")</pre>
 *
 * This performs what is known as "model tying", and it
 * amounts to sharing the models for the two contexts.
 * The tag model is also approximated by tying start
 * and interior tag histories:
 *
 * <pre>
 * P(TN|WN-2,WN-1,TN-1)
 *   ~ P(TN|WN-2,WN-1,interior(TN-1))</pre>
*
 * The tag and token models are themselves simple linear
 * interpolation models, with smoothing parameters defined by the
 * Witten-Bell method. The order of contexts for the token model is:
 *
 * <pre>
 * P(WN|TN,interior(TN-1),WN-1)
 *   ~ lambda(TN,interior(TN-1),WN-1) * P_ml(WN|TN,interior(TN-1),WN-1)
 *   + (1 - lambda(TN,interior(TN-1),WN-1)) * P(WN|TN,interior(TN-1))
 *
 * P(WN|TN,interior(TN-1))
 *   ~ lambda(TN,interior(TN-1)) * P_ml(WN|TN,interior(TN-1))
 *   + (1 - lambda(TN,interior(TN-1))) * P(WN|TN)
 *
 * P(WN|TN)
 *   ~ lambda(TN) * P_ml(WN|TN)
 *   + (1 - lambda(TN)) * UNIFORM_ESTIMATE</pre>
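 *
 * <p>As a concrete illustration of the mixing weights, here is a
 * sketch using the textbook Witten-Bell definition
 * <code>lambda(C) = count(C) / (count(C) + numOutcomes(C))</code>;
 * the counts are hypothetical, and the trained estimator may scale
 * the weights differently:
 *
 * <pre>
 * count(TN,interior(TN-1),WN-1) = 12       // context observed 12 times
 * numOutcomes(TN,interior(TN-1),WN-1) = 4  // 4 distinct next tokens seen
 * lambda = 12 / (12 + 4) = 0.75
 * P(WN|TN,interior(TN-1),WN-1)
 *   ~ 0.75 * P_ml(WN|TN,interior(TN-1),WN-1)
 *   + 0.25 * P(WN|TN,interior(TN-1))</pre>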
*
 * <p>The last step is degenerate in that <code>SUM_W P(W|T) =
 * INFINITY</code>, because there are infinitely many possible
 * tokens, and each is assigned the uniform estimate. To fix this, a
 * model of character sequences would be needed that ensured
 * <code>SUM_W P(W|T) = 1.0</code>. (The steps to do the final
 * uniform estimate are handled by the compiled estimator.)
*
 * The tag estimator is smoothed by:
 *
 * <pre>
 * P(TN|interior(TN-1),WN-1,WN-2)
 *   ~ lambda(interior(TN-1),WN-1,WN-2) * P_ml(TN|interior(TN-1),WN-1,WN-2)
 *   + (1 - lambda(interior(TN-1),WN-1,WN-2)) * P(TN|interior(TN-1),WN-1)
 *
 * P(TN|interior(TN-1),WN-1)
 *   ~ lambda(interior(TN-1),WN-1) * P_ml(TN|interior(TN-1),WN-1)
 *   + (1 - lambda(interior(TN-1),WN-1)) * P_ml(TN|interior(TN-1))</pre>
 *
 * Note that the smoothing stops at estimating a tag in terms of the
 * previous tag. This guarantees that only bigram tag sequences seen
 * in the training data get non-zero probability under the estimator.
 *
 * <p>Sequences of training pairs are added via the {@link
 * #handle(Chunking)} method, as in the following sketch.
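 *
 * <p>(A minimal sketch: the tokenizer factory, categorizer, text,
 * and chunk spans are illustrative choices, not part of this class's
 * API; any {@code TokenizerFactory} and {@code TokenCategorizer}
 * may be used.)
 *
 * <pre>{@code
 * TokenizerFactory factory = IndoEuropeanTokenizerFactory.INSTANCE;
 * TokenCategorizer categorizer
 *     = IndoEuropeanTokenCategorizer.CATEGORIZER;
 * TrainTokenShapeChunker trainer
 *     = new TrainTokenShapeChunker(categorizer, factory);
 *
 * // one training chunking: "John Smith" is a PERSON,
 * // "Seattle" a LOCATION
 * ChunkingImpl chunking
 *     = new ChunkingImpl("John Smith lives in Seattle.");
 * chunking.add(ChunkFactory.createChunk(0, 10, "PERSON"));
 * chunking.add(ChunkFactory.createChunk(20, 27, "LOCATION"));
 * trainer.handle(chunking);
 *
 * // compile the accumulated training data into a runtime chunker
 * Chunker chunker = (Chunker) AbstractExternalizable.compile(trainer);
 * }</pre>
 *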
* @author Bob Carpenter
* @version 4.0.0
* @since LingPipe1.0
*/
public class TrainTokenShapeChunker
    implements ObjectHandler<Chunking>,
               Compilable {
private final boolean mValidateTokenizer = false;
private final int mKnownMinTokenCount;
private final int mMinTokenCount;
private final int mMinTagCount;
private final TokenCategorizer mTokenCategorizer;
private final TokenizerFactory mTokenizerFactory;
private final TrainableEstimator mTrainableEstimator;
private final List<String> mTokenList = new ArrayList<String>();
private final List<String> mTagList = new ArrayList<String>();
/**
* Construct a trainer for a token/shape chunker based on
* the specified token categorizer and tokenizer factory.
 * The other parameters receive default values: the interpolation
 * ratio is set to <code>4.0</code>, the number of tokens to
 * <code>3,000,000</code>, the known minimum token count to 8, and
 * the minimum tag and token counts for pruning to 1.
*
* @param categorizer Token categorizer for unknown tokens.
* @param factory Tokenizer factory for creating tokenizers.
*/
public TrainTokenShapeChunker(TokenCategorizer categorizer,
TokenizerFactory factory) {
this(categorizer,factory,
8, 1, 1);
}
/**
* Construct a trainer for a token/shape chunker based on
* the specified token categorizer, tokenizer factory and
* numerical parameters. The parameters are described in
* detail in the class documentation above.
*
* @param categorizer Token categorizer for unknown tokens.
* @param factory Tokenizer factory for tokenizing data.
* @param knownMinTokenCount Number of instances required for
* a token to count as known for unknown training.
* @param minTokenCount Minimum token count for token contexts to
* survive after pruning.
* @param minTagCount Minimum count for tag contexts to survive
* after pruning.
*/
public TrainTokenShapeChunker(TokenCategorizer categorizer,
TokenizerFactory factory,
int knownMinTokenCount,
int minTokenCount,
int minTagCount) {
mTokenCategorizer = categorizer;
mTokenizerFactory = factory;
mKnownMinTokenCount = knownMinTokenCount;
mMinTokenCount = minTokenCount;
mMinTagCount = minTagCount;
mTrainableEstimator = new TrainableEstimator(categorizer);
}
/**
 * Stores the specified BIO-encoded chunk tagging as training data
 * for the underlying estimator.
*
* @param tokens Sequence of tokens to train.
* @param whitespaces Sequence of whitespaces (ignored).
* @param tags Sequence of tags to train.
* @throws IllegalArgumentException If the tags and tokens are
* different lengths.
* @throws NullPointerException If any of the tags or tokens are null.
*/
void handle(String[] tokens, String[] whitespaces, String[] tags) {
if (tokens.length != tags.length) {
String msg = "Tokens and tags must be same length."
+ " Found tokens.length=" + tokens.length
+ " tags.length=" + tags.length;
throw new IllegalArgumentException(msg);
}
for (int i = 0; i < tokens.length; ++i) {
if (tokens[i] == null || tags[i] == null) {
String msg = "Tags and tokens must not be null."
+ " Found tokens[" + i + "]=" + tokens[i]
+ " tags[" + i + "]=" + tags[i];
throw new NullPointerException(msg);
}
mTokenList.add(tokens[i]);
mTagList.add(tags[i]);
}
}
// cut and paste from old adapter; another copy in CharLmHmmChunker
/**
* Add the specified chunking as a training event.
*
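 * <p>The chunking is converted to a BIO-style token tagging before
 * training. For instance, assuming a tokenizer that separates the
 * final period, and assuming the usual BIO tag strings from {@code
 * ChunkTagHandlerAdapter2} (<code>O</code> for out, <code>B-</code>
 * and <code>I-</code> prefixes for begin and in), a chunking of
 * "John Smith lives in Seattle." with a PERSON chunk over
 * "John Smith" and a LOCATION chunk over "Seattle" yields:
 *
 * <pre>
 * John/B-PERSON Smith/I-PERSON lives/O in/O Seattle/B-LOCATION ./O</pre>
 *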
* @param chunking Chunking for training.
*/
public void handle(Chunking chunking) {
CharSequence cSeq = chunking.charSequence();
char[] cs = Strings.toCharArray(cSeq);
Set<Chunk> chunkSet = chunking.chunkSet();
Chunk[] chunks = chunkSet.toArray(EMPTY_CHUNK_ARRAY);
Arrays.sort(chunks,Chunk.TEXT_ORDER_COMPARATOR);
List<String> tokenList = new ArrayList<String>();
List<String> whiteList = new ArrayList<String>();
List<String> tagList = new ArrayList<String>();
int pos = 0;
for (Chunk nextChunk : chunks) {
String type = nextChunk.type();
int start = nextChunk.start();
int end = nextChunk.end();
// tag the gap before the chunk as OUT, then the chunk itself
outTag(cs,pos,start,tokenList,whiteList,tagList,mTokenizerFactory);
chunkTag(cs,start,end,type,tokenList,whiteList,tagList,mTokenizerFactory);
pos = end;
}
// tag any trailing text after the last chunk
outTag(cs,pos,cSeq.length(),tokenList,whiteList,tagList,mTokenizerFactory);
String[] toks = tokenList.toArray(Strings.EMPTY_STRING_ARRAY);
String[] whites = whiteList.toArray(Strings.EMPTY_STRING_ARRAY);
String[] tags = tagList.toArray(Strings.EMPTY_STRING_ARRAY);
if (mValidateTokenizer
&& !consistentTokens(toks,whites,mTokenizerFactory)) {
String msg = "Tokens not consistent with tokenizer factory."
+ " Tokens=" + Arrays.asList(toks)
+ " Tokenization=" + tokenization(toks,whites)
+ " Factory class=" + mTokenizerFactory.getClass();
throw new IllegalArgumentException(msg);
}
handle(toks,whites,tags);
}
/**
* Compiles a chunker based on the training data received by
* this trainer to the specified object output.
*
* @param objOut Object output to which the chunker is written.
* @throws IOException If there is an underlying I/O error.
*/
public void compileTo(ObjectOutput objOut) throws IOException {
objOut.writeObject(new Externalizer(this));
}
static class Externalizer extends AbstractExternalizable {
private static final long serialVersionUID = 142720610674437597L;
final TrainTokenShapeChunker mChunker;
public Externalizer() {
this(null);
}
public Externalizer(TrainTokenShapeChunker chunker) {
mChunker = chunker;
}
@Override
public Object read(ObjectInput in)
throws ClassNotFoundException, IOException {
TokenizerFactory factory = (TokenizerFactory) in.readObject();
TokenCategorizer categorizer = (TokenCategorizer) in.readObject();
// System.out.println("factory.class=" + factory.getClass());
// System.out.println("categorizer.class=" + categorizer.getClass());
CompiledEstimator estimator = (CompiledEstimator) in.readObject();
// System.out.println("estimator.class=" + estimator.getClass());
TokenShapeDecoder decoder
= new TokenShapeDecoder(estimator,categorizer,1000.0);
return new TokenShapeChunker(factory,decoder);
}
@Override
public void writeExternal(ObjectOutput objOut) throws IOException {
int len = mChunker.mTagList.size();
String[] tokens
= mChunker.mTokenList.toArray(Strings.EMPTY_STRING_ARRAY);
String[] tags
= mChunker.mTagList.toArray(Strings.EMPTY_STRING_ARRAY);
// train once with straight vals
mChunker.mTrainableEstimator.handle(tokens,tags);
// train again with unknown tokens replaced with categories
mChunker.replaceUnknownsWithCategories(tokens);
mChunker.mTrainableEstimator.handle(tokens,tags);
mChunker.mTrainableEstimator.prune(mChunker.mMinTagCount,
mChunker.mMinTokenCount);
// smooth after pruning, for persistence
mChunker.mTrainableEstimator.smoothTags(1);
// write: tokfact, tokcat, estimator
AbstractExternalizable.compileOrSerialize(mChunker.mTokenizerFactory,objOut);
AbstractExternalizable.compileOrSerialize(mChunker.mTokenCategorizer,objOut);
mChunker.mTrainableEstimator.compileTo(objOut);
}
}
// copied from old adapter; another copy in CharLmHmmChunker
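// Replaces tokens seen fewer than mKnownMinTokenCount times with
// their shape category, so the estimator learns statistics for
// rare-token shapes rather than for the rare tokens themselves.
// Example (hypothetical count and category name): with
// mKnownMinTokenCount=8, a token "Zylotech" seen only twice would
// be replaced by whatever category the TokenCategorizer assigns,
// e.g. a capitalized-word category.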
void replaceUnknownsWithCategories(String[] tokens) {
ObjectToCounterMap<String> counter = new ObjectToCounterMap<String>();
for (int i = 0; i < tokens.length; ++i)
counter.increment(tokens[i]);
for (int i = 0; i < tokens.length; ++i)
if (counter.getCount(tokens[i]) < mKnownMinTokenCount)
tokens[i] = mTokenCategorizer.categorize(tokens[i]);
}
static final Chunk[] EMPTY_CHUNK_ARRAY = new Chunk[0];
static void outTag(char[] cs, int start, int end,
                   List<String> tokenList, List<String> whiteList,
                   List<String> tagList, TokenizerFactory factory) {
Tokenizer tokenizer = factory.tokenizer(cs,start,end-start);
whiteList.add(tokenizer.nextWhitespace());
String nextToken;
while ((nextToken = tokenizer.nextToken()) != null) {
tokenList.add(nextToken);
tagList.add(ChunkTagHandlerAdapter2.OUT_TAG);
whiteList.add(tokenizer.nextWhitespace());
}
}
static void chunkTag(char[] cs, int start, int end, String type,
                     List<String> tokenList, List<String> whiteList,
                     List<String> tagList, TokenizerFactory factory) {
Tokenizer tokenizer = factory.tokenizer(cs,start,end-start);
// the whitespace before the first chunk token was already added
// by the preceding outTag() call, so only the token is added here
String firstToken = tokenizer.nextToken();
tokenList.add(firstToken);
tagList.add(ChunkTagHandlerAdapter2.BEGIN_TAG_PREFIX + type);
while (true) {
String nextWhitespace = tokenizer.nextWhitespace();
String nextToken = tokenizer.nextToken();
if (nextToken == null) break;
tokenList.add(nextToken);
whiteList.add(nextWhitespace);
tagList.add(ChunkTagHandlerAdapter2.IN_TAG_PREFIX + type);
}
}
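// Round-trip sanity check (used only when mValidateTokenizer is
// true): rebuild the underlying text from the alternating
// whitespace/token sequence and verify that retokenizing it
// reproduces exactly the same tokens and whitespaces.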
static boolean consistentTokens(String[] toks,
String[] whitespaces,
TokenizerFactory tokenizerFactory) {
if (toks.length+1 != whitespaces.length) return false;
char[] cs = getChars(toks,whitespaces);
Tokenizer tokenizer = tokenizerFactory.tokenizer(cs,0,cs.length);
String nextWhitespace = tokenizer.nextWhitespace();
if (!whitespaces[0].equals(nextWhitespace)) {
return false;
}
for (int i = 0; i < toks.length; ++i) {
String token = tokenizer.nextToken();
if (token == null) {
return false;
}
if (!toks[i].equals(token)) {
return false;
}
nextWhitespace = tokenizer.nextWhitespace();
if (!whitespaces[i+1].equals(nextWhitespace)) {
return false;
}
}
return true;
}
List<String> tokenization(String[] toks, String[] whitespaces) {
List<String> tokList = new ArrayList<String>();
List<String> whiteList = new ArrayList<String>();
char[] cs = getChars(toks,whitespaces);
Tokenizer tokenizer = mTokenizerFactory.tokenizer(cs,0,cs.length);
tokenizer.tokenize(tokList,whiteList);
return tokList;
}
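// Interleaves whitespaces and tokens back into the original text:
// whitespaces[0] + toks[0] + whitespaces[1] + ... + toks[n-1]
// + whitespaces[n], so whitespaces must be one element longer than
// toks. E.g. toks={"John","ran"}, whites={""," ",""} yields "John ran".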
static char[] getChars(String[] toks, String[] whitespaces) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < toks.length; ++i) {
sb.append(whitespaces[i]);
sb.append(toks[i]);
}
sb.append(whitespaces[whitespaces.length-1]);
return Strings.toCharArray(sb);
}
}