All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.classify.WeightedDataset Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.classify;

import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.util.Index;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * A {@link Dataset} in which every datum carries a {@code float} weight.
 *
 * <p>Weights live in a parallel array, {@link #weights}. Entry {@code i} belongs to datum
 * {@code i}. The array is grown together with the parent's storage in {@link #ensureSize()}.
 * Datums added without an explicit weight get weight {@code 1.0f}. The shuffling methods are
 * overridden so that weights stay aligned with their datums.
 *
 * @author Galen Andrew
 * @author Sarah Spikes ([email protected]) (Templatization)
 */
public class WeightedDataset<L, F> extends Dataset<L, F> {

  private static final long serialVersionUID = -5435125789127705430L;

  /**
   * Per-datum weights. Entry {@code i} is the weight of datum {@code i}. Only the first
   * {@code size} entries are meaningful; the remainder is spare capacity.
   */
  protected float[] weights;

  /**
   * Creates a weighted dataset from pre-built arrays.
   *
   * @param labelIndex index of labels
   * @param labels label id of each datum
   * @param featureIndex index of features
   * @param data feature ids of each datum
   * @param size number of datums in use
   * @param weights weight of each datum; stored directly (not copied). Its length should be
   *     at least {@code size}.
   */
  public WeightedDataset(Index<L> labelIndex, int[] labels, Index<F> featureIndex, int[][] data,
                         int size, float[] weights) {
    super(labelIndex, labels, featureIndex, data, size);
    this.weights = weights;
  }

  /** Creates an empty weighted dataset, pre-sized for 10 datums. */
  public WeightedDataset() {
    this(10);
  }

  /**
   * Creates an empty weighted dataset.
   *
   * @param initSize initial capacity; passed to the parent and used as the initial length of
   *     the weights array
   */
  public WeightedDataset(int initSize) {
    super(initSize);
    weights = new float[initSize];
  }

  /**
   * Returns a new array holding the first {@code size} entries of {@code source}. If
   * {@code source} is shorter than {@code size}, the copy is zero-padded.
   */
  private float[] trimToSize(float[] source) {
    return Arrays.copyOf(source, size);
  }

  /**
   * Trims the internal weights array to exactly {@code size} entries and returns it.
   *
   * <p>The returned array IS the internal array until the next call that reallocates it (a
   * later {@code getWeights()} or growth during {@code add}). Writes to it in the meantime
   * change the stored weights.
   *
   * @return the weight of each datum, indexed by datum position
   */
  public float[] getWeights() {
    weights = trimToSize(weights);
    return weights;
  }

  /**
   * Computes weighted feature counts.
   *
   * <p>Entry {@code f} of the result is the sum, over all datums, of the datum's weight times
   * the number of times feature {@code f} appears in that datum's feature array.
   *
   * @return an array of length {@code featureIndex.size()}
   */
  @Override
  public float[] getFeatureCounts() {
    float[] counts = new float[featureIndex.size()];
    for (int i = 0; i < size; i++) {
      float weight = weights[i];
      for (int feature : data[i]) {
        counts[feature] += weight;
      }
    }
    return counts;
  }

  /** Adds a datum with weight {@code 1.0f}. */
  @Override
  public void add(Datum<L, F> d) {
    add(d, 1.0f);
  }

  /** Adds a datum, given as its features and label, with weight {@code 1.0f}. */
  @Override
  public void add(Collection<F> features, L label) {
    add(features, label, 1.0f);
  }

  /**
   * Adds a datum with the given weight.
   *
   * @param d the datum; its features and label are added
   * @param weight the datum's weight
   */
  public void add(Datum<L, F> d, float weight) {
    add(d.asFeatures(), d.label(), weight);
  }

  /**
   * Ensures there is room for one more datum, growing the parent's storage and the weights
   * array as needed.
   */
  @Override
  protected void ensureSize() {
    super.ensureSize();
    // '<=' also covers a weights array shorter than size (e.g. one supplied to the
    // array constructor).
    if (weights.length <= size) {
      // Math.max: doubling a zero size (e.g. after getWeights() on an empty dataset)
      // would otherwise leave a zero-length array and the next write would fail.
      weights = Arrays.copyOf(weights, Math.max(size * 2, 1));
    }
  }

  /**
   * Adds a datum with the given features, label and weight.
   *
   * @param features the datum's features
   * @param label the datum's label
   * @param weight the datum's weight
   */
  public void add(Collection<F> features, L label, float weight) {
    ensureSize();
    addLabel(label);
    addFeatures(features);
    weights[size++] = weight;
  }

  /**
   * Sets the weight of datum {@code i}.
   *
   * @param i the index of the datum whose weight to change
   * @param weight the weight to set
   * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, size)}. Writes past
   *     {@code size} would otherwise land in spare capacity and be silently lost.
   */
  public void setWeight(int i, float weight) {
    if (i < 0 || i >= size) {
      throw new IndexOutOfBoundsException("Datum index " + i + " out of range [0, " + size + ")");
    }
    weights[i] = weight;
  }

  /**
   * Swaps datums {@code a} and {@code b}: their feature arrays, labels and weights.
   */
  private void swapDatumsAndWeights(int a, int b) {
    int[] tmpData = data[a];
    data[a] = data[b];
    data[b] = tmpData;

    int tmpLabel = labels[a];
    labels[a] = labels[b];
    labels[b] = tmpLabel;

    float tmpWeight = weights[a];
    weights[a] = weights[b];
    weights[b] = tmpWeight;
  }

  /**
   * Shuffles the datums in place (features, labels and weights together) using a seeded
   * Fisher–Yates shuffle.
   *
   * <p>Overridden so the weights are permuted along with the data.
   *
   * @param randomSeed seed for the random number generator
   */
  @Override
  public void randomize(long randomSeed) {
    Random rand = new Random(randomSeed);
    for (int j = size - 1; j > 0; j--) {
      // nextInt(j + 1), not nextInt(j): j must be allowed to stay in place. Otherwise only
      // cyclic permutations are produced (Sattolo's algorithm) and the shuffle is biased.
      swapDatumsAndWeights(j, rand.nextInt(j + 1));
    }
  }

  /**
   * Shuffles the datums in place (features, labels and weights together) and applies the
   * same permutation to {@code sideInformation}, using a seeded Fisher–Yates shuffle.
   *
   * <p>Overridden so the weights are permuted along with the data.
   *
   * @param randomSeed seed for the random number generator
   * @param sideInformation per-datum data that must stay aligned with the datums
   * @throws IllegalArgumentException if {@code sideInformation.size() != size}
   */
  @Override
  public <E> void shuffleWithSideInformation(long randomSeed, List<E> sideInformation) {
    if (size != sideInformation.size()) {
      throw new IllegalArgumentException("shuffleWithSideInformation: sideInformation not of same size as Dataset");
    }
    Random rand = new Random(randomSeed);
    for (int j = size - 1; j > 0; j--) {
      // nextInt(j + 1) for an unbiased Fisher–Yates shuffle (see randomize).
      int randIndex = rand.nextInt(j + 1);
      swapDatumsAndWeights(randIndex, j);
      Collections.swap(sideInformation, randIndex, j);
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy