All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cmu.arktweetnlp.RunTagger Maven / Gradle / Ivy

The newest version!
package cmu.arktweetnlp;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.LineNumberReader;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.HashSet;
import java.util.List;

import cmu.arktweetnlp.impl.ModelSentence;
import cmu.arktweetnlp.impl.Sentence;
import cmu.arktweetnlp.impl.features.FeatureExtractor;
import cmu.arktweetnlp.impl.features.WordClusterPaths;
import cmu.arktweetnlp.io.CoNLLReader;
import cmu.arktweetnlp.io.JsonTweetReader;
import cmu.arktweetnlp.util.BasicFileIO;
import edu.stanford.nlp.util.StringUtils;

/**
 * Commandline interface to run the Twitter POS tagger with a variety of possible input and output formats.
 * Also does basic evaluation if given labeled input text.
 * 
 * For basic usage of the tagger from Java, see instead Tagger.java.
 */
public class RunTagger {
	Tagger tagger;
	
	// Commandline I/O-ish options
	String inputFormat = "auto";
	String outputFormat = "auto";
	/** 1-based index of the tab-separated input field that holds the tweet data. */
	int inputField = 1;
	
	String inputFilename;
	/** Can be either filename or resource name **/
	String modelFilename = "/cmu/arktweetnlp/model.20120919";

	/** When true, per-tweet output is suppressed; summary statistics still go to stderr. */
	public boolean noOutput = false;
	/** When true, only tokenize: the model is not loaded and no tags are produced. */
	public boolean justTokenize = false;
	
	public static enum Decoder { GREEDY, VITERBI }
	public Decoder decoder = Decoder.GREEDY; 
	public boolean showConfidence = true;

	PrintStream outputStream;
	Iterable<Sentence> inputIterable = null;
	
	// Evaluation stuff
	private static HashSet<String> _wordsInCluster;
	// Only for evaluation mode (conll inputs)
	int numTokensCorrect = 0;
	int numTokens = 0;
	int oovTokensCorrect = 0;
	int oovTokens = 0;
	int clusterTokensCorrect = 0;
	int clusterTokens = 0;
	
	/**
	 * Prints {@code message} to stderr and terminates the JVM with exit status -1.
	 */
	public static void die(String message) {
		// (BTO) I like "assert false" but assertions are disabled by default in java
		System.err.println(message);
		System.exit(-1);
	}

	/**
	 * Creates a runner whose output stream is stdout wrapped as UTF-8 with autoflush enabled.
	 */
	public RunTagger() throws UnsupportedEncodingException {
		// force UTF-8 here, so don't need -Dfile.encoding
		this.outputStream = new PrintStream(System.out, true, "UTF-8");
	}

	/**
	 * Sets {@link #inputFormat} to "json" or "text", depending on whether {@code tweetData}
	 * is recognized as JSON by {@link JsonTweetReader#isJson}. The decision is reported on stderr.
	 */
	public void detectAndSetInputFormat(String tweetData) throws IOException {
		JsonTweetReader jsonTweetReader = new JsonTweetReader();
		if (jsonTweetReader.isJson(tweetData)) {
			System.err.println("Detected JSON input format");
			inputFormat = "json";
		} else {
			System.err.println("Detected text input format");
			inputFormat = "text";
		}
	}
	
	/**
	 * Main driver. Loads the model (unless only tokenizing), then processes {@link #inputFilename}.
	 *
	 * CoNLL input is delegated to {@link #runTaggerInEvalMode()}. Otherwise every input line is
	 * one tweet: the configured field is extracted, tokenized, optionally tagged, and written to
	 * {@link #outputStream} (unless {@link #noOutput}). A throughput summary is printed to stderr.
	 */
	public void runTagger() throws IOException, ClassNotFoundException {
		
		tagger = new Tagger();
		if (!justTokenize) {
			tagger.loadModel(modelFilename);
		}
		
		if (inputFormat.equals("conll")) {
			runTaggerInEvalMode();
			return;
		} 

		JsonTweetReader jsonTweetReader = new JsonTweetReader();
		
		// try-with-resources so the input is closed even if tokenizing/tagging throws midway.
		try (LineNumberReader reader = new LineNumberReader(BasicFileIO.openFileToReadUTF8(inputFilename))) {
			String line;
			long currenttime = System.currentTimeMillis();
			int numtoks = 0;
			while ((line = reader.readLine()) != null) {
				String tweetData = extractInputField(line, reader.getLineNumber());
				
				// Auto-detection inspects only the first line and then applies to the whole input.
				if (reader.getLineNumber() == 1 && inputFormat.equals("auto")) {
					detectAndSetInputFormat(tweetData);
				}
				
				String text;
				if (inputFormat.equals("json")) {
					text = jsonTweetReader.getText(tweetData);
					if (text == null) {
						System.err.println("Warning, null text (JSON parse error?), using blank string instead");
						text = "";
					}
				} else {
					text = tweetData;
				}
				
				Sentence sentence = new Sentence();
				sentence.tokens = Twokenize.tokenizeRawTweetText(text);

				// Stays null for empty tweets and tokenize-only runs; both output methods accept that.
				ModelSentence modelSentence = null;
				if (sentence.T() > 0 && !justTokenize) {
					modelSentence = new ModelSentence(sentence.T());
					tagger.featureExtractor.computeFeatures(sentence, modelSentence);
					goDecode(modelSentence);
				}
				
				if (!noOutput) {
					if (outputFormat.equals("conll")) {
						outputJustTagging(sentence, modelSentence);
					} else {
						outputPrependedTagging(sentence, modelSentence, justTokenize, line);
					}
				}
				numtoks += sentence.T();
			}
			long finishtime = System.currentTimeMillis();
			double elapsedSec = (finishtime - currenttime) / 1000.0;
			System.err.printf("Tokenized%s %d tweets (%d tokens) in %.1f seconds: %.1f tweets/sec, %.1f tokens/sec\n",
					justTokenize ? "" : " and tagged", 
					reader.getLineNumber(), numtoks, elapsedSec,
					reader.getLineNumber() / elapsedSec,
					numtoks / elapsedSec
			);
		}
	}

	/**
	 * Returns the {@link #inputField}-th (1-based) tab-separated field of {@code line}.
	 * If the line has no such field, a warning naming {@code lineNumber} is printed to stderr
	 * and the empty string is returned (mirroring how unparseable JSON is treated).
	 */
	private String extractInputField(String line, int lineNumber) {
		// limit -1 keeps trailing empty fields, so "a\t" really has an (empty) second field
		String[] parts = line.split("\t", -1);
		int index = inputField - 1;
		if (index >= 0 && index < parts.length) {
			return parts[index];
		}
		System.err.printf("Warning, line %d has no input field %d, using blank string instead\n",
				lineNumber, inputField);
		return "";
	}

	/**
	 * Decodes {@code mSent} with the configured {@link #decoder}; callers read the result back
	 * from {@code mSent.labels} (and {@code mSent.confidences}).
	 * Only the greedy decoder is given {@link #showConfidence}, which is why
	 * {@link #finalizeOptions()} turns confidence output off for Viterbi.
	 */
	public void goDecode(ModelSentence mSent) {
		if (decoder == Decoder.GREEDY) {
			tagger.model.greedyDecode(mSent, showConfidence);
		} else if (decoder == Decoder.VITERBI) {
			tagger.model.viterbiDecode(mSent);
		}
	}
	
	/**
	 * Tags every sentence of the labeled CoNLL file {@link #inputFilename}, optionally printing
	 * the taggings, and reports token accuracy and throughput on stderr.
	 */
	public void runTaggerInEvalMode() throws IOException, ClassNotFoundException {
		
		long t0 = System.currentTimeMillis();
		int n=0;

		List<Sentence> examples = CoNLLReader.readFile(inputFilename); 
		inputIterable = examples;

		// Only used by the (currently disabled) confusion-matrix report below.
		int[][] confusion = new int[tagger.model.numLabels][tagger.model.numLabels];
		
		for (Sentence sentence : examples) {
			n++;
			
			ModelSentence mSent = new ModelSentence(sentence.T());
			tagger.featureExtractor.computeFeatures(sentence, mSent);
			goDecode(mSent);
			
			if ( ! noOutput) {
				outputJustTagging(sentence, mSent);	
			}
			evaluateSentenceTagging(sentence, mSent);
			//evaluateOOV(sentence, mSent);
			//getconfusion(sentence, mSent, confusion);
		}

		System.err.printf("%d / %d correct = %.4f acc, %.4f err\n", 
				numTokensCorrect, numTokens,
				numTokensCorrect*1.0 / numTokens,
				1 - (numTokensCorrect*1.0 / numTokens)
		);
		double elapsed = ((double) (System.currentTimeMillis() - t0)) / 1000.0;
		System.err.printf("%d tweets in %.1f seconds, %.1f tweets/sec\n",
				n, elapsed, n*1.0/elapsed);
		
/*		System.err.printf("%d / %d cluster words correct = %.4f acc, %.4f err\n", 
				oovTokensCorrect, oovTokens,
				oovTokensCorrect*1.0 / oovTokens,
				1 - (oovTokensCorrect*1.0 / oovTokens)
		);	*/
/*		int i=0;
		System.out.println("\t"+tagger.model.labelVocab.toString().replaceAll(" ", ", "));
		for (int[] row:confusion){
			System.out.println(tagger.model.labelVocab.name(i)+"\t"+Arrays.toString(row));
			i++;
		}		*/
	}
	
	/**
	 * Accumulates accuracy counters ({@link #oovTokensCorrect}, {@link #oovTokens}) restricted to
	 * tokens that appear in the word-cluster vocabulary.
	 */
	private void evaluateOOV(Sentence lSent, ModelSentence mSent) throws FileNotFoundException, IOException, ClassNotFoundException {
		for (int t=0; t < mSent.T; t++) {
			int trueLabel = tagger.model.labelVocab.num(lSent.labels.get(t));
			int predLabel = mSent.labels[t];
			if(wordsInCluster().contains(lSent.tokens.get(t))){
				oovTokensCorrect += (trueLabel == predLabel) ? 1 : 0;
				oovTokens += 1;
			}
		}
	}

	/**
	 * Adds this sentence's (gold, predicted) label pairs into {@code confusion[gold][pred]},
	 * skipping gold labels the vocabulary does not know (index -1).
	 */
	private void getconfusion(Sentence lSent, ModelSentence mSent, int[][] confusion) {
		for (int t=0; t < mSent.T; t++) {
			int trueLabel = tagger.model.labelVocab.num(lSent.labels.get(t));
			int predLabel = mSent.labels[t];
			if(trueLabel!=-1)
				confusion[trueLabel][predLabel]++;
		}
	}

	/**
	 * Adds this sentence's token counts to {@link #numTokensCorrect} and {@link #numTokens},
	 * comparing gold labels in {@code lSent} to predicted labels in {@code mSent}.
	 */
	public void evaluateSentenceTagging(Sentence lSent, ModelSentence mSent) {
		for (int t=0; t < mSent.T; t++) {
			int trueLabel = tagger.model.labelVocab.num(lSent.labels.get(t));
			int predLabel = mSent.labels[t];
			numTokensCorrect += (trueLabel == predLabel) ? 1 : 0;
			numTokens += 1;
		}
	}
	
	/** Formats a confidence value with four decimal places. */
	private String formatConfidence(double confidence) {
		// too many decimal places wastes space
		return String.format("%.4f", confidence);
	}

	/**
	 * Writes one sentence in token-per-line CoNLL format: {@code token \t tag [\t confidence]},
	 * followed by a blank line. Dies if {@link #outputFormat} is not "conll".
	 *
	 * @param lSent sentence providing the tokens
	 * @param mSent decoded labels; may be null (e.g. tokenize-only runs), in which case only tokens are written
	 */
	public void outputJustTagging(Sentence lSent, ModelSentence mSent) {
		if (outputFormat.equals("conll")) {
			for (int t=0; t < lSent.T(); t++) {
				outputStream.print(lSent.tokens.get(t));
				if (mSent != null) {
					outputStream.printf("\t%s", tagger.model.labelVocab.name(mSent.labels[t]));
					if (mSent.confidences != null) {
						outputStream.printf("\t%s", formatConfidence(mSent.confidences[t]));
					}
				}
				outputStream.printf("\n");
			}
			outputStream.println("");
		} 
		else {
			die("bad output format for just tagging: " + outputFormat);
		}
	}

	/**
	 * Writes one output line: space-separated tokens, then (unless suppressed) space-separated tags,
	 * then (if {@link #showConfidence}) space-separated confidences, each followed by a tab, and
	 * finally the original input line.
	 * 
	 * @param lSent sentence providing the tokens
	 * @param mSent decoded labels; may be null when the sentence is empty or tags are suppressed
	 * @param suppressTags if true, no tag column is written
	 * @param inputLine -- assume does NOT have trailing newline.  (default from java's readLine)
	 */
	public void outputPrependedTagging(Sentence lSent, ModelSentence mSent, 
			boolean suppressTags, String inputLine) {
		int T = lSent.T();
		String[] tokens = new String[T];
		String[] tags = new String[T];
		String[] confs = new String[T];
		for (int t=0; t < T; t++) {
			tokens[t] = lSent.tokens.get(t);
			if (!suppressTags) {
				tags[t] = tagger.model.labelVocab.name(mSent.labels[t]);	
			}
			if (showConfidence) {
				confs[t] = formatConfidence(mSent.confidences[t]);
			}
		}
		
		StringBuilder sb = new StringBuilder();
		sb.append(StringUtils.join(tokens));
		sb.append("\t");
		if (!suppressTags) {
			sb.append(StringUtils.join(tags));
			sb.append("\t");
		}
		if (showConfidence) {
			sb.append(StringUtils.join(confs));
			sb.append("\t");
		}
		sb.append(inputLine);
		
		outputStream.println(sb.toString());
	}


	///////////////////////////////////////////////////////////////////


	/**
	 * Commandline entry point; see {@link #usage(String)} for the accepted options.
	 */
	public static void main(String[] args) throws IOException, ClassNotFoundException {        
		if (args.length > 0 && (args[0].equals("-h") || args[0].equals("--help"))) {
			usage();
		}

		RunTagger tagger = new RunTagger();

		int i = 0;
		while (i < args.length) {
			// A bare "-" names stdin and is a positional argument, not an option.
			if (!args[i].startsWith("-") || args[i].equals("-")) {
				break;
			} else if (args[i].equals("--model")) {
				tagger.modelFilename = optionValue(args, i);
				i += 2;
			} else if (args[i].equals("--just-tokenize")) {
				tagger.justTokenize = true;
				i += 1;
			} else if (args[i].equals("--decoder")) {
				String d = optionValue(args, i);
				if (d.equals("viterbi")) tagger.decoder = Decoder.VITERBI;
				else if (d.equals("greedy"))  tagger.decoder = Decoder.GREEDY;
				else die("unknown decoder " + d);
				i += 2;
			} else if (args[i].equals("--quiet")) {
				tagger.noOutput = true;
				i += 1;
			} else if (args[i].equals("--input-format")) {
				String s = optionValue(args, i);
				if (!(s.equals("json") || s.equals("text") || s.equals("conll")))
					usage("input format must be: json, text, or conll");
				tagger.inputFormat = s;
				i += 2;
			} else if (args[i].equals("--output-format")) {
				tagger.outputFormat = optionValue(args, i);
				i += 2;
			} else if (args[i].equals("--input-field")) {
				tagger.inputField = parsePositiveInt(optionValue(args, i), args[i]);
				i += 2;
			} else if (args[i].equals("--word-clusters")) {
				WordClusterPaths.clusterResourceName = optionValue(args, i);
				// skip both the option and its value
				i += 2;
			} else if (args[i].equals("--no-confidence")) {
				tagger.showConfidence = false;
				i += 1;
			}	
			else {
				System.out.println("bad option " + args[i]);
				usage();                
			}
		}
		
		if (args.length - i > 1) usage();
		if (args.length == i || args[i].equals("-")) {
			System.err.println("Listening on stdin for input.  (-h for help)");
			tagger.inputFilename = "/dev/stdin";
		} else {
			tagger.inputFilename = args[i];
		}
		
		tagger.finalizeOptions();
		
		tagger.runTagger();		
	}

	/**
	 * Returns the value following option {@code args[i]}, or exits via {@link #usage(String)}
	 * if the option is the last argument.
	 */
	private static String optionValue(String[] args, int i) {
		if (i + 1 >= args.length) {
			usage("missing value for option " + args[i]);
		}
		return args[i + 1];
	}

	/**
	 * Parses {@code value} as an integer >= 1, or exits via {@link #usage(String)} naming
	 * {@code option} if it is not one.
	 */
	private static int parsePositiveInt(String value, String option) {
		int parsed = 0;
		try {
			parsed = Integer.parseInt(value);
		} catch (NumberFormatException ignored) {
			// reported by the range check below
		}
		if (parsed < 1) {
			usage(option + " must be a positive integer, got: " + value);
		}
		return parsed;
	}
	
	/**
	 * Resolves option defaults and incompatible combinations after parsing:
	 * picks the output format from the input format when "auto", disables confidences for
	 * Viterbi and tokenize-only runs, and rejects tokenize-only evaluation of CoNLL input
	 * (evaluation needs the model, which tokenize-only runs never load).
	 */
	public void finalizeOptions() throws IOException {
		if (justTokenize && inputFormat.equals("conll")) {
			die("--just-tokenize cannot be combined with conll (evaluation) input");
		}
		if (outputFormat.equals("auto")) {
			if (inputFormat.equals("conll")) {
				outputFormat = "conll";
			} else {
				outputFormat = "pretsv";
			}
		}
		if (showConfidence && decoder==Decoder.VITERBI) {
			System.err.println("Confidence output is unimplemented in Viterbi, turning it off.");
			showConfidence = false;
		}
		if (justTokenize) {
			showConfidence = false;
		}
	}
	
	/** Prints the usage message to stdout and exits with status 1. */
	public static void usage() {
		usage(null);
	}

	/**
	 * Prints the usage message to stdout, followed by {@code "ERROR: " + extra} when
	 * {@code extra} is non-null, then exits with status 1.
	 */
	public static void usage(String extra) {
		System.out.println(
"RunTagger [options] [ExamplesFilename]" +
"\n  runs the CMU ARK Twitter tagger on tweets from ExamplesFilename, " +
"\n  writing taggings to standard output. Listens on stdin if no input filename" +
"\n  is given, or if it is '-'." +
"\n\nOptions:" +
"\n  --model <Filename>        Specify model filename. (Else use built-in.)" +
"\n  --just-tokenize           Only run the tokenizer; no POS tags." +
"\n  --quiet                   Quiet: no output" +
"\n  --input-format <Format>   Default: auto" +
"\n                            Options: json, text, conll" +
"\n  --output-format <Format>  Default: automatically decide from input format." +
"\n                            Options: pretsv, conll" +
"\n  --input-field NUM         Default: 1" +
"\n                            Which tab-separated field contains the input" +
"\n                            (1-indexed, like unix 'cut')" +
"\n                            Only for {json, text} input formats." +
"\n  --word-clusters <File>    Alternate word clusters file (see FeatureExtractor)" +
"\n  --no-confidence           Don't output confidence probabilities" +
"\n  --decoder <Decoder>       Change the decoding algorithm (default: greedy)" +
"\n                            Options: greedy, viterbi" +
"\n" +
"\nTweet-per-line input formats:" +
"\n   json: Every input line has a JSON object containing the tweet," +
"\n         as per the Streaming API. (The 'text' field is used.)" +
"\n   text: Every input line has the text for one tweet." +
"\nWe actually assume input lines are TSV and the tweet data is one field."+
"\n(Therefore tab characters are not allowed in tweets." +
"\nTwitter's own JSON formats guarantee this;" +
"\nif you extract the text yourself, you must remove tabs and newlines.)" +
"\nTweet-per-line output format is" +
"\n   pretsv: Prepend the tokenization and tagging as new TSV fields, " +
"\n           so the output includes a complete copy of the input." +
"\nBy default, three TSV fields are prepended:" +
"\n   Tokenization \\t POSTags \\t Confidences \\t (original data...)" +
"\nThe tokenization and tags are parallel space-separated lists." +
"\nThe 'conll' format is token-per-line, blank lines separating tweets."+
"\n");
		
		if (extra != null) {
			System.out.println("ERROR: " + extra);
		}
		System.exit(1);
	}

	/**
	 * Returns the set of words present in {@link WordClusterPaths#wordToPath}, computed lazily
	 * on first call and cached in a static field (not synchronized).
	 */
	public static HashSet<String> wordsInCluster() {
		if (_wordsInCluster==null) {
			_wordsInCluster = new HashSet<String>(WordClusterPaths.wordToPath.keySet());
		}
		return _wordsInCluster;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy