
epic.features.HackyHeadFinderTest.scala

package epic.features

import epic.trees.SimpleTreebank
import epic.trees.PartialTreeProcessor
import epic.trees.Treebank
import java.io.File
import epic.trees.HeadFinder
import breeze.linalg.Counter2
import epic.trees.Tree
import breeze.linalg.Counter
import scala.collection.mutable.HashMap

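/**
 * Sanity check for RuleBasedHackyHeadFinder: trains a most-frequent-tag
 * lexicon on the WSJ training trees, then measures how often the hacky
 * head finder recovers the Collins head on a sample of dev trees, both
 * with gold part-of-speech tags and with tags predicted by the lexicon.
 */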
object HackyHeadFinderTest {

  def main(args: Array[String]): Unit = {
//    val treebank = new SimpleTreebank(new File(ptbPath), new File(ptbPath), new File(ptbPath));
    val treebank = Treebank.fromPennTreebankDir(new File("data/wsj"));
    
    val process = PartialTreeProcessor();
    val treesWords = treebank.train.trees.toSeq;
    val processedTreesWords = treesWords.map { case (tree, words) => (process(tree), words) };
    
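    // Count (tag, word) occurrences over the training trees; after processing,
    // each tree's leaf labels are the part-of-speech tags for the sentence's words.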
    println("Training lexicon");
    var sentIdx = 0;
    val trainWordTagCounts = Counter2[String,String,Double]();
    for ((tree, words) <- processedTreesWords) {
      if (sentIdx % 1000 == 0) {
        println("Sentence: " + sentIdx);
      }
      val treeLeaves = tree.leaves.toSeq;
      for (i <- 0 until treeLeaves.size) {
        trainWordTagCounts(treeLeaves(i).label, words(i)) += 1.0;
      }
      sentIdx += 1;
    }
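    // For each word, keep its most frequent tag. keysIterator yields
    // (tag, word) pairs, so a word seen with several tags is visited more
    // than once and simply recomputes the same argmax.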
    val wordToTagMap = new HashMap[String,String];
    for (word <- trainWordTagCounts.keysIterator.map(_._2)) {
      var bestTag = "";
      var bestTagCount = 0.0;
      val tagCounts = trainWordTagCounts(::, word).iterator;
      for ((tag, count) <- tagCounts) {
        if (count > bestTagCount) {
          bestTag = tag;
          bestTagCount = count;
        }
      }
      wordToTagMap.put(word, bestTag);
    }
    println("Done training lexicon");
    
    
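    // Collins head rules provide the reference head for each constituent;
    // the rule-based hacky head finder is the system under evaluation.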
    val hf = HeadFinder.collins;
    val hackyHeadFinder = new RuleBasedHackyHeadFinder;
    
    val correct = Counter[String,Int]();
    val correctPredTags = Counter[String,Int]();
    val total = Counter[String,Int]();
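    // Walk the head-annotated tree: at each labeled nonterminal, compare the
    // hacky head finder's choice to the Collins head index (made relative to
    // the start of the span), once with gold tags and once with tags predicted
    // by the lexicon. Words unseen in training back off to the tag "NN".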
    def rec(tree: Tree[(String,Int)], words: Seq[String]): Unit = {
      if (!tree.isLeaf && tree.label._1.nonEmpty) {
        val headIdx = tree.label._2 - tree.begin;
        val hhfHead = hackyHeadFinder.findHead(tree.label._1, tree.leaves.map(_.label._1).toSeq);
        val predTags = words.slice(tree.begin, tree.end).map(word => if (wordToTagMap.contains(word)) wordToTagMap(word) else "NN");
        val hhfHeadPredTags = hackyHeadFinder.findHead(tree.label._1, predTags);
        if (hhfHead == headIdx) {
          correct(tree.label._1) += 1;
        }
        if (hhfHeadPredTags == headIdx) {
          correctPredTags(tree.label._1) += 1;
        } else {
          println(tree.label + " => " + tree.leaves.map(_.label._1).toIndexedSeq + "\n      " + predTags + "; gold = " + headIdx + ", pred (gold) = " + hhfHead + ", pred (pred) = " + hhfHeadPredTags);
        }
        total(tree.label._1) += 1;
      }
      if (!tree.isLeaf) {
        tree.children.foreach(rec(_, words));
      }
    };
    
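    // Annotate the processed dev trees with gold head indices, then score the
    // hacky head finder on a sample of them, printing each mismatch as it occurs.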
    val devTreesWords = treebank.dev.trees.toSeq.map { case (tree, words) => (hf.annotateHeadIndices(process(tree)), words) };
    for (i <- 0 until math.min(100, devTreesWords.size)) {
      val tree = devTreesWords(i)._1;
      val words = devTreesWords(i)._2;
      rec(tree, words);
      
//      println(tree.render(devTreesWords(i)._2, false));
//      println(processedTrees(i).render(treesWords(i)._2, false));
//      println(processedTreesWithIndices(i).render(treesWords(i)._2, false));
    }
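    // Per-category and overall accuracy with lexicon-predicted tags.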
    var totalCorrect = 0;
    var totalCount = 0;
    for (key <- total.keySet) {
      println(key + ": " + correctPredTags(key) + " / " + total(key));
      totalCorrect += correctPredTags(key);
      totalCount += total(key);
    }
    println(totalCorrect + " / " + totalCount);
  }
}