epic.sequences.CRFModel.scala

package epic.sequences

import epic.framework._
import breeze.util._
import breeze.linalg._
import epic.sequences.CRF.{AnchoredFeaturizer, TransitionVisitor}
import breeze.features.FeatureVector
import epic.features._
import epic.lexicon.SimpleLexicon
import java.util
import epic.util.{ProgressLog, SafeLogging, NotProvided, Optional}
import epic.constraints.TagConstraints

/**
 * A linear-chain conditional random field for tagging sequences of words with labels.
 * Holds the feature index, tag constraints, and indexed featurizer; [[inferenceFromWeights]]
 * turns a weight vector into a [[CRFInference]] used for marginals and decoding.
 *
 * @author dlwh
 */
@SerialVersionUID(1L)
class CRFModel[L, W](val featureIndex: Index[Feature],
                     val lexicon: TagConstraints.Factory[L, W],
                     val featurizer: CRF.IndexedFeaturizer[L, W],
                     initialWeights: Feature=>Double = {(_: Feature) => 0.0}) extends Model[TaggedSequence[L, W]] with StandardExpectedCounts.Model[TaggedSequence[L, W]] with Serializable {
  def labelIndex: Index[L] = featurizer.labelIndex

  def extractCRF(weights: DenseVector[Double]) = {
     inferenceFromWeights(weights)
  }

  type Inference = CRFInference[L, W]
  type Marginal = CRF.Marginal[L, W]
  type Scorer = CRF.Anchoring[L, W]

  def initialValueForFeature(f: Feature): Double = initialWeights(f)

  def inferenceFromWeights(weights: DenseVector[Double]): Inference =
    new CRFInference(weights, featureIndex, lexicon, featurizer)


  /** Accumulates the log partition into the loss and expected feature counts into `counts`, both scaled by `scale`. */
  def accumulateCounts(inf: Inference, s: Scorer, d: TaggedSequence[L, W], marg: Marginal, counts: ExpectedCounts, scale: Double): Unit = {
    counts.loss += marg.logPartition * scale
    val localization = s.asInstanceOf[inf.Anchoring].localization
    val visitor = new TransitionVisitor[L, W] {

      def apply(pos: Int, prev: Int, cur: Int, count: Double): Unit = {
        val feats = localization.featuresForTransition(pos, prev, cur)
        if(count != 0) assert(feats ne null, (pos, prev, cur, marg.length, marg.anchoring.validSymbols(pos), marg.anchoring.validSymbols(pos-1)))
        axpy(scale * count, feats, counts)
      }
    }
    marg.visit(visitor)
  }

}


/**
 * Inference for [[CRFModel]]: anchors the featurizer to a sentence (caching transition scores),
 * computes marginals, and decodes with Viterbi or posterior decoding.
 */
@SerialVersionUID(1L)
class CRFInference[L, W](val weights: DenseVector[Double],
                         val featureIndex: Index[Feature],
                         val lexicon: TagConstraints.Factory[L, W],
                         featurizer: CRF.IndexedFeaturizer[L, W]) extends AugmentableInference[TaggedSequence[L, W], CRF.Anchoring[L, W]] with CRF[L, W] with AnnotatingInference[TaggedSequence[L, W]] with Serializable {




  def scorer(v: TaggedSequence[L, W]): Scorer = new Anchoring(v.words)

  def viterbi(sentence: IndexedSeq[W], anchoring: CRF.Anchoring[L, W]): TaggedSequence[L, W] = {
    CRF.viterbi(new Anchoring(sentence) * anchoring)
  }


  def annotate(datum: TaggedSequence[L, W], m: Marginal): TaggedSequence[L, W] = {
    CRF.posteriorDecode(m)
  }

  type Marginal = CRF.Marginal[L, W]
  type ExpectedCounts = StandardExpectedCounts[Feature]
  type Scorer = CRF.Anchoring[L, W]

  def emptyCounts = StandardExpectedCounts.zero(this.featureIndex)

  def anchor(w: IndexedSeq[W]) = new Anchoring(w)


  def labelIndex = featurizer.labelIndex
  def startSymbol = featurizer.startSymbol


  def marginal(scorer: Scorer, v: TaggedSequence[L, W], aug: CRF.Anchoring[L, W]): CRFInference[L, W]#Marginal = {
    CRF.Marginal(scorer * aug)
  }


  def goldMarginal(scorer: Scorer, v: TaggedSequence[L, W], aug: CRF.Anchoring[L, W]): Marginal = {
    CRF.Marginal.goldMarginal[L, W](new Anchoring(v.words) * aug, v.label)
  }


  private val allLabels =  (0 until labelIndex.size).toSet

  def baseAugment(v: TaggedSequence[L, W]): CRF.Anchoring[L, W] = {
    new CRF.IdentityAnchoring(v.words, IndexedSeq.fill(v.words.length)(allLabels), labelIndex, featurizer.startSymbol)
  }

  /** Per-sentence scorer: precomputes `weights dot features` for every valid (prev, cur, position) transition. */
  class Anchoring(val words: IndexedSeq[W]) extends CRF.Anchoring[L, W] {
    val localization = featurizer.anchor(words)

    val transCache = Array.ofDim[Double](labelIndex.size, labelIndex.size, length)
    for(a <- transCache; b <- a) util.Arrays.fill(b, Double.NegativeInfinity)
    for(i <- 0 until length; c <- validSymbols(i); p <- validSymbols(i-1)) {
      val feats = localization.featuresForTransition(i, p, c)
      if(feats ne null)
        transCache(p)(c)(i) = weights dot feats
      else transCache(p)(c)(i) = Double.NegativeInfinity
    }



    def validSymbols(pos: Int): Set[Int] = localization.validSymbols(pos)

    def scoreTransition(pos: Int, prev: Int, cur: Int): Double = {
       transCache(prev)(cur)(pos)
    }

    def labelIndex: Index[L] = featurizer.labelIndex

    def startSymbol = featurizer.startSymbol
  }


  def posteriorDecode(m: Marginal):TaggedSequence[L, W] = {
    CRF.posteriorDecode(m)
  }
}
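
// A minimal decoding sketch using only the API defined above (illustrative, not part of the
// library): `inference` stands for a CRFInference[L, W] extracted from a trained model and
// `words` for a sentence of type IndexedSeq[W].
//
//   val anchoring = inference.anchor(words)        // per-sentence cached transition scores
//   val marginal  = CRF.Marginal(anchoring)        // forward-backward marginal
//   val tagged    = inference.posteriorDecode(marginal)
//   // or CRF.viterbi(anchoring) for the single highest-scoring tag sequence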

/**
 * Builds a [[CRFModel]] over strings from training data: indexes the labels, derives tag
 * constraints from a [[SimpleLexicon]], and indexes word and transition features, optionally
 * augmented with a gazetteer and hashed features.
 */
class TaggedSequenceModelFactory[L](val startSymbol: L,
                                    gazetteer: Optional[Gazetteer[Any, String]] = NotProvided,
                                    wordFeaturizer: Optional[WordFeaturizer[String]] = NotProvided,
                                    transitionFeaturizer: Optional[WordFeaturizer[String]] = NotProvided,
                                    weights: Feature=>Double = { (f:Feature) => 0.0},
                                    hashFeatureScale: Double = 0.0) extends SafeLogging {

  import TaggedSequenceModelFactory._

  def makeModel(train: IndexedSeq[TaggedSequence[L, String]]): CRFModel[L, String] = {
    val labelIndex: Index[L] = Index[L](Iterator(startSymbol) ++ train.iterator.flatMap(_.label))
    val counts: Counter2[L, String, Double] = Counter2.count(train.flatMap(p => p.label zip p.words)).mapValues(_.toDouble)


    val lexicon:TagConstraints.Factory[L, String] = new SimpleLexicon[L, String](labelIndex, counts)


    var featurizer: WordFeaturizer[String] = wordFeaturizer.getOrElse(WordFeaturizer.goodPOSTagFeaturizer(counts))
    featurizer = gazetteer.foldLeft(featurizer)(_ + _)
    val l2featurizer: WordFeaturizer[String] = transitionFeaturizer.getOrElse(WordFeaturizer.goodPOSTagTransitionFeaturizer(counts))

    val indexedFeaturizer = IndexedWordFeaturizer.fromData(featurizer, train.map{_.words})
    val indexedL2featurizer = IndexedWordFeaturizer.fromData(l2featurizer, train.map{_.words})

    val lfBuilder = new CrossProductIndex.Builder(labelIndex, indexedFeaturizer.featureIndex, includeLabelOnlyFeatures = false, hashFeatures = HashFeature.Relative(hashFeatureScale))
    val label2Index = Index[(L, L)]()
    val label2Features = Array.tabulate(labelIndex.size, labelIndex.size) { (l1, l2) =>
      label2Index.index(labelIndex.get(l1) -> labelIndex.get(l2))
    }
    val l2Builder = new CrossProductIndex.Builder(label2Index, indexedL2featurizer.featureIndex, includeLabelOnlyFeatures = true, hashFeatures = HashFeature.Relative(hashFeatureScale))

    val progress = new ProgressLog(logger, train.length, frequency = 1000, name = "NumFeatures")
    for(s <- train) {
      val loc = indexedFeaturizer.anchor(s.words)
      val l2loc = indexedL2featurizer.anchor(s.words)
      val lexLoc = lexicon.anchor(s.words)

      for {
        b <- 0 until s.length
        l <- lexLoc.allowedTags(b)
      } {
        lfBuilder.add(l, loc.featuresForWord(b))
        if(lexLoc.allowedTags(b).size > 1) {
          for(prevTag <- if(b == 0) Set(labelIndex(startSymbol)) else lexLoc.allowedTags(b-1)) {
            l2Builder.add(label2Features(prevTag)(l), l2loc.featuresForWord(b))
          }
        }
      }
      progress.info(s"${lfBuilder.size + l2Builder.size}")
    }


    val indexed = new IndexedStandardFeaturizer[L, String](indexedFeaturizer,
      indexedL2featurizer, lexicon, startSymbol, labelIndex, label2Features, lfBuilder.result(), l2Builder.result())

    logger.info(s"There are ${indexed.label2FeatureIndex.size} and ${indexed.labelFeatureIndex.size} features total.")
    val model = new CRFModel(indexed.featureIndex, lexicon, indexed, weights)

    model
  }

}
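
// A minimal construction sketch (illustrative only): `trainingData` stands for a hypothetical
// IndexedSeq[TaggedSequence[String, String]]; actually fitting the weights (e.g. with epic's
// model objective and a breeze optimizer) is omitted.
//
//   val factory = new TaggedSequenceModelFactory[String](startSymbol = "-START-")
//   val model   = factory.makeModel(trainingData)
//   val weights = DenseVector.zeros[Double](model.featureIndex.size)
//   val crf     = model.extractCRF(weights)   // a CRFInference[String, String]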

object TaggedSequenceModelFactory {


  /**
   * Combines per-label word features with label-pair transition features, indexed through two
   * [[CrossProductIndex]]es and merged into a single [[SegmentedIndex]].
   *
   * @tparam L      the label type
   * @tparam String the word type (a type parameter here, despite the name)
   */
  @SerialVersionUID(1L)
  class IndexedStandardFeaturizer[L, String](wordFeaturizer: IndexedWordFeaturizer[String],
                                             l2WordFeaturizer: IndexedWordFeaturizer[String],
                                             val lexicon: TagConstraints.Factory[L, String],
                                             val startSymbol: L,
                                             val labelIndex: Index[L],
                                             label2Features: Array[Array[Int]],
                                             val labelFeatureIndex: CrossProductIndex[L, Feature],
                                             val label2FeatureIndex: CrossProductIndex[(L, L), Feature]
                                             ) extends CRF.IndexedFeaturizer[L,String] with Serializable { outer =>

    val featureIndex = SegmentedIndex(labelFeatureIndex, label2FeatureIndex)
    private val loff = featureIndex.componentOffset(0)
    private val l2off = featureIndex.componentOffset(1)


    private val startSymbolSet = Set(labelIndex(startSymbol))

    def anchor(w: IndexedSeq[String]): AnchoredFeaturizer[L, String] = new AnchoredFeaturizer[L, String] {
      val loc = wordFeaturizer.anchor(w)
      val l2loc = l2WordFeaturizer.anchor(w)
      val lexLoc = lexicon.anchor(w)
      def featureIndex: Index[Feature] =  outer.featureIndex

      def validSymbols(pos: Int): Set[Int] = if(pos < 0 || pos >= w.length) startSymbolSet else  lexLoc.allowedTags(pos)

      def length = w.length



      val featureArray = Array.ofDim[FeatureVector](length, labelIndex.size, labelIndex.size)
      private val posNeedsAmbiguity = Array.tabulate(length)(i => validSymbols(i).size > 1)
      for {
        pos <- 0 until length
        curTag <- validSymbols(pos)
        features = loc.featuresForWord(pos)
        justLabel = labelFeatureIndex.crossProduct(Array(curTag), features, offset = loff, usePlainLabelFeatures = false)
        prevTag <- validSymbols(pos-1)
      } {
        val l2feats = l2loc.featuresForWord(pos)
        val feats = if(posNeedsAmbiguity(pos)) {
          justLabel ++ label2FeatureIndex.crossProduct(Array(label2Features(prevTag)(curTag)), l2feats, offset = l2off, usePlainLabelFeatures = true)
        } else {
          justLabel
        }

        featureArray(pos)(prevTag)(curTag) = new FeatureVector(feats)
      }

      def featuresForTransition(pos: Int, prev: Int, cur: Int): FeatureVector = {
        featureArray(pos)(prev)(cur)
      }

    }
  }


}



