All downloads are free. The search and download functionality uses the official Maven repository.

cc.mallet.extract.LatticeViewer Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.extract;


import java.io.*;
import java.text.DecimalFormat;
import java.util.List;

import cc.mallet.fst.CRF;
import cc.mallet.fst.MaxLattice;
import cc.mallet.fst.MaxLatticeDefault;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.types.*;

/**
 * Created: Oct 31, 2004
 *
 * @author ERROR OUTPUT\n\n");
  }


  /**
   * Writes the closing part of a report to the given writer.
   *
   * In this copy of the file the footer text is the empty string, so the
   * only output is one line terminator.
   *
   * @param out writer that receives the footer line
   */
  private static void writeFooter (PrintWriter out)
  {
    final String footer = "";
    out.println (footer);
  }



  // Display HTML for one error
  //
  // Roadmap of the remainder of the class (all members below share a few
  // very long physical lines, so they are documented here in one place).
  //
  // NOTE(review): several string literals below are empty or span physical
  // lines, and one declaration reads List> with no type arguments. This looks
  // like markup lost when the file was extracted from a web page -- confirm
  // against the upstream source before compiling. LENGTH, FEATURE_CUTOFF_PCT
  // and lattice2html are referenced below but are not defined in the visible
  // text.
  //
  // error2html (out, info, start, end)
  //   Prints one report section for positions start..end-1 of info: the
  //   instance description and position range, a [Lattice] marker when
  //   info.link is non-null, then the index row, the input row, and the
  //   target and predicted rows. When info.lattice is non-null it also prints
  //   the lattice rows, the transition costs and the features. The local
  //   anchor is built from info.idx, start and end but is not referenced
  //   again in the visible text.
  //
  // numMaxViterbi / outputLatticeRows (out, MaxLattice, start, end)
  //   Prints one row per rank k, for at most min (numMaxViterbi, number of
  //   states) best state sequences. Each cell shows the state name and the
  //   negated delta for lattice index ip+1, formatted with pattern 0.##.
  //   NOTE(review): lattice.bestOutputSequence() is called once per cell;
  //   it could presumably be hoisted out of the loops -- confirm it is
  //   side-effect free.
  //
  // numFeaturesToDisplay, getNumFeaturesToDisplay, setNumFeaturesToDisplay
  //   Static setting with accessor methods; it is not read anywhere else in
  //   the visible text.
  //
  // outputTransitionCosts (out, info, start, end)
  //   Prints three rows: Cost(pred. trans), Cost(target trans) and
  //   Cost (pred->target trans). Position 0 gets an empty cell in each row.
  //   The transducer is cast to CRF and the source state is looked up by the
  //   string form of the label at ip-1 (info.bestStates for the first and
  //   third rows, info.target for the second). The second and third rows
  //   only print a cost where predicted and target differ at ip or ip-1.
  //   NOTE(review): only the target row checks the looked-up state for null;
  //   the other two rows dereference it directly.
  //
  // outputLatticeRows (out, SumLatticeDefault, start, end)
  //   Prints alpha, beta and gamma rows, one row per transducer state, reading
  //   lattice index ip+1. error2html passes info.lattice, which is declared as
  //   MaxLattice, so this overload is not reached from the visible callers.
  //
  // outputInputRow / outputIndices / outputTableRow
  //   Print the token text, the position numbers, and the labels of seq1
  //   (cells are opened differently depending on whether seq1 and seq2 agree
  //   at that position).
  //
  // outputFeatures (out, fvs, in, output, start, end)
  //   For each position where in and output differ, prints every feature of
  //   that position's feature vector, followed by its value when the value is
  //   not 1.0; positions that agree get an empty cell.
  //
  // seqMatches / allSeqMatches
  //   Element comparison is done on the toString() form; allSeqMatches checks
  //   every index in [start, end).
  //
  // extraction2html (four overloads)
  //   The PrintWriter overload with showLattice writes the header, one
  //   lattice2html section per document, then the footer; when showLattice is
  //   false every info gets link = lattice.html. The other overloads delegate
  //   with showLattice = false and/or wrap the PrintStream in an
  //   auto-flushing PrintWriter.
  //   NOTE(review): the three-argument PrintStream overload creates a writer
  //   that is never used and passes the PrintStream on instead.
  //
  // ExtorInfo
  //   Plain holder for the input tokens, predicted and target labels, feature
  //   vectors, lattice, best states, optional cross-link file name,
  //   description and index string of one document.
  //
  // infoForDoc (doc, desc, idx, docextr, extor, showLattice)
  //   Builds an ExtorInfo from the document extraction. When showLattice is
  //   true it also pipes the input through the extractor's feature pipe,
  //   builds a MaxLatticeDefault over the result and stores its best output
  //   sequence in bestStates. The doc parameter is only mentioned in
  //   commented-out code.
  //
  // EXTRACTIONS_PER_FILE / viewDualResults (dir, e1, extor1, e2, extor2)
  //   Throws IllegalArgumentException when the two extractions have different
  //   document counts. Writes errors.html for all documents without lattices,
  //   then one lattice-<start>.html file per group of EXTRACTIONS_PER_FILE
  //   documents with lattices.
  //   NOTE(review): each writer is closed by a plain close() call with no
  //   try/finally, so it stays open if writeDualExtractions throws.
  //
  // computeLatticeFname (docIdx)
  //   Returns the lattice file name of the group that contains docIdx.
  //
  // writeDualExtractions (out, e1, extor1, e2, extor2, start, end, showLattice)
  //   For documents start..end-1: skips (with a message on System.err) any
  //   pair whose document text differs, and for pairs whose predicted labels
  //   differ builds both ExtorInfos and calls dualLattice2html. Without
  //   lattices, both infos are linked to the matching lattice file.
  //
  // dualLattice2html (out, desc, info1, info2)
  //   Asserts that input, predicted and target sizes agree, then walks the
  //   sequence in windows of LENGTH positions advancing by LENGTH - 1, and
  //   prints both sections for every window where the two predictions
  //   differ. The desc parameter is not used in the body.
  //   NOTE(review): the method is public but its parameters use the private
  //   ExtorInfo type.
  //
  // predictionsMatch (targ1, targ2)
  //   True when both sequences have the same size and equal toString() forms
  //   at every index.
  //
  // The final closing brace ends the enclosing class, whose header is not in
  // the visible text.
  private static void error2html (PrintWriter out, ExtorInfo info, int start, int end)
  {
    String anchor = info.idx+":"+start+":"+end;
    out.println ("

"); out.println ("

Instance "+info.desc+" Position "+start+"..."+end); if (info.link != null) { out.println ("[Lattice]"); } out.println ("

"); out.println (""); outputIndices (out, start, end); outputInputRow (out, info.input, start, end); outputTableRow (out, "target", info.target, info.predicted, start, end); outputTableRow (out, "predicted", info.predicted, info.target, start, end); if (info.lattice != null) { outputLatticeRows (out, info.lattice, start, end); outputTransitionCosts (out, info, start, end); outputFeatures (out, info.fvs, info.predicted, info.target, start, end); } out.println ("
"); } public static int numMaxViterbi = 5; private static void outputLatticeRows (PrintWriter out, MaxLattice lattice, int start, int end) { DecimalFormat f = new DecimalFormat ("0.##"); Transducer ducer = lattice.getTransducer (); int max = Math.min (numMaxViterbi, ducer.numStates()); List> stateSequences = lattice.bestStateSequences(max); for (int k = 0; k < max; k++) { out.println (" "); out.println (" δ rank "+k+""); for (int ip = start; ip < end; ip++) { Transducer.State state = stateSequences.get(k).get(ip+1); if (state.getName().equals (lattice.bestOutputSequence().get(ip))) { out.print (""); } else { out.print (""); } out.print (state.getName()+"
"+f.format (-lattice.getDelta (ip+1, state.getIndex ()))+""); } out.println (""); } } private static int numFeaturesToDisplay = 5; public static int getNumFeaturesToDisplay () { return numFeaturesToDisplay; } public static void setNumFeaturesToDisplay (int numFeaturesToDisplay) { LatticeViewer.numFeaturesToDisplay = numFeaturesToDisplay; } private static void outputTransitionCosts (PrintWriter out, ExtorInfo info, int start, int end) { Transducer ducer = info.lattice.getTransducer (); out.println (""); out.println ("Cost(pred. trans)"); for (int ip = start; ip < end; ip++) { if (ip == 0) { out.println (""); continue; } Transducer.State from = ((CRF) ducer).getState (info.bestStates.get (ip - 1).toString ()); Transducer.TransitionIterator iter = from.transitionIterator (info.fvs, ip, info.predicted, ip); if (iter.hasNext ()) { iter.next (); double cost = iter.getWeight(); String str = iter.describeTransition ((int) (Math.abs(cost) / FEATURE_CUTOFF_PCT)); out.print ("" + str + ""); } else { out.print ("No matching transition"); } } out.println (""); out.println (""); out.println ("Cost(target trans)"); for (int ip = start; ip < end; ip++) { if (ip == 0) { out.println (""); continue; } if (!seqMatches (info.predicted, info.target, ip) || !seqMatches (info.predicted, info.target, ip - 1)) { Transducer.State from = ((CRF) ducer).getState (info.target.get (ip - 1).toString ()); if (from == null) { out.println ("Could not find state for "+info.target.get(ip-1)+""); } else { Transducer.TransitionIterator iter = from.transitionIterator (info.fvs, ip, info.target, ip); if (iter.hasNext ()) { iter.next (); double cost = iter.getWeight(); String str = iter.describeTransition ((int) (Math.abs(cost) / FEATURE_CUTOFF_PCT)); out.print ("" + str + ""); } else { out.print ("No matching transition"); } } } else { out.print (""); } } out.println (""); out.println (""); out.println ("Cost (pred->target trans)"); for (int ip = start; ip < end; ip++) { if (ip == 0) { out.println (""); 
continue; } if (!seqMatches (info.predicted, info.target, ip) || !seqMatches (info.predicted, info.target, ip - 1)) { Transducer.State from = ((CRF) ducer).getState (info.bestStates.get (ip - 1).toString ()); Transducer.TransitionIterator iter = from.transitionIterator (info.fvs, ip, info.target, ip); if (iter.hasNext ()) { iter.next (); double cost = iter.getWeight(); String str = iter.describeTransition ((int) (Math.abs(cost) / FEATURE_CUTOFF_PCT)); out.print ("" + str + ""); } else { out.print ("No matching transition"); } } else { out.print (""); } } out.println (""); } private static void outputLatticeRows (PrintWriter out, SumLatticeDefault lattice, int start, int end) { DecimalFormat f = new DecimalFormat ("0.##"); Transducer ducer = lattice.getTransducer (); for (int k = 0; k < ducer.numStates(); k++) { Transducer.State state = ducer.getState (k); out.println (" "); out.println (" α("+state.getName()+")"); for (int ip = start; ip < end; ip++) { out.print (""+f.format (lattice.getAlpha (ip+1, state))+""); } out.println (""); } for (int k = 0; k < ducer.numStates(); k++) { Transducer.State state = ducer.getState (k); out.println (" "); out.println (" β("+state.getName()+")"); for (int ip = start; ip < end; ip++) { out.print (""+f.format (lattice.getBeta (ip+1, state))+""); } out.println (""); } for (int k = 0; k < ducer.numStates(); k++) { Transducer.State state = ducer.getState (k); out.println (" "); out.println (" γ("+state.getName()+")"); for (int ip = start; ip < end; ip++) { out.print (""+f.format (lattice.getGammaWeight(ip+1, state))+""); } out.println (""); } } private static void outputInputRow (PrintWriter out, TokenSequence input, int start, int end) { out.println (" "); out.println (" "); for (int ip = start; ip < end; ip++) { out.print (""+input.get(ip).getText()+""); } out.println (" "); } private static void outputIndices (PrintWriter out, int start, int end) { out.println (" "); out.println (" "); for (int ip = start; ip < end; ip++) { 
out.print (""+ip+""); } out.println (" "); } private static void outputTableRow (PrintWriter out, String cssClass, Sequence seq1, Sequence seq2, int start, int end) { out.println (" "); out.println (" "+cssClass+""); for (int i = start; i < end; i++) { if (seqMatches (seq1, seq2, i)) { out.print (""); } else { out.print (""); } out.print (seq1.get(i)); out.print (""); } out.println (" "); } private static void outputFeatures (PrintWriter out, FeatureVectorSequence fvs, Sequence in, Sequence output, int start, int end) { out.println (" \nFeatures"); for (int i = start; i < end; i++) { if (!seqMatches (in, output, i)) { out.print (""); FeatureVector fv = fvs.getFeatureVector (i); for (int k = 0; k < fv.numLocations (); k++) { out.print (fv.getAlphabet ().lookupObject (fv.indexAtLocation (k))); if (fv.valueAtLocation (k) != 1.0) { out.print (" "+fv.valueAtLocation (k)); } out.println ("
"); } out.println (""); } else { out.println (""); } } out.println (" "); } private static boolean seqMatches (Sequence seq1, Sequence seq2, int i) { return seq1.get(i).toString().equals (seq2.get(i).toString()); } private static boolean allSeqMatches (Sequence seq1, Sequence seq2, int start, int end) { for (int i = start; i < end; i++) { if (!seqMatches (seq1, seq2, i)) return false; } return true; } public static void extraction2html (Extraction extraction, CRFExtractor extor, PrintStream out) { PrintWriter writer = new PrintWriter (new OutputStreamWriter (out), true); extraction2html (extraction, extor, out, false); } public static void extraction2html (Extraction extraction, CRFExtractor extor, PrintWriter out) { extraction2html (extraction, extor, out, false); } public static void extraction2html (Extraction extraction, CRFExtractor extor, PrintStream out, boolean showLattice) { PrintWriter writer = new PrintWriter (new OutputStreamWriter (out), true); extraction2html (extraction, extor, writer, showLattice); } public static void extraction2html (Extraction extraction, CRFExtractor extor, PrintWriter out, boolean showLattice) { writeHeader (out); for (int i = 0; i < extraction.getNumDocuments (); i++) { DocumentExtraction docextr = extraction.getDocumentExtraction (i); String desc = docextr.getName(); String doc = ((CharSequence) docextr.getDocument ()).toString(); ExtorInfo info = infoForDoc (doc, desc, "N"+i, docextr, extor, showLattice); if (!showLattice) info.link = "lattice.html"; lattice2html (out, info); } writeFooter (out); } private static class ExtorInfo { TokenSequence input; Sequence predicted; LabelSequence target; FeatureVectorSequence fvs; MaxLattice lattice; Sequence bestStates; String link; // If non-null, name of HTML file to use for cross-links String desc; String idx; public ExtorInfo (TokenSequence input, Sequence predicted, LabelSequence target, String desc, String idx) { this.input = input; this.predicted = predicted; this.target = 
target; this.desc = desc; this.idx = idx; } } private static ExtorInfo infoForDoc (String doc, String desc, String idx, DocumentExtraction docextr, CRFExtractor extor, boolean showLattice) { // Instance c2 = new Instance (doc, null, null, null, extor.getTokenizationPipe ()); // TokenSequence input = (TokenSequence) c2.getData (); TokenSequence input = (TokenSequence) docextr.getInput (); LabelSequence target = docextr.getTarget (); Sequence predicted = docextr.getPredictedLabels (); ExtorInfo info = new ExtorInfo (input, predicted, target, desc, idx); if (showLattice == true) { CRF crf = extor.getCrf(); // xxx perhaps the next two lines could be a transducer method??? Instance carrier = extor.getFeaturePipe().pipe(new Instance (input, null, null, null)); info.fvs = (FeatureVectorSequence) carrier.getData (); info.lattice = new MaxLatticeDefault (crf, (Sequence) carrier.getData(), null); info.bestStates = info.lattice.bestOutputSequence(); } return info; } // Lattice files get too large if too many instances are written to one file private static final int EXTRACTIONS_PER_FILE = 25; public static void viewDualResults (File dir, Extraction e1, CRFExtractor extor1, Extraction e2, CRFExtractor extor2) throws IOException { if (e1.getNumDocuments () != e2.getNumDocuments ()) throw new IllegalArgumentException ("Extractions don't match: different number of docs."); PrintWriter errorStr = new PrintWriter (new FileWriter (new File (dir, "errors.html"))); writeDualExtractions (errorStr, e1, extor1, e2, extor2, 0, e1.getNumDocuments (), false); errorStr.close (); int max = e1.getNumDocuments (); for (int start = 0; start < max; start += EXTRACTIONS_PER_FILE) { int end = Math.min (start + EXTRACTIONS_PER_FILE, max); PrintWriter latticeStr = new PrintWriter (new FileWriter (new File (dir, "lattice-"+start+".html"))); writeDualExtractions (latticeStr, e1, extor1, e2, extor2, start, end, true); latticeStr.close (); } } private static String computeLatticeFname (int docIdx) { int 
htmlDocNo = docIdx / EXTRACTIONS_PER_FILE; // this will get integer truncated int start = EXTRACTIONS_PER_FILE * htmlDocNo; return "lattice-"+start+".html"; } private static void writeDualExtractions (PrintWriter out, Extraction e1, CRFExtractor extor1, Extraction e2, CRFExtractor extor2, int start, int end, boolean showLattice) { writeHeader (out); for (int i = start; i < end; i++) { DocumentExtraction doc1 = e1.getDocumentExtraction (i); DocumentExtraction doc2 = e2.getDocumentExtraction (i); String desc = doc1.getName(); String doc1Str = ((CharSequence) doc1.getDocument ()).toString(); String doc2Str = ((CharSequence) doc2.getDocument ()).toString(); if (!doc1Str.equals (doc2Str)) { System.err.println ("Skipping document "+i+": Extractions don't match"); continue; } Sequence targ1 = doc1.getPredictedLabels (); Sequence targ2 = doc2.getPredictedLabels (); if (!predictionsMatch (targ1, targ2)) { ExtorInfo info1 = infoForDoc (doc1Str, "CRF1::"+desc, "C1I"+i, doc1, extor1, showLattice); ExtorInfo info2 = infoForDoc (doc1Str, "CRF2::"+desc, "C2I"+i, doc2, extor2, showLattice); if (!showLattice) { // add links from errors.html --> lattice.html info1.link = info2.link = computeLatticeFname (i); } dualLattice2html (out, desc, info1, info2); } } writeFooter (out); } // if lattice == null, no alpha, beta values printed public static void dualLattice2html (PrintWriter out, String desc, ExtorInfo info1, ExtorInfo info2) { assert (info1.predicted.size() == info1.target.size()); assert (info1.input.size() == info1.predicted.size()); assert (info2.input.size() == info2.predicted.size()); assert (info2.predicted.size() == info2.target.size()); int N = info1.target.size(); for (int start = 0; start < N; start += LENGTH - 1) { int end = Math.min (info1.predicted.size(), start + LENGTH); if (!allSeqMatches (info1.predicted, info2.predicted, start, end)) { error2html (out, info1, start, end); error2html (out, info2, start, end); } } } private static boolean predictionsMatch 
(Sequence targ1, Sequence targ2) { if (targ1.size() != targ2.size()) return false; for (int i = 0; i < targ1.size(); i++) if (!targ1.get(i).toString().equals (targ2.get(i).toString())) return false; return true; } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy