
de.julielab.jsbd.SentenceSplitterApplication Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jcore-jsbd-ae Show documentation
Show all versions of jcore-jsbd-ae Show documentation
UIMA Wrapper for the JULIE Sentence Boundary Detector.
/**
* SentenceSplitterApplication.java
*
* Copyright (c) 2015, JULIE Lab.
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the GNU Lesser General Public License (LGPL) v3.0
*
* Author: tomanek
*
* Current version: 2.0
* Since version: 1.0
*
* Creation date: Aug 01, 2006
**/
package de.julielab.jsbd;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.TreeSet;
import java.util.zip.GZIPInputStream;

import cc.mallet.fst.CRF;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelSequence;
/**
 * The user interface (command line version) for the JULIE Sentence Boundary Detector. Includes
 * training, prediction, file format check, and evaluation.
 *
 * When splits are done (e.g., for 90/10 or X-Val) the same randomization is enforced: the input
 * files are first sorted by path (File.listFiles() gives no ordering guarantee) and then shuffled
 * with a random number generator seeded with 1.
 *
 * @author tomanek
 */
public class SentenceSplitterApplication {

	/** Post-processing flag handed to SentenceSplitter.predict(...) in every mode. */
	private final static boolean doPostprocessing = true;

	/** Label of a token that ends a sentence (the positive class in the evaluation). */
	private static final String LABEL_EOS = "EOS";

	/** Label of a token inside a sentence. */
	private static final String LABEL_IS = "IS";

	private static final String DIR_MISSING = "Error: the specified directory does not exist.";

	public static void main(String[] args) {
		if (args.length < 1) {
			System.err.println("usage: JSBD <mode> {mode_specific_parameters}");
			System.err.println("different modes:");
			System.err.println("c: check texts");
			System.err.println("t: train a sentence splitting model");
			System.err.println("p: do the sentence splitting");
			System.err.println("s: evaluation with 90-10 split");
			System.err.println("x: evaluation with cross-validation");
			System.err.println("e: evaluation on previously trained model");
			System.exit(-1);
		}
		switch (args[0]) {
		case "c": // check mode
			startCheckMode(args);
			break;
		case "t": // training mode
			startTrainingMode(args);
			break;
		case "p": // prediction mode
			startPredictionMode(args);
			break;
		case "x": // cross validation mode
			startXValidationMode(args);
			break;
		case "s": // 90-10 validation split mode
			start9010ValidationMode(args);
			break;
		case "e": // compare validation mode
			startCompareValidationMode(args);
			break;
		default: // unknown mode
			System.err.println("Unknown run mode.");
			System.exit(-1);
		}
	}

	/**
	 * Entry point for evaluating a previously trained (gzipped, serialized) model.
	 *
	 * @param args
	 *            the command line arguments: e, model file, test directory, error output file
	 */
	private static void startCompareValidationMode(String[] args) {
		System.out.println("performing evaluation previously trained model.");
		if (args.length != 4)
			exitWithError("usage: JSBD e <modelFile> <testDir> <errorFile>");
		CRF crf = readModelOrExit(args[1]);
		File[] abstractArray = listFilesOrExit(requireDirectory(args[2], DIR_MISSING));
		TreeSet<String> errorList = new TreeSet<>();
		EvalResult er = doEvaluation(crf, abstractArray, errorList);
		writeFile(errorList, new File(args[3]));
		System.out.println("\n\nAccuracy on pretrained model: " + er.getAcc());
		System.exit(0);
	}

	/**
	 * Entry point for 90-10 split.
	 *
	 * @param args
	 *            the command line arguments: s, data directory, error output file
	 */
	private static void start9010ValidationMode(String[] args) {
		System.out.println("performing evaluation on 90/10 split");
		if (args.length != 3)
			exitWithError("usage: JSBD s <dataDir> <errorFile>");
		File[] abstractArray = listFilesOrExit(requireDirectory(args[1], DIR_MISSING));
		TreeSet<String> errorList = new TreeSet<>();
		EvalResult er = do9010Evaluation(abstractArray, errorList);
		writeFile(errorList, new File(args[2]));
		System.out.println("\n\nAccuracy on 90/10 split: " + er.getAcc());
		System.exit(0);
	}

	/**
	 * Entry point for cross validation mode
	 *
	 * @param args
	 *            the command line arguments: x, data directory, number of rounds, error output file
	 */
	private static void startXValidationMode(String[] args) {
		System.out.println("performing cross-validation");
		if (args.length != 4)
			exitWithError("usage: JSBD x <dataDir> <n> <errorFile>");
		File[] abstractArray = listFilesOrExit(requireDirectory(args[1], DIR_MISSING));
		int n = 0;
		try {
			n = Integer.parseInt(args[2]);
		} catch (NumberFormatException e) {
			exitWithError("Error: the number of cross-validation rounds must be an integer, but was: " + args[2]);
		}
		if (n > (abstractArray.length / 2) || n > 10 || n < 2)
			exitWithError("Error: cannot perform " + n + " cross-validation rounds. Choose n in [2:10].");
		TreeSet<String> errorList = new TreeSet<>();
		double acc = doCrossEvaluation(abstractArray, n, errorList);
		writeFile(errorList, new File(args[3]));
		System.out.println("\n\nAccuracy on cross validation: " + acc);
		System.exit(0);
	}

	/**
	 * Entry point for prediction mode
	 *
	 * @param args
	 *            the command line arguments: p, input directory, output directory, model file
	 */
	private static void startPredictionMode(String[] args) {
		System.out.println("doing the sentence splitting...");
		if (args.length != 4)
			exitWithError("usage: JSBD p <inDir> <outDir> <modelFile>");
		File[] inFiles = listFilesOrExit(requireDirectory(args[1], "Error: the specified input directory does not exist."));
		File outDir = requireDirectory(args[2], "Error: the specified output directory does not exist.");
		doPrediction(inFiles, outDir, args[3]);
	}

	/**
	 * Entry point for training mode
	 *
	 * @param args
	 *            the command line arguments: t, training directory, model output file
	 */
	private static void startTrainingMode(String[] args) {
		System.out.println("training the model...");
		if (args.length != 3)
			exitWithError("usage: JSBD t <trainDir> <modelFile>");
		File[] trainFiles = listFilesOrExit(requireDirectory(args[1], DIR_MISSING));
		System.out.println("number of files to train on: " + trainFiles.length);
		String modelFilename = args[2];
		doTraining(trainFiles, modelFilename);
		System.out.println("Saved model to: " + modelFilename);
	}

	/**
	 * Entry point for check mode
	 *
	 * @param args
	 *            the command line arguments: c, data directory
	 */
	private static void startCheckMode(String[] args) {
		System.out.println("checking abstracts...");
		if (args.length != 2)
			exitWithError("usage: JSBD c <dataDir>");
		File[] abstractArray = listFilesOrExit(requireDirectory(args[1], DIR_MISSING));
		// check data for validity:
		doCheckAbstracts(abstractArray);
		System.exit(0);
	}

	/** Prints the message to stderr and terminates the JVM with exit code -1. */
	private static void exitWithError(String message) {
		System.err.println(message);
		System.exit(-1);
	}

	/** Returns the directory denoted by path; prints errorMessage and exits if it is not a directory. */
	private static File requireDirectory(String path, String errorMessage) {
		File dir = new File(path);
		if (!dir.isDirectory())
			exitWithError(errorMessage);
		return dir;
	}

	/** Lists the files of dir; exits with an error if the listing fails (listFiles() returns null then). */
	private static File[] listFilesOrExit(File dir) {
		File[] files = dir.listFiles();
		if (files == null)
			exitWithError("Error: could not list the files of directory " + dir + ".");
		return files;
	}

	/** Deserializes a gzipped CRF model; exits with an error if it cannot be read. */
	private static CRF readModelOrExit(String modelFilename) {
		try (ObjectInputStream in = new ObjectInputStream(new GZIPInputStream(new FileInputStream(modelFilename)))) {
			return (CRF) in.readObject();
		} catch (IOException | ClassNotFoundException | ClassCastException e) {
			e.printStackTrace();
			exitWithError("Error: could not load model from " + modelFilename + ".");
			return null; // not reached: exitWithError terminates the JVM
		}
	}

	/**
	 * checks the data for validity... just for the beginning and to debug
	 *
	 * @param abstractList
	 *            the files to check (building the training data reports format problems)
	 */
	private static void doCheckAbstracts(File[] abstractList) {
		SentenceSplitter tpFunctions = new SentenceSplitter();
		tpFunctions.makeTrainingData(abstractList, false);
		System.out.println("done.");
	}

	/**
	 * Returns the files sorted by path and then shuffled with a fixed seed, so that the resulting
	 * splits are identical across runs and platforms.
	 */
	private static List<File> reproducibleShuffle(File[] files) {
		List<File> list = new ArrayList<>(files.length);
		Collections.addAll(list, files);
		// File.listFiles() has no defined order; without sorting, the seeded shuffle alone
		// would not yield a reproducible split.
		Collections.sort(list);
		Collections.shuffle(list, new Random(1));
		return list;
	}

	/**
	 * evaluation via 90-10 split of data
	 *
	 * @param abstractArray
	 *            all data files
	 * @param errorList
	 *            receives one entry per misclassified token
	 * @return the evaluation result on the 10% test part
	 */
	private static EvalResult do9010Evaluation(File[] abstractArray, TreeSet<String> errorList) {
		List<File> abstractList = reproducibleShuffle(abstractArray);
		int sizeAll = abstractList.size();
		int sizeTest = (int) (sizeAll * 0.1);
		int sizeTrain = sizeAll - sizeTest;
		if (sizeTest == 0)
			exitWithError("Error: no test files for this split. Number of files in directory might be too small.");
		System.out.println("all: " + sizeAll + "\ttrain: " + sizeTrain + "\t" + "test: " + sizeTest);
		File[] trainFiles = abstractList.subList(0, sizeTrain).toArray(new File[0]);
		File[] predictFiles = abstractList.subList(sizeTrain, sizeAll).toArray(new File[0]);
		return doEvaluation(trainFiles, predictFiles, errorList);
	}

	/**
	 * cross-evaluation, returns average accuracy
	 *
	 * @param abstractArray
	 *            an array of File-objects
	 * @param n
	 *            the number of rounds for cross-validation
	 * @param errorList
	 *            receives one entry per misclassified token over all rounds
	 * @return avg accuracy over all x-validation rounds
	 */
	private static double doCrossEvaluation(File[] abstractArray, int n, TreeSet<String> errorList) {
		List<File> abstractList = reproducibleShuffle(abstractArray);
		int sizeAll = abstractList.size();
		int sizeRound = sizeAll / n;
		int sizeLastRound = sizeRound + sizeAll % n;
		System.out.println("number of files in directory: " + sizeAll);
		System.out.println("size of each/last round: " + sizeRound + "/" + sizeLastRound);
		System.out.println();
		EvalResult[] evalResults = new EvalResult[n];
		for (int i = 0; i < n; i++) {
			int testStart = i * sizeRound;
			// the last round additionally takes the remainder files
			int testEnd = (i == n - 1) ? sizeAll : testStart + sizeRound;
			List<File> train = new ArrayList<>(abstractList.subList(0, testStart));
			train.addAll(abstractList.subList(testEnd, sizeAll));
			File[] trainFiles = train.toArray(new File[0]);
			File[] predictFiles = abstractList.subList(testStart, testEnd).toArray(new File[0]);
			System.out.println("training size: " + trainFiles.length);
			System.out.println("prediction size: " + predictFiles.length);
			evalResults[i] = doEvaluation(trainFiles, predictFiles, errorList);
		}
		DecimalFormat df = new DecimalFormat("0.000");
		double avgAcc = 0;
		double avgF = 0;
		for (int i = 0; i < evalResults.length; i++) {
			avgAcc += evalResults[i].getAcc();
			avgF += evalResults[i].getF();
			System.out.println(i + ": " + df.format(evalResults[i].getAcc()));
		}
		avgAcc = avgAcc / n;
		avgF = avgF / n;
		System.out.println("avg accuracy: " + df.format(avgAcc));
		System.out.println("avg f-score: " + df.format(avgF));
		return avgAcc;
	}

	/**
	 * Trains a model on trainFiles and evaluates it on predictFiles.
	 *
	 * @param trainFiles
	 *            the files from which the model should be learned
	 * @param predictFiles
	 *            the files for evaluated prediction
	 * @param errorList
	 *            receives one entry per misclassified token, formatted as
	 *            {@code filename TAB gold label TAB predicted label TAB token (position)}
	 * @return the evaluation result
	 */
	private static EvalResult doEvaluation(File[] trainFiles, File[] predictFiles, TreeSet<String> errorList) {
		SentenceSplitter tpFunctions = new SentenceSplitter();
		InstanceList trainData = tpFunctions.makeTrainingData(trainFiles, false);
		Pipe myPipe = trainData.getPipe();
		System.out.println("training...");
		tpFunctions.train(trainData, myPipe);
		return doEvaluation(tpFunctions.getModel(), predictFiles, errorList);
	}

	/**
	 * Evaluates a trained model. Only tokens ending with an EOS symbol are counted as decisions;
	 * EOS is the positive class for precision/recall.
	 *
	 * @param crf
	 *            a previously trained model
	 * @param predictFiles
	 *            the files for evaluated prediction
	 * @param errorList
	 *            receives one entry per misclassified token, formatted as
	 *            {@code filename TAB gold label TAB predicted label TAB token (position)}
	 * @return the evaluation result
	 */
	private static EvalResult doEvaluation(CRF crf, File[] predictFiles, TreeSet<String> errorList) {
		SentenceSplitter tpFunctions = new SentenceSplitter();
		tpFunctions.setModel(crf);
		EOSSymbols eoss = new EOSSymbols();
		InstanceList predictData = tpFunctions.makePredictionData(predictFiles, crf.getInputPipe());
		System.out.println("predicting...");
		EvalResult er = new EvalResult();
		for (Instance inst : predictData) {
			String abstractName = (String) inst.getSource();
			List<Unit> units = predictOrNull(tpFunctions, inst);
			if (units == null)
				continue; // failure was already reported
			List<String> orgLabels = getLabelsFromLabelSequence((LabelSequence) inst.getTarget());
			for (int j = 0; j < units.size(); j++) {
				Unit unit = units.get(j);
				String pred = unit.label;
				String org = orgLabels.get(j);
				// evaluate only tokens ending with an EOS symbol (the candidate boundaries)
				if (!eoss.tokenEndsWithEOSSymbol(unit.rep))
					continue;
				er.nrDecisions++;
				if (pred.equals(org)) {
					er.corrDecisions++;
					if (org.equals(LABEL_EOS))
						er.tp++;
				} else { // store errors
					errorList.add(abstractName + "\t" + org + "\t" + pred + "\t" + unit.rep + " (" + j + ")");
					if (pred.equals(LABEL_EOS) && org.equals(LABEL_IS))
						er.fp++;
					else if (pred.equals(LABEL_IS) && org.equals(LABEL_EOS))
						er.fn++;
				}
			}
		}
		System.out.println("all : " + er.nrDecisions);
		System.out.println("corr: " + er.corrDecisions);
		System.out.println("fp :" + er.fp);
		System.out.println("fn :" + er.fn);
		System.out.println("R :" + er.getR());
		System.out.println("P :" + er.getP());
		System.out.println("F :" + er.getF());
		System.out.println("ACC : " + er.getAcc());
		return er;
	}

	/**
	 * Runs the splitter on one instance. Returns null (after reporting the problem) when prediction
	 * fails, so callers can skip the instance instead of failing with a NullPointerException.
	 */
	private static List<Unit> predictOrNull(SentenceSplitter splitter, Instance inst) {
		try {
			return splitter.predict(inst, doPostprocessing);
		} catch (Exception e) {
			System.err.println("Error: prediction failed for " + inst.getSource() + ", skipping it.");
			e.printStackTrace();
			return null;
		}
	}

	/**
	 * to train a sentence boundary detector input: abstracts with one sentence per line output: a
	 * crf model stored in file
	 *
	 * @param trainFiles
	 *            the training abstracts
	 * @param modelFilename
	 *            the file to store the trained model
	 */
	private static void doTraining(File[] trainFiles, String modelFilename) {
		SentenceSplitter sentenceSplitter = new SentenceSplitter();
		System.out.println("making training data...");
		InstanceList trainData = sentenceSplitter.makeTrainingData(trainFiles, false);
		Pipe myPipe = trainData.getPipe();
		System.out.println("training model...");
		sentenceSplitter.train(trainData, myPipe);
		sentenceSplitter.writeModel(modelFilename);
	}

	/**
	 * this performs sentence splitting input: files with all sentences of an abstract within one
	 * line output: files with one sentence per line
	 *
	 * @param inFiles
	 *            the input files
	 * @param outDir
	 *            the output directory, where to store the split sentences (same file names as input)
	 * @param modelFilename
	 *            the stored model to load
	 */
	private static void doPrediction(File[] inFiles, File outDir, String modelFilename) {
		SentenceSplitter sentenceSplitter = new SentenceSplitter();
		System.out.println("reading model...");
		try {
			sentenceSplitter.readModel(new File(modelFilename));
		} catch (Exception e) {
			e.printStackTrace();
			exitWithError("Error: could not read model from " + modelFilename + ".");
		}
		System.out.println("starting sentence splitting...");
		Pipe myPipe = sentenceSplitter.getModel().getInputPipe();
		int step = 100;
		for (int i = 0; i < inFiles.length; i++) {
			if (i % step == 0 && i > 0)
				System.out.println(i + " files done...");
			File inFile = inFiles[i];
			Instance inst = myPipe.instanceFrom(new Instance(sentenceSplitter.readFile(inFile), "", "", inFile.getName()));
			List<Unit> units = predictOrNull(sentenceSplitter, inst);
			if (units == null)
				continue; // failure was already reported
			// getName() is platform independent, unlike splitting the path at "/"
			writeFile(toSentences(units), new File(outDir, inFile.getName()));
		}
	}

	/**
	 * Joins unit representations with single spaces into sentences, closing a sentence after each
	 * unit labeled EOS. Trailing units without a closing EOS label form a final sentence instead of
	 * being dropped.
	 */
	private static List<String> toSentences(List<Unit> units) {
		List<String> sentences = new ArrayList<>();
		StringBuilder sentence = new StringBuilder();
		for (Unit unit : units) {
			if (sentence.length() > 0)
				sentence.append(' ');
			sentence.append(unit.rep);
			if (unit.label.equals(LABEL_EOS)) {
				sentences.add(sentence.toString());
				sentence.setLength(0);
			}
		}
		if (sentence.length() > 0)
			sentences.add(sentence.toString());
		return sentences;
	}

	/** Copies the labels of a label sequence into a list of strings. */
	private static List<String> getLabelsFromLabelSequence(LabelSequence ls) {
		List<String> labels = new ArrayList<>(ls.size());
		for (int j = 0; j < ls.size(); j++)
			labels.add((String) ls.get(j));
		return labels;
	}

	/** Writes each line followed by "\n" to outFile; I/O errors are reported, not propagated. */
	private static void writeFile(Collection<String> lines, File outFile) {
		try (FileWriter fw = new FileWriter(outFile)) {
			for (String line : lines)
				fw.write(line + "\n");
		} catch (IOException e) {
			System.err.println("Error: could not write " + outFile + ".");
			e.printStackTrace();
		}
	}

	/** Counts collected during an evaluation; EOS is the positive class for P/R/F. */
	private static class EvalResult {
		/** number of evaluated decisions (tokens ending with an EOS symbol) */
		int nrDecisions;
		/** decisions where the predicted label equals the gold label */
		int corrDecisions;
		/** true positives: predicted EOS, gold EOS */
		int tp;
		/** false positives: predicted EOS, gold IS */
		int fp;
		/** false negatives: predicted IS, gold EOS */
		int fn;

		double getAcc() {
			return corrDecisions / (double) nrDecisions;
		}

		double getR() {
			return tp / (double) (tp + fn);
		}

		double getP() {
			return tp / (double) (tp + fp);
		}

		double getF() {
			double r = getR();
			double p = getP();
			// avoid 0/0 when there is not a single true positive
			return (r + p) == 0 ? 0 : 2 * r * p / (r + p);
		}
	}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy