package epic.parser.models

import java.io.File

import breeze.collection.mutable.TriangularArray
import breeze.config.Help
import breeze.features.FeatureVector
import breeze.linalg._
import breeze.util.Index
import epic.constraints.ChartConstraints
import epic.constraints.ChartConstraints.Factory
import epic.dense.{IdentityTransform, AffineTransform, Transform}
import epic.features.SurfaceFeaturizer.SingleWordSpanFeaturizer
import epic.features._
import epic.framework.{Inference, Feature}
import epic.lexicon.Lexicon
import epic.parser._
import epic.parser.projections.GrammarRefinements
import epic.preprocess.MLSentenceSegmenter.BiasFeature
import epic.trees.annotations.TreeAnnotator
import epic.trees._
import epic.util.{NotProvided, Optional, LRUCache}

/**
 * A low-rank, factored "three point" parser model. Label and rule features are
 * projected by a rank x |labelFeatures| matrix, and the word features at up to
 * three positions of a span (the first word, the last word, and, for binary
 * rules, the split point) are each projected by their own rank x |wordFeatures|
 * matrix into the same rank-dimensional space. A production's score is
 *
 *   (ruleMatrix * ruleFeats) dot (firstAct :* lastAct [:* splitAct])
 *
 * where :* is the elementwise product and the split factor appears only for
 * binary rules.
 *
 * @author dlwh
 **/
class ThreePointModel[L, L2, W](annotator: (BinarizedTree[L], IndexedSeq[W]) => BinarizedTree[IndexedSeq[L2]],
                                 val constrainer: ChartConstraints.Factory[L, W],
                                 val topology: RuleTopology[L],
                                 val lexicon: Lexicon[L, W],
                                 refinedTopology: RuleTopology[L2],
                                 refinements: GrammarRefinements[L, L2],
                                 labelFeaturizer: RefinedFeaturizer[L, W, Feature],
                                 wordFeaturizer: IndexedWordFeaturizer[W],
                                 rank: Int) extends ParserModel[L, W] {
  override type Inference = ThreePointModel.ThreePointInference[L, L2, W]


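  // Tallies this sentence's expected rule counts into the accumulator and adds
  // the scaled log partition to the loss.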
  override def accumulateCounts(inf: Inference, s: Scorer, d: TreeInstance[L, W], m: Marginal, accum: ExpectedCounts, scale: Double): Unit = {
    inf.grammar.extractEcounts(m, accum.counts, scale)
    accum.loss += scale * m.logPartition
  }

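  // Weight layout: one rank x |labelFeatures| matrix for label/rule features,
  // followed by one rank x |wordFeatures| matrix for each of the three points
  // (First, Split, Last).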
  override val featureIndex = new SegmentedIndex(
    new AffineTransform.Index(rank, labelFeaturizer.index.size, false) +:
      IndexedSeq.fill(3)(new AffineTransform.Index(rank, wordFeaturizer.featureIndex.size, false)))

  override def inferenceFromWeights(weights: DenseVector[Double]): Inference = {
    val grammar = new ThreePointModel.Grammar[L, L2, W](topology, lexicon, refinedTopology, refinements,
      labelFeaturizer, wordFeaturizer, featureIndex, weights)
    new Inference(annotator, constrainer, grammar, refinements)
  }

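  // Deterministic pseudo-random initialization: breaks the symmetry between
  // latent refinements, which would otherwise remain identical during training.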
  override def initialValueForFeature(f: Feature): Double = f.hashCode().toDouble / 1000 % 2
}

object ThreePointModel {

  object Point extends Enumeration {
    val First, Split, Last = Value // TODO: probably add Lex
  }

  case class ThreePointInference[L, L2, W](annotator: (BinarizedTree[L], IndexedSeq[W]) => BinarizedTree[IndexedSeq[L2]],
                                         constrainer: ChartConstraints.Factory[L, W],
                                         grammar: Grammar[L, L2, W],
                                         refinements: GrammarRefinements[L, L2]) extends ParserInference[L, W] {

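    // The gold marginal sums over the latent refinements of the annotated gold tree.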
    override def goldMarginal(scorer: Scorer, ti: TreeInstance[L, W], aug: UnrefinedGrammarAnchoring[L, W]): Marginal = {
      import ti._

      val annotated = annotator(tree, words).map(_.map(refinements.labels.localize))

      val product = grammar.anchor(words, constrainer.constraints(ti.words))
      LatentTreeMarginal(product, annotated)
    }
  }

  @SerialVersionUID(1L)
  case class Grammar[L, L2, W](topology: RuleTopology[L],
                          lexicon: Lexicon[L, W],
                          refinedTopology: RuleTopology[L2],
                          refinements: GrammarRefinements[L, L2],
                          labelFeaturizer: RefinedFeaturizer[L, W, Feature],
                          wordFeaturizer: IndexedWordFeaturizer[W],
                          featureIndex: SegmentedIndex[Feature, AffineTransform.Index],
                          weights: DenseVector[Double]) extends epic.parser.Grammar[L, W] with Serializable {
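    // Unpack the flat weight vector into the rule matrix and the three point
    // matrices, in featureIndex order.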
    val IndexedSeq(ruleMatrix, wordMatrices @ _*) = reshapeWeightMatrices(weights)
    assert(wordMatrices.length == 3)

    private def reshapeWeightMatrices(weights: DenseVector[Double]): IndexedSeq[DenseMatrix[Double]] = {
      val segments = featureIndex.shardWeights(weights)
      (featureIndex.indices zip segments).map { case (index, segment) => index.makeMatrix(segment) }
    }

    override def withPermissiveLexicon = {
      new Grammar(topology, lexicon.morePermissive, refinedTopology, refinements, labelFeaturizer, wordFeaturizer, featureIndex, weights)
    }

    def anchor(w: IndexedSeq[W], cons: ChartConstraints[L]) = new Anchoring(w, cons)

    // TODO: cache these if necessary for speed
    class Anchoring(w: IndexedSeq[W], cons: ChartConstraints[L]) extends ProjectionsGrammarAnchoring[L, L2, W] {

      override def addConstraints(constraints: ChartConstraints[L]): GrammarAnchoring[L, W] = {
        anchor(w, cons & constraints)
      }

      override def sparsityPattern: ChartConstraints[L] = cons

      def refinements = Grammar.this.refinements
      def refinedTopology: RuleTopology[L2] = Grammar.this.refinedTopology

      val topology = Grammar.this.topology
      val lexicon = Grammar.this.lexicon

      def words = w

      val sspec = wordFeaturizer.anchor(w)
      val lspec = labelFeaturizer.anchor(w)

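      // One rank-dimensional activation per (word position, point matrix) pair.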
      val wordActivations: Array[Array[DenseVector[Double]]] = Array.tabulate(words.length, wordMatrices.length) { (i, c) =>
        wordMatrices(c) * new FeatureVector(sspec.featuresForWord(i))
      }

      def actForPos(w: Int, c: Point.Value) = wordActivations(w)(c.id)

      // Elementwise product of the First and Last activations for each allowed
      // span. The Split activation depends on the split point, so it can't be
      // precached and is multiplied in online in scoreBinaryRule.
      val precachedSpanActivations = TriangularArray.tabulate(words.length + 1) { (i, j) =>
        if(sparsityPattern.isAllowedSpan(i, j) && i != j) {
          val result = DenseVector.ones[Double](wordActivations.head.head.size)

          result :*= actForPos(i, Point.First)
          result :*= actForPos(j - 1, Point.Last)

          result
        } else {
          null
        }
      }

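      // Score = (ruleMatrix * ruleFeats) dot (firstAct :* lastAct :* splitAct);
      // the first/last product is precached, the split factor is applied here.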
      def scoreBinaryRule(begin: Int, split: Int, end: Int, rule: Int, ref: Int) = {
        val surfaceAct = precachedSpanActivations(begin, end)
        if(surfaceAct == null) {
          Double.NegativeInfinity
        } else {
          val rfeats = lspec.featuresForBinaryRule(begin, split, end, rule, ref)
          (ruleMatrix * new FeatureVector(rfeats)) dot (surfaceAct :* actForPos(split, Point.Split))
        }
      }

      def scoreUnaryRule(begin: Int, end: Int, rule: Int, ref: Int) = {
        val surfaceAct = precachedSpanActivations(begin, end)
        if(surfaceAct == null) {
          Double.NegativeInfinity
        } else {
          val rfeats = lspec.featuresForUnaryRule(begin, end, rule, ref)
          (ruleMatrix * new FeatureVector(rfeats)) dot surfaceAct
        }
      }

      def scoreSpan(begin: Int, end: Int, tag: Int, ref: Int) = {
        val surfaceAct = precachedSpanActivations(begin, end)
        if(surfaceAct == null) {
          Double.NegativeInfinity
        } else {
          val rfeats = lspec.featuresForSpan(begin, end, tag, ref)
          (ruleMatrix * new FeatureVector(rfeats)) dot surfaceAct
        }
      }

    }

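    // Accumulates this marginal's gradient contributions into `deriv`, which is
    // laid out like `weights`. The score is linear in each matrix separately, so
    // each matrix's gradient is an outer product of its own features with the
    // product of all the other factors.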
    def extractEcounts(m: ParseMarginal[L, W], deriv: DenseVector[Double], scale: Double): Unit = {
      val anchoring = anchor(m.words, m.anchoring.sparsityPattern)

      val IndexedSeq(dRuleAct, dWeights @ _*) = reshapeWeightMatrices(deriv)

      val w = m.words
      val length = w.length

      /*
      // cache: we remember the (begin, end) pair most recently seen for each split index
      val UNUSED = (-1, -1)
      val states = Array.fill(w.length + 2)(UNUSED) // 1 for each split,  length for unaries, length +1 for spans
      val ruleCountsPerState = Array.fill(w.length + 2)(SparseVector.zeros[Double](labelFeaturizer.index.size))

      def checkFlush(begin: Int, split: Int, end: Int) {
        val state: (Int, Int) = (begin, end)
        val oldState: (Int, Int) = states(split)
        if(oldState != state) {
          if(oldState != UNUSED) {
            val ffeats = if(split >= length) sspec.featuresForSpan(oldState._1, oldState._2) else sspec.featuresForSplit(oldState._1, split, oldState._2)
            layer.tallyDerivative(deriv, ruleCountsPerState(split) *= scale, new FeatureVector(ffeats))
            ruleCountsPerState(split) := 0.0
          }
          states(split) = state
        }
      }
      */

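      // Visit every scored production in the marginal and add its
      // score * scale weighted outer products to the gradient.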
      m visit new AnchoredVisitor[L] {
        import anchoring._

        override def visitUnaryRule(begin: Int, end: Int, rule: Int, ref: Int, score: Double): Unit = {
          val rfeats = lspec.featuresForUnaryRule(begin, end, rule, ref)
          val surfaceAct = precachedSpanActivations(begin, end)

          val ruleAct = ruleMatrix * new FeatureVector(rfeats)
          for(f <- rfeats) {
            axpy(score * scale, surfaceAct, dRuleAct(::, f))
          }

          val beginAct = ruleAct :* actForPos(end - 1, Point.Last)
          for(f <- sspec.featuresForWord(begin)) {
            axpy(score * scale, beginAct, dWeights(Point.First.id)(::, f))
          }

          val endAct = ruleAct :* actForPos(begin, Point.First)
          for(f <- sspec.featuresForWord(end - 1)) {
            axpy(score * scale, endAct, dWeights(Point.Last.id)(::, f))
          }

        }

        override def visitSpan(begin: Int, end: Int, tag: Int, ref: Int, score: Double): Unit = {
          val rfeats = lspec.featuresForSpan(begin, end, tag, ref)
          val surfaceAct = precachedSpanActivations(begin, end)

          val ruleAct = ruleMatrix * new FeatureVector(rfeats)
          for(f <- rfeats) {
            axpy(score * scale, surfaceAct, dRuleAct(::, f))
          }

          val actWithoutBegin = ruleAct :* actForPos(end - 1, Point.Last)
          for(f <- sspec.featuresForWord(begin)) {
            axpy(score * scale, actWithoutBegin, dWeights(Point.First.id)(::, f))
          }

          val actWithoutEnd = ruleAct :* actForPos(begin, Point.First)
          for(f <- sspec.featuresForWord(end - 1)) {
            axpy(score * scale, actWithoutEnd, dWeights(Point.Last.id)(::, f))
          }


        }

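        // Binary case: all three point matrices participate; each factor's
        // gradient multiplies the rule activation by the other two factors.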
        override def visitBinaryRule(begin: Int, split: Int, end: Int, rule: Int, ref: Int, score: Double): Unit = {
          val rfeats = lspec.featuresForBinaryRule(begin, split, end, rule, ref)
          val surfaceAct = precachedSpanActivations(begin, end)
          val splitPointAct = actForPos(split, Point.Split)

          val ruleAct = ruleMatrix * new FeatureVector(rfeats)
          // The rule gradient needs the full three-way activation (including the
          // split point), to match scoreBinaryRule.
          val fullSpanAct = surfaceAct :* splitPointAct
          for(f <- rfeats) {
            axpy(score * scale, fullSpanAct, dRuleAct(::, f))
          }

          val beginAct = ruleAct :* actForPos(end - 1, Point.Last) :* splitPointAct
          for(f <- sspec.featuresForWord(begin)) {
            axpy(score * scale, beginAct, dWeights(Point.First.id)(::, f))
          }

          val endAct = ruleAct :* actForPos(begin, Point.First) :* splitPointAct
          for(f <- sspec.featuresForWord(end - 1)) {
            axpy(score * scale, endAct, dWeights(Point.Last.id)(::, f))
          }

          val splitAct = ruleAct :* surfaceAct
          for(f <- sspec.featuresForWord(split)) {
            axpy(score * scale, splitAct, dWeights(Point.Split.id)(::, f))
          }
        }
      }

    }
  }



}

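/**
 * Factory that assembles a [[ThreePointModel]] over [[epic.trees.AnnotatedLabel]]
 * trees: it annotates the training trees, builds the word and label featurizers,
 * and fixes the rank of the factorization.
 */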
case class ThreePointModelFactory(@Help(text=
  """The kind of annotation to do on the refined grammar. Default uses just parent annotation.
You can also use epic.trees.annotations.KMAnnotator to get something close to Klein and Manning (2003).
  """)
                                 annotator: TreeAnnotator[AnnotatedLabel, String, AnnotatedLabel] = GenerativeParser.defaultAnnotator(),
                                 @Help(text="Old weights to initialize with. Optional")
                                 oldWeights: File = null,
                                 @Help(text="For features not seen in gold trees, we bin them into dummyFeats * numGoldFeatures bins using hashing. If negative, use absolute value as number of hash features.")
                                 dummyFeats: Double = 0.5,
                                 commonWordThreshold: Int = 100,
                                 ngramCountThreshold: Int = 5,
                                 useNGrams:Boolean = false,
                                 maxNGramOrder:Int = 2,
                                 useGrammar: Boolean = true,
                                 rank: Int = 1,
                                 posFeaturizer: Optional[WordFeaturizer[String]] = NotProvided,
                                 spanFeaturizer: Optional[SplitSpanFeaturizer[String]] = NotProvided,
                                 extraParams: ExtraParams = ExtraParams()) extends ParserModelFactory[AnnotatedLabel, String] {

  type MyModel = ThreePointModel[AnnotatedLabel, AnnotatedLabel, String]



  override def make(trainTrees: IndexedSeq[TreeInstance[AnnotatedLabel, String]],
                    topology: RuleTopology[AnnotatedLabel],
                    lexicon: Lexicon[AnnotatedLabel, String],
                    constrainer: ChartConstraints.Factory[AnnotatedLabel, String]): MyModel = {
    import extraParams._
    val annTrees: IndexedSeq[TreeInstance[AnnotatedLabel, String]] = trainTrees.map(annotator(_))
    println("Here's what the annotation looks like on the first few trees")
    annTrees.take(3).foreach(tree => println(tree.render(false)))

    val (annWords, annBinaries, annUnaries) = this.extractBasicCounts(annTrees)
    val refGrammar = RuleTopology(AnnotatedLabel.TOP, annBinaries, annUnaries)

    val xbarGrammar = topology
    val xbarLexicon = lexicon

    val indexedRefinements = GrammarRefinements(xbarGrammar, refGrammar, (_: AnnotatedLabel).baseAnnotatedLabel)

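    // Base word featurizer: unigram lfsuf (frequent-suffix class) features, the
    // word itself, and a constant bias feature at every position.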
    var wf: WordFeaturizer[String] = {
      val dsl = new WordFeaturizer.DSL(annWords)
      import dsl._
      unigrams(lfsuf, 1) + word + new WordFeaturizer[String] {
        override def anchor(w: IndexedSeq[String]): WordFeatureAnchoring[String] = new WordFeatureAnchoring[String] {
          override def featuresForWord(pos: Int): Array[Feature] = Array(BiasFeature)

          override def words: IndexedSeq[String] =  w
        }
      }
    }

    if(useMorph)
      wf += MorphFeaturizer(pathsToMorph.split(","))

    val indexedWord = IndexedWordFeaturizer.fromData(wf, annTrees.map{_.words})

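    // Featurizers emit both the refined symbol and its base symbol, so refined
    // features back off to their unannotated versions.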
    def labelFeaturizer(l: AnnotatedLabel) = Set(l, l.baseAnnotatedLabel).toSeq
    def ruleFeaturizer(r: Rule[AnnotatedLabel]) =
      if (useGrammar) Set(r, r.map(_.baseAnnotatedLabel)).toSeq
      else r match {
        case u: UnaryRule[AnnotatedLabel] => Set(u.parent, u.parent.baseAnnotatedLabel).toSeq
        case _ => Seq.empty
      }

    val featurizer = new ProductionFeaturizer[AnnotatedLabel, AnnotatedLabel, String](xbarGrammar, indexedRefinements,
      lGen=labelFeaturizer,
      rGen=ruleFeaturizer)


    new ThreePointModel(annotator.latent,
      constrainer,
      topology, lexicon,
      refGrammar, indexedRefinements,
      featurizer, indexedWord,
      rank)

  }



}