All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.grmm.learning.extract.AcrfExtractorTui Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

There is a newer version: 2.0.12
Show newest version
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.learning.extract;
/**
 *
 * Created: Aug 23, 2005
 *
 * @author  trainSource = constructIterator(trainFile.value, dataDir.value, trainingIsList.value);
    Iterator testSource;
    if (testFile.wasInvoked ()) {
      testSource = constructIterator (testFile.value, dataDir.value, trainingIsList.value);
    } else {
      testSource = null;
    }

    ACRF.Template[] tmpls = parseModelFile (modelFile.value);

    ACRFExtractorTrainer trainer = createTrainer (trainerOption.value);
    ACRFEvaluator eval = createEvaluator (evalOption.value);
    ExtractionEvaluator extractionEval = createExtractionEvaluator (extractionEvalOption.value);

    Inferencer inf = createInferencer (inferencerOption.value);
    Inferencer maxInf = createInferencer (maxInferencerOption.value);

    trainer.setPipes (tokPipe, new TokenSequence2FeatureVectorSequence ())
            .setDataSource (trainSource, testSource)
            .setEvaluator (eval)
            .setTemplates (tmpls)
            .setInferencer (inf)
            .setViterbiInferencer (maxInf)
            .setCheckpointDirectory (outputPrefix.value)
            .setNumCheckpointIterations (checkpointIterations.value)
            .setCacheUnrolledGraphs (cacheUnrolledGraph.value)
            .setUsePerTemplateTrain (perTemplateTrain.value)
            .setPerTemplateIterations (pttIterations.value);

    logger.info ("Starting training...");
    ACRFExtractor extor = trainer.trainExtractor ();
    timing.tick ("Training");

    FileUtils.writeGzippedObject (new File (outputPrefix.value, "extor.ser.gz"), extor);
    timing.tick ("Serializing");

    InstanceList testing = trainer.getTestingData ();
    if (testing != null) {
      eval.test (extor.getAcrf (), testing, "Final results");
    }

    if ((extractionEval != null) && (testing != null)) {
      Extraction extraction = extor.extract (testing);
      extractionEval.evaluate (extraction);
      timing.tick ("Evaluting");
    }

    System.out.println ("Total time (ms) = " + timing.elapsedTime ());
  }

  /**
   * Obtains the BeanShell interpreter from {@link CommandOption#getInterpreter()} and imports the
   * extraction, GRMM inference and GRMM learning packages into it, so that command-line specs
   * (trainers, evaluators, inferencers, templates) may use unqualified class names.
   *
   * @return the configured interpreter
   * @throws RuntimeException wrapping any {@link EvalError} raised while importing
   */
  private static BshInterpreter setupInterpreter ()
  {
    BshInterpreter interpreter = CommandOption.getInterpreter ();
    try {
      // Package names must match the cc.mallet.* layout this class lives in
      // (cc.mallet.grmm.learning.extract); the pre-rename edu.umass.cs.mallet.* names do not.
      interpreter.eval ("import cc.mallet.extract.*");
      interpreter.eval ("import cc.mallet.grmm.inference.*");
      interpreter.eval ("import cc.mallet.grmm.learning.*");
      interpreter.eval ("import cc.mallet.grmm.learning.templates.*");
      interpreter.eval ("import cc.mallet.grmm.learning.extract.*");
    } catch (EvalError e) {
      throw new RuntimeException (e);
    }

    return interpreter;
  }


  /**
   * Opens an instance source over a data file.
   *
   * @param trainFile the data file (a training or testing file)
   * @param dataDir directory passed through to {@code FileListIterator} when {@code isList} is true
   * @param isList if true, {@code trainFile} is read by a {@code FileListIterator}; otherwise it is
   *               read as groups of lines separated by blank (whitespace-only) lines
   * @return an iterator over the raw instances
   * @throws IOException if the file cannot be opened
   */
  private static Iterator constructIterator (File trainFile, File dataDir, boolean isList) throws IOException
  {
    if (!isList) {
      // Groups of lines delimited by whitespace-only lines form one instance each.
      FileReader reader = new FileReader (trainFile);
      Pattern blankLine = Pattern.compile ("^\\s*$");
      return new LineGroupIterator (reader, blankLine, true);
    }
    return new FileListIterator (trainFile, dataDir, null, null, true);
  }

  /**
   * Creates the ACRF evaluator described by {@code spec}.
   * <p>
   * A spec containing '(' is evaluated verbatim as BeanShell code. Otherwise it is split on
   * whitespace and parsed as an evaluator keyword list (SEGMENT, LOG, SERIAL ...).
   *
   * @param spec evaluator specification from the command line
   * @return the evaluator
   * @throws EvalError if the spec is Java code that the interpreter cannot evaluate
   */
  public static ACRFEvaluator createEvaluator (String spec) throws EvalError
  {
    if (spec.indexOf ('(') >= 0) {
      // assume it's Java code, and don't modify it.
      return (ACRFEvaluator) interpreter.eval (spec);
    } else {
      // Trim first: split("\\s+") on a string with leading whitespace yields an empty first
      // token, which would otherwise be reported as an illegal evaluator type.
      LinkedList toks = new LinkedList (Arrays.asList (spec.trim ().split ("\\s+")));
      return createEvaluator (toks);
    }
  }

  /**
   * Creates the extraction evaluator described by {@code spec}.
   * <p>
   * A spec containing '(' is evaluated verbatim as BeanShell code. Otherwise it is treated as a
   * class-name prefix: {@code "X"} becomes {@code new XEvaluator ()}.
   *
   * @param spec evaluator specification; null or blank means no extraction evaluation
   * @return the evaluator, or null if {@code spec} is null or blank (the caller skips
   *         extraction evaluation when this is null)
   * @throws EvalError if the generated code cannot be evaluated
   */
  private static ExtractionEvaluator createExtractionEvaluator (String spec) throws EvalError
  {
    // The caller already handles a null evaluator, so an absent spec disables extraction
    // evaluation instead of failing with a NullPointerException.
    if (spec == null || spec.trim ().length () == 0) {
      return null;
    }

    if (spec.indexOf ('(') >= 0) {
      // assume it's Java code, and don't modify it.
      return (ExtractionEvaluator) interpreter.eval (spec);
    } else {
      // Trim so trailing whitespace does not split the generated class name.
      String cmd = "new "+spec.trim ()+"Evaluator ()";
      return (ExtractionEvaluator) interpreter.eval (cmd);
    }
  }

  /**
   * Builds an evaluator from a tokenized spec, consuming tokens from the front of {@code toks}.
   * Recognized forms (keywords are case-insensitive):
   * <ul>
   *   <li>{@code SEGMENT slice start1 continue1 [start2 continue2 ...]}: consumes all remaining
   *       tokens as (start, continue) tag pairs</li>
   *   <li>{@code LOG}</li>
   *   <li>{@code SERIAL spec1 spec2 ...}: recursively builds evaluators until tokens run out</li>
   * </ul>
   *
   * @param toks remaining spec tokens (Strings); modified in place
   * @return the evaluator
   * @throws RuntimeException if the spec is malformed
   */
  private static ACRFEvaluator createEvaluator (LinkedList toks)
  {
    if (toks.isEmpty ())
      throw new RuntimeException ("Error in --eval "+evalOption.value+": missing evaluator type");
    String type = (String) toks.removeFirst ();

    if (type.equalsIgnoreCase ("SEGMENT")) {
      if (toks.isEmpty ())
        throw new RuntimeException ("Error in --eval "+evalOption.value+": SEGMENT requires a slice index");
      String sliceStr = (String) toks.removeFirst ();
      int slice;
      try {
        slice = Integer.parseInt (sliceStr);
      } catch (NumberFormatException e) {
        throw new RuntimeException ("Error in --eval "+evalOption.value+": SEGMENT slice must be an integer, got "+sliceStr, e);
      }
      if (toks.size() % 2 != 0)
        throw new RuntimeException ("Error in --eval "+evalOption.value+": Every start tag must have a continue.");
      int numTags = toks.size () / 2;
      String[] startTags = new String [numTags];
      String[] continueTags = new String [numTags];

      for (int i = 0; i < numTags; i++) {
        startTags[i] = (String) toks.removeFirst ();
        continueTags[i] = (String) toks.removeFirst ();
      }

      return new MultiSegmentationEvaluatorACRF (startTags, continueTags, slice);

    } else if (type.equalsIgnoreCase ("LOG")) {
      return new DefaultAcrfTrainer.LogEvaluator ();

    } else if (type.equalsIgnoreCase ("SERIAL")) {
      List evals = new ArrayList ();
      while (!toks.isEmpty ()) {
        evals.add (createEvaluator (toks));
      }
      return new AcrfSerialEvaluator (evals);

    } else {
      throw new RuntimeException ("Error in --eval "+evalOption.value+": illegal evaluator "+type);
    }
  }

  /**
   * Creates the extractor trainer named by {@code spec}.
   * <p>
   * If {@code spec} contains '(' it is evaluated verbatim as BeanShell code. Otherwise it is taken
   * as a class name, with "Trainer" appended unless already present, and instantiated with a
   * no-argument constructor. A result that is already an {@link ACRFExtractorTrainer} is returned
   * as is; any {@link ACRFTrainer} is wrapped as the training method of a new
   * {@link ACRFExtractorTrainer}.
   *
   * @param spec trainer specification from the command line
   * @return the extractor trainer
   * @throws EvalError if the generated code cannot be evaluated
   * @throws RuntimeException if the result is neither an ACRFExtractorTrainer nor an ACRFTrainer
   */
  private static ACRFExtractorTrainer createTrainer (String spec) throws EvalError
  {
    String cmd;
    if (spec.indexOf ('(') >= 0) {
      // assume it's Java code, and don't modify it.
      cmd = spec;
    } else if (spec.endsWith ("Trainer")) {
      cmd = "new "+spec+"()";
    } else {
      cmd = "new "+spec+"Trainer()";
    }

    // Return whatever the Java code says to
    Object trainer = interpreter.eval (cmd);

    if (trainer instanceof ACRFExtractorTrainer)
      return (ACRFExtractorTrainer) trainer;
    else if (trainer instanceof ACRFTrainer)
      // Test against the type actually cast to and accepted by setTrainingMethod, rather than
      // only DefaultAcrfTrainer, so any ACRFTrainer can be used as the training method.
      return new ACRFExtractorTrainer ().setTrainingMethod ((ACRFTrainer) trainer);
    else throw new RuntimeException ("Don't know what to do with trainer "+trainer);
  }

  /**
   * Instantiates an inferencer from a command-line spec.
   * <p>
   * Specs containing '(' are passed to the interpreter unchanged; any other spec is treated as a
   * class name and constructed with its no-argument constructor.
   *
   * @param spec inferencer specification
   * @return the inferencer produced by the interpreter
   * @throws EvalError if the interpreter cannot evaluate the code
   * @throws RuntimeException if the evaluated object is not an Inferencer
   */
  private static Inferencer createInferencer (String spec) throws EvalError
  {
    boolean isJavaCode = spec.indexOf ('(') >= 0;
    String code = isJavaCode ? spec : "new " + spec + "()";

    Object result = interpreter.eval (code);
    if (!(result instanceof Inferencer)) {
      throw new RuntimeException ("Don't know what to do with inferencer " + result);
    }
    return (Inferencer) result;
  }


  /**
   * Parses {@code args} against the command options declared by {@code childClass}, then logs the
   * resulting option values to the root logger.
   *
   * @param childClass class whose CommandOption fields should be registered
   * @param args command-line arguments
   */
  public static void doProcessOptions (Class childClass, String[] args)
  {
    CommandOption.List optionList = new CommandOption.List ("", new CommandOption[0]);
    optionList.add (childClass);
    optionList.process (args);

    Logger rootLogger = Logger.getLogger ("");
    optionList.logOptions (rootLogger);
  }

  /**
   * Reads ACRF templates from a model file containing one BeanShell expression per line.
   * Blank lines are skipped; every other line must evaluate to an {@link ACRF.Template}.
   *
   * @param mdlFile model file to read
   * @return the templates, in file order
   * @throws IOException if the file cannot be read
   * @throws EvalError if a line cannot be evaluated by the interpreter
   * @throws RuntimeException if a line evaluates to something other than a template
   */
  private static ACRF.Template[] parseModelFile (File mdlFile) throws IOException, EvalError
  {
    BufferedReader in = new BufferedReader (new FileReader (mdlFile));
    try {
      List tmpls = new ArrayList ();
      int lineNo = 0;
      String line;
      while ((line = in.readLine ()) != null) {
        lineNo++;
        // A blank line (e.g. a trailing newline) is not a template; skip it rather than fail.
        if (line.trim ().length () == 0) continue;

        Object tmpl = interpreter.eval (line);
        if (!(tmpl instanceof ACRF.Template)) {
          throw new RuntimeException ("Error in "+mdlFile+" line "+lineNo+":\n  Object "+tmpl+" not a template");
        }
        tmpls.add (tmpl);
      }
      return (ACRF.Template[]) tmpls.toArray (new ACRF.Template [0]);
    } finally {
      // Release the file handle on every path, including evaluation errors.
      in.close ();
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy