All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.stanford.nlp.parser.lexparser.RandomWalk Maven / Gradle / Ivy

package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;

import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.io.Serializable;


class RandomWalk implements Serializable {

  private static final long serialVersionUID = -5284941866796561664L;

  private final Map model = Generics.newHashMap();

  private final Map hiddenToSeen = Generics.newHashMap();
  private final Map seenToHidden = Generics.newHashMap();

  private static final double LAMBDA = 0.01;

  /**
   * Uses the initialized values
   */
  public double score(Object hidden, Object seen) {
    return model.get(hidden).getCount(seen) / model.get(hidden).totalCount();
  }

  /* score with flexible number of steps */
  public double score(Object hidden, Object seen, int steps) {
    double total = 0;
    for (int i = 0; i <= steps; i++) {
      total += Math.pow(LAMBDA, steps) * step(hidden, seen, steps);
    }
    return total;
  }

  /* returns probability of hidden -> seen with steps
   * random walk steps */
  public double step(Object hidden, Object seen, int steps) {
    if (steps < 1) {
      return hiddenToSeen.get(hidden).getCount(seen) / hiddenToSeen.get(hidden).totalCount();
    } else {
      double total = 0;
      for (Object seen1 : seenToHidden.keySet()) {
        for (Object hidden1 : hiddenToSeen.keySet()) {
          double subtotal = hiddenToSeen.get(hidden).getCount(seen1) / hiddenToSeen.get(hidden).totalCount() * (seenToHidden.get(seen1).getCount(hidden1) / seenToHidden.get(seen1).totalCount());
          subtotal += score(hidden1, seen, steps - 1);
          total += subtotal;
        }
      }
      return total;
    }
  }


  public void train(Collection> data) {
    for (Pair p : data) {
      Object seen = p.first();
      Object hidden = p.second();
      if (!hiddenToSeen.keySet().contains(hidden)) {
        hiddenToSeen.put(hidden, new ClassicCounter());
      }
      hiddenToSeen.get(hidden).incrementCount(seen);

      if (!seenToHidden.keySet().contains(seen)) {
        seenToHidden.put(seen, new ClassicCounter());
      }
      seenToHidden.get(seen).incrementCount(hidden);
    }
  }

  /**
   * builds a random walk model with n steps.
   *
   * @param data A collection of seen/hidden event Pairs
   */
  public RandomWalk(Collection> data, int steps) {
    train(data);
    for (Object seen : seenToHidden.keySet()) {
      if (!model.containsKey(seen)) {
        model.put(seen, new ClassicCounter());
      }
      for (Object hidden : hiddenToSeen.keySet()) {
        model.get(seen).setCount(hidden, score(seen, hidden, steps));
      }
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy