// All Downloads are FREE. Search and download functionalities are using the official Maven repository.

// epic.parser.LatentTreeMarginal.scala Maven / Gradle / Ivy

// There is a newer version: 0.4.4
// Show newest version
package epic.parser
/*
 Copyright 2012 David Hall

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
*/
import projections.{GrammarRefinements, ProjectionIndexer}
import java.util.Arrays
import epic.trees._
import breeze.numerics._

import LatentTreeMarginal._
import scala._
import java.util
import breeze.linalg.Counter2

/**
 * Inside/outside marginals over the latent refinements of one fixed binarized tree.
 *
 * Every node of `tree` carries the candidate `(label, refinement)` pairs allowed at that node.
 * On construction the inside pass and then the outside pass are run over the tree. The per-node
 * score arrays hold exponentiated (non-log) scores; each array also carries an integer scale
 * maintained through the `Scaling` helper.
 *
 * @param anchoring scorer for spans, unary rules and binary rules
 * @param tree      skeleton tree whose node labels list the admissible (label, refinement) pairs
 * @author dlwh
 */
case class LatentTreeMarginal[L, W](anchoring: GrammarAnchoring[L, W],
                                    tree: BinarizedTree[IndexedSeq[(L, Int)]]) extends ParseMarginal[L, W] {

  // The inside pass builds the belief tree; the outside pass then fills in the outside half of
  // every node's beliefs in place.
  private val stree = insideScores()
  outsideScores(stree)


  def isMaxMarginal: Boolean = false

  // Sum of the (scaled) inside scores at the root; used as the normalizer in visitPostorder.
  private val z = stree.label.inside.sum
  val logPartition: Double = Scaling.toLogSpace(z, stree.label.iscale)

  /**
   * Walks the tree in postorder and reports a normalized score for every span and rule
   * combination to `spanVisitor`.
   *
   * Rule combinations that have no indexed rule or no matching rule refinement are skipped,
   * mirroring the inside pass, which never adds any mass for them.
   *
   * @param spanVisitor receiver of the visitSpan / visitUnaryRule / visitBinaryRule callbacks
   * @param threshold   not used by this implementation
   */
  def visitPostorder(spanVisitor: AnchoredVisitor[L], threshold: Double = Double.NegativeInfinity): Unit = {
    // normalizer
    val rootScale = stree.label.iscale
    if (logPartition.isInfinite || logPartition.isNaN)
      sys.error(s"NaN or infinite $logPartition ${stree.label.inside.mkString(", ")}")

    stree.postorder foreach {
      case t@NullaryTree(Beliefs(labels, iScores, iScale, oScores, oScale), _) =>
        for (i <- 0 until labels.length) {
          val (l, ref) = labels(i)
          val iS = iScores(i)
          val oS = oScores(i)
          val ruleScore = Scaling.unscaleValue(iS / z * oS, iScale + oScale - rootScale)
          assert(!ruleScore.isNaN)
          spanVisitor.visitSpan(t.span.begin, t.span.end, l, ref, ruleScore)
        }
      case t@UnaryTree(Beliefs(aLabels, _, _, aScores, aScale), Tree(Beliefs(cLabels, cScores, cScale, _, _), _, _), chain, _) =>
        var pi = 0
        while (pi < aLabels.size) {
          val (a, aRef) = aLabels(pi)
          val opScore = aScores(pi)
          pi += 1
          var ci = 0
          while (ci < cLabels.size) {
            val (c, cRef) = cLabels(ci)
            val icScore = cScores(ci)
            ci += 1
            val rule = topology.index(UnaryRule(topology.labelIndex.get(a), topology.labelIndex.get(c), chain))
            // Same guards as the inside pass: skip unknown rules and impossible refinements.
            if (rule != -1) {
              val ruleRef = anchoring.ruleRefinementFromRefinements(rule, aRef, cRef)
              if (ruleRef != -1) {
                val rs = math.exp(anchoring.scoreUnaryRule(t.span.begin, t.span.end, rule, ruleRef)) // exp!
                val ruleScore = Scaling.unscaleValue(opScore / z * rs * icScore, aScale + cScale - rootScale)
                assert(!ruleScore.isNaN)
                spanVisitor.visitUnaryRule(t.span.begin, t.span.end, rule, ruleRef, ruleScore)
              }
            }
          }
        }
      case t@BinaryTree(Beliefs(aLabels, _, _, aScores, aScale), Tree(Beliefs(bLabels, bScores, bScale, _, _), _, _), Tree(Beliefs(cLabels, cScores, cScale, _, _), _, _), span) =>
        val begin = span.begin
        val split = t.rightChild.span.begin
        val end = t.span.end
        for {
          ((a, aRef), opScore) <- aLabels zip aScores
          ((b, bRef), ilScore) <- bLabels zip bScores
          ((c, cRef), irScore) <- cLabels zip cScores
        } {
          val rule = topology.index(BinaryRule(topology.labelIndex.get(a),
            topology.labelIndex.get(b),
            topology.labelIndex.get(c)))
          // Same guards as the inside pass: without them a missing rule (-1) or refinement (-1)
          // would be handed straight to the scoring functions.
          if (rule != -1) {
            val ruleRef = anchoring.ruleRefinementFromRefinements(rule, aRef, bRef, cRef)
            if (ruleRef != -1) {
              val rs = math.exp(anchoring.scoreBinaryRule(begin, split, end, rule, ruleRef) + anchoring.scoreSpan(begin, end, a, aRef)) // exp!
              val count = Scaling.unscaleValue(opScore / z * rs * ilScore * irScore, aScale + bScale + cScale - rootScale)
              spanVisitor.visitSpan(begin, end, a, aRef, count)
              spanVisitor.visitBinaryRule(begin, split, end, rule, ruleRef, count)
            }
          }
        }
    }
  }


  // private stuff to do the computation

  /**
   * Builds the belief tree and fills in the inside score of every (label, refinement) pair at
   * every node, bottom-up.
   *
   * Fails via `sys.error` when no candidate at a leaf has a finite span score, or when every
   * parent/child combination at a unary node scores zero.
   *
   * @return the tree of beliefs with inside scores and inside scales set
   */
  private def insideScores(): BinarizedTree[Beliefs[L]] = {
    val indexedTree: BinarizedTree[Beliefs[L]] = tree.map { labels => Beliefs(labels.map { case (l, ref) => topology.labelIndex(l) -> ref }) }

    indexedTree.postorder.foreach {
      case t@NullaryTree(Beliefs(labels, scores, _, _, _), _) =>
        // fill in POS tags:
        assert(t.span.length == 1)
        var foundOne = false
        for {
          i <- 0 until scores.length
          (label, ref) = labels(i)
          wScore = anchoring.scoreSpan(t.span.begin, t.span.end, label, ref)
          if !wScore.isInfinite
        } {
          scores(i) = math.exp(wScore) // exp!
          assert(!wScore.isInfinite)
          assert(!wScore.isNaN)
          foundOne = true
        }
        if (!foundOne) {
          // Report the offending word itself; `$words(...)` would print the whole sentence
          // followed by the literal text "(t.span.begin)".
          sys.error(s"Trouble with lexical  ${words(t.span.begin)}")
        }
        t.label.scaleInside(0)
      case t@UnaryTree(Beliefs(aLabels, aScores, _, _, _), Tree(Beliefs(cLabels, cScores, cScale, _, _), _, _), chain, _) =>
        var foundOne = false
        var ai = 0
        while (ai < aLabels.length) {
          val (a, aRef) = aLabels(ai)

          var sum = 0.0
          var ci = 0
          while (ci < cLabels.length) {
            val (c, cRef) = cLabels(ci)
            val rule = topology.index(UnaryRule(topology.labelIndex.get(a), topology.labelIndex.get(c), chain))
            if (rule != -1) {
              val ruleRef = anchoring.ruleRefinementFromRefinements(rule, aRef, cRef)
              if (ruleRef != -1) {
                val score = anchoring.scoreUnaryRule(t.span.begin, t.span.end, rule, ruleRef)
                val ruleScore = cScores(ci) * math.exp(score) // exp!
                sum += ruleScore
                assert(!ruleScore.isNaN)
                if (score != Double.NegativeInfinity && math.exp(score) == 0.0) {
                  println("Underflow!!!")
                }
                if (ruleScore != 0.0) {
                  foundOne = true
                }
              }
            }
            ci += 1
          }

          aScores(ai) = sum
          ai += 1
        }

        if (!foundOne) {
          sys.error("unary problems")
        }
        t.label.scaleInside(cScale)
      case t@BinaryTree(Beliefs(aLabels, aScores, _, _, _),
                        Tree(Beliefs(bLabels, bScores, bScale, _, _), _, _),
                        Tree(Beliefs(cLabels, cScores, cScale, _, _), _, _), span) =>
        val begin = span.begin
        val split = t.leftChild.span.end
        val end = span.end
        var ai = 0
        while (ai < aScores.length) {
          var sum = 0.0
          val (a, aRef) = aLabels(ai)
          var bi = 0
          while (bi < bLabels.length) {
            val (b, bRef) = bLabels(bi)
            var ci = 0
            while (ci < cLabels.length) {
              val (c, cRef) = cLabels(ci)
              val rule = topology.index(BinaryRule(topology.labelIndex.get(a),
                topology.labelIndex.get(b),
                topology.labelIndex.get(c)))
              if (rule != -1) {
                val ruleRef = anchoring.ruleRefinementFromRefinements(rule, aRef, bRef, cRef)
                if (ruleRef != -1) {
                  val spanScore = anchoring.scoreSpan(begin, end, a, aRef)
                  sum += (bScores(bi)
                    * cScores(ci)
                    * math.exp(anchoring.scoreBinaryRule(begin, split, end, rule, ruleRef) + spanScore)
                    ) // exp!

                }
              }
              ci += 1
            }
            bi += 1
          }
          aScores(ai) = sum
          ai += 1
        }

        // NOTE(review): unlike the nullary and unary cases, a binary node whose inside scores are
        // all zero raises no error here (the original diagnostic was disabled).
        t.label.scaleInside(cScale + bScale)
      case _ => sys.error("bad tree!")
    }

    indexedTree
  }

  /**
   * Fills in the outside score of every (label, refinement) pair, top-down, using the inside
   * scores already stored in `tree`.
   *
   * Rule combinations that have no indexed rule or no matching rule refinement are skipped,
   * exactly as in the inside pass.
   *
   * @param tree belief tree produced by `insideScores`; its outside arrays are mutated in place
   */
  private def outsideScores(tree: BinarizedTree[Beliefs[L]]): Unit = {
    // Root gets score exp(0) == 1
    util.Arrays.fill(tree.label.outside, 1.0)

    // Set the outside score of each child
    tree.preorder.foreach {
      case t @ BinaryTree(_, lchild, rchild, span) =>
        for {
          ((a, aRef), aScore) <- t.label.labels zip t.label.outside
          bi <- 0 until lchild.label.labels.length
          (b, bRef) = lchild.label.labels(bi)
          bScore = lchild.label.inside(bi)
          ci <- 0 until rchild.label.labels.length
          (c, cRef) = rchild.label.labels(ci)
          cScore = rchild.label.inside(ci)
        } {
          val rule = topology.index(BinaryRule(topology.labelIndex.get(a),
            topology.labelIndex.get(b),
            topology.labelIndex.get(c)))

          if (rule != -1) {
            val ruleRef = anchoring.ruleRefinementFromRefinements(rule, aRef, bRef, cRef)
            if (ruleRef != -1) {
              val spanScore = math.exp(
                anchoring.scoreBinaryRule(span.begin, lchild.span.end, span.end, rule, ruleRef)
                  + anchoring.scoreSpan(t.span.begin, t.span.end, a, aRef)
                ) // exp!
              // Each child's outside score combines the parent's outside score with the
              // sibling's inside score.
              lchild.label.outside(bi) += aScore * cScore * spanScore
              rchild.label.outside(ci) += aScore * bScore * spanScore
            }
          }
        }
        lchild.label.scaleOutside(t.label.oscale + rchild.label.iscale)
        rchild.label.scaleOutside(t.label.oscale + lchild.label.iscale)
      case _: NullaryTree[_] => () // leaves have no children to update
      case t @ UnaryTree(_, child, chain, span) =>
        for (((c, cRef), ci) <- child.label.labels.zipWithIndex) {
          var sum = 0.0
          for {
            ((a, aRef), aScore) <- t.label.labels zip t.label.outside
          } {
            val rule = topology.index(UnaryRule(topology.labelIndex.get(a), topology.labelIndex.get(c), chain))
            if (rule != -1) {
              val ruleRef = anchoring.ruleRefinementFromRefinements(rule, aRef, cRef)
              if (ruleRef != -1) {
                val ruleScore = anchoring.scoreUnaryRule(span.begin, span.end, rule, ruleRef)
                sum += aScore * math.exp(ruleScore) // exp!
              }
            }
          }
          child.label.outside(ci) = sum
        }
        child.label.scaleOutside(t.label.oscale)



    }

  }

  /**
   * The only split point compatible with the fixed tree for the span `[begin, end)`: the split of
   * the binary node covering that span (directly, or underneath a unary node). Empty when the
   * span is not covered by such a node. The child symbols and refinements are not consulted.
   */
  override def feasibleSplitPoints(begin: Int, end: Int, leftChild: Int, leftChildRef: Int, rightChild: Int, rightChildRef: Int): IndexedSeq[Int] = {
    tree.findSpan(begin, end) match {
      case Some(UnaryTree(_, b@BinaryTree(_, _, _, _), _, _)) => IndexedSeq(b.splitPoint)
      case Some(b@BinaryTree(_, _, _, _)) => IndexedSeq(b.splitPoint)
      case _ => IndexedSeq.empty
    }
  }

  /**
   * Log-space inside score of `(sym, ref)` at the unary node found for `[begin, end)`.
   * Returns negative infinity when the node found is not unary or the pair is not a candidate.
   */
  override def insideTopScore(begin: Int, end: Int, sym: Int, ref: Int): Double = {
    stree.findSpan(begin, end) match {
      case Some(UnaryTree(a, _, _, _)) => a.labels.indexOf(sym -> ref) match {
        case -1 =>
          Double.NegativeInfinity
        case pos =>
          Scaling.toLogSpace(a.inside(pos), a.iscale)
      }
      case _ => Double.NegativeInfinity
    }
  }

  /**
   * Log-space inside score of `(sym, ref)` at the binary node directly below the unary node
   * found for `[begin, end)`. Returns negative infinity when there is no such node or the pair
   * is not a candidate there.
   */
  override def insideBotScore(begin: Int, end: Int, sym: Int, ref: Int): Double = {
    stree.findSpan(begin, end) match {
      case Some(UnaryTree(_, BinaryTree(a, _, _, _), _, _)) => a.labels.indexOf(sym -> ref) match {
        case -1 =>
          Double.NegativeInfinity
        case pos =>
          Scaling.toLogSpace(a.inside(pos), a.iscale)
      }
      case _ => Double.NegativeInfinity
    }

  }

//  override def marginalAt(begin: Int, end: Int): Counter2[L, Int, Double] = {
//   stree.findSpan(begin, end) match {
//      case None => Counter2[L, Int, Double]()
//      case Some(Tree(a, _, _)) => Counter2( (0 until a.labels.length).map {i => (grammar.labelIndex.get(a.labels(i)._1), a.labels(i)._2, a.inside) })
//    }
//  }
}

/**
 * Factory methods for [[LatentTreeMarginal]] plus the private per-node score container used by
 * its inside/outside passes.
 */
object LatentTreeMarginal {
  /**
   * Builds a marginal from an unrefined tree by attaching to every node all local refinements
   * that `projections` lists for that node's label.
   */
  def apply[L, L2, W](anchoring: GrammarAnchoring[L, W],
                      projections: ProjectionIndexer[L, L2],
                      tree: BinarizedTree[L]): LatentTreeMarginal[L, W] = {
    new LatentTreeMarginal(anchoring,
      tree.map { l => projections.localRefinements(anchoring.topology.labelIndex(l)).toIndexedSeq.map(l -> _)})

  }

  /** Anchors `grammar` to `words` and builds the marginal for an already refined tree. */
  def apply[L, W](grammar: Grammar[L, W],
                  words: IndexedSeq[W],
                  tree: BinarizedTree[IndexedSeq[(L, Int)]]):LatentTreeMarginal[L, W] = {
    LatentTreeMarginal(grammar.anchor(words), tree)
  }

  /** Anchors `grammar` to `words` and builds the marginal for an unrefined tree, using `ref` to enumerate refinements. */
  def apply[L, L2, W](grammar: Grammar[L, W],
                      ref: ProjectionIndexer[L, L2],
                      words: IndexedSeq[W],
                      tree: BinarizedTree[L]):LatentTreeMarginal[L, W] = {
    apply(grammar.anchor(words), ref, tree)
  }


  /**
   * Mutable inside/outside scores for the candidate labels of one tree node.
   *
   * `inside(i)` and `outside(i)` belong to `labels(i)`, an (indexed label, refinement) pair.
   * `iscale` and `oscale` are the scales returned by `Scaling.scaleArray` for the respective
   * array.
   */
  private case class Beliefs[L](labels: IndexedSeq[(Int, Int)],
                                inside: Array[Double],
                                var iscale: Int,
                                outside: Array[Double],
                                var oscale: Int) {
    override def toString = {
      s"Beliefs($labels, ${inside.mkString("{", ", ", " }")}, ${outside.mkString("{", ", ", "}")})"
    }

    /** Rescales `inside` via `Scaling.scaleArray` and records the scale it returns. */
    def scaleInside(currentScale: Int): Unit = {
      iscale = Scaling.scaleArray(inside, currentScale)
    }

    /** Rescales `outside` via `Scaling.scaleArray` and records the scale it returns. */
    def scaleOutside(currentScale: Int): Unit = {
      oscale = Scaling.scaleArray(outside, currentScale)
    }
  }

  private object Beliefs {
    /** Creates beliefs for `labels` with zero-filled score arrays and both scales set to 0. */
    private[LatentTreeMarginal] def apply[L](labels: IndexedSeq[(Int, Int)]):Beliefs[L] = {
      // Scores are kept in non-log space, so the arrays start at 0.0 rather than -Infinity.
      val r = new Beliefs[L](labels, new Array[Double](labels.length), 0, new Array[Double](labels.length), 0)
      r
    }
  }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy