All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.classify.Dataset Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.classify;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.TwoDimensionalCounter;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;


import edu.stanford.nlp.util.logging.Redwood;


/**
 * An interfacing class for {@link ClassifierFactory} that incrementally
 * builds a more memory-efficient representation of a {@link List} of
 * {@link Datum} objects for the purposes of training a {@link Classifier}
 * with a {@link ClassifierFactory}.
 *
 * @author Roger Levy ([email protected])
 * @author Anna Rafferty (various refactoring with GeneralDataset/RVFDataset)
 * @author Sarah Spikes ([email protected]) (templatization)
 * @author [email protected] {@link #getL1NormalizedTFIDFDatum(Datum, Counter) and #getL1NormalizedTFIDFDataset()}
 *
 * @param  Label type
 * @param  Feature type
 */
public class Dataset extends GeneralDataset {

  private static final long serialVersionUID = -3883164942879961091L;

  final static Redwood.RedwoodChannels logger = Redwood.channels(Dataset.class);

  public Dataset() {
    this(10);
  }

  public Dataset(int numDatums) {
    initialize(numDatums);
  }

  public Dataset(int numDatums, Index featureIndex, Index labelIndex) {
    initialize(numDatums);
    this.featureIndex = featureIndex;
    this.labelIndex = labelIndex;
  }

  public Dataset(Index featureIndex, Index labelIndex) {
    this(10, featureIndex, labelIndex);
  }


  /**
   * Constructor that fully specifies a Dataset.  Needed this for MulticlassDataset.
   */
  public Dataset(Index labelIndex, int[] labels, Index featureIndex, int[][] data) {
    this (labelIndex, labels, featureIndex, data, data.length);
  }

  /**
   * Constructor that fully specifies a Dataset.  Needed this for MulticlassDataset.
   */
  public Dataset(Index labelIndex, int[] labels, Index featureIndex, int[][] data, int size) {
    this.labelIndex = labelIndex;
    this.labels = labels;
    this.featureIndex = featureIndex;
    this.data = data;
    this.size = size;
  }

  /** {@inheritDoc} */
  @Override
  public Pair, GeneralDataset> split(double percentDev) {
    return split(0, (int)(percentDev * size()));
  }

  /** {@inheritDoc} */
  @Override
  public Pair,GeneralDataset> split(int start, int end) {
    int devSize = end - start;
    int trainSize = size() - devSize;

    int[][] devData = new int[devSize][];
    int[] devLabels = new int[devSize];

    int[][] trainData = new int[trainSize][];
    int[] trainLabels = new int[trainSize];

    synchronized (System.class) {
      System.arraycopy(data, start, devData, 0, devSize);
      System.arraycopy(labels, start, devLabels, 0, devSize);

      System.arraycopy(data, 0, trainData, 0, start);
      System.arraycopy(data, end, trainData, start, size() - end);
      System.arraycopy(labels, 0, trainLabels, 0, start);
      System.arraycopy(labels, end, trainLabels, start, size() - end);
    }

    if (this instanceof WeightedDataset) {
      float[] trainWeights = new float[trainSize];
      float[] devWeights = new float[devSize];

      WeightedDataset w = (WeightedDataset)this;

      synchronized (System.class) {
        System.arraycopy(w.weights, start, devWeights, 0, devSize);
        System.arraycopy(w.weights, 0, trainWeights, 0, start);
        System.arraycopy(w.weights, end, trainWeights, start, size() - end);
      }

      WeightedDataset dev = new WeightedDataset<>(labelIndex, devLabels, featureIndex, devData, devSize, devWeights);
      WeightedDataset train = new WeightedDataset<>(labelIndex, trainLabels, featureIndex, trainData, trainSize, trainWeights);

      return new Pair<>(train, dev);
    }
    Dataset dev = new Dataset<>(labelIndex, devLabels, featureIndex, devData, devSize);
    Dataset train = new Dataset<>(labelIndex, trainLabels, featureIndex, trainData, trainSize);

    return new Pair<>(train, dev);
  }


  public Dataset getRandomSubDataset(double p, int seed) {
    int newSize = (int)(p * size());
    Set indicesToKeep = Generics.newHashSet();
    Random r = new Random(seed);
    int s = size();
    while (indicesToKeep.size() < newSize) {
      indicesToKeep.add(r.nextInt(s));
    }

    int[][] newData = new int[newSize][];
    int[] newLabels = new int[newSize];

    int i = 0;
    for (int j : indicesToKeep) {
      newData[i] = data[j];
      newLabels[i] = labels[j];
      i++;
    }

    return new Dataset<>(labelIndex, newLabels, featureIndex, newData);
  }

  @Override
  public double[][] getValuesArray() {
    return null;
  }

  /**
   * Constructs a Dataset by reading in a file in SVM light format.
   */
  public static Dataset readSVMLightFormat(String filename) {
    return readSVMLightFormat(filename, new HashIndex<>(), new HashIndex<>());
  }

  /**
   * Constructs a Dataset by reading in a file in SVM light format.
   * The lines parameter is filled with the lines of the file for further processing
   * (if lines is null, it is assumed no line information is desired)
   */
  public static Dataset readSVMLightFormat(String filename, List lines) {
    return readSVMLightFormat(filename, new HashIndex<>(), new HashIndex<>(), lines);
  }

  /**
   * Constructs a Dataset by reading in a file in SVM light format.
   * the created dataset has the same feature and label index as given
   */
  public static Dataset readSVMLightFormat(String filename, Index featureIndex, Index labelIndex) {
    return readSVMLightFormat(filename, featureIndex, labelIndex, null);
  }
  /**
   * Constructs a Dataset by reading in a file in SVM light format.
   * the created dataset has the same feature and label index as given
   */
  public static Dataset readSVMLightFormat(String filename, Index featureIndex, Index labelIndex, List lines) {
    Dataset dataset;
    try {
      dataset = new Dataset<>(10, featureIndex, labelIndex);
      for (String line : ObjectBank.getLineIterator(new File(filename))) {
        if(lines != null)
          lines.add(line);
        dataset.add(svmLightLineToDatum(line));
      }

    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    return dataset;
  }

  private static int line1 = 0;

  public static Datum svmLightLineToDatum(String l) {
    line1++;
    l = l.replaceAll("#.*", ""); // remove any trailing comments
    String[] line = l.split("\\s+");
    Collection features = new ArrayList<>();
    for (int i = 1; i < line.length; i++) {
      String[] f = line[i].split(":");
      if (f.length != 2) {
        logger.info("Dataset error: line " + line1);
      }
      int val = (int) Double.parseDouble(f[1]);
      for (int j = 0; j < val; j++) {
        features.add(f[0]);
      }
    }
    features.add(String.valueOf(Integer.MAX_VALUE));  // a constant feature for a class
    Datum d = new BasicDatum<>(features, line[0]);
    return d;
  }

  /**
   *  Get Number of datums a given feature appears in.
   */
  public Counter getFeatureCounter()
  {
    Counter featureCounts = new ClassicCounter<>();
    for (int i=0; i < this.size(); i++)
    {
      BasicDatum datum = (BasicDatum) getDatum(i);
      Set featureSet   = Generics.newHashSet(datum.asFeatures());
      for (F key : featureSet) {
        featureCounts.incrementCount(key, 1.0);
      }
    }
    return featureCounts;
  }

  /**
   * Method to convert features from counts to L1-normalized TFIDF based features
   * @param datum with a collection of features.
   * @param featureDocCounts a counter of doc-count for each feature.
   * @return RVFDatum with l1-normalized tf-idf features.
   */
  public RVFDatum getL1NormalizedTFIDFDatum(Datum datum,Counter featureDocCounts){
      Counter tfidfFeatures = new ClassicCounter<>();
      for(F feature : datum.asFeatures()){
        if(featureDocCounts.containsKey(feature))
          tfidfFeatures.incrementCount(feature,1.0);
      }
      double l1norm = 0;
      for(F feature: tfidfFeatures.keySet()){
        double idf = Math.log(((double)(this.size()+1))/(featureDocCounts.getCount(feature)+0.5));
        double tf = tfidfFeatures.getCount(feature);
        tfidfFeatures.setCount(feature, tf*idf);
        l1norm += tf*idf;
      }
      for(F feature: tfidfFeatures.keySet()){
        double tfidf = tfidfFeatures.getCount(feature);
        tfidfFeatures.setCount(feature, tfidf/l1norm);
      }
      RVFDatum rvfDatum = new RVFDatum<>(tfidfFeatures, datum.label());
      return rvfDatum;
  }

  /**
   * Method to convert this dataset to RVFDataset using L1-normalized TF-IDF features
   * @return RVFDataset
   */
  public RVFDataset getL1NormalizedTFIDFDataset(){
    RVFDataset rvfDataset = new RVFDataset<>(this.size(), this.featureIndex, this.labelIndex);
    Counter featureDocCounts = getFeatureCounter();
    for(int i = 0; i < this.size(); i++){
      Datum datum = this.getDatum(i);
      RVFDatum rvfDatum = getL1NormalizedTFIDFDatum(datum,featureDocCounts);
      rvfDataset.add(rvfDatum);
    }
    return rvfDataset;
  }

  @Override
  public void add(Datum d) {
    add(d.asFeatures(), d.label());
  }

  public void add(Collection features, L label) {
    add(features, label, true);
  }

  public void add(Collection features, L label, boolean addNewFeatures) {
    ensureSize();
    addLabel(label);
    addFeatures(features, addNewFeatures);
    size++;
  }

  /**
   * Adds a datums defined by feature indices and label index
   * Careful with this one! Make sure that all indices are valid!
   * @param features
   * @param label
   */
  public void add(int [] features, int label) {
    ensureSize();
    addLabelIndex(label);
    addFeatureIndices(features);
    size++;
  }

  protected void ensureSize() {
    if (labels.length == size) {
      int[] newLabels = new int[size * 2];
      int[][] newData = new int[size * 2][];
      synchronized (System.class) {
        System.arraycopy(labels, 0, newLabels, 0, size);
        System.arraycopy(data, 0, newData, 0, size);
      }
      labels = newLabels;
      data = newData;
    }
  }

  protected void addLabel(L label) {
    labelIndex.add(label);
    labels[size] = labelIndex.indexOf(label);
  }

  protected void addLabelIndex(int label) {
    labels[size] = label;
  }

  protected void addFeatures(Collection features) {
    addFeatures(features, true);
  }

  protected void addFeatures(Collection features, boolean addNewFeatures) {
    int[] intFeatures = new int[features.size()];
    int j = 0;
    for (F feature : features) {
      if(addNewFeatures) featureIndex.add(feature);
      int index = featureIndex.indexOf(feature);
      if (index >= 0) {
        intFeatures[j] = featureIndex.indexOf(feature);
        j++;
      }
    }
    data[size] = new int[j];
    synchronized (System.class) {
      System.arraycopy(intFeatures, 0, data[size], 0, j);
    }
  }

  protected void addFeatureIndices(int [] features) {
    data[size] = features;
  }

  @Override
  protected final void initialize(int numDatums) {
    labelIndex = new HashIndex<>();
    featureIndex = new HashIndex<>();
    labels = new int[numDatums];
    data = new int[numDatums][];
    size = 0;
  }

  /**
   * @return the index-ed datum
   */
  @Override
  public Datum getDatum(int index) {
    return new BasicDatum<>(featureIndex.objects(data[index]), labelIndex.get(labels[index]));
  }

  /**
   * @return the index-ed datum
   */
  @Override
  public RVFDatum getRVFDatum(int index) {
     ClassicCounter c = new ClassicCounter<>();
    for (F key : featureIndex.objects(data[index])) {
      c.incrementCount(key);
    }
    return new RVFDatum<>(c, labelIndex.get(labels[index]));
  }

  /**
   * Prints some summary statistics to stderr for the Dataset.
   */
  @Override
  public void summaryStatistics() {
    logger.info(toSummaryStatistics());
  }

  /** A String that is multiple lines of text giving summary statistics.
   *  (It does not end with a newline, though.)
   *
   *  @return A textual summary of the Dataset
   */
  public String toSummaryStatistics() {
    StringBuilder sb = new StringBuilder();
    sb.append("numDatums: ").append(size).append('\n');
    sb.append("numDatumsPerLabel: ").append(this.numDatumsPerLabel()).append('\n');
    sb.append("numLabels: ").append(labelIndex.size()).append(" [");
    Iterator iter = labelIndex.iterator();
    while (iter.hasNext()) {
      sb.append(iter.next());
      if (iter.hasNext()) {
        sb.append(", ");
      }
    }
    sb.append("]\n");
    sb.append("numFeatures (Phi(X) types): ").append(featureIndex.size()).append(" [");
    int sz = Math.min(5, featureIndex.size());
    for (int i = 0; i < sz; i++) {
      if (i > 0) {
        sb.append(", ");
      }
      sb.append(featureIndex.get(i));
    }
    if (sz < featureIndex.size()) {
      sb.append(", ...");
    }
    sb.append(']');
    return sb.toString();
  }


  /**
   * Applies feature count thresholds to the Dataset.
   * Only features that match pattern_i and occur at
   * least threshold_i times (for some i) are kept.
   *
   * @param thresholds a list of pattern, threshold pairs
   */
  public void applyFeatureCountThreshold(List> thresholds) {

    // get feature counts
    float[] counts = getFeatureCounts();

    // build a new featureIndex
    Index newFeatureIndex = new HashIndex<>();
    LOOP:
    for (F f : featureIndex) {
      for (Pair threshold : thresholds) {
        Pattern p = threshold.first();
        Matcher m = p.matcher(f.toString());
        if (m.matches()) {
          if (counts[featureIndex.indexOf(f)] >= threshold.second) {
            newFeatureIndex.add(f);
          }
          continue LOOP;
        }
      }
      // we only get here if it didn't match anything on the list
      newFeatureIndex.add(f);
    }

    counts = null;

    int[] featMap = new int[featureIndex.size()];
    for (int i = 0; i < featMap.length; i++) {
      featMap[i] = newFeatureIndex.indexOf(featureIndex.get(i));
    }

    featureIndex = null;

    for (int i = 0; i < size; i++) {
      List featList = new ArrayList<>(data[i].length);
      for (int j = 0; j < data[i].length; j++) {
        if (featMap[data[i][j]] >= 0) {
          featList.add(featMap[data[i][j]]);
        }
      }
      data[i] = new int[featList.size()];
      for (int j = 0; j < data[i].length; j++) {
        data[i][j] = featList.get(j);
      }
    }

    featureIndex = newFeatureIndex;
  }


  /**
   * prints the full feature matrix in tab-delimited form.  These can be BIG
   * matrices, so be careful!
   */
  public void printFullFeatureMatrix(PrintWriter pw) {
    String sep = "\t";
    for (int i = 0; i < featureIndex.size(); i++) {
      pw.print(sep + featureIndex.get(i));
    }
    pw.println();
    for (int i = 0; i < labels.length; i++) {
      pw.print(labelIndex.get(i));
      Set feats = Generics.newHashSet();
      for (int j = 0; j < data[i].length; j++) {
        int feature = data[i][j];
        feats.add(Integer.valueOf(feature));
      }
      for (int j = 0; j < featureIndex.size(); j++) {
        if (feats.contains(Integer.valueOf(j))) {
          pw.print(sep + '1');
        } else {
          pw.print(sep + '0');
        }
      }
    }
  }

  /** {@inheritDoc} */
  @Override
  public void printSparseFeatureMatrix() {
    printSparseFeatureMatrix(new PrintWriter(System.out, true));
  }

  /** {@inheritDoc} */
  @Override
  public void printSparseFeatureMatrix(PrintWriter pw) {
    String sep = "\t";
    for (int i = 0; i < size; i++) {
      pw.print(labelIndex.get(labels[i]));
      int[] datum = data[i];
      for (int j : datum) {
        pw.print(sep + featureIndex.get(j));
      }
      pw.println();
    }
  }


  public void changeLabelIndex(Index newLabelIndex) {

    labels = trimToSize(labels);

    for (int i = 0; i < labels.length; i++) {
      labels[i] = newLabelIndex.indexOf(labelIndex.get(labels[i]));
    }
    labelIndex = newLabelIndex;
  }

  public void changeFeatureIndex(Index newFeatureIndex) {

    data = trimToSize(data);
    labels = trimToSize(labels);

    int[][] newData = new int[data.length][];
    for (int i = 0; i < data.length; i++) {
      int[] newD = new int[data[i].length];
      int k = 0;
      for (int j = 0; j < data[i].length; j++) {
        int newIndex = newFeatureIndex.indexOf(featureIndex.get(data[i][j]));
        if (newIndex >= 0) {
          newD[k++] = newIndex;
        }
      }
      newData[i] = new int[k];
      synchronized (System.class) {
        System.arraycopy(newD, 0, newData[i], 0, k);
      }
    }
    data = newData;
    featureIndex = newFeatureIndex;
  }

  public void selectFeaturesBinaryInformationGain(int numFeatures) {
    double[] scores = getInformationGains();
    selectFeatures(numFeatures,scores);
  }

  /**
   * Generic method to select features based on the feature scores vector provided as an argument.
   * @param numFeatures number of features to be selected.
   * @param scores a vector of size total number of features in the data.
   */
  public void selectFeatures(int numFeatures, double[] scores) {

    List> scoredFeatures = new ArrayList<>();

    for (int i = 0; i < scores.length; i++) {
      scoredFeatures.add(new ScoredObject<>(featureIndex.get(i), scores[i]));
    }

    Collections.sort(scoredFeatures, ScoredComparator.DESCENDING_COMPARATOR);
    Index newFeatureIndex = new HashIndex<>();
    for (int i = 0; i < scoredFeatures.size() && i < numFeatures; i++) {
      newFeatureIndex.add(scoredFeatures.get(i).object());
      //logger.info(scoredFeatures.get(i));
    }

    for (int i = 0; i < size; i++) {
      int[] newData = new int[data[i].length];
      int curIndex = 0;
      for (int j = 0; j < data[i].length; j++) {
        int index;
        if ((index = newFeatureIndex.indexOf(featureIndex.get(data[i][j]))) != -1) {
          newData[curIndex++] = index;
        }
      }
      int[] newDataTrimmed = new int[curIndex];
      synchronized (System.class) {
        System.arraycopy(newData, 0, newDataTrimmed, 0, curIndex);
      }
      data[i] = newDataTrimmed;
    }
    featureIndex = newFeatureIndex;
  }


  public double[] getInformationGains() {

//    assert size > 0;
//    data = trimToSize(data);  // Don't need to trim to size, and trimming is dangerous the dataset is empty (you can't add to it thereafter)
    labels = trimToSize(labels);

    // counts the number of times word X is present
    ClassicCounter featureCounter = new ClassicCounter<>();

    // counts the number of time a document has label Y
    ClassicCounter labelCounter = new ClassicCounter<>();

    // counts the number of times the document has label Y given word X is present
    TwoDimensionalCounter condCounter = new TwoDimensionalCounter<>();

    for (int i = 0; i < labels.length; i++) {
      labelCounter.incrementCount(labelIndex.get(labels[i]));

      // convert the document to binary feature representation
      boolean[] doc = new boolean[featureIndex.size()];
      //logger.info(i);
      for (int j = 0; j < data[i].length; j++) {
        doc[data[i][j]] = true;
      }

      for (int j = 0; j < doc.length; j++) {
        if (doc[j]) {
          featureCounter.incrementCount(featureIndex.get(j));
          condCounter.incrementCount(featureIndex.get(j), labelIndex.get(labels[i]), 1.0);
        }
      }
    }

    double entropy = 0.0;
    for (int i = 0; i < labelIndex.size(); i++) {
      double labelCount = labelCounter.getCount(labelIndex.get(i));
      double p = labelCount / size();
      entropy -= p * (Math.log(p) / Math.log(2));
    }

    double[] ig = new double[featureIndex.size()];
    Arrays.fill(ig, entropy);

    for (int i = 0; i < featureIndex.size(); i++) {
      F feature = featureIndex.get(i);

      double featureCount = featureCounter.getCount(feature);
      double notFeatureCount = size() - featureCount;

      double pFeature =  featureCount / size();
      double pNotFeature = (1.0 - pFeature);

      if (featureCount == 0) { ig[i] = 0; continue; }
      if (notFeatureCount == 0) { ig[i] = 0; continue; }

      double sumFeature = 0.0;
      double sumNotFeature = 0.0;

      for (int j = 0; j < labelIndex.size(); j++) {
        L label = labelIndex.get(j);

        double featureLabelCount = condCounter.getCount(feature, label);
        double notFeatureLabelCount = size() - featureLabelCount;

        // yes, these dont sum to 1.  that is correct.
        // one is the prob of the label, given that the
        // feature is present, and the other is the prob
        // of the label given that the feature is absent
        double p = featureLabelCount / featureCount;
        double pNot = notFeatureLabelCount / notFeatureCount;

        if (featureLabelCount != 0) {
          sumFeature += p * (Math.log(p) / Math.log(2));
        }

        if (notFeatureLabelCount != 0) {
          sumNotFeature += pNot * (Math.log(pNot) / Math.log(2));
        }
        //System.out.println(pNot+" "+(Math.log(pNot)/Math.log(2)));

      }

        //logger.info(pFeature+" * "+sumFeature+" = +"+);
        //logger.info("^ "+pNotFeature+" "+sumNotFeature);

      ig[i] += pFeature*sumFeature + pNotFeature*sumNotFeature;
      /* earlier the line above used to be: ig[i] = pFeature*sumFeature + pNotFeature*sumNotFeature;
       * This completely ignored the entropy term computed above. So added the "+=" to take that into account.
       * -Ramesh ([email protected])
       */
    }
    return ig;
  }

  public void updateLabels(int[] labels) {
    if (labels.length != size())
      throw new IllegalArgumentException(
          "size of labels array does not match dataset size");

    this.labels = labels;
  }

  @Override
  public String toString() {
    return "Dataset of size " + size;
  }

  public String toSummaryString() {
    StringWriter sw = new StringWriter();
    PrintWriter pw = new PrintWriter(sw);
    pw.println("Number of data points: " + size());
    pw.println("Number of active feature tokens: " + numFeatureTokens());
    pw.println("Number of active feature types:" + numFeatureTypes());
    return pw.toString();
  }

  /**
   * Need to sort the counter by feature keys and dump it
   *
   */
  public static void printSVMLightFormat(PrintWriter pw, ClassicCounter c, int classNo) {
    Integer[] features = c.keySet().toArray(new Integer[c.keySet().size()]);
    Arrays.sort(features);
    StringBuilder sb = new StringBuilder();
    sb.append(classNo);
    sb.append(' ');
    for (int f: features) {
      sb.append(f + 1).append(':').append(c.getCount(f)).append(' ');
    }
    pw.println(sb.toString());
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy