
cc.mallet.classify.FeatureConstraintUtil
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
/* Copyright (C) 2009 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.logging.Logger;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.InfoGain;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
/**
* Utility functions for creating feature constraints that can be used with GE training.
* @author Gregory Druck [email protected]
*/
public class FeatureConstraintUtil {
private static Logger logger = MalletLogger.getLogger(FeatureConstraintUtil.class.getName());
/**
* Reads range constraints stored using strings from a file. Format can be either:
*
* feature_name (label_name:lower_probability,upper_probability)+
*
* or
*
* feature_name (label_name:probability)+
*
* Constraints are only added for feature-label pairs that are present.
*
* @param filename File with feature constraints.
* @param data InstanceList used for alphabets.
* @return Constraints.
*/
public static HashMap<Integer,double[][]> readRangeConstraintsFromFile(String filename, InstanceList data) {
HashMap<Integer,double[][]> constraints = new HashMap<Integer,double[][]>();
for (int li = 0; li < data.getTargetAlphabet().size(); li++) {
System.err.println(data.getTargetAlphabet().lookupObject(li));
}
try {
BufferedReader reader = new BufferedReader(new FileReader(filename));
String line = reader.readLine();
while (line != null) {
String[] split = line.split("\\s+");
// assume the feature name has no spaces
String featureName = split[0];
int featureIndex = data.getDataAlphabet().lookupIndex(featureName,false);
if (featureIndex == -1) {
throw new RuntimeException("Feature " + featureName + " not found in the alphabet!");
}
double[][] probs = new double[data.getTargetAlphabet().size()][2];
for (int i = 0; i < probs.length; i++) Arrays.fill(probs[i],Double.NEGATIVE_INFINITY);
for (int index = 1; index < split.length; index++) {
String[] labelSplit = split[index].split(":");
int li = data.getTargetAlphabet().lookupIndex(labelSplit[0],false);
assert (li != -1) : labelSplit[0];
if (labelSplit[1].contains(",")) {
String[] rangeSplit = labelSplit[1].split(",");
double lower = Double.parseDouble(rangeSplit[0]);
double upper = Double.parseDouble(rangeSplit[1]);
probs[li][0] = lower;
probs[li][1] = upper;
}
else {
double prob = Double.parseDouble(labelSplit[1]);
probs[li][0] = prob;
probs[li][1] = prob;
}
}
constraints.put(featureIndex, probs);
line = reader.readLine();
}
}
catch (Exception e) {
e.printStackTrace();
System.exit(1);
}
return constraints;
}
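/*
* Illustrative example (file contents and label names are hypothetical):
* for a sentiment task with labels "positive" and "negative", a range
* constraints file could contain lines such as
*
* excellent (positive:0.8,1.0) (negative:0.0,0.2)
* terrible (negative:0.9)
*
* and would be read with
*
* HashMap<Integer,double[][]> constraints =
* FeatureConstraintUtil.readRangeConstraintsFromFile("constraints.txt", data);
*/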
/**
* Reads feature constraints from a file, whether they are stored
* using Strings or indices.
*
* @param filename File with feature constraints.
* @param data InstanceList used for alphabets.
* @return Constraints.
*/
public static HashMap<Integer,double[]> readConstraintsFromFile(String filename, InstanceList data) {
if (testConstraintsFileIndexBased(filename)) {
return readConstraintsFromFileIndex(filename,data);
}
return readConstraintsFromFileString(filename,data);
}
/**
* Reads feature constraints stored using strings from a file.
*
* feature_name (label_name:probability)+
*
* Labels that do not appear get probability 0.
*
* @param filename File with feature constraints.
* @param data InstanceList used for alphabets.
* @return Constraints.
*/
public static HashMap<Integer,double[]> readConstraintsFromFileString(String filename, InstanceList data) {
HashMap<Integer,double[]> constraints = new HashMap<Integer,double[]>();
File file = new File(filename);
try {
BufferedReader reader = new BufferedReader(new FileReader(file));
String line = reader.readLine();
while (line != null) {
String[] split = line.split("\\s+");
// assume the feature name has no spaces
String featureName = split[0];
int featureIndex = data.getDataAlphabet().lookupIndex(featureName,false);
assert(split.length - 1 == data.getTargetAlphabet().size()) : split.length + " " + data.getTargetAlphabet().size();
double[] probs = new double[split.length - 1];
for (int index = 1; index < split.length; index++) {
String[] labelSplit = split[index].split(":");
int li = data.getTargetAlphabet().lookupIndex(labelSplit[0],false);
assert(li != -1) : "Label " + labelSplit[0] + " not found";
double prob = Double.parseDouble(labelSplit[1]);
probs[li] = prob;
}
constraints.put(featureIndex, probs);
line = reader.readLine();
}
}
catch (Exception e) {
e.printStackTrace();
System.exit(1);
}
return constraints;
}
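/*
* Illustrative example (feature names, labels, and the index are
* hypothetical): with target labels "positive" and "negative", a line in
* this string-based format might read
*
* excellent positive:0.9 negative:0.1
*
* The equivalent index-based format (see readConstraintsFromFileIndex
* below) replaces names with alphabet indices, e.g. "137 0.9 0.1"; the
* absence of ':' in the first line is what testConstraintsFileIndexBased
* keys on to distinguish the two formats.
*/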
/**
* Reads feature constraints stored using feature indices from a file.
*
* feature_index label_0_prob label_1_prob ... label_n_prob
*
* Here each label must appear.
*
* @param filename File with feature constraints.
* @param data InstanceList used for alphabets.
* @return Constraints.
*/
public static HashMap<Integer,double[]> readConstraintsFromFileIndex(String filename, InstanceList data) {
HashMap<Integer,double[]> constraints = new HashMap<Integer,double[]>();
File file = new File(filename);
try {
BufferedReader reader = new BufferedReader(new FileReader(file));
String line = reader.readLine();
while (line != null) {
String[] split = line.split("\\s+");
int featureIndex = Integer.parseInt(split[0]);
assert(split.length - 1 == data.getTargetAlphabet().size());
double[] probs = new double[split.length - 1];
for (int index = 1; index < split.length; index++) {
double prob = Double.parseDouble(split[index]);
probs[index-1] = prob;
}
constraints.put(featureIndex, probs);
line = reader.readLine();
}
}
catch (Exception e) {
e.printStackTrace();
System.exit(1);
}
return constraints;
}
private static boolean testConstraintsFileIndexBased(String filename) {
File file = new File(filename);
String firstLine = "";
try {
BufferedReader reader = new BufferedReader(new FileReader(file));
firstLine = reader.readLine();
}
catch (Exception e) {
e.printStackTrace();
System.exit(1);
}
return !firstLine.contains(":");
}
/**
* Select features with the highest information gain.
*
* @param list InstanceList for computing information gain.
* @param numFeatures Number of features to select.
* @return List of features with the highest information gains.
*/
public static ArrayList<Integer> selectFeaturesByInfoGain(InstanceList list, int numFeatures) {
ArrayList<Integer> features = new ArrayList<Integer>();
InfoGain infogain = new InfoGain(list);
for (int rank = 0; rank < numFeatures; rank++) {
features.add(infogain.getIndexAtRank(rank));
}
return features;
}
/**
* Select top features in LDA topics.
*
* @param numSelFeatures Number of features to select.
* @param lda ParallelTopicModel that provides the topics and their top words.
* @param alphabet The vector dataset alphabet, which may be different from the sequence dataset alphabet used by the topic model.
* @return ArrayList with the int indices of the selected features.
*/
public static ArrayList<Integer> selectTopLDAFeatures(int numSelFeatures, ParallelTopicModel lda, Alphabet alphabet) {
ArrayList<Integer> features = new ArrayList<Integer>();
Alphabet seqAlphabet = lda.getAlphabet();
int numTopics = lda.getNumTopics();
Object[][] sorted = lda.getTopWords(seqAlphabet.size());
for (int pos = 0; pos < seqAlphabet.size(); pos++) {
for (int ti = 0; ti < numTopics; ti++) {
Object feat = sorted[ti][pos].toString();
int fi = alphabet.lookupIndex(feat,false);
if ((fi >=0) && (!features.contains(fi))) {
logger.info("Selected feature: " + feat);
features.add(fi);
if (features.size() == numSelFeatures) {
return features;
}
}
}
}
return features;
}
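/*
* Usage sketch (illustrative; variable names and settings are assumptions,
* not from the original source). A topic model is trained on a
* token-sequence version of the corpus, then its top words are mapped into
* the feature-vector alphabet:
*
* ParallelTopicModel lda = new ParallelTopicModel(50);
* lda.addInstances(seqData); // token sequences; may use a different alphabet
* lda.setNumIterations(1000);
* lda.estimate();
* ArrayList<Integer> features =
* FeatureConstraintUtil.selectTopLDAFeatures(100, lda, data.getDataAlphabet());
*/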
public static HashMap<Integer,double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features) {
return setTargetsUsingData(list,features,true);
}
public static HashMap<Integer,double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features, boolean normalize) {
return setTargetsUsingData(list,features,false,normalize);
}
/**
* Set target distributions using estimates from data.
*
* @param list InstanceList used to estimate targets.
* @param features List of features for constraints.
* @param useValues Whether to weight counts by feature values (true) or by feature presence only (false).
* @param normalize Whether to normalize by feature counts.
* @return Constraints (map of feature index to target), with targets
* set using estimates from supplied data.
*/
public static HashMap<Integer,double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features, boolean useValues, boolean normalize) {
HashMap<Integer,double[]> constraints = new HashMap<Integer,double[]>();
double[][] featureLabelCounts = getFeatureLabelCounts(list,useValues);
for (int i = 0; i < features.size(); i++) {
int fi = features.get(i);
if (fi != list.getDataAlphabet().size()) {
double[] prob = featureLabelCounts[fi];
if (normalize) {
// Smooth probability distributions by adding a (very)
// small count. We just need to make sure they aren't
// zero in which case the KL-divergence is infinite.
MatrixOps.plusEquals(prob, 1e-8);
MatrixOps.timesEquals(prob, 1./MatrixOps.sum(prob));
}
constraints.put(fi, prob);
}
}
return constraints;
}
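/*
* For example (numbers are illustrative): if a feature co-occurs three
* times with label 0 and once with label 1, the raw counts [3.0, 1.0]
* become approximately [0.75, 0.25] after the 1e-8 smoothing and
* normalization above.
*/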
/**
* Set target distributions using "Schapire" heuristic described in
* "Learning from Labeled Features using Generalized Expectation Criteria"
* Gregory Druck, Gideon Mann, Andrew McCallum.
*
* @param labeledFeatures HashMap of feature indices to lists of label indices for that feature.
* @param numLabels Total number of labels.
* @param majorityProb Probability mass divided among majority labels.
* @return Constraints (map of feature index to target distribution), with target
* distributions set using heuristic.
*/
public static HashMap<Integer,double[]> setTargetsUsingHeuristic(HashMap<Integer,ArrayList<Integer>> labeledFeatures, int numLabels, double majorityProb) {
HashMap<Integer,double[]> constraints = new HashMap<Integer,double[]>();
Iterator<Integer> keyIter = labeledFeatures.keySet().iterator();
while (keyIter.hasNext()) {
int fi = keyIter.next();
ArrayList<Integer> labels = labeledFeatures.get(fi);
constraints.put(fi, getHeuristicPrior(labels,numLabels,majorityProb));
}
return constraints;
}
/**
* Set target distributions using feature voting heuristic described in
* "Learning from Labeled Features using Generalized Expectation Criteria"
* Gregory Druck, Gideon Mann, Andrew McCallum.
*
* @param labeledFeatures HashMap of feature indices to lists of label indices for that feature.
* @param trainingData InstanceList to use for computing expectations with feature voting.
* @return Constraints (map of feature index to target distribution), with target
* distributions set using feature voting.
*/
public static HashMap<Integer,double[]> setTargetsUsingFeatureVoting(HashMap<Integer,ArrayList<Integer>> labeledFeatures,
InstanceList trainingData) {
HashMap<Integer,double[]> constraints = new HashMap<Integer,double[]>();
int numLabels = trainingData.getTargetAlphabet().size();
Iterator<Integer> keyIter = labeledFeatures.keySet().iterator();
double[][] featureCounts = new double[labeledFeatures.size()][numLabels];
for (int ii = 0; ii < trainingData.size(); ii++) {
Instance instance = trainingData.get(ii);
FeatureVector fv = (FeatureVector)instance.getData();
Labeling labeling = trainingData.get(ii).getLabeling();
double[] labelDist = new double[numLabels];
if (labeling == null) {
labelByVoting(labeledFeatures,instance,labelDist);
} else {
int li = labeling.getBestIndex();
labelDist[li] = 1.;
}
keyIter = labeledFeatures.keySet().iterator();
int i = 0;
while (keyIter.hasNext()) {
int fi = keyIter.next();
if (fv.location(fi) >= 0) {
for (int li = 0; li < numLabels; li++) {
featureCounts[i][li] += labelDist[li] * fv.valueAtLocation(fv.location(fi));
}
}
i++;
}
}
keyIter = labeledFeatures.keySet().iterator();
int i = 0;
while (keyIter.hasNext()) {
int fi = keyIter.next();
// smoothing counts
MatrixOps.plusEquals(featureCounts[i], 1e-8);
MatrixOps.timesEquals(featureCounts[i],1./MatrixOps.sum(featureCounts[i]));
constraints.put(fi, featureCounts[i]);
i++;
}
return constraints;
}
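/*
* In outline (descriptive comment, added for clarity): each instance
* contributes a label distribution, either its true labeling or, if
* unlabeled, a vote among the labeled features that fire on it; every
* labeled feature present in the instance then accumulates that
* distribution weighted by its feature value, and the per-feature totals
* are smoothed and normalized into target distributions.
*/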
/**
* Label features using heuristic described in
* "Learning from Labeled Features using Generalized Expectation Criteria"
* Gregory Druck, Gideon Mann, Andrew McCallum.
*
* @param list InstanceList used to compute statistics for labeling features.
* @param features List of features to label.
* @param reject Whether to reject features whose labelings look unreliable (information gain below the cutoff, or more than numLabels/2 majority labels).
* @return Labeled features, HashMap mapping feature indices to list of labels.
*/
public static HashMap<Integer,ArrayList<Integer>> labelFeatures(InstanceList list, ArrayList<Integer> features, boolean reject) {
HashMap<Integer,ArrayList<Integer>> labeledFeatures = new HashMap<Integer,ArrayList<Integer>>();
double[][] featureLabelCounts = getFeatureLabelCounts(list,true);
int numLabels = list.getTargetAlphabet().size();
int minRank = 100 * numLabels;
InfoGain infogain = new InfoGain(list);
double sum = 0;
for (int rank = 0; rank < minRank; rank++) {
sum += infogain.getValueAtRank(rank);
}
double mean = sum / minRank;
for (int i = 0; i < features.size(); i++) {
int fi = features.get(i);
// reject features with infogain
// less than cutoff
if (reject && infogain.value(fi) < mean) {
//System.err.println("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
logger.info("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
continue;
}
double[] prob = featureLabelCounts[fi];
MatrixOps.plusEquals(prob,1e-8);
MatrixOps.timesEquals(prob, 1./MatrixOps.sum(prob));
int[] sortedIndices = getMaxIndices(prob);
ArrayList<Integer> labels = new ArrayList<Integer>();
if (numLabels > 2) {
// take anything within a factor of 2 of the best
// but no more than numLabels/2
boolean discard = false;
double threshold = prob[sortedIndices[0]] / 2;
for (int li = 0; li < numLabels; li++) {
if (prob[li] > threshold) {
labels.add(li);
}
if (reject && labels.size() > (numLabels / 2)) {
//System.err.println("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
logger.info("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
discard = true;
break;
}
}
if (discard) {
continue;
}
}
else {
labels.add(sortedIndices[0]);
}
labeledFeatures.put(fi, labels);
}
return labeledFeatures;
}
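/*
* End-to-end sketch (illustrative; variable names are hypothetical). A
* typical GE workflow built from this class:
*
* ArrayList<Integer> candidates =
* FeatureConstraintUtil.selectFeaturesByInfoGain(data, 300);
* HashMap<Integer,ArrayList<Integer>> labeled =
* FeatureConstraintUtil.labelFeatures(data, candidates);
* HashMap<Integer,double[]> targets =
* FeatureConstraintUtil.setTargetsUsingHeuristic(labeled,
* data.getTargetAlphabet().size(), 0.9);
*
* The resulting map from feature index to target distribution can then be
* passed to a GE-based trainer such as cc.mallet.classify.MaxEntGETrainer.
*/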
public static HashMap<Integer,ArrayList<Integer>> labelFeatures(InstanceList list, ArrayList<Integer> features) {
return labelFeatures(list,features,true);
}
public static double[][] getFeatureLabelCounts(InstanceList list, boolean useValues) {
int numFeatures = list.getDataAlphabet().size();
int numLabels = list.getTargetAlphabet().size();
double[][] featureLabelCounts = new double[numFeatures][numLabels];
for (int ii = 0; ii < list.size(); ii++) {
Instance instance = list.get(ii);
FeatureVector featureVector = (FeatureVector)instance.getData();
// this handles distributions over labels
for (int li = 0; li < numLabels; li++) {
double py = instance.getLabeling().value(li);
for (int loc = 0; loc < featureVector.numLocations(); loc++) {
int fi = featureVector.indexAtLocation(loc);
double val;
if (useValues) {
val = featureVector.valueAtLocation(loc);
}
else {
val = 1.0;
}
featureLabelCounts[fi][li] += py * val;
}
}
}
return featureLabelCounts;
}
private static double[] getHeuristicPrior (ArrayList<Integer> labeledFeatures, int numLabels, double majorityProb) {
int numIndices = labeledFeatures.size();
double[] dist = new double[numLabels];
if (numIndices == numLabels) {
for (int i = 0; i < dist.length; i++) {
dist[i] = 1./numLabels;
}
return dist;
}
double keywordProb = majorityProb / numIndices;
double otherProb = (1 - majorityProb) / (numLabels - numIndices);
for (int i = 0; i < labeledFeatures.size(); i++) {
int li = labeledFeatures.get(i);
dist[li] = keywordProb;
}
for (int li = 0; li < numLabels; li++) {
if (dist[li] == 0) {
dist[li] = otherProb;
}
}
assert(Maths.almostEquals(MatrixOps.sum(dist),1));
return dist;
}
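/*
* Worked example (illustrative): with numLabels = 4, majorityProb = 0.9,
* and a feature labeled with the single label 2, the returned distribution
* is approximately [0.033, 0.033, 0.9, 0.033]: the majority label receives
* 0.9 and the remaining 0.1 is split evenly among the other three labels.
* When every label is a majority label, the distribution is uniform.
*/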
private static void labelByVoting(HashMap<Integer,ArrayList<Integer>> labeledFeatures, Instance instance, double[] scores) {
FeatureVector fv = (FeatureVector)instance.getData();
int numFeatures = instance.getDataAlphabet().size() + 1;
int[] numLabels = new int[instance.getTargetAlphabet().size()];
Iterator<Integer> keyIterator = labeledFeatures.keySet().iterator();
while (keyIterator.hasNext()) {
ArrayList<Integer> majorityClassList = labeledFeatures.get(keyIterator.next());
for (int i = 0; i < majorityClassList.size(); i++) {
int li = majorityClassList.get(i);
numLabels[li]++;
}
}
keyIterator = labeledFeatures.keySet().iterator();
while (keyIterator.hasNext()) {
int next = keyIterator.next();
assert(next < numFeatures);
int loc = fv.location(next);
if (loc < 0) {
continue;
}
ArrayList<Integer> majorityClassList = labeledFeatures.get(next);
for (int i = 0; i < majorityClassList.size(); i++) {
int li = majorityClassList.get(i);
scores[li] += 1;
}
}
double sum = MatrixOps.sum(scores);
if (sum == 0) {
MatrixOps.plusEquals(scores, 1.0);
sum = MatrixOps.sum(scores);
}
for (int li = 0; li < scores.length; li++) {
scores[li] /= sum;
}
}
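/*
* Voting example (illustrative): if an instance contains two labeled
* features, one associated with label 0 and one with labels 0 and 1, the
* raw scores are [2, 1], which normalize to [2/3, 1/3]. An instance that
* contains no labeled features receives the uniform distribution.
*/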
/*
* These functions are no longer needed.
*
private static double[][] getPrWordTopic(LDAHyper lda){
int numTopics = lda.getNumTopics();
int numTypes = lda.getAlphabet().size();
double[][] prWordTopic = new double[numTopics][numTypes];
for (int ti = 0 ; ti < numTopics; ti++){
for (int wi = 0 ; wi < numTypes; wi++){
prWordTopic[ti][wi] = (double) lda.getCountFeatureTopic(wi, ti) / (double) lda.getCountTokensPerTopic(ti);
}
}
return prWordTopic;
}
private static int[][] getSortedTopic(double[][] prTopicWord){
int numTopics = prTopicWord.length;
int numTypes = prTopicWord[0].length;
int[][] sortedTopicIdx = new int[numTopics][numTypes];
for (int ti = 0; ti < numTopics; ti++){
int[] topicIdx = getMaxIndices(prTopicWord[ti]);
System.arraycopy(topicIdx, 0, sortedTopicIdx[ti], 0, topicIdx.length);
}
return sortedTopicIdx;
}
*/
private static int[] getMaxIndices(double[] x) {
ArrayList<Element> list = new ArrayList<Element>();
for (int i = 0; i < x.length; i++) {
Element element = new Element(i,x[i]);
list.add(element);
}
Collections.sort(list);
Collections.reverse(list);
int[] sortedIndices = new int[x.length];
for (int i = 0; i < x.length; i++) {
sortedIndices[i] = list.get(i).index;
}
return sortedIndices;
}
private static class Element implements Comparable<Element> {
private int index;
private double value;
public Element(int index, double value) {
this.index = index;
this.value = value;
}
public int compareTo(Element element) {
return Double.compare(this.value, element.value);
}
}
}