All Downloads are FREE. Search and download functionalities are using the official Maven repository.

epic.sentiment.SentimentLossAugmentation.scala Maven / Gradle / Ivy

The newest version!
package epic.sentiment

import epic.trees.{Span, AnnotatedLabel, TreeInstance}
import epic.framework.LossAugmentation
import epic.lexicon.Lexicon
import epic.parser.{UnrefinedGrammarAnchoring, RuleTopology}
import epic.constraints.ChartConstraints

/**
 * Loss augmentation for sentiment-labelled trees.
 *
 * For a training instance this builds an anchoring whose span scores are the loss between
 * the gold sentiment value recorded for that span and the sentiment value of the tag being
 * scored; binary and unary rule scores are always zero. Sentences that are not among
 * `trainTrees` receive the identity anchoring instead.
 *
 * @param trainTrees        training instances, looked up by their word sequence in `anchor`
 * @param topology          rule topology; its label encoder defines the tag indices used here
 * @param lexicon           lexicon handed through to the anchorings built by this class
 * @param constraintFactory produces the chart constraints attached to each anchoring
 * @param loss              loss as a function of (gold, guess); tabulated once over a 5x5 grid
 * @param rootLossScaling   multiplier applied to the loss of the span covering the whole sentence
 * @author dlwh
 **/
case class SentimentLossAugmentation[W](trainTrees: IndexedSeq[TreeInstance[AnnotatedLabel, W]],
                                        topology: RuleTopology[AnnotatedLabel],
                                        lexicon: Lexicon[AnnotatedLabel, W],
                                        constraintFactory: ChartConstraints.Factory[AnnotatedLabel, W],
                                        loss: (Int, Int)=>Double = SentimentLossAugmentation.defaultLoss,
                                        rootLossScaling:Double = 1.0)
  extends LossAugmentation[TreeInstance[AnnotatedLabel, W], UnrefinedGrammarAnchoring[AnnotatedLabel, W]] {

  /** Precomputed loss table, indexed as `losses(gold)(guess)` for values 0 until 5. */
  val losses: Array[Array[Double]] = Array.tabulate(5, 5)(loss)

  /** Sentiment value of a label: -1 for `AnnotatedLabel.TOP`, otherwise `l.label.toInt`. */
  def projectedLabel(l: AnnotatedLabel): Int = {
    if (l == AnnotatedLabel.TOP) -1
    else l.label.toInt
  }

  /** Sentiment value of every label known to the topology's label encoder, by tag index. */
  val sentimentScores: Array[Int] = topology.labelEncoder.tabulateArray(projectedLabel)

  /** Training instances keyed by their words; a later instance with the same words replaces an earlier one. */
  val trainingMap = trainTrees.iterator.map(instance => instance.words -> instance).toMap

  /**
   * Builds the loss anchoring for a single training instance.
   *
   * @param datum the gold tree and its words
   * @return an anchoring that scores spans by their loss against the gold labels of `datum`
   */
  def lossAugmentation(datum: TreeInstance[AnnotatedLabel, W]): UnrefinedGrammarAnchoring[AnnotatedLabel, W] = {
    val projectedTree = datum.tree.map(projectedLabel)
    // Nodes projected to -1 (the root label) carry no sentiment and are left out.
    // If several remaining nodes share a span, the last one visited wins (toMap semantics).
    val goldMap = projectedTree.preorder.collect {
      case node if node.label != -1 => node.span -> node.label
    }.toMap
    val constraints = constraintFactory.constraints(datum.words)

    new SentimentLossAnchoring(topology, lexicon, datum.words, goldMap, constraints)
  }


  /**
   * Returns a [[epic.parser.UnrefinedGrammarAnchoring]] for this particular sentence.
   *
   * @param words the sentence to anchor
   * @return the loss anchoring when `words` matches a training instance,
   *         otherwise the identity anchoring over the same constraints
   */
  def anchor(words: IndexedSeq[W]): UnrefinedGrammarAnchoring[AnnotatedLabel, W] = {
    trainingMap.get(words) match {
      case Some(instance) =>
        lossAugmentation(instance)
      case None =>
        UnrefinedGrammarAnchoring.identity(topology, lexicon, words, constraintFactory.constraints(words))
    }
  }

  /**
   * Anchoring whose only non-zero scores are span scores derived from the gold labels.
   *
   * @param goldLabels      gold sentiment value for each span that has one
   * @param sparsityPattern chart constraints currently in force for this sentence
   */
  case class SentimentLossAnchoring[L, W](topology: RuleTopology[L],
                                          lexicon: Lexicon[L, W],
                                          words: IndexedSeq[W],
                                          goldLabels: Map[Span, Int],
                                          sparsityPattern: ChartConstraints[L])
    extends epic.parser.UnrefinedGrammarAnchoring[L, W] {

    /** Returns a copy whose constraints are the existing ones combined with `cs` via `&`. */
    override def addConstraints(cs: ChartConstraints[L]): UnrefinedGrammarAnchoring[L, W] = {
      copy(sparsityPattern = sparsityPattern & cs)
    }

    /** Binary rules never contribute to the loss. */
    def scoreBinaryRule(begin: Int, split: Int, end: Int, rule: Int): Double = 0.0

    /** Unary rules never contribute to the loss. */
    def scoreUnaryRule(begin: Int, end: Int, rule: Int): Double = 0.0

    /**
     * Loss of labelling the span `[begin, end)` with `tag`.
     *
     * @return 0 when the span has no gold label; otherwise the tabulated loss between the
     *         gold value and the tag's sentiment value, multiplied by `rootLossScaling` when
     *         the span covers the whole sentence
     */
    def scoreSpan(begin: Int, end: Int, tag: Int): Double = {
      goldLabels.get(Span(begin, end)).fold(0.0) { goldLabel =>
        assert(goldLabel != -1)
        val guessLabel = sentimentScores(tag)
        if (guessLabel != -1) {
          val baseLoss = losses(goldLabel)(guessLabel)
          val coversSentence = begin == 0 && end == words.size
          val scale = if (coversSentence) rootLossScaling else 1.0
          baseLoss * scale
        } else {
          // NOTE(review): with goldLabel != -1 asserted above, this indicator is 0 whenever
          // this branch is reached - confirm whether a penalty was intended here.
          breeze.numerics.I(goldLabel == guessLabel) * 10000
        }
      }
    }

  }


}

/** Ready-made loss functions over integer sentiment values, each taking (gold, guess). */
object SentimentLossAugmentation {

  /** Absolute difference between the gold and guessed values. */
  def defaultLoss(gold: Int, guess: Int): Double = math.abs(gold - guess).toDouble

  /** Side of the value 2 that `score` falls on: 1 above, -1 below, 0 when equal. */
  private def polarity(score: Int): Int = {
    if (score > 2) 1
    else if (score < 2) -1
    else 0
  }

  /** 0 when gold and guess fall on the same side of 2 (or both equal 2), 1 otherwise. */
  def posNegLoss(gold: Int, guess: Int): Int = {
    if (polarity(gold) == polarity(guess)) 0 else 1
  }

  /** 0 when the values are equal, 1 otherwise. */
  def hammingLoss(gold: Int, guess: Int): Int = if (gold == guess) 0 else 1

  /** Always 0, whatever the arguments. */
  def noLoss(gold: Int, guess: Int): Int = 0
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy