// All Downloads are FREE. Search and download functionalities are using the official Maven repository.
//
// epic.parser.SimpleChartMarginal.scala Maven / Gradle / Ivy
//
// The newest version!
package epic.parser

import breeze.util.Index
import breeze.linalg.{DenseVector, softmax, max, DenseMatrix}
import breeze.collection.mutable.TriangularArray

import spire.syntax.cfor._
import epic.constraints.ChartConstraints

/**
 * A [[ParseMarginal]] for a [[SimpleGrammar]] anchoring, backed by a pair of dense
 * inside/outside charts whose cells are indexed by refined ("globalized") label.
 *
 * Chart entries and [[logPartition]] are log-space scores; [[visitPostorder]]
 * exponentiates the normalized scores before passing them to the visitor.
 *
 * @param anchoring     the grammar anchored to the sentence being parsed
 * @param inside        inside chart over refined labels
 * @param outside       outside chart over refined labels
 * @param isMaxMarginal flag recording how the charts were combined; the companion's
 *                      chart-building `apply` forwards its `maxMarginal` argument here
 * @tparam L  coarse label type
 * @tparam L2 refined label type (the label type of the charts)
 * @tparam W  word type
 * @author dlwh
 **/
final case class SimpleChartMarginal[L, L2, W](anchoring: SimpleGrammar.Anchoring[L, L2, W],
                                               inside: SimpleParseChart[L2], outside: SimpleParseChart[L2],
                                               isMaxMarginal: Boolean = true) extends ParseMarginal[L, W] {
  /** Log score of the root label over the whole sentence, read from the top inside cell. */
  override val logPartition: Double = inside.top.labelScore(0, inside.length, anchoring.refinedTopology.rootIndex)

  /** Inside score of the (coarse symbol, refinement) pair in the top cell of `[begin, end)`. */
  override def insideTopScore(begin: Int, end: Int, sym: Int, ref: Int): Double = {
    inside.top.labelScore(begin, end, anchoring.refinements.labels.globalize(sym, ref))
  }

  /** Inside score of the (coarse symbol, refinement) pair in the bot cell of `[begin, end)`. */
  override def insideBotScore(begin: Int, end: Int, sym: Int, ref: Int): Double = {
    inside.bot.labelScore(begin, end, anchoring.refinements.labels.globalize(sym, ref))
  }

  /** Every interior position of the span is reported; the child arguments are not consulted. */
  override def feasibleSplitPoints(begin: Int, end: Int, leftChild: Int, leftChildRef: Int, rightChild: Int, rightChildRef: Int): IndexedSeq[Int] = {
    (begin + 1) until (end)
  }

  /**
   * Reports normalized marginals to `spanVisitor` in three passes:
   *
   *  1. length-one spans, for every allowed tag and valid refinement;
   *  1. bot cells of spans of length two or more, shortest spans first, each followed
   *     by its binary rules unless `spanVisitor.skipBinaryRules`;
   *  1. unary rules over top cells, unless `spanVisitor.skipUnaryRules`.
   *
   * @param spanVisitor   receiver of the exponentiated marginals
   * @param spanThreshold log-space cutoff: cells whose log marginal is not strictly
   *                      greater than this are skipped in passes 2 and 3 (pass 1 only
   *                      drops negative-infinite scores)
   * @throws RuntimeException if `logPartition` is infinite or NaN
   */
  override def visitPostorder(spanVisitor: AnchoredVisitor[L], spanThreshold: Double): Unit = {
    if(logPartition.isInfinite) throw new RuntimeException("No parse for " + words)
    if(logPartition.isNaN) throw new RuntimeException("NaN prob!")

    val refinedTopology = anchoring.refinedTopology
    val labelRefinements = anchoring.refinements.labels
    val ruleRefinements = anchoring.refinements.rules

    val lexLoc = anchoring.lexicon.anchor(anchoring.words)
    // handle lexical
    for {
      i <- 0 until words.length
      a <- lexLoc.allowedTags(i)
      ref <- anchoring.validLabelRefinements(i, i + 1, a)
    } {
      val aa = labelRefinements.globalize(a, ref)
      // The span score from the anchoring stands in for the inside score here.
      val score: Double = anchoring.scoreSpan(i, i + 1, a, ref) + outside.bot(i, i + 1, aa) - logPartition
      assert(!score.isNaN, s"${anchoring.scoreSpan(i, i + 1, a, ref)} ${outside.bot(i, i + 1, aa)} $logPartition")
      if (score != Double.NegativeInfinity) {
        spanVisitor.visitSpan(i, i + 1, a, ref, math.exp(score))
      }
    }

    // handle binaries
    for {
      span <- 2 to length
      begin <- 0 to (length - span)
      parent <- 0 until refinedTopology.labelIndex.size
    } {
      val end = begin + span
      val aOutside = outside.bot(begin, end, parent)
      val labelMarginal = inside.bot(begin, end, parent) + aOutside - logPartition
      if(labelMarginal > spanThreshold) {
        val aCoarse = labelRefinements.project(parent)
        val aRef = labelRefinements.localize(parent)
        spanVisitor.visitSpan(begin, end, aCoarse, aRef, math.exp(labelMarginal))
        if(!spanVisitor.skipBinaryRules) {
          val rules = refinedTopology.indexedBinaryRulesWithParent(parent)
          var i = 0
          while(i < rules.length) {
            val r = rules(i)
            val b = refinedTopology.leftChild(r)
            val c = refinedTopology.rightChild(r)

            // These depend only on the rule, so compute them once rather than per split.
            // (span >= 2 guarantees at least one split, so they were always evaluated.)
            val ruleScore = anchoring.grammar.ruleScore(r)
            val coarseR = ruleRefinements.project(r)
            val refR = ruleRefinements.localize(r)

            var split = begin + 1
            while(split < end) {
              val bInside = inside.top.labelScore(begin, split, b)
              val cInside = inside.top.labelScore(split, end, c)

              // Summation order kept as-is so floating-point results are unchanged.
              val margScore = bInside + cInside + ruleScore + aOutside - logPartition

              if(margScore != Double.NegativeInfinity) {
                spanVisitor.visitBinaryRule(begin, split, end, coarseR, refR, math.exp(margScore))
              }

              split += 1
            }

            i += 1
          }

        }

      }
    }

    if(!spanVisitor.skipUnaryRules)
      for {
        span <- 1 to words.length
        begin <- 0 to (words.length - span)
        parent <- 0 until refinedTopology.labelIndex.size
      } {
        val end = begin + span
        val aOutside = outside.top(begin, end, parent)
        val labelMarginal = inside.top(begin, end, parent) + aOutside - logPartition
        if (labelMarginal > spanThreshold) {

          for (r <- refinedTopology.indexedUnaryRulesWithParent(parent)) {
            val b = refinedTopology.child(r)
            val bScore = inside.bot.labelScore(begin, end, b)
            val rScore = anchoring.grammar.ruleScore(r)
            val prob = math.exp(bScore + aOutside + rScore - logPartition)
            val refR = ruleRefinements.localize(r)
            val projR = ruleRefinements.project(r)
            // exp of negative infinity is 0.0, so impossible rules are dropped here
            if (prob > 0)
              spanVisitor.visitUnaryRule(begin, end, projR, refR, prob)
          }
        }
      }

  }
}

/**
 * Builders for [[SimpleChartMarginal]]: they fill the inside and outside charts for an
 * anchored [[SimpleGrammar]] and wrap the result in a marginal.
 */
object SimpleChartMarginal  {
  import RefinedChartMarginal.{Summer, MaxSummer, LogSummer}

  /** Anchors `grammar` to `words` and builds a marginal using `LogSummer` (not max). */
  def apply[L, L2, W](grammar: SimpleGrammar[L, L2, W], words: IndexedSeq[W]): SimpleChartMarginal[L, L2, W] = {
    SimpleChartMarginal(grammar.anchor(words), maxMarginal = false)
  }

  /**
   * Builds the inside chart, then the outside chart, for `anchoring`.
   *
   * @param maxMarginal when true scores are combined with `MaxSummer`, otherwise with
   *                    `LogSummer`; the flag is also stored on the resulting marginal
   */
  def apply[L, L2, W](anchoring: SimpleGrammar.Anchoring[L, L2, W], maxMarginal: Boolean): SimpleChartMarginal[L, L2, W] = {
    val sum = if(maxMarginal) MaxSummer else LogSummer
    val inside = buildInsideChart(anchoring, sum)
    val outside = buildOutsideChart(anchoring, inside, sum)
    SimpleChartMarginal(anchoring, inside, outside, maxMarginal)
  }

  /**
   * Fills an inside chart bottom-up: length-one spans from the lexicon and span scores,
   * then longer spans from the grammar's inside tensor. After each cell's bot scores
   * are complete, its top scores are derived via [[updateInsideUnaries]].
   */
  private def buildInsideChart[L, L2, W](anchoring: SimpleGrammar.Anchoring[L, L2, W], sum: Summer): SimpleParseChart[L2] = {
    val length = anchoring.words.length
    val chart = new SimpleParseChart[L2](anchoring.grammar.refinedTopology.labelIndex, length)

    val lexLoc = anchoring.lexicon.anchor(anchoring.words)
    // handle lexical
    for (i <- 0 until length) {
      for {
        a <- lexLoc.allowedTags(i)
        ref <- anchoring.validLabelRefinements(i, i + 1, a)
      } {
        val aa = anchoring.refinements.labels.globalize(a, ref)
        val score: Double = anchoring.scoreSpan(i, i + 1, a, ref)
        if (score != Double.NegativeInfinity) {
          chart.bot.enter(i, i + 1, aa, score)
        }
      }

      updateInsideUnaries(chart, anchoring, i, i + 1, sum)
    }

    val tensor = anchoring.grammar.insideTensor
    val numSyms = tensor.numLeftChildren

    for {
      span <- 2 to length
      begin <- 0 to (length - span)
    } {
      val end = begin + span
      // Write straight into the backing array of the parent (bot) cell.
      val pcell = chart.bot.cell(begin, end)
      val pdata = pcell.data
      val pdoff = pcell.offset
      var split = begin + 1
      while (split < end) {

        val lcell = chart.top.cell(begin, split)
        val ldata = lcell.data
        val ldoff = lcell.offset
        val rcell = chart.top.cell(split, end)
        val rdata = rcell.data
        val rdoff = rcell.offset

        // The tensor is walked left child -> right child -> parent; a branch is
        // skipped as soon as one child has no inside score.
        var lc = 0
        while(lc < numSyms) {
          val lcSpan = tensor.leftChildRange(lc)
          var rcOff = lcSpan.begin
          val rcEnd = lcSpan.end
          val bInside = ldata(ldoff + lc)
          if(bInside != Double.NegativeInfinity) {
            while(rcOff < rcEnd) {
              val rc = tensor.rightChildForOffset(rcOff)
              val cInside = rdata(rdoff + rc)

              val rcSpan = tensor.rightChildRange(rcOff)
              val withoutRule = bInside + cInside


              if(cInside != Double.NegativeInfinity) {
                var pOff = rcSpan.begin
                val pEnd = rcSpan.end
                while(pOff < pEnd) {
                  val p = tensor.parentForOffset(pOff)
                  val score = tensor.ruleScoreForOffset(pOff) + withoutRule
                  pdata(p + pdoff) = sum(pdata(p + pdoff), score)

                  pOff += 1
                }

              }



              rcOff += 1
            }
          }

          lc += 1

        }

        split += 1
      }


      updateInsideUnaries(chart, anchoring,  begin, end, sum)
    }

    chart
  }

  /**
   * Fills an outside chart top-down, from the full-sentence span to length-one spans.
   * The root's top cell is seeded with 0.0. For each span the bot scores are first
   * derived from the top scores ([[updateOutsideUnaries]]), then pushed down into the
   * top cells of both children at every split point.
   */
  private def buildOutsideChart[L, L2, W](anchoring: SimpleGrammar.Anchoring[L, L2, W],
                                      inside: SimpleParseChart[L2], sum: Summer): SimpleParseChart[L2] = {
    val length = inside.length
    val refinedTopology = anchoring.refinedTopology
    val outside = new SimpleParseChart[L2](refinedTopology.labelIndex, length)
    outside.top.enter(0, inside.length, refinedTopology.rootIndex, 0.0)

    val tensor = anchoring.grammar.outsideTensor
    val numSyms = tensor.numLeftChildren

    for {
      span <- (length) until 0 by (-1)
      begin <- 0 to (length-span)
    } {
      val end = begin + span
      updateOutsideUnaries(outside, anchoring, begin, end, sum)

      val pcell = outside.bot.cell(begin, end)
      val pdata = pcell.data
      val pdoff = pcell.offset

      // NOTE: the outside tensor is used here with shifted roles: it is indexed first
      // by the parent (through the "leftChild" accessors), then the left child
      // (through "rightChild"), then the right child (through "parent").
      var a = 0
      while(a < numSyms) {
        val outsideA = pdata(pdoff + a)
        if (outsideA != Double.NegativeInfinity) {
          val pSpan = tensor.leftChildRange(a)
          val lcEnd = pSpan.end
          var split = begin + 1
          while (split < end) {

            val lcell = inside.top.cell(begin, split)
            val ldata = lcell.data
            val ldoff = lcell.offset
            val rcell = inside.top.cell(split, end)
            val rdata = rcell.data
            val rdoff = rcell.offset

            val olcell = outside.top.cell(begin, split)
            val oldata = olcell.data
            val oldoff = olcell.offset
            val orcell = outside.top.cell(split, end)
            val ordata = orcell.data
            val ordoff = orcell.offset

            var lcOff = pSpan.begin
            while(lcOff < lcEnd) {
              val lc = tensor.rightChildForOffset(lcOff)
              val bInside = ldata(ldoff + lc)
              if(bInside != Double.NegativeInfinity) {
                val lcSpan = tensor.rightChildRange(lcOff)
                var rcOff = lcSpan.begin
                val rcEnd = lcSpan.end

                while(rcOff < rcEnd) {
                  val rc = tensor.parentForOffset(rcOff)
                  val score = tensor.ruleScoreForOffset(rcOff) + outsideA
                  val cInside = rdata(rdoff + rc)
                  if (cInside != Double.NegativeInfinity) {
                    // Each child's outside score combines the rule, the parent's
                    // outside score and the sibling's inside score.
                    oldata(oldoff + lc) = sum(oldata(oldoff + lc), cInside + score)
                    ordata(ordoff + rc) = sum(ordata(ordoff + rc), bInside + score)
                  }
                  rcOff += 1
                }
              }
              lcOff += 1
            }
            split += 1
          }
        }
        a += 1
      }
    }
    outside
  }


  /** Derives the top cell of `[begin, end)` from its bot cell using the inside tensor's unary rules. */
  private def updateInsideUnaries[L, L2, W](chart: SimpleParseChart[L2],
                                        anchoring: SimpleGrammar.Anchoring[L, L2, W],
                                        begin: Int, end: Int, sum: Summer): Unit = {
    val childCell = chart.bot.cell(begin, end)
    val parentCell = chart.top.cell(begin, end)
    val tensor = anchoring.grammar.insideTensor
    doMatrixMultiply(childCell, parentCell, tensor, sum)

  }


  /**
   * Accumulates unary-rule scores from `childCell` into `parentCell`, in place.
   *
   * For every index `b` of `childCell` whose score is not negative infinity, and every
   * unary entry the tensor lists for `b`, the target entry `a` of `parentCell` becomes
   * `sum(parentCell(a), ruleScore + childCell(b))`.
   */
  private def doMatrixMultiply[L2](childCell: DenseVector[Double], parentCell: DenseVector[Double], tensor: SparseRuleTensor[L2], sum: RefinedChartMarginal.Summer): Unit = {
    val numSyms = childCell.size
    val cdata = childCell.data
    val coffset = childCell.offset
    val pdata = parentCell.data
    val poffset = parentCell.offset
    var b = 0
    while (b < numSyms) {
      val bScore = cdata(coffset + b)
      val aSpan = tensor.unaryChildRange(b)
      if (bScore != Double.NegativeInfinity) {
        var aOff = aSpan.begin
        val aEnd = aSpan.end
        while (aOff < aEnd) {
          val a = tensor.unaryParentForOffset(aOff)
          val ruleScore: Double = tensor.unaryScoreForOffset(aOff)
          val prob = pdata(a + poffset)
          pdata(a + poffset) = sum(prob, ruleScore + bScore)
          aOff += 1
        }
      }

      b += 1
    }
  }

  /**
   * Derives the bot cell of `[begin, end)` from its top cell using the outside tensor's
   * unary rules. Note the argument order: the top cell is the source and the bot cell
   * the target, the reverse of the inside direction.
   */
  private def updateOutsideUnaries[L, L2, W](outside: SimpleParseChart[L2],
                                            anchoring: SimpleGrammar.Anchoring[L, L2, W],
                                            begin: Int, end: Int, sum: Summer): Unit = {
    val childCell = outside.bot.cell(begin, end)
    val parentCell = outside.top.cell(begin, end)
    val tensor = anchoring.grammar.outsideTensor
    doMatrixMultiply(parentCell, childCell, tensor, sum)

  }

  /**
   * A [[ParseMarginal.Factory]] that anchors `refinedGrammar` to each sentence, together
   * with the supplied constraints, and builds a [[SimpleChartMarginal]] from it.
   *
   * @param maxMarginal forwarded to the chart-building `apply`
   */
  case class SimpleChartFactory[L, L2, W](refinedGrammar: SimpleGrammar[L, L2, W], maxMarginal: Boolean = false) extends ParseMarginal.Factory[L, W] {
    def apply(w: IndexedSeq[W], constraints: ChartConstraints[L]):SimpleChartMarginal[L, L2, W] = {
      SimpleChartMarginal(refinedGrammar.anchor(w, constraints), maxMarginal = maxMarginal)
    }
  }

}



/**
 * A pair of dense score tables, `top` and `bot`, covering every span of a sentence of
 * the given length, with one entry per label of `index` in each span's cell.
 * All entries start at negative infinity.
 *
 * @param index  the label index; its size is the number of entries in each cell
 * @param length the sentence length the chart is sized for
 */
@SerialVersionUID(1)
final class SimpleParseChart[L](val index: Index[L], val length: Int) extends Serializable {

  // `val a, b = expr` evaluates `expr` once per name, so these are two distinct tables.
  val top, bot = new ChartScores()

  /**
   * One score table: a matrix with one row per label and one column per span.
   *
   * Extends `Serializable` because the enclosing chart is `Serializable` and stores
   * instances of this class in its fields; without it, serializing a chart would fail
   * with `NotSerializableException`.
   */
  final class ChartScores private[SimpleParseChart]() extends Serializable {
    val scores = DenseMatrix.zeros[Double](index.size, TriangularArray.arraySize(length + 1))
    scores := Double.NegativeInfinity

    /** Score of label index `label` over `[begin, end)`. */
    @inline def labelScore(begin: Int, end: Int, label: Int): Double = scores(label, TriangularArray.index(begin, end))

    /** Same lookup as the `Int` overload of [[labelScore]]. */
    @inline def apply(begin: Int, end: Int, label: Int): Double = scores(label, TriangularArray.index(begin, end))

    /**
     * A vector over all label scores of the span `[begin, end)` that shares the
     * matrix's backing array, so writes through it update the chart.
     */
    @inline def cell(begin: Int, end: Int): DenseVector[Double] = {
      val ind = TriangularArray.index(begin, end)
      // Hand-built view of column `ind`; relies on each column occupying `scores.rows`
      // consecutive array slots. The original author left `scores(::, ind)` commented
      // out here, presumably as the equivalent slicing form.
      new DenseVector[Double](scores.data, scores.offset + scores.rows * ind, 1, scores.rows)
    }


    /** Score of `label` over `[begin, end)`, looking the label up in `index`. */
    def apply(begin: Int, end: Int, label: L):Double = apply(begin, end, index(label))

    /** Same lookup as the label-typed overload of `apply`. */
    def labelScore(begin: Int, end: Int, label: L):Double = apply(begin, end, index(label))

    /** Overwrites (does not accumulate) the score of label index `parent` over `[begin, end)`. */
    def enter(begin: Int, end: Int, parent: Int, w: Double): Unit = {
      scores(parent, TriangularArray.index(begin, end)) = w
    }

  }

}







// © 2015 - 2025 Weber Informatics LLC | Privacy Policy