All Downloads are FREE. Search and download functionalities are using the official Maven repository.

epic.sentiment.SentimentTreebankPipeline.scala Maven / Gradle / Ivy

The newest version!
package epic.sentiment

import java.io.File
import breeze.config.CommandLineParser
import epic.trees._
import epic.parser.models.{ParserInference, ParserModel}
import epic.parser._
import breeze.linalg._
import epic.framework._
import epic.constraints.{LabeledSpanConstraints, SpanConstraints, ChartConstraints}
import breeze.optimize.CachedBatchDiffFunction
import com.typesafe.scalalogging.slf4j.LazyLogging
import epic.parser.models.SpanModelFactory
import epic.trees.ProcessedTreebank
import epic.trees.TreeInstance
import breeze.optimize.FirstOrderMinimizer.OptParams
import epic.trees.Span
import scala.collection.mutable.HashMap
import breeze.util._
import epic.parser.models.ParserExtractableModelFactory

/**
 *
 *
 * @author dlwh
 */
object SentimentTreebankPipeline extends LazyLogging {
  case class Options(path: File,
                     opt: OptParams,
                     lossType: String = "",
                     iterPerEval: Int = 100,
                     evalOnTest: Boolean = false,
                     includeDevInTrain: Boolean = false,
                     modelFactory: ParserExtractableModelFactory[AnnotatedLabel, String] = new SpanModelFactory,
                     rootLossScaling: Double = 1.0)


  def main(args: Array[String]):Unit = {
    val params = CommandLineParser.readIn[Options](args)

    val treebank = new ProcessedTreebank(params.path, treebankType = "simple")

    var trainTrees: IndexedSeq[TreeInstance[AnnotatedLabel, String]] = treebank.trainTrees
    if(params.evalOnTest && params.includeDevInTrain)
      trainTrees ++= treebank.devTrees
    println(trainTrees.size + " train trees, " + treebank.devTrees.size + " dev trees, " + treebank.testTrees.size + " test trees");
    val gen = GenerativeParser.fromTrees(trainTrees)


    class GoldBracketingsConstraints extends ChartConstraints.Factory[AnnotatedLabel, String] {
      val trees = (trainTrees ++ treebank.devTrees ++ treebank.testTrees).map(ti => ti.words -> ti.tree).toMap
//      val trees = ((if (params.includeDevInTrain) trainTrees else trainTrees ++ treebank.devTrees) ++ treebank.testTrees).map(ti => ti.words -> ti.tree).toMap

      def constraints(w: IndexedSeq[String]): ChartConstraints[AnnotatedLabel] = {
        val constraints = SpanConstraints.fromTree(trees.getOrElse(w, gen.bestBinarizedTree(w)))
        val cons = new LabeledSpanConstraints.PromotedSpanConstraints(constraints)
        ChartConstraints(cons, cons)
      }
    }

    // TODO: params are inelegant
    val sentimentLoss: (Int, Int) => Double = if (params.lossType == "defaultLoss") {
      SentimentLossAugmentation.defaultLoss
    } else if (params.lossType == "posNegLoss") {
      SentimentLossAugmentation.posNegLoss
    } else if (params.lossType == "hammingLoss") {
      SentimentLossAugmentation.hammingLoss
    } else {
      SentimentLossAugmentation.noLoss;
    }
    val constrainer = new SentimentLossAugmentation(trainTrees,
      gen.topology,
      gen.lexicon,
      new GoldBracketingsConstraints,
      sentimentLoss,
      params.rootLossScaling)

//    val model = new SpanModelFactory(annotator = GenerativeParser.defaultAnnotator(vertical = params.v), dummyFeats = 0.5).make(trainTrees, constrainer)
    val model = params.modelFactory.make(trainTrees, gen.topology, gen.lexicon, new GoldBracketingsConstraints)

    val obj = new ModelObjective(model, trainTrees)
    val cachedObj = new CachedBatchDiffFunction(obj)
    val init = obj.initialWeightVector(true)

    val name = "SentiParser"

    for ((state, iter) <- params.opt.iterations(cachedObj, init).take(1000).zipWithIndex
         if iter % params.iterPerEval == 0) try {
      val parser = model.extractParser(state.x).copy(decoder=new MaxConstituentDecoder[AnnotatedLabel, String])
//      if(params.evalOnTest)
//        println("Eval: " + evaluate(s"$name-$iter", parser, treebank.testTrees))
//      else
//        println("Eval: " + evaluate(s"$name-$iter", parser, treebank.devTrees))
      if(params.evalOnTest) {
        println("NORMAL DECODE: Eval: " + evaluate(s"$name-$iter", parser, treebank.testTrees, DecodeType.Normal));
      } else {
        println("Span confusions");
        println(renderArr(evaluateSpanConfusions(s"$name-$iter", parser, treebank.devTrees, DecodeType.Normal)));
        println("Root confusions");
        println(renderArr(evaluateRootConfusions(s"$name-$iter", parser, treebank.devTrees, DecodeType.Normal)));
        println("NORMAL DECODE: Eval: " + evaluate(s"$name-$iter", parser, treebank.devTrees, DecodeType.Normal));
        println("TERNARY DECODE: Eval: " + evaluate(s"$name-$iter", parser, treebank.devTrees, DecodeType.Ternary));
//        println("BINARY DECODE: Eval: " + evaluateBetter(s"$name-$iter", parser, treebank.devTrees, DecodeType.Binary));
      }
    } catch {
      case e: Exception => e.printStackTrace(); throw e
    }


  }
  
  def renderArr(arr: Array[Array[Int]]) = arr.map(_.map(_.toString).reduce(_ + "\t" + _)).reduce(_ + "\n" + _);


  class Model[L, W](val inner: ParserModel[L, W]) extends epic.framework.Model[TreeInstance[L, W]] {
    type ExpectedCounts = inner.ExpectedCounts
    type Marginal = inner.Marginal
    type Inference = SentimentTreebankPipeline.Inference[L, W]
    type Scorer = inner.Scorer

    def emptyCounts = inner.emptyCounts


    def accumulateCounts(inf: Inference, s: Scorer, d: TreeInstance[L, W], m: Marginal, accum: ExpectedCounts, scale: Double): Unit = {
      inner.accumulateCounts(inf.pm.asInstanceOf[inner.Inference], s, d, m, accum, scale)
    }

    /**
     * Models have features, and this defines the mapping from indices in the weight vector to features.
     * @return
     */
    def featureIndex: Index[Feature] = inner.featureIndex

    def initialValueForFeature(f: Feature): Double = inner.initialValueForFeature(f)

    def inferenceFromWeights(weights: DenseVector[Double]): Inference = new SentimentTreebankPipeline.Inference(inner.inferenceFromWeights(weights))

    def expectedCountsToObjective(ecounts: ExpectedCounts): (Double, DenseVector[Double]) = {
      inner.expectedCountsToObjective(ecounts)
    }
  }

  class Inference[L, W](val pm: ParserInference[L, W]) extends epic.framework.Inference[TreeInstance[L, W]] {
    val labels = pm.grammar.topology.labelIndex.toIndexedSeq.map(_ -> 0)
    type Scorer = pm.Scorer
    type Marginal = pm.Marginal

    def scorer(v: TreeInstance[L, W]): Scorer = pm.scorer(v)

    def goldMarginal(scorer: Scorer, v: TreeInstance[L, W]): Inference[L, W]#Marginal = {
      pm.goldMarginal(scorer, v)
    }


    def marginal(anch: Scorer, v: TreeInstance[L, W]): Inference[L, W]#Marginal = {
      LatentTreeMarginal[L, W](anch, v.tree.map(l => labels:scala.collection.IndexedSeq[(L, Int)]))
    }
  }

  case class Stats(spansRight: Int,
                         numSpans: Int,
                         spansRightTernary: Int, // denom is same as numSpans
                         spansRightBinary: Int,
                         numBinarySpans: Int,
                         numBinarySpansSocher: Int,
                         rootsRight: Int,
                         numRoots: Int,
                         rootsRightTernary: Int, // denom is same as numSpans
                         rootsRightBinary: Int,
                         numBinaryRoots: Int,
                         numBinaryRootsSocher: Int) {
    def +(stats: Stats) = Stats(spansRight + stats.spansRight,
                                            numSpans + stats.numSpans,
                                            spansRightTernary + stats.spansRightTernary,
                                            spansRightBinary + stats.spansRightBinary,
                                            numBinarySpans + stats.numBinarySpans,
                                            numBinarySpansSocher + stats.numBinarySpansSocher,
                                            rootsRight + stats.rootsRight,
                                            numRoots + stats.numRoots,
                                            rootsRightTernary + stats.rootsRightTernary,
                                            rootsRightBinary + stats.rootsRightBinary,
                                            numBinaryRoots + stats.numBinaryRoots,
                                            numBinaryRootsSocher + stats.numBinaryRootsSocher);

    
    override def toString = {
      val render: (Int, Int) => String = SentimentEvaluator.renderNumerDenom;
      "Spans: " + render(spansRight, numSpans) + "\n" +
                    "  Ternary: " + render(spansRightTernary, numSpans) + "\n" +
//                    "  Binary: " + render(spansRightBinary, numBinarySpans) + "\n" +
//                    "  Binary (Socher): " + render(spansRightBinary, numBinarySpansSocher) + "\n" +
                    "Roots: " + render(rootsRight, numRoots) + "\n" +
                    "  Ternary: " + render(rootsRightTernary, numRoots) + "\n";
//                    "  Binary: " + render(rootsRightBinary, numBinaryRoots) + "\n" +
//                    "  Binary (Socher): " + render(rootsRightBinary, numBinaryRootsSocher);
    }
      
                                            
//    override def toString = f"Stats(cspans=${coarseSpansRight.toDouble/coarseSpans}%.4f: $coarseSpansRight/$coarseSpans spans=${spansRight.toDouble/numSpans}%.4f: $spansRight/$numSpans, coarseRoots=${coarseRootsRight.toDouble/numCoarseRoots}: $coarseRootsRight/$numCoarseRoots , roots=${rootsRight.toDouble/numRoots}%.4f: $rootsRight/$numRoots)"
  }

  object DecodeType extends Enumeration {
    type DecodeType = Value
    val Normal, Binary, Ternary = Value;
  }
  import DecodeType._
  
  def evaluateSpanConfusions(name: String, parser: Parser[AnnotatedLabel, String], testTrees: IndexedSeq[TreeInstance[AnnotatedLabel, String]], decodeType: DecodeType) = {
    testTrees.par.map { ti =>
      val spanConfusions = Array.tabulate(5, 5)((i, j) => 0);
      val goldTree = ti.tree.children.head.map(_.label.toInt)
      val marg = parser.marginal(ti.words)
      val guessTree = decode(ti.tree.map(_ => ()), marg, decodeType).map(_.label.toInt)
      val guess: Set[(Int, Span)] = guessTree.preorder.map(t => (t.label, t.span)).toSet
      val guessMap: HashMap[Span,Int] = new HashMap[Span,Int]() ++ guess.map(_.swap)
      val gold: Set[(Int, Span)] = goldTree.preorder.map(t => (t.label, t.span)).toSet
      for ((gLabel, gSpan) <- gold) {
        val pLabel = guessMap(gSpan);
        spanConfusions(gLabel)(pLabel) += 1;
      }
      spanConfusions;
    }.reduce((arr1, arr2) => Array.tabulate(5, 5)((i, j) => arr1(i)(j) + arr2(i)(j)));
  }
  
  def evaluateRootConfusions(name: String, parser: Parser[AnnotatedLabel, String], testTrees: IndexedSeq[TreeInstance[AnnotatedLabel, String]], decodeType: DecodeType) = {
    testTrees.par.map { ti =>
      val rootConfusions = Array.tabulate(5, 5)((i, j) => 0);
      val goldTree = ti.tree.children.head.map(_.label.toInt)
      val marg = parser.marginal(ti.words)
      val guessTree = decode(ti.tree.map(_ => ()), marg, decodeType).map(_.label.toInt)
      rootConfusions(goldTree.label)(guessTree.label) += 1;
      rootConfusions;
    }.reduce((arr1, arr2) => Array.tabulate(5, 5)((i, j) => arr1(i)(j) + arr2(i)(j)));
  }

  def evaluate(name: String, parser: Parser[AnnotatedLabel, String], testTrees: IndexedSeq[TreeInstance[AnnotatedLabel, String]], decodeType: DecodeType) = {
    println("Evaluating at " + name);
    testTrees.par.map { ti =>
      val goldTree = ti.tree.children.head.map(_.label.toInt)
      val goldRoot = goldTree.label
      val marg = parser.marginal(ti.words)
      val guessTree = decode(ti.tree.map(_ => ()), marg, decodeType).map(_.label.toInt)
      val guessRoot = guessTree.label;
      val guess: Set[(Int, Span)] = guessTree.preorder.map(t => (t.label, t.span)).toSet
      val guessMap: HashMap[Span,Int] = new HashMap[Span,Int]() ++ guess.map(pair => (pair._2, pair._1));
      val gold: Set[(Int, Span)] = goldTree.preorder.map(t => (t.label, t.span)).toSet
      
      var spansRight = 0;
      var numSpans = 0;
      var spansRightTernary = 0;
      var spansRightBinary = 0;
      var numBinarySpans = 0;
      var numBinarySpansSocher = 0;
      for ((gLabel, gSpan) <- gold) {
        val pLabel = guessMap(gSpan);
        spansRight += (if (SentimentEvaluator.isCorrectNormal(gLabel, pLabel)) 1 else 0);
        numSpans += 1;
        spansRightTernary += (if (SentimentEvaluator.isCorrectTernary(gLabel, pLabel)) 1 else 0);
        spansRightBinary += (if (SentimentEvaluator.isUsedBinaryCoarse(gLabel, pLabel) && SentimentEvaluator.isCorrectBinary(gLabel, pLabel)) 1 else 0);
        numBinarySpans += (if (SentimentEvaluator.isUsedBinaryCoarse(gLabel, pLabel)) 1 else 0);
        numBinarySpansSocher += (if (SentimentEvaluator.isUsedSocherCoarse(gLabel, pLabel)) 1 else 0);
      }
      val rootsRight = (if (SentimentEvaluator.isCorrectNormal(goldRoot, guessRoot)) 1 else 0);
      val numRoots = 1;
      val rootsRightTernary = if (SentimentEvaluator.isCorrectTernary(goldRoot, guessRoot)) 1 else 0;
      val rootsRightBinary = (if (SentimentEvaluator.isUsedBinaryCoarse(goldRoot, guessRoot) && SentimentEvaluator.isCorrectBinary(goldRoot, guessRoot)) 1 else 0);
      val numBinaryRoots = (if (SentimentEvaluator.isUsedBinaryCoarse(goldRoot, guessRoot)) 1 else 0);
      val numBinaryRootsSocher = (if (SentimentEvaluator.isUsedSocherCoarse(goldRoot, guessRoot)) 1 else 0);
      Stats(spansRight, numSpans, spansRightTernary, spansRightBinary, numBinarySpans, numBinarySpansSocher,
                  rootsRight, numRoots, rootsRightTernary, rootsRightBinary, numBinaryRoots, numBinaryRootsSocher)
    }.reduce(_+_);
  }

  def decode(tree: BinarizedTree[Unit], marginal: ParseMarginal[AnnotatedLabel, String], decodeType: DecodeType) = {
    val (topMarg, botMarg) = marginal.labelMarginals
    tree.extend { t =>
      val summed = topMarg(t.begin, t.end)
      if(decodeType == Binary) {
        val neg = (summed(AnnotatedLabel("0")) + summed(AnnotatedLabel("1")) )
        val pos = (summed(AnnotatedLabel("3")) + summed(AnnotatedLabel("4")) )
        if(neg > pos) {
          AnnotatedLabel("0")
        } else  {
          AnnotatedLabel("4")
        }
      } else if(decodeType == Ternary) {
        val neg = (summed(AnnotatedLabel("0")) + summed(AnnotatedLabel("1")) )
        val pos = (summed(AnnotatedLabel("3")) + summed(AnnotatedLabel("4")) )
        val neutral = (summed(AnnotatedLabel("2")));
        
        if(neg > pos && neg > neutral) {
          AnnotatedLabel("0")
        } else if (pos > neg && pos > neutral) {
          AnnotatedLabel("4")
        } else {
          AnnotatedLabel("2")
        }
      } else {
        summed.argmax
      }
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy