
Source file: ChainNERDemo.scala, from factorie_2.11 version 1.2

/* Copyright (C) 2008-2016 University of Massachusetts Amherst.
   This file is part of "FACTORIE" (Factor graphs, Imperative, Extensible)
   http://factorie.cs.umass.edu, http://github.com/factorie
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

package cc.factorie.tutorial
import java.io.File

import cc.factorie._
import cc.factorie.infer.{BP, BPSummary, GibbsSampler, IteratedConditionalModes}
import cc.factorie.model.{DotTemplateWithStatistics1, DotTemplateWithStatistics2, Parameters, TemplateModel}
import cc.factorie.variable._

/** A demonstration of training a linear-chain CRF for named entity recognition.
    Prints various diagnostics suitable for a demo.
    @author Andrew McCallum  */
object ChainNERDemo {

  // The variable classes
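  // A Token is an observed word carrying a binary feature vector over string features
  // and a link into its Sentence chain; a Label is the latent NER tag variable attached
  // to each Token; a Sentence is simply a chain of Tokens.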
  object TokenDomain extends CategoricalVectorDomain[String]
  class Token(val word:String, features:Seq[String], labelString:String) extends BinaryFeatureVectorVariable[String] with ChainLink[Token,Sentence] {
    def domain = TokenDomain
    val label: Label = new Label(labelString, this)
    this ++= features
  }
  object LabelDomain extends CategoricalDomain[String]
  class Label(labelname: String, val token: Token) extends LabeledCategoricalVariable(labelname) {
    def domain = LabelDomain
    def hasNext = token.hasNext && token.next.label != null
    def hasPrev = token.hasPrev && token.prev.label != null
    def next = token.next.label
    def prev = token.prev.label
  }
  class Sentence extends Chain[Sentence,Token]

  // The model
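  // Three templates define the linear-chain CRF: a unary bias factor on each label,
  // a transition factor between each pair of adjacent labels, and an evidence factor
  // connecting each label to its token's observed features.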
  val model = new TemplateModel with Parameters {
    // Bias term on each individual label
    object bias extends DotTemplateWithStatistics1[Label] {
      val weights = Weights(new la.DenseTensor1(LabelDomain.size))
    }
    // Transition factors between two successive labels
    object transition extends DotTemplateWithStatistics2[Label, Label] {
      val weights = Weights(new la.DenseTensor2(LabelDomain.size, LabelDomain.size))
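      // unroll1/unroll2 tell the template which factors neighbor a changed Label:
      // the factor linking it to its predecessor and the one linking it to its successor.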
      def unroll1(label: Label) = if (label.hasPrev) Factor(label.prev, label) else Nil
      def unroll2(label: Label) = if (label.hasNext) Factor(label, label.next) else Nil
    }
    // Factor between label and observed token
    object evidence extends DotTemplateWithStatistics2[Label, Token] {
      val weights = Weights(new la.DenseTensor2(LabelDomain.size, TokenDomain.dimensionSize))
      def unroll1(label: Label) = Factor(label, label.token)
      def unroll2(token: Token) = throw new Error("Token values shouldn't change")
    }
    this += evidence
    this += bias
    this += transition
  }

  // The training objective
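  // Hamming loss counts per-label disagreements with the target (gold) values, so
  // objective.accuracy below is ordinary per-token accuracy.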
  val objective = new HammingTemplate[Label, Label#TargetType]

  def main(args: Array[String]): Unit = {
    implicit val random = new scala.util.Random(0)
    if (args.length != 2) throw new Error("Usage: ChainNERDemo trainfile testfile")

    // Read in the data
    val trainSentences = load(args(0))
    val testSentences = load(args(1))

    // Get the variables to be inferred
    val trainLabels = trainSentences.flatMap(_.links.map(_.label)).take(50000) // Cap the training set size for this demo

    // Add offset conjunction features copied from the previous and next tokens, tagged "@-1"/"@+1";
    // filtering on '@' keeps already-offset features from being offset again
    trainLabels.map(_.token).foreach(t => {
      if (t.hasPrev) t ++= t.prev.activeCategories.filter(!_.contains('@')).map(_+"@-1")
      if (t.hasNext) t ++= t.next.activeCategories.filter(!_.contains('@')).map(_+"@+1")
    })

    // Freeze domain now that we've added all feature values observed in training data
    TokenDomain.freeze()
    println("Using "+TokenDomain.dimensionSize+" observable features.")

    // Compute features on test data
    val testLabels = testSentences.flatMap(_.links.map(_.label))
    testLabels.map(_.token).foreach(t => {
      if (t.hasPrev) t ++= t.prev.activeCategories.filter(!_.contains('@')).map(_+"@-1")
      if (t.hasNext) t ++= t.next.activeCategories.filter(!_.contains('@')).map(_+"@+1")
    })

    // Print some significant features
    //println("Most predictive features:")
    //val pllo = new cc.factorie.app.classify.PerLabelLogOdds(trainSentences.flatMap(_.map(_.label)), (label:Label) => label.token)
    //for (label <- LabelDomain.values) println(label.category+": "+pllo.top(label, 20))

    // Sample and Learn!
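    // SampleRank updates parameters whenever the model ranks two sampled configurations
    // differently than the objective does; a low-temperature Gibbs sampler proposes the
    // configurations, and IteratedConditionalModes is used for greedy decoding.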
    val startTime = System.currentTimeMillis
    (trainLabels ++ testLabels).foreach(_.setRandomly)
    val learner = new optimize.SampleRankTrainer(new GibbsSampler(model, objective) {temperature=0.1}, new cc.factorie.optimize.AdaGrad)
    val predictor = new IteratedConditionalModes(model, null)
    for (i <- 1 to 3) {
      // println("Iteration "+i)
      learner.processContexts(trainLabels)
      predictor.processAll(testLabels); predictor.processAll(trainLabels)
      trainLabels.take(20).foreach(printLabel _); println(); println()
      printDiagnostic(trainLabels.take(400))
      //trainLabels.take(20).foreach(label => println("%30s %s %s %f".format(label.token.word, label.targetCategory, label.categoryValue, objective.currentScore(label))))
      //println ("Tr50  accuracy = "+ objective.accuracy(trainLabels.take(20)))
      println ("Train accuracy = "+ objective.accuracy(trainLabels))
      println ("Test  accuracy = "+ objective.accuracy(testLabels))
    }
    if (false) { // Disabled by default: demonstrates belief-propagation inference instead
      // Use BP Viterbi (max-product) for prediction
      for (sentence <- testSentences)
        BP.inferChainMax(sentence.asSeq.map(_.label), model).setToMaximize(null)
      //BP.inferChainSum(sentence.asSeq.map(_.label), model).setToMaximize(null) // max-marginal inference

      for (sentence <- trainSentences.take(10)) {
        println("---SumProduct---")
        printTokenMarginals(sentence.asSeq, BP.inferChainSum(sentence.asSeq.map(_.label), model))
        println("---MaxProduct---")
        // printTokenMarginals(sentence.asSeq, BP.inferChainMax(sentence.asSeq.map(_.label), model))
        println("---Gibbs Sampling---")
        predictor.processAll(testLabels, 2)
        sentence.asSeq.foreach(token => printLabel(token.label))
      }
    } else {
      // Use IteratedConditionalModes for prediction
      //predictor.temperature *= 0.1
      predictor.processAll(testLabels, 2)
    }

    println ("Final Test  accuracy = "+ objective.accuracy(testLabels))
    //println("norm " + model.weights.twoNorm)
    println("Finished in " + ((System.currentTimeMillis - startTime) / 1000.0) + " seconds")

    //for (sentence <- testSentences) BP.inferChainMax(sentence.asSeq.map(_.label), model); println ("MaxBP Test accuracy = "+ objective.accuracy(testLabels))
    //for (sentence <- testSentences) BP.inferChainSum(sentence.asSeq.map(_.label), model).setToMaximize(null); println ("SumBP Test accuracy = "+ objective.accuracy(testLabels))
    //predictor.processAll(testLabels, 2); println ("Gibbs Test accuracy = "+ objective.accuracy(testLabels))
  }

  def printTokenMarginals(tokens:Seq[Token], summary:BPSummary): Unit = {
    for (token <- tokens)
      println(token.word + " " + LabelDomain.categories.zip(summary.marginal(token.label).proportions.asSeq).sortBy(_._2).reverse.mkString(" "))
    println()
  }

  // Feature extraction
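  // Builds a token's feature set: the word identity, any caller-supplied features,
  // a three-character prefix for longer words, and orthographic-shape flags
  // matched by the regular expressions below.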
  def wordToFeatures(word:String, initialFeatures:String*) : Seq[String] = {
    import scala.collection.mutable.ArrayBuffer
    val f = new ArrayBuffer[String]
    f += "W="+word
    f ++= initialFeatures
    if (word.length > 3) f += "PRE="+word.substring(0,3)
    if (Capitalized.findFirstMatchIn(word).isDefined) f += "CAPITALIZED"
    if (Numeric.findFirstMatchIn(word).isDefined) f += "NUMERIC"
    if (Punctuation.findFirstMatchIn(word).isDefined) f += "PUNCTUATION"
    f
  }
  val Capitalized = "^[A-Z].*".r
  val Numeric = "^[0-9]+$".r
  val Punctuation = "[-,\\.;:?!()]+".r


  def printLabel(label:Label) : Unit = {
    println("%-16s TRUE=%-8s PRED=%-8s %s".format(label.token.word, label.target.categoryValue, label.value.category, label.token.toString))
  }

  def printDiagnostic(labels:Seq[Label]) : Unit = {
    for (label <- labels; if label.intValue != label.domain.index("O")) {
      if (!label.hasPrev || label.value != label.prev.value)
        print("%-7s %-7s ".format(if (label.value != label.target.value) label.target.value.category.drop(2) else " ", label.value.category.drop(2)))
      print(label.token.word+" ")
      if (!label.hasNext || label.value != label.next.value) println()
    }
    println()
  }

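  /** Reads a CoNLL-2003-style file: one token per line with four space-separated
      fields (word, POS tag, chunk tag (unused here), NER label), blank lines
      separating sentences, and "-DOCSTART-" lines marking document boundaries. */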
  def load(filename:String) : Seq[Sentence] = {
    import scala.collection.mutable.ArrayBuffer
    import scala.io.Source
    var wordCount = 0
    val sentences = new ArrayBuffer[Sentence]
    val source = Source.fromFile(new File(filename))
    var sentence = new Sentence
    for (line <- source.getLines()) {
      if (line.length < 2) { // Sentence boundary
        if (sentence.links.nonEmpty) sentences += sentence // Skip empty sentences produced by consecutive blank lines
        sentence = new Sentence
      } else if (line.startsWith("-DOCSTART-")) {
        // Skip document boundaries
      } else {
        val fields = line.split(' ')
        assert(fields.length == 4)
        val word = fields(0)
        val pos = fields(1)
        val label = fields(3) // getLines() already strips line endings
        sentence += new Token(word, wordToFeatures(word,"POS="+pos), label)
        wordCount += 1
      }
    }
    if (sentence.links.nonEmpty) sentences += sentence // Keep the final sentence when the file lacks a trailing blank line
    source.close()
    println("Loaded "+sentences.length+" sentences with "+wordCount+" words total from file "+filename)
    sentences
  }

}
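
// Example invocation, assuming an sbt build and CoNLL-2003-formatted data files
// (the paths below are hypothetical placeholders):
//   sbt "runMain cc.factorie.tutorial.ChainNERDemo data/eng.train data/eng.testa"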