All downloads are free. The search and download functionality uses the official Maven repository.

edu.stanford.nlp.sequences.SequenceGibbsSampler Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools that take raw English-language text as input and give the base forms of words and their parts of speech; identify whether words are names of companies, people, etc.; normalize dates, times, and numeric quantities; mark up the structure of sentences in terms of phrases and word dependencies; and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher-level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.sequences; 
import edu.stanford.nlp.util.logging.Redwood;

import edu.stanford.nlp.util.RuntimeInterruptedException;
import edu.stanford.nlp.util.concurrent.*;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;

//debug
import edu.stanford.nlp.ie.*;

import java.util.*;
import java.io.PrintStream;

// TODO: change so that it uses the scoresOf() method properly

/**
 * A Gibbs sampler for sequence models. Given a model implementing the
 * {@link SequenceModel} interface, it resamples one position at a time from the
 * conditional distribution defined by the model's per-position scores (see
 * {@code samplePositionHelper}). It can collect samples ({@code collectSamples}),
 * return the highest-scoring sample ({@code bestSequence},
 * {@code findBestUsingSampling}), or search for a good sequence by simulated
 * annealing under a {@link CoolingSchedule} ({@code findBestUsingAnnealing}).
 *
 * <p>How positions are visited is selected by {@code samplingStyle}
 * (0 = random, 1 = sequential, 2 = chromatic). In chromatic mode, the positions in
 * each group of {@code partition} are sampled together, using
 * {@code chromaticSize} worker threads when the group is larger than
 * {@code chromaticSize}.
 *
 * <p>NOTE(review): this copy was recovered from an HTML page, and the text between
 * some angle brackets was stripped. The losses include generic type arguments and
 * several loop headers/bodies, each flagged inline below. The class does not
 * compile as shown. Restore the missing text from the upstream CoreNLP source
 * before changing any code.
 *
 * @author grenager
 */
public class SequenceGibbsSampler implements BestSequenceFinder  {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(SequenceGibbsSampler.class);

  // Random number generator. It is static, so every instance shares it, and it is
  // seeded with a fixed constant, so a run that makes the same calls in the same
  // order draws the same numbers. In chromatic mode, several worker threads draw
  // from it concurrently, so the draw order -- and therefore the samples -- can vary.
  private static Random random = new Random(2147483647L);
  // Logging level shared by all instances.
  //   > 0: logs progress (start message, "." ticks, "done.").
  //   > 1: also logs a score every 50 annealing iterations and dumps the samples
  //        with printSamples, which needs a non-null document.
  public static int verbose = 0;

  // The underlying document (a list of HasWord tokens, per the constructor's doc).
  // Only printSamples reads it, for debug output. May be null.
  private List document;
  // Defaults used by bestSequence(): how many samples to collect, and how far
  // apart they are.
  private int numSamples;
  private int sampleInterval;
  // When > 0, findBestUsingAnnealing allocates a set of changed positions.
  // Presumably later sweeps are restricted to those positions, but the code that
  // does this was lost in this copy -- TODO confirm. A value <= 0 disables it.
  private int speedUpThreshold = -1;
  // Receives the initial sequence and every single-position change made by
  // samplePosition.
  private SequenceListener listener;
  // The accepted values of samplingStyle.
  private static final int RANDOM_SAMPLING = 0;
  private static final int SEQUENTIAL_SAMPLING = 1;
  private static final int CHROMATIC_SAMPLING = 2;

  // Debug-only priors. The constructor assigns them, but nothing else in this
  // class (as shown) reads them.
  EmpiricalNERPriorBIO priorEn, priorCh = null;


  // Set by the constructor. No read of it survives in this copy apart from
  // commented-out code in findBestUsingAnnealing. Presumably it makes annealing
  // return the last sampled sequence instead of the best-scoring one -- TODO
  // confirm against upstream.
  public boolean returnLastFoundSequence = false;
  // One of RANDOM_SAMPLING, SEQUENTIAL_SAMPLING or CHROMATIC_SAMPLING. Other
  // values are not rejected.
  private int samplingStyle;
  // Number of worker threads for chromatic sampling. Partition groups with at
  // most this many positions are sampled serially instead.
  private int chromaticSize;
  // Groups of positions for chromatic sampling. In the parallel path, every
  // position of a group is sampled against the same, unchanged sequence. Each
  // group is therefore presumably a set of positions that are conditionally
  // independent of one another -- confirm with the code that builds it.
  // (The type arguments were lost in this copy; presumably a list of integer lists.)
  private List> partition;

  /**
   * Returns a new array with the same length and elements as {@code a}.
   *
   * @param a the array to copy; must not be null
   * @return an independent copy of {@code a}
   */
  public static int[] copy(int[] a) {
    int[] result = new int[a.length];
    System.arraycopy(a, 0, result, 0, a.length);
    return result;
  }

  /**
   * Builds a starting sequence. For each position {@code i} in
   * {@code [0, model.length())}, it independently chooses one entry of
   * {@code model.getPossibleValues(i)} uniformly at random, using the shared
   * {@code random}.
   *
   * @param model the model that supplies the sequence length and the allowed values
   * @return a new array of length {@code model.length()}
   * @throws IllegalArgumentException (from {@code Random.nextInt}) if some
   *         position has no possible values
   */
  public static int[] getRandomSequence(SequenceModel model) {
    int[] result = new int[model.length()];
    for (int i = 0; i < result.length; i++) {
      int[] classes = model.getPossibleValues(i);
      result[i] = classes[random.nextInt(classes.length)];
    }
    return result;
  }

  /**
   * Finds the best sequence in four steps: start from a random sequence
   * ({@code getRandomSequence}), collect {@code numSamples} samples spaced
   * {@code sampleInterval} apart (both as given at construction), score them,
   * and choose the highest-scoring sample.
   *
   * @param model the sequence model to sample from and score with
   * @return the array of type int representing the highest scoring sequence,
   *         or null if no sample scored above negative infinity
   */
  public int[] bestSequence(SequenceModel model) {
    int[] initialSequence = getRandomSequence(model);
    return findBestUsingSampling(model, numSamples, sampleInterval, initialSequence);
  }

  /**
   * Finds the best sequence by collecting {@code numSamples} samples starting
   * from {@code initialSequence} (via {@code collectSamples}), scoring each with
   * {@code model.scoreOf}, and choosing the highest-scoring sample. Ties keep the
   * earliest sample. Every new best is logged at info level, regardless of
   * {@code verbose}.
   *
   * @param model the sequence model to sample from and score with
   * @param numSamples how many samples to collect
   * @param sampleInterval spacing between collected samples
   * @param initialSequence the sequence to start sampling from
   * @return the array of type int representing the highest scoring sequence, or
   *         null if no sample scored above negative infinity (e.g. none were collected)
   */
  public int[] findBestUsingSampling(SequenceModel model, int numSamples, int sampleInterval, int[] initialSequence) {
    List samples = collectSamples(model, numSamples, sampleInterval, initialSequence);
    int[] best = null;
    double bestScore = Double.NEGATIVE_INFINITY;
    for (Object sample : samples) {
      int[] sequence = (int[]) sample;
      double score = model.scoreOf(sequence);
      // strict '>' keeps the earliest of several equally scored samples
      if (score > bestScore) {
        best = sequence;
        bestScore = score;
        log.info("found new best (" + bestScore + ")");
        log.info(ArrayMath.toString(best));
      }
    }
    return best;
  }

  /**
   * Runs simulated annealing from a random initial sequence
   * (see {@code getRandomSequence}).
   *
   * @param model the sequence model
   * @param schedule the cooling schedule
   * @return the sequence produced by the annealing run
   */
  public int[] findBestUsingAnnealing(SequenceModel model, CoolingSchedule schedule) {
    int[] initialSequence = getRandomSequence(model);
    return findBestUsingAnnealing(model, schedule, initialSequence);
  }

  /**
   * Searches for a high-scoring sequence by simulated annealing, working on a
   * copy of {@code initialSequence}. The listener is first given the initial
   * sequence.
   *
   * <p>NOTE(review): most of the iteration loop was lost in this copy (see the
   * note at the loop header). The surviving code shows three things:
   * <ul>
   *   <li>the best sequence seen so far is tracked with a strict {@code >} score
   *       comparison;</li>
   *   <li>progress is logged every 50 iterations when {@code verbose > 1};</li>
   *   <li>when {@code verbose > 1}, the contents of {@code result} are printed to
   *       {@code System.err} at the end, which fails if {@code document} is null.</li>
   * </ul>
   *
   * @param model the sequence model
   * @param schedule the cooling schedule (presumably supplies the iteration
   *        count and per-iteration temperature -- TODO confirm)
   * @param initialSequence the starting sequence
   * @return the tracked {@code best} sequence; null if it was never assigned
   */
  public int[] findBestUsingAnnealing(SequenceModel model, CoolingSchedule schedule, int[] initialSequence) {
    if (verbose>0) log.info("Doing annealing");
    listener.setInitialSequence(initialSequence);
    // Presumably filled by the lost loop body; only the debug dump below reads it.
    List result = new ArrayList();
    // so we don't change the initial, or the one we just stored
    int[] sequence = copy(initialSequence);
    int[] best = null;
    double bestScore = Double.NEGATIVE_INFINITY;
    double score = Double.NEGATIVE_INFINITY;
    // if (!returnLastFoundSequence) {
    //   score = model.scoreOf(sequence);
    // }

    // Only allocated when the speed-up is enabled; it is used in the lost loop body.
    Set positionsChanged = null;
    if (speedUpThreshold > 0)
      positionsChanged = Generics.newHashSet();

    // NOTE(review): the next line was truncated during HTML extraction.
    // Everything between "i" in the loop test and "bestScore" in a later score
    // comparison is missing. That includes the loop condition, the per-iteration
    // sampling, and the start of the best-score check. Restore it from upstream.
    for (int i=0; ibestScore) {
          // NOTE(review): this keeps a reference, not a copy. If the lost loop
          // body resamples `sequence` in place, `best` keeps changing after
          // this point instead of freezing the best-scoring state -- verify.
          best = sequence;
          bestScore = score;
        }      
      }
      if (i % 50 == 0) {
        if (verbose > 1) log.info("itr " + i + ": " + bestScore + "\t");
      }
      if (verbose>0) log.info(".");
    }
    if (verbose>1) {
      log.info();
      printSamples(result, System.err);
    }
    if (verbose>0) log.info("done.");
    //return sequence;
    return best;
  }

  /**
   * Collects numSamples samples of sequences from the distribution over sequences
   * defined by {@code model}, starting from a random initial sequence
   * (see {@code getRandomSequence}).
   * All samples collected are sampleInterval samples apart, in an attempt to reduce
   * autocorrelation.
   *
   * @param model the sequence model to sample from
   * @param numSamples how many samples to collect
   * @param sampleInterval spacing between collected samples
   * @return a List of the sampled sequences, each an {@code int[]}
   */
  public List collectSamples(SequenceModel model, int numSamples, int sampleInterval) {
    int[] initialSequence = getRandomSequence(model);
    return collectSamples(model, numSamples, sampleInterval, initialSequence);
  }

  /**
   * Collects numSamples samples of sequences from the distribution over sequences
   * defined by {@code model}, starting from {@code initialSequence}.
   * All samples collected are sampleInterval samples apart, in an attempt to reduce
   * autocorrelation.
   *
   * @param model the sequence model to sample from
   * @param numSamples how many samples to collect
   * @param sampleInterval spacing between collected samples
   * @param initialSequence the starting sequence (also passed to the listener)
   * @return a List of the sampled sequences, each an {@code int[]}. Callers such
   *         as findBestUsingSampling and printSamples cast the elements to int[].
   */
  public List collectSamples(SequenceModel model, int numSamples, int sampleInterval, int[] initialSequence) {
    if (verbose>0) log.info("Collecting samples");
    listener.setInitialSequence(initialSequence);
    List result = new ArrayList<>();
    // NOTE(review): no copy is made here, unlike findBestUsingAnnealing and
    // sampleSequenceRepeatedly. This copy does not show whether the lost loop
    // below modifies initialSequence in place.
    int[] sequence = initialSequence;
    // NOTE(review): the next line was truncated during HTML extraction. The
    // loop condition, the sampling step, and the code that adds each sample to
    // `result` are missing. Only the tail of a per-sample progress log
    // (presumably guarded by verbose > 0) survives.
    for (int i=0; i0) log.info(".");
      System.err.flush();
    }
    if (verbose>1) {
      log.info();
      printSamples(result, System.err);
    }
    if (verbose>0) log.info("done.");
    return result;
  }

  /**
   * Samples the sequence repeatedly, making numSamples passes over the entire sequence.
   * Works on a copy of {@code sequence}, which is also passed to the listener as
   * the initial sequence, so this method does not modify the caller's array.
   *
   * <p>NOTE(review): the rest of this method was lost in this copy (see the note
   * at its loop header), so the meaning of its return value cannot be confirmed.
   */
  public double sampleSequenceRepeatedly(SequenceModel model, int[] sequence, int numSamples) {
    sequence = copy(sequence); // so we don't change the initial, or the one we just stored
    listener.setInitialSequence(sequence);
    double returnScore = Double.NEGATIVE_INFINITY;
    // NOTE(review): the next line was truncated during HTML extraction. It
    // opens this method's sampling loop but ends with the parameter list of a
    // DIFFERENT method. Judging by the names used below, that method is
    // presumably the forward counterpart of sampleSequenceBackward, taking
    // (model, sequence, temperature, onlySampleThesePositions). The rest of this
    // method and any declarations in between are missing.
    //
    // From here to the matching closing brace is that forward-sampling method.
    // It resamples positions of `sequence` in place. It returns either the
    // probability of the last drawn label (restricted mode) or a whole-sequence
    // score from model.scoreOf (chromatic path). What the lost branches return
    // cannot be seen.
    for (int iter=0; iter onlySampleThesePositions) {
    double returnScore = Double.NEGATIVE_INFINITY;
    // log.info("Sampling forward");
    if (onlySampleThesePositions != null) {
      // Restricted mode: resample only the given positions, in the collection's
      // iteration order. The result is the value of the last samplePosition call
      // (a label probability), or negative infinity if there are no positions.
      for (int pos: onlySampleThesePositions) {
        returnScore = samplePosition(model, sequence, pos, temperature);
      }
    } else {
      if (samplingStyle == SEQUENTIAL_SAMPLING) {
        // NOTE(review): the next line was truncated during HTML extraction.
        // The following are missing:
        //   - the sequential loop over positions;
        //   - the branch for the remaining style(s) (presumably RANDOM_SAMPLING);
        //   - the opening of the chromatic branch, including the declared type of
        //     `results` (presumably a list of (position, label) pairs, given its
        //     use below).
        for (int pos=0; pos> results = new ArrayList<>();
        // Chromatic sampling: process the partition's groups one after another.
        // NOTE(review): neither path below calls listener.updateSequenceElement,
        // whereas samplePosition does. Confirm that listeners do not depend on
        // per-position updates in chromatic mode.
        for (List indieList: partition) {
          // Small group: sample serially, writing each new label immediately so
          // later positions in the group see it.
          if (indieList.size() <= chromaticSize) {
            for (int pos: indieList) {
              Pair newPosProb = samplePositionHelper(model, sequence, pos, temperature); 
              sequence[pos] = newPosProb.first();
            }
          } else {
            // Large group: split it into chunks sampled on chromaticSize worker
            // threads. Nothing is written to `sequence` until every chunk is
            // done, so all positions are sampled against the same state. This
            // assumes model.scoresOf is safe to call concurrently -- confirm.
            // A new MulticoreWrapper is created for each large group on each call.
            // (The type arguments on the declarations below were also lost.)
            MulticoreWrapper, List>> wrapper = new MulticoreWrapper<>(chromaticSize,
                    new ThreadsafeProcessor, List>>() {
                      @Override
                      public List> process(List posList) {
                        List> allPos = new ArrayList<>(posList.size());
                        Pair newPosProb = null;
                        for (int pos : posList) {
                          newPosProb = samplePositionHelper(model, sequence, pos, temperature);
                          // record (position, new label); the label is applied to `sequence` only after all chunks finish
                          allPos.add(new Pair<>(pos, newPosProb.first()));
                        }
                        return allPos;
                      }

                      @Override
                      public ThreadsafeProcessor, List>> newInstance() {
                        // the processor has no fields of its own, so one instance is shared by all workers
                        return this;
                      }
                    });
            // `results` is reused across groups
            results.clear();
            // Chunk length is about size / chromaticSize (at least 1). The last
            // chunk may be shorter, and there can be more than chromaticSize chunks.
            // NOTE(review): when chromaticSize == 0, any non-empty group takes this
            // branch, and the division below is by zero.
            int interval = Math.max(1, indieList.size() / chromaticSize);
            // `end` is set at the top of each pass; the loop stops once a chunk reaches the end of the group
            for (int begin = 0, end = 0, indieListSize = indieList.size(); end < indieListSize; begin += interval) {
              end = Math.min(begin + interval, indieListSize);
              // chunks are subList views of the group; workers only read them
              wrapper.put(indieList.subList(begin, end));
              // drain any results that are already available
              while (wrapper.peek()) {
                results.addAll(wrapper.poll());
              }
            }
            // join the workers, then drain whatever results remain
            wrapper.join();
            while (wrapper.peek()) {
              results.addAll(wrapper.poll());
            }
            // apply all new labels at once
            for(Pair posVal : results) {
              sequence[posVal.first()] = posVal.second();
            }
          }
        }
        // In the chromatic path, the result is a score of the whole sequence,
        // not a single label probability as in restricted mode.
        returnScore = model.scoreOf(sequence);
      }
    }
    return returnScore;
  }

  /**
   * Samples the complete sequence once in the backward direction (last position
   * first), at temperature 1.0.
   * Destructively modifies the sequence in place. The listener is notified of
   * each position through samplePosition.
   * @param model the sequence model
   * @param sequence the sequence to start with.
   * @return the value returned by the last samplePosition call, i.e. the
   *         probability of the label drawn for position 0 (negative infinity for
   *         an empty sequence). This is not a score of the whole sequence.
   */
  public double sampleSequenceBackward(SequenceModel model, int[] sequence) {
    return sampleSequenceBackward(model, sequence, 1.0);
  }
  /**
   * Samples the complete sequence once in the backward direction (last position
   * first), at the given temperature.
   * Destructively modifies the sequence in place. The listener is notified of
   * each position through samplePosition.
   * @param model the sequence model
   * @param sequence the sequence to start with.
   * @param temperature the sampling temperature (see samplePositionHelper)
   * @return the value returned by the last samplePosition call, i.e. the
   *         probability of the label drawn for position 0 (negative infinity for
   *         an empty sequence). This is not a score of the whole sequence.
   */
  public double sampleSequenceBackward(SequenceModel model, int[] sequence, double temperature) {
    double returnScore = Double.NEGATIVE_INFINITY;
    for (int pos=sequence.length-1; pos>=0; pos--) {
      returnScore = samplePosition(model, sequence, pos, temperature);
    }
    return returnScore;
  }

  /**
   * Samples a single position in the sequence at temperature 1.0.
   * Destructively modifies the sequence in place and notifies the listener.
   * @param model the sequence model
   * @param sequence the sequence to start with
   * @param pos the position to sample.
   * @return the probability of the newly drawn label (not a score of the whole sequence)
   */
  public double samplePosition(SequenceModel model, int[] sequence, int pos) {
    return samplePosition(model, sequence, pos, 1.0);
  }

  /**
   * Draws a new label for position {@code pos}. This method does not write to
   * {@code sequence}.
   *
   * <p>The procedure:
   * <ol>
   *   <li>Take the model's scores for {@code pos} from {@code model.scoresOf},
   *       treated as log-scores.</li>
   *   <li>Divide them by {@code temperature}. A temperature of exactly 0 instead
   *       puts all mass on the arg-max.</li>
   *   <li>Log-normalize and exponentiate them into a probability
   *       distribution.</li>
   *   <li>Draw one index from that distribution, using the shared
   *       {@code random}.</li>
   * </ol>
   *
   * @param model the sequence model providing per-position scores
   * @param sequence the current sequence, used as context for the scores
   * @param pos the position to sample.
   * @param temperature 1.0 samples from the model's distribution; values below 1
   *        sharpen it, values above 1 flatten it, and 0.0 selects the arg-max.
   *        Negative values are not rejected.
   * @return a pair of (sampled index, its probability under the tempered distribution)
   */
  private Pair samplePositionHelper(SequenceModel model, int[] sequence, int pos, double temperature) {
    // NOTE(review): the array returned by scoresOf is overwritten below. This
    // assumes the model returns a fresh array rather than an internal buffer.
    double[] distribution = model.scoresOf(sequence, pos);
    if (temperature!=1.0) {
      if (temperature==0.0) {
        // all mass on the max: log-score 0 there, -infinity elsewhere (probability 1.0 after exp)
        int argmax = ArrayMath.argmax(distribution);
        Arrays.fill(distribution, Double.NEGATIVE_INFINITY);
        distribution[argmax] = 0.0;
      } else {
        // scaling log-scores by 1/T raises the unnormalized probabilities to the power 1/T;
        // use the temperature to increase/decrease the entropy of the sampling distribution
        ArrayMath.multiplyInPlace(distribution, 1.0/temperature);
      }
    }
    ArrayMath.logNormalize(distribution);
    ArrayMath.expInPlace(distribution);
    // NOTE(review): the drawn index is used directly as the new label. This
    // assumes scoresOf returns one entry per label, indexed by label value,
    // whereas getRandomSequence draws labels from getPossibleValues(pos).
    // Confirm that the two agree for the models in use.
    int newTag = ArrayMath.sampleFromDistribution(distribution, random);
    double newProb = distribution[newTag];
    return new Pair<>(newTag, newProb);
  }

  /**
   * Resamples a single position in place. It draws a new label with
   * samplePositionHelper, writes the label into {@code sequence}, and then tells
   * the listener which position changed and what its old label was.
   * @param model the sequence model
   * @param sequence the sequence to start with
   * @param pos the position to sample.
   * @param temperature the temperature to control annealing
   * @return the probability of the newly drawn label under the tempered
   *         distribution (not a score of the whole sequence)
   */
  public double samplePosition(SequenceModel model, int[] sequence, int pos, double temperature) {
    int oldTag = sequence[pos];
    Pair newPosProb = samplePositionHelper(model, sequence, pos, temperature); 
    int newTag = newPosProb.first();
//    System.out.println("Sampled " + oldTag + "->" + newTag);
    sequence[pos] = newTag;
    listener.updateSequenceElement(sequence, pos, oldTag);
    return newPosProb.second();
  }

  /**
   * Debug output: prints one line per token of {@code document}. Each line
   * holds the word padded or trimmed to 10 characters ("null" for a null
   * token), followed by that position's label in each sample, left-padded to
   * width 2.
   *
   * <p>NOTE(review): this throws NullPointerException when {@code document} is
   * null, which is what the 3-argument and 9-argument constructors pass. It is
   * called from findBestUsingAnnealing and collectSamples when verbose > 1.
   *
   * @param samples the samples to print, each an {@code int[]} at least as long
   *        as {@code document}
   * @param out the stream to print to
   */
  public void printSamples(List samples, PrintStream out) {
    for (int i = 0; i < document.size(); i++) {
      HasWord word = (HasWord) document.get(i);
      String s = "null";
      if (word!=null) {
        s = word.word();
      }
      out.print(StringUtils.padOrTrim(s, 10));
      for (Object sample : samples) {
        int[] sequence = (int[]) sample;
        out.print(" " + StringUtils.padLeft(sequence[i], 2));
      }
      out.println();
    }
  }

  /**
   * Creates a sampler.
   *
   * @param numSamples number of samples collected by bestSequence
   * @param sampleInterval spacing between the samples collected by bestSequence
   * @param listener receives the initial sequence and single-position updates
   * @param document the underlying document which is a list of HasWord; a slight abstraction violation, but useful for debugging!!
   *        Only printSamples reads it, so it may be null as long as printSamples
   *        is never reached (see verbose).
   * @param returnLastFoundSequence stored in the public field of the same name
   * @param samplingStyle 0 (random), 1 (sequential) or 2 (chromatic); other values
   *        are not validated
   * @param chromaticSize number of worker threads for chromatic sampling
   * @param partition groups of positions; read only by the chromatic path shown here
   * @param speedUpThreshold when {@code > 0}, enables tracking of changed positions
   *        during annealing
   * @param priorEn debug-only; stored but not otherwise used in this class as shown
   * @param priorCh debug-only; stored but not otherwise used in this class as shown
   */
  public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener, List document,
      boolean returnLastFoundSequence, int samplingStyle, int chromaticSize, List> partition, int speedUpThreshold, EmpiricalNERPriorBIO priorEn, EmpiricalNERPriorBIO priorCh) {
    this.numSamples = numSamples;
    this.sampleInterval = sampleInterval;
    this.listener = listener;
    this.document = document;
    this.returnLastFoundSequence = returnLastFoundSequence;
    this.samplingStyle = samplingStyle;
    // the style is only logged here, not validated
    if (verbose > 0) {
      if (samplingStyle == RANDOM_SAMPLING) {
        log.info("Using random sampling");
      } else if (samplingStyle == CHROMATIC_SAMPLING) {
        log.info("Using chromatic sampling with " + chromaticSize + " threads");
      } else if (samplingStyle == SEQUENTIAL_SAMPLING) {
        log.info("Using sequential sampling");
      }
    }
    this.chromaticSize = chromaticSize;
    this.partition = partition;
    this.speedUpThreshold = speedUpThreshold;
    // debug-only priors; not read elsewhere in this class as shown
    this.priorEn = priorEn;
    this.priorCh = priorCh;
  }

  /**
   * Creates a sequential-sampling sampler (samplingStyle 1) with the following
   * defaults: no chromatic partition (chromaticSize 0, partition null), speed-up
   * disabled (-1), no priors, and returnLastFoundSequence set to false.
   */
  public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener, List document) {
    this(numSamples, sampleInterval, listener, document, false, 1, 0, null, -1, null, null);
  }

  /**
   * Same as the 4-argument constructor, but with a null document.
   * NOTE(review): printSamples (called when verbose > 1) then throws
   * NullPointerException.
   */
  public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener) {
    this(numSamples, sampleInterval, listener, null);
  }

  /**
   * Creates a sampler with a null document and returnLastFoundSequence set to
   * false. The remaining parameters are as in the full constructor.
   * NOTE(review): with a null document, printSamples (called when verbose > 1)
   * throws NullPointerException.
   */
  public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener,
      int samplingStyle, int chromaticSize, List> partition, int speedUpThreshold, EmpiricalNERPriorBIO priorEn, EmpiricalNERPriorBIO priorCh) {
    this(numSamples, sampleInterval, listener, null, false, samplingStyle, chromaticSize, partition, speedUpThreshold, priorEn, priorCh);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy