edu.stanford.nlp.classify.WeightedDataset Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of stanford-corenlp Show documentation
Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.
package edu.stanford.nlp.classify;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.util.Index;
import java.util.Collection;
import java.util.List;
import java.util.Random;
/**
 * A {@link Dataset} that stores one float weight per datum.
 *
 * <p>The weights live in {@link #weights}, an array indexed in parallel with the
 * inherited {@code data} and {@code labels} arrays: entry {@code i} is the weight of
 * datum {@code i}. Data added without an explicit weight get a weight of {@code 1.0f}.
 *
 * @author Galen Andrew
 * @author Sarah Spikes (Templatization)
 *
 * @param <L> type of the labels
 * @param <F> type of the features
 */
public class WeightedDataset<L, F> extends Dataset<L, F> {

  private static final long serialVersionUID = -5435125789127705430L;

  /**
   * Per-datum weights; only the first {@code size} entries are meaningful.
   * The array may be longer than {@code size} (spare capacity).
   */
  protected float[] weights;

  /**
   * Builds a weighted dataset from already-populated arrays.
   *
   * @param labelIndex index of labels, passed through to the superclass
   * @param labels label array, passed through to the superclass
   * @param featureIndex index of features, passed through to the superclass
   * @param data feature array, passed through to the superclass
   * @param size number of data, passed through to the superclass
   * @param weights weight array; stored as-is (not copied)
   */
  public WeightedDataset(Index<L> labelIndex, int[] labels, Index<F> featureIndex,
                         int[][] data, int size, float[] weights) {
    super(labelIndex, labels, featureIndex, data, size);
    this.weights = weights;
  }

  /** Creates an empty weighted dataset with an initial capacity of 10. */
  public WeightedDataset() {
    this(10);
  }

  /**
   * Creates an empty weighted dataset.
   *
   * @param initSize initial capacity of the weight array (also passed to the superclass)
   */
  public WeightedDataset(int initSize) {
    super(initSize);
    weights = new float[initSize];
  }

  /**
   * Returns a fresh array holding the first {@code size} entries of {@code source}.
   */
  private float[] trimToSize(float[] source) {
    float[] trimmed = new float[size];
    // NOTE(review): the lock on System.class is kept from the original code; its
    // purpose is not visible from this class -- confirm before removing it.
    synchronized (System.class) {
      System.arraycopy(source, 0, trimmed, 0, size);
    }
    return trimmed;
  }

  /**
   * Trims the internal weight array to exactly {@code size} entries and returns it.
   *
   * <p>The returned array is the internal array itself, not a defensive copy, so
   * writes through it are visible to this dataset.
   *
   * @return the weights of the data currently in the dataset
   */
  public float[] getWeights() {
    weights = trimToSize(weights);
    return weights;
  }

  /**
   * Computes weighted feature counts: each occurrence of a feature in datum
   * {@code i} contributes {@code weights[i]} to that feature's count.
   *
   * @return an array of length {@code featureIndex.size()} indexed by feature id
   */
  @Override
  public float[] getFeatureCounts() {
    float[] counts = new float[featureIndex.size()];
    for (int i = 0; i < size; i++) {
      float weight = weights[i];
      for (int feature : data[i]) {
        counts[feature] += weight;
      }
    }
    return counts;
  }

  /** Adds a datum with weight {@code 1.0f}. */
  @Override
  public void add(Datum<L, F> d) {
    add(d, 1.0f);
  }

  /** Adds a datum, given as features and label, with weight {@code 1.0f}. */
  @Override
  public void add(Collection<F> features, L label) {
    add(features, label, 1.0f);
  }

  /**
   * Adds a datum with the given weight.
   *
   * @param d the datum whose features and label are added
   * @param weight the weight to associate with the datum
   */
  public void add(Datum<L, F> d, float weight) {
    add(d.asFeatures(), d.label(), weight);
  }

  /**
   * Grows the weight array when it is full, after letting the superclass grow
   * its own storage.
   */
  @Override
  protected void ensureSize() {
    super.ensureSize();
    if (weights.length == size) {
      // Doubling a zero-length array would leave no room for the next element,
      // so always grow by at least one slot.
      float[] grown = new float[Math.max(size * 2, size + 1)];
      synchronized (System.class) {
        System.arraycopy(weights, 0, grown, 0, size);
      }
      weights = grown;
    }
  }

  /**
   * Adds a datum, given as features and label, with the given weight.
   *
   * @param features the features of the datum
   * @param label the label of the datum
   * @param weight the weight to associate with the datum
   */
  public void add(Collection<F> features, L label, float weight) {
    ensureSize();
    addLabel(label);
    addFeatures(features);
    // size is advanced here; presumably addLabel/addFeatures write at the current
    // size without incrementing it -- verify against Dataset.
    weights[size++] = weight;
  }

  /**
   * Set the weight of datum i.
   *
   * <p>No check is made that {@code i} is below the current {@code size}; only the
   * bounds of the underlying array are enforced.
   *
   * @param i The index of the datum to change the weight of.
   * @param weight The weight to set
   */
  public void setWeight(int i, float weight) {
    weights[i] = weight;
  }

  /**
   * Exchanges the data at positions {@code a} and {@code b}, keeping the feature
   * rows, labels and weights aligned.
   */
  private void swapWeightedExamples(int a, int b) {
    int[] tmpData = data[a];
    data[a] = data[b];
    data[b] = tmpData;

    int tmpLabel = labels[a];
    labels[a] = labels[b];
    labels[b] = tmpLabel;

    float tmpWeight = weights[a];
    weights[a] = weights[b];
    weights[b] = tmpWeight;
  }

  /**
   * Randomizes (shuffles) the data array in place.
   * Needs to be redefined here because we need to randomize the weights as well.
   *
   * @param randomSeed seed for the random number generator; the same seed gives
   *                   the same ordering
   */
  @Override
  public void randomize(long randomSeed) {
    Random rand = new Random(randomSeed);
    for (int j = size - 1; j > 0; j--) {
      // rand.nextInt(j) draws from [0, j), so position j is always exchanged with
      // a strictly lower index.
      // NOTE(review): a uniform Fisher-Yates shuffle would draw from [0, j]; the
      // bound is left unchanged so that a given seed keeps producing the same
      // ordering -- confirm whether the restriction is intended.
      int randIndex = rand.nextInt(j);
      swapWeightedExamples(randIndex, j);
    }
  }

  /**
   * Shuffles the data in place and applies the same permutation to
   * {@code sideInformation}, so that entry {@code i} of the list still belongs to
   * datum {@code i} afterwards. Redefined here so the weights are permuted too.
   *
   * @param randomSeed seed for the random number generator
   * @param sideInformation list with one entry per datum; reordered in place
   * @param <E> element type of the side information
   * @throws IllegalArgumentException if the list size differs from the dataset size
   */
  @Override
  public <E> void shuffleWithSideInformation(long randomSeed, List<E> sideInformation) {
    if (size != sideInformation.size()) {
      throw new IllegalArgumentException("shuffleWithSideInformation: sideInformation not of same size as Dataset");
    }
    Random rand = new Random(randomSeed);
    for (int j = size - 1; j > 0; j--) {
      // Same index selection as randomize(long); see the note there.
      int randIndex = rand.nextInt(j);
      swapWeightedExamples(randIndex, j);

      E tmpE = sideInformation.get(randIndex);
      sideInformation.set(randIndex, sideInformation.get(j));
      sideInformation.set(j, tmpE);
    }
  }
}