All Downloads are FREE. Search and download functionalities are using the official Maven repository.

epic.parser.projections.OracleParser.scala Maven / Gradle / Ivy

The newest version!
package epic.parser.projections

import epic.constraints.{CachedChartConstraintsFactory, ChartConstraints}
import epic.trees._
import epic.parser._
import breeze.numerics.I
import epic.util.{NotProvided, Optional, SafeLogging, CacheBroker}
import epic.parser.GrammarAnchoring.StructureDelegatingAnchoring
import breeze.config.{Configuration, CommandLineParser, Help}
import java.io.File
import breeze.util._
import epic.parser.ParseExtractionException
import epic.trees.ProcessedTreebank
import epic.trees.TreeInstance
import epic.parser.ViterbiDecoder
import epic.parser.ParserParams.XbarGrammar
import breeze.collection.mutable.TriangularArray
import epic.constraints.ChartConstraints.UnifiedFactory

/**
 * Finds the best tree (relative to the gold tree) s.t. it's reacheable given the current anchoring.
 * Best is measured as number of correct labeled spans, as usual. If the given treebank symbol is correct,
 * bonus points can be awarded for getting the right refinement.
 *
 * On the training set, the "best" reachable tree will not always (~5% of the time) be the correct tree,
 * because pruning will remove the right answer. We don't want to try to train towards an unreachable tree,
 * because the training algorithm will do bad things. Instead, we want the best possible tree that
 * our parser could conceivably produce. That is why this class exists.
 *
 * If backupGrammar is provided, it will be used to find such a tree in the case that no tree can be
 * found with grammar (given the current constraints).
 *
 * Typically, the first grammar will be a treebank grammar that has no horizontal markovization (i.e.
 * it is not forgetfully binarized) and it also remembers the functional tags like -TMP.
 * The backup grammar is usually the grammar
 * with which the pruning masks were produced; because of the way we prune, that parser will always
 * be able to find a tree (assuming that it was able to find a tree without pruning.)
 *
 * TODO: should be a cascade of grammars
 *
 * @author dlwh
 */
class OracleParser[L, L2, W](val grammar: SimpleGrammar[L, L2, W], backupGrammar: Optional[SimpleGrammar[L, L2, W]] = NotProvided) extends SafeLogging {
  private val cache = CacheBroker().make[IndexedSeq[W], BinarizedTree[L2]]("OracleParser")

  private var problems  = 0
  private var total = 0

  def forTree(tree: BinarizedTree[L2],
              words: IndexedSeq[W],
              constraints: ChartConstraints[L]) = try {
    val projectedTree: BinarizedTree[L] = tree.map(grammar.refinements.labels.project)
    cache.getOrElseUpdate(words, {
      val treeconstraints = ChartConstraints.fromTree(grammar.topology.labelIndex, projectedTree)
      if(constraints.top.containsAll(treeconstraints.top) && constraints.bot.containsAll(treeconstraints.bot)) {
        synchronized(total += 1)
        tree
      } else try {

        logger.warn {
          val ratio = synchronized {
            problems += 1;
            total += 1;
            problems * 1.0 / total
          }
          f"Gold tree for $words not reachable. $ratio%.2f are bad so far. "
        }

        (IndexedSeq(grammar) ++ backupGrammar.iterator).iterator.map { grammar =>
          try {
            val w = words
            val marg = makeGoldPromotingAnchoring(grammar, w, tree, treeconstraints, constraints).maxMarginal

            val closest: BinarizedTree[(L, Int)] = new ViterbiDecoder[L, W]().extractMaxDerivationParse(marg)

            val globalizedClosest: BinarizedTree[L2] = closest.map({
              grammar.refinements.labels.globalize(_: L, _: Int)
            }.tupled)
            logger.trace {
              s"Reachable: ${globalizedClosest.render(words, newline = true)}\nGold:${tree.render(words, newline = true)}"
            }
            Some(globalizedClosest)
          } catch {
            case ex: ParseExtractionException => None
          }
        }.flatten.next()

      } catch {
        case ex:NoSuchElementException =>
          logger.error("Couldn't find a tree for $words")
          tree
      }
    })
  } catch {
    case ex: Exception =>
      logger.error(s"while handling projectability for $tree $words: " + ex.getMessage, ex)
      throw ex
  }


  def makeGoldPromotingAnchoring(grammar: SimpleGrammar[L, L2, W],
                                 w: IndexedSeq[W],
                                 tree: BinarizedTree[L2],
                                 treeconstraints: ChartConstraints[L],
                                 constraints: ChartConstraints[L]): GrammarAnchoring[L, W] = {
    import grammar.{refinements => _, _}
    val correctRefinedSpans = GoldTagPolicy.goldTreeForcing(tree.map(grammar.refinements.labels.fineIndex))
    val correctUnaryChains = {
      val arr = TriangularArray.fill(w.length + 1)(IndexedSeq.empty[String])
      for(UnaryTree(a, btree, chain, span) <- tree.allChildren) {
        arr(span.begin, span.end) = chain
      }
      arr
    }
    new StructureDelegatingAnchoring[L, W] {
      protected val baseAnchoring: GrammarAnchoring[L, W] = grammar.anchor(w)
      override def addConstraints(cs: ChartConstraints[L]): GrammarAnchoring[L, W] = {
        makeGoldPromotingAnchoring(grammar, w, tree, treeconstraints, constraints & cs)
      }


      override def sparsityPattern: ChartConstraints[L] = constraints

      def scoreBinaryRule(begin: Int, split: Int, end: Int, rule: Int, ref: Int): Double = {
        0.1 * baseAnchoring.scoreBinaryRule(begin, split, end, rule, ref)
      }

      def scoreUnaryRule(begin: Int, end: Int, rule: Int, ref: Int): Double = {
        val top = topology.parent(rule)
        val isAllowed = treeconstraints.top.isAllowedLabeledSpan(begin, end, top)
        val isRightRefined = isAllowed && {
          val parentRef = baseAnchoring.parentRefinement(rule, ref)
          val refTop = grammar.refinements.labels.globalize(top, parentRef)
          correctRefinedSpans.isGoldTopTag(begin, end, refTop)
        }

        val isCorrectChain = {
          isAllowed && topology.chain(rule) == correctUnaryChains(begin, end)
        }

        200 * I(isAllowed) + 10 * I(isRightRefined) + 5 * I(isCorrectChain) + 0.1 * baseAnchoring.scoreUnaryRule(begin, end, rule, ref)
      }

      def scoreSpan(begin: Int, end: Int, tag: Int, ref: Int): Double = {
        val globalized = grammar.refinements.labels.globalize(tag, ref)
        200 * I(treeconstraints.bot.isAllowedLabeledSpan(begin, end, tag)) +
          10 * I(correctRefinedSpans.isGoldBotTag(begin, end, globalized)) + 0.1 * baseAnchoring.scoreSpan(begin, end, tag, ref)
      }
    }


  }

  def oracleMarginalFactory(trees: IndexedSeq[TreeInstance[L2, W]]):ParseMarginal.Factory[L, W] = new ParseMarginal.Factory[L, W] {
    val knownTrees = trees.iterator.map(ti => ti.words -> ti.tree).toMap

    def apply(w: IndexedSeq[W], constraints: ChartConstraints[L]): RefinedChartMarginal[L, W] = {
      val refAnchoring = knownTrees.get(w).map { t =>
        val projectedTree: BinarizedTree[L] = t.map(grammar.refinements.labels.project)
        val treeconstraints = ChartConstraints.fromTree(grammar.topology.labelIndex, projectedTree)
        makeGoldPromotingAnchoring(grammar, w, t, treeconstraints, constraints)
      }.getOrElse(grammar.anchor(w))

      RefinedChartMarginal(refAnchoring, true)
    }
  }

  def oracleParser(constraintGrammar: ChartConstraints.Factory[L, W], trees: IndexedSeq[TreeInstance[L2, W]])(implicit deb: Debinarizer[L]): Parser[L, W] = {
    new Parser(grammar.topology, grammar.lexicon, constraintGrammar, oracleMarginalFactory(trees), ViterbiDecoder())
  }
}

object OracleParser {
  case class Params(treebank: ProcessedTreebank,
                    @Help(text="Prefix for the name of the eval directory. Sentences are dumped for EVALB to read.")
                    name: String,
                    @Help(text="Path to the parser file. Look in parsers/")
                    parser: File = null,
                    implicit val cache: CacheBroker,
                    @Help(text="Print this and exit.")
                    help: Boolean = false,
                    @Help(text="How many threads to parse with. Default is whatever Scala wants")
                    threads: Int = -1)

  def main(args: Array[String]) {
    val params = CommandLineParser.readIn[Params](args)
    println("Command line arguments for recovery:\n" + Configuration.fromObject(params).toCommandLineString)
    println("Evaluating Parser...")

    import params.treebank._

    val ann = GenerativeParser.defaultAnnotator()

    val initialParser = params.parser match {
      case null =>
        val (grammar, lexicon) = XbarGrammar().xbarGrammar(trainTrees)
        GenerativeParser.annotatedParser(grammar, lexicon, ann, trainTrees)
      //        GenerativeParser.annotatedParser(grammar, lexicon, Xbarize(), trainTrees)
      case f =>
        readObject[Parser[AnnotatedLabel, String]](f)
    }

    val constraints: ChartConstraints.Factory[AnnotatedLabel, String] = new UnifiedFactory({

      val maxMarginalized = initialParser.copy(marginalFactory=initialParser.marginalFactory match {
        case StandardChartFactory(ref, mm) => StandardChartFactory(ref, maxMarginal = true)
        case x => x
      })


      val uncached = new ParserChartConstraintsFactory[AnnotatedLabel, String](maxMarginalized, {(_:AnnotatedLabel).isIntermediate})
      new CachedChartConstraintsFactory[AnnotatedLabel, String](uncached)(params.cache)
    })

    val refGrammar = GenerativeParser.annotated(initialParser.topology, initialParser.lexicon, ann, trainTrees)

    val annDevTrees = devTrees.map(ann)

    val parser = {
      new OracleParser(refGrammar).oracleParser(constraints, annDevTrees)
    }

    for(ti <- devTrees.par) try {
      val m = TreeMarginal(refGrammar, ti.words, ann.localized(refGrammar.refinements.labels)(ti.tree, ti.words))
      println(m.logPartition)
    } catch {
      case e: Exception => e.printStackTrace()
    }


    val name = params.name

    println("Parser " + name)

    {
      println("Evaluating Parser on dev...")
      val stats = ParserTester.evalParser(devTrees, parser, name+ "-dev", params.threads)
      println("Eval finished. Results:")
      println(stats)
    }


  }

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy