
cc.mallet.grmm.learning.extract.AcrfExtractorTui Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.learning.extract;
/**
*
* Created: Aug 23, 2005
*
* @author trainSource = constructIterator(trainFile.value, dataDir.value, trainingIsList.value);
Iterator testSource;
if (testFile.wasInvoked ()) {
testSource = constructIterator (testFile.value, dataDir.value, trainingIsList.value);
} else {
testSource = null;
}
ACRF.Template[] tmpls = parseModelFile (modelFile.value);
ACRFExtractorTrainer trainer = createTrainer (trainerOption.value);
ACRFEvaluator eval = createEvaluator (evalOption.value);
ExtractionEvaluator extractionEval = createExtractionEvaluator (extractionEvalOption.value);
Inferencer inf = createInferencer (inferencerOption.value);
Inferencer maxInf = createInferencer (maxInferencerOption.value);
trainer.setPipes (tokPipe, new TokenSequence2FeatureVectorSequence ())
.setDataSource (trainSource, testSource)
.setEvaluator (eval)
.setTemplates (tmpls)
.setInferencer (inf)
.setViterbiInferencer (maxInf)
.setCheckpointDirectory (outputPrefix.value)
.setNumCheckpointIterations (checkpointIterations.value)
.setCacheUnrolledGraphs (cacheUnrolledGraph.value)
.setUsePerTemplateTrain (perTemplateTrain.value)
.setPerTemplateIterations (pttIterations.value);
logger.info ("Starting training...");
ACRFExtractor extor = trainer.trainExtractor ();
timing.tick ("Training");
FileUtils.writeGzippedObject (new File (outputPrefix.value, "extor.ser.gz"), extor);
timing.tick ("Serializing");
InstanceList testing = trainer.getTestingData ();
if (testing != null) {
eval.test (extor.getAcrf (), testing, "Final results");
}
if ((extractionEval != null) && (testing != null)) {
Extraction extraction = extor.extract (testing);
extractionEval.evaluate (extraction);
timing.tick ("Evaluting");
}
System.out.println ("Total time (ms) = " + timing.elapsedTime ());
}
/**
 * Obtains the shared BeanShell interpreter from CommandOption and imports the packages that
 * model files and command-line specs refer to by simple class name.
 *
 * @return the configured interpreter
 * @throws RuntimeException wrapping the EvalError if any import statement fails to evaluate
 */
private static BshInterpreter setupInterpreter ()
{
  BshInterpreter interpreter = CommandOption.getInterpreter ();
  // The edu.umass.cs.mallet.* entries are the older package names and are kept for backward
  // compatibility. The cc.mallet.* entries match the renamed packages; this class itself lives in
  // cc.mallet.grmm.learning.extract. Without them, simple names such as inferencer or template
  // classes would not resolve.
  String[] packages = {
    "edu.umass.cs.mallet.base.extract",
    "edu.umass.cs.mallet.grmm.inference",
    "edu.umass.cs.mallet.grmm.learning",
    "edu.umass.cs.mallet.grmm.learning.templates",
    "edu.umass.cs.mallet.grmm.learning.extract",
    "cc.mallet.extract",
    "cc.mallet.grmm.inference",
    "cc.mallet.grmm.learning",
    "cc.mallet.grmm.learning.templates",
    "cc.mallet.grmm.learning.extract",
  };
  try {
    for (int i = 0; i < packages.length; i++) {
      interpreter.eval ("import " + packages[i] + ".*");
    }
  } catch (EvalError e) {
    throw new RuntimeException (e);
  }
  return interpreter;
}
/**
 * Builds the raw instance iterator for a data file.
 *
 * @param trainFile either a data file or, when isList is true, a file listing data files
 * @param dataDir directory passed to FileListIterator (presumably the base for listed paths)
 * @param isList whether trainFile is a list of files rather than the data itself
 * @return a FileListIterator when isList is true, otherwise a LineGroupIterator that splits
 *         trainFile on lines matching ^\s*$ (blank or whitespace-only lines)
 * @throws IOException if the underlying file cannot be opened
 */
private static Iterator constructIterator (File trainFile, File dataDir, boolean isList) throws IOException
{
  if (isList) {
    return new FileListIterator (trainFile, dataDir, null, null, true);
  }
  // The reader is handed to the iterator, which consumes it lazily; it is not closed here.
  FileReader reader = new FileReader (trainFile);
  Pattern blankLine = Pattern.compile ("^\\s*$");
  return new LineGroupIterator (reader, blankLine, true);
}
/**
 * Creates an ACRFEvaluator from a --eval specification.
 *
 * <p>A spec containing '(' is evaluated verbatim as BeanShell code. Otherwise it is treated as a
 * whitespace-separated token list (e.g. "SEGMENT 0 B-X I-X", "LOG", "SERIAL ...") and parsed
 * by the token-list overload.
 *
 * @param spec evaluator specification
 * @return the evaluator described by spec
 * @throws EvalError if the BeanShell expression fails to evaluate
 */
public static ACRFEvaluator createEvaluator (String spec) throws EvalError
{
  if (spec.indexOf ('(') >= 0) {
    // assume it's Java code, and don't screw with it.
    return (ACRFEvaluator) interpreter.eval (spec);
  } else {
    // Trim first: with leading whitespace, split("\\s+") yields an empty first token, which the
    // token parser would reject as an illegal evaluator type.
    LinkedList toks = new LinkedList (Arrays.asList (spec.trim ().split ("\\s+")));
    return createEvaluator (toks);
  }
}
/**
 * Creates an ExtractionEvaluator from a specification string.
 *
 * <p>A spec containing '(' is evaluated verbatim as BeanShell code. Otherwise it is a class-name
 * shorthand. Like createTrainer, the "Evaluator" suffix is appended only when missing, so both
 * "PerDocumentF1" and "PerDocumentF1Evaluator" work.
 *
 * @param spec evaluator specification; null or blank means "no extraction evaluator"
 * @return the evaluator, or null when no spec was given (main() skips extraction evaluation
 *         when this is null)
 * @throws EvalError if the BeanShell expression fails to evaluate
 */
private static ExtractionEvaluator createExtractionEvaluator (String spec) throws EvalError
{
  if (spec == null || spec.trim ().length () == 0) {
    return null;
  }
  String cmd;
  if (spec.indexOf ('(') >= 0) {
    // assume it's Java code, and don't screw with it.
    cmd = spec;
  } else {
    String name = spec.trim ();
    if (name.endsWith ("Evaluator")) {
      cmd = "new "+name+" ()";
    } else {
      cmd = "new "+name+"Evaluator ()";
    }
  }
  return (ExtractionEvaluator) interpreter.eval (cmd);
}
/**
 * Parses an evaluator from a token list, consuming the tokens it uses.
 *
 * <p>Grammar, with types matched case-insensitively:
 * <ul>
 *   <li>SEGMENT slice start1 cont1 [start2 cont2 ...]: consumes ALL remaining tokens as
 *       (start, continue) tag pairs, so it must be the last evaluator in the list.</li>
 *   <li>LOG: a DefaultAcrfTrainer.LogEvaluator.</li>
 *   <li>SERIAL eval...: recursively parses evaluators until the tokens run out.</li>
 * </ul>
 *
 * @param toks mutable list of String tokens; tokens are removed as they are parsed
 * @return the parsed evaluator
 * @throws RuntimeException with a descriptive "Error in --eval" message on malformed input
 */
private static ACRFEvaluator createEvaluator (LinkedList toks)
{
  if (toks.isEmpty ())
    throw new RuntimeException ("Error in --eval "+evalOption.value+": expected an evaluator type but found end of input.");
  String type = (String) toks.removeFirst ();
  if (type.equalsIgnoreCase ("SEGMENT")) {
    if (toks.isEmpty ())
      throw new RuntimeException ("Error in --eval "+evalOption.value+": SEGMENT requires a slice index.");
    String sliceTok = (String) toks.removeFirst ();
    int slice;
    try {
      slice = Integer.parseInt (sliceTok);
    } catch (NumberFormatException e) {
      throw new RuntimeException ("Error in --eval "+evalOption.value+": SEGMENT slice must be an integer, got "+sliceTok, e);
    }
    if (toks.size() % 2 != 0)
      throw new RuntimeException ("Error in --eval "+evalOption.value+": Every start tag must have a continue.");
    int numTags = toks.size () / 2;
    String[] startTags = new String [numTags];
    String[] continueTags = new String [numTags];
    for (int i = 0; i < numTags; i++) {
      startTags[i] = (String) toks.removeFirst ();
      continueTags[i] = (String) toks.removeFirst ();
    }
    return new MultiSegmentationEvaluatorACRF (startTags, continueTags, slice);
  } else if (type.equalsIgnoreCase ("LOG")) {
    return new DefaultAcrfTrainer.LogEvaluator ();
  } else if (type.equalsIgnoreCase ("SERIAL")) {
    List evals = new ArrayList ();
    while (!toks.isEmpty ()) {
      evals.add (createEvaluator (toks));
    }
    return new AcrfSerialEvaluator (evals);
  } else {
    throw new RuntimeException ("Error in --eval "+evalOption.value+": illegal evaluator "+type);
  }
}
/**
 * Creates the extractor trainer from a --trainer specification.
 *
 * <p>A spec containing '(' is evaluated verbatim. Otherwise it is a class-name shorthand with an
 * optional "Trainer" suffix. If the result is already an ACRFExtractorTrainer it is returned
 * as-is. Any other ACRFTrainer is wrapped as the training method of a new ACRFExtractorTrainer.
 *
 * @param spec trainer specification
 * @return the extractor trainer
 * @throws EvalError if the BeanShell expression fails to evaluate
 * @throws RuntimeException if the expression yields neither kind of trainer
 */
private static ACRFExtractorTrainer createTrainer (String spec) throws EvalError
{
  String cmd;
  if (spec.indexOf ('(') >= 0) {
    // assume it's Java code, and don't screw with it.
    cmd = spec;
  } else if (spec.endsWith ("Trainer")) {
    cmd = "new "+spec+"()";
  } else {
    cmd = "new "+spec+"Trainer()";
  }
  // Return whatever the Java code says to
  Object trainer = interpreter.eval (cmd);
  if (trainer instanceof ACRFExtractorTrainer)
    return (ACRFExtractorTrainer) trainer;
  // Test against ACRFTrainer, the type setTrainingMethod actually accepts, rather than the single
  // DefaultAcrfTrainer implementation, so that other ACRFTrainer implementations are accepted too.
  else if (trainer instanceof ACRFTrainer)
    return new ACRFExtractorTrainer ().setTrainingMethod ((ACRFTrainer) trainer);
  else throw new RuntimeException ("Don't know what to do with trainer "+trainer);
}
/**
 * Creates an Inferencer from a specification string.
 *
 * @param spec either BeanShell code (anything containing '(') evaluated as-is, or a bare class
 *             name that is instantiated with its no-argument constructor
 * @return the resulting inferencer
 * @throws EvalError if the BeanShell expression fails to evaluate
 * @throws RuntimeException if the expression does not yield an Inferencer
 */
private static Inferencer createInferencer (String spec) throws EvalError
{
  boolean isJavaCode = spec.indexOf ('(') >= 0;
  String expression = isJavaCode ? spec : "new " + spec + "()";
  Object result = interpreter.eval (expression);
  if (!(result instanceof Inferencer)) {
    throw new RuntimeException ("Don't know what to do with inferencer " + result);
  }
  return (Inferencer) result;
}
/**
 * Registers the command options of childClass, parses args against them, and logs the
 * resulting option values.
 *
 * @param childClass class whose CommandOption fields are added to the option list
 * @param args command-line arguments to process
 */
public static void doProcessOptions (Class childClass, String[] args)
{
  CommandOption[] noInitialOptions = new CommandOption[0];
  CommandOption.List optionList = new CommandOption.List ("", noInitialOptions);
  optionList.add (childClass);
  optionList.process (args);
  // Logger named "" -- presumably the root logger (java.util.logging); confirm against imports.
  Logger rootLogger = Logger.getLogger ("");
  optionList.logOptions (rootLogger);
}
/**
 * Reads a model file in which each non-blank line is a BeanShell expression that must evaluate
 * to an ACRF.Template.
 *
 * @param mdlFile the model file
 * @return the templates, in file order
 * @throws IOException if the file cannot be read
 * @throws EvalError if a line fails to evaluate
 * @throws RuntimeException if a line evaluates to something other than an ACRF.Template; the
 *         message includes the 1-based line number
 */
private static ACRF.Template[] parseModelFile (File mdlFile) throws IOException, EvalError
{
  BufferedReader in = new BufferedReader (new FileReader (mdlFile));
  try {
    List tmpls = new ArrayList ();
    int lineNo = 0;
    String line;
    while ((line = in.readLine ()) != null) {
      lineNo++;
      // Blank lines hold no template. Skip them rather than handing an empty expression to the
      // interpreter, whose result could not be an ACRF.Template and would abort parsing.
      if (line.trim ().length () == 0) {
        continue;
      }
      Object tmpl = interpreter.eval (line);
      if (!(tmpl instanceof ACRF.Template)) {
        throw new RuntimeException ("Error in "+mdlFile+" line "+lineNo+":\n Object "+tmpl+" not a template");
      }
      tmpls.add (tmpl);
    }
    return (ACRF.Template[]) tmpls.toArray (new ACRF.Template [0]);
  } finally {
    // Always release the file handle, including when a line fails to parse.
    in.close ();
  }
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy