All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.stanford.nlp.classify.ColumnDataClassifier Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
// Stanford Classifier, ColumnDataClassifier - a multiclass maxent classifier
// Copyright (c) 2003-2012 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
//
// This code is a parameter language for front-end feature
// generation for the loglinear model classification code in
// the Stanford Classifier package (mainly written by Dan Klein).
//
// For more information, bug reports, fixes, contact:
//    Christopher Manning
//    Dept of Computer Science, Gates 1A
//    Stanford CA 94305-9010
//    USA
//    Support/Questions: [email protected]
//    Licensing: [email protected]
//    http://www-nlp.stanford.edu/software/classifier.shtml

package edu.stanford.nlp.classify;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.process.PTBTokenizer;
import edu.stanford.nlp.process.Tokenizer;
import edu.stanford.nlp.process.TokenizerFactory;
import edu.stanford.nlp.process.WordShapeClassifier;
import edu.stanford.nlp.stats.*;
import edu.stanford.nlp.util.*;

import java.io.*;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;


import edu.stanford.nlp.util.logging.Redwood;


/**
 * ColumnDataClassifier provides a command-line interface for doing
 * context-free (independent) classification of a series of data items,
 * where each data item is represented by a line of
 * a file, as a list of String variables, in tab-separated columns.  Some
 * features will interpret these variables as numbers, but
 * the code is mainly oriented towards generating features for string
 * classification.  To designate a real-valued feature, use the realValued
 * option described below. The classifier by default is a maxent classifier
 * (also known as a softmax classifier or a discriminative loglinear classifier;
 * equivalent to multiclass logistic regression apart from a slightly different
 * symmetric parameterization. It also implements a Bernoulli Naive
 * Bayes model and can implement an SVM by an external call to SVMlight.
 * 

* You can also use ColumnDataClassifier programmatically, where its main * usefulness beyond simply building your own LinearClassifier is that it * provides easy conversion of data items into features, using the same * properties as the command-line version. You can see example of usage in * the class {@link edu.stanford.nlp.classify.demo.ClassifierDemo}. *

* Input files are expected to * be one data item per line with two or more columns indicating the class * of the item and one or more predictive features. Columns are * separated by tab characters. Tab and newline characters cannot occur * inside field values (there is no escaping mechanism); any other characters * are legal in field values. *

* Typical usage: *

{@code * java edu.stanford.nlp.classify.ColumnDataClassifier -prop propFile * }

or

* {@code java -mx300m edu.stanford.nlp.classify.ColumnDataClassifier * -trainFile trainFile -testFile testFile -useNGrams|... > output * } *

* (Note that for large data sets, you may wish to specify * the amount of memory available to Java, such * as in the second example above.) *

* In the simplest case, there are just two tab-separated columns in the * training input: the first for the class, and the second for the String * datum which has that class. In more complex uses, each datum can * be multidimensional, and there are many columns of data attributes. *

* To illustrate simple uses, and the behavior of Naive Bayes and Maximum * entropy classifiers, example files corresponding to the examples from the * Manning and Klein maxent classifier tutorial, slides 46-49, available at * http://nlp.stanford.edu/downloads/classifier.shtml are included in the * classify package source directory (files starting with "easy"). Other * examples appear in the {@code examples} directory of the distributed * classifier. *

* In many instances, parameters can either be given on the command line * or provided using a Properties file * (specified on the command-line with {@code -prop} propFile). * Option names are the same as property names with a preceding dash. Boolean * properties can simply be turned on via the command line. Parameters of * types int, String, and double take a following argument after the option. * Command-line parameters can only define features for the first column * describing the datum. If you have multidimensional data, you need to use * a properties file. Property names, as below, are either global (things * like the testFile name) or are seen as properties that define features * for the first data column (NOTE: we count columns from 0 - unlike the Unix cut * command!). To specify features for a particular data column, precede a * feature by a column number and then a period (for example, * {@code 3.wordShape=chris4}). If no number is specified, then the * default interpretation is column 0. Note that in properties files you must * give a value to boolean properties (e.g., {@code 2.useString=true}); * just giving the property name (as {@code 2.useString}) isn't * sufficient. *

* The following properties are recognized: *

* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
Properties for ColumnDataClassifier
Property NameTypeDefault ValueDescriptionFeatName
loadClassifier Stringn/aPath of serialized classifier file to load
serializeToStringn/aPath to serialize classifier to
printToStringn/aPath to print a text representation of the linear classifier to
trainFileStringn/aPath of file to use as training data
testFileStringn/aPath of file to use as test data
encodingStringutf-8Character encoding of training and test file, e.g., utf-8, GB18030, or iso-8859-1
displayedColumnint1Column number that will be printed out to stdout in the output next to the gold class and the chosen class. This is just an aide memoire. If the value is negative, nothing is printed.
displayAllAnswersbooleanfalseIf true, print all classes and their probability, sorted by probability, rather than just the highest scoring and correct classes.
goldAnswerColumnint0Column number that contains the correct class for each data item (again, columns are numbered from 0 up).
groupingColumnint-1Column for grouping multiple data items for the purpose of computing ranking accuracy. This is appropriate when only one datum in a group can be correct, and the intention is to choose the highest probability one, rather than accepting all above a threshold. Multiple items in the same group must be contiguous in the test file (otherwise it would be necessary to cache probabilities and groups for the entire test file to check matches). If it is negative, no grouping column is used, and no ranking accuracy is reported.
rankingScoreColumnint-1If this parameter is non-negative and a groupingColumn is defined, then an average ranking score will be calculated by scoring the chosen candidate from a group according to its value in this column (for instance, the values of this column can be set to a mean reciprocal rank of 1.0 for the best answer, 0.5 for the second best and so on, or the value of this column can be a similarity score reflecting the similarity of the answer to the true answer.
rankingAccuracyClassStringnullIf this and groupingColumn are defined (positive), then the system will compute a ranking accuracy under the assumption that there is (at most) one assignment of this class for each group, and ranking accuracy counts the classifier as right if that datum is the one with highest probability according to the model.
useStringbooleanfalseGives you a feature for whole string sS-str
useClassFeaturebooleanfalseInclude a feature for the class (as a class marginal). This is the same thing as having a bias vector or having an always-on feature in a model.CLASS
binnedLengthsStringnullIf non-null, treat as a sequence of comma separated integer bounds, where items above the previous bound (if any) up to the next bound (inclusive) are binned (e.g., "1,5,15,30,60"). The feature represents the length of the String in this column.Len-range
binnedLengthsStatisticsbooleanfalseIf true, print to stderr contingency table of statistics for binnedLengths.
binnedValuesStringnullIf non-null, treat as a sequence of comma separated double bounds, where data items above the previous bound up to the next bound (inclusive) are binned. If a value in this column isn't a legal {@code double}, then the value is treated as {@code binnedValuesNaN}.Val-range
binnedValuesNaNdouble-1.0If the value of a numeric binnedValues field is not a number, it will be given this value.
binnedValuesStatisticsbooleanfalseIf true, print to stderr a contingency table of statistics for binnedValues.
countCharsStringnullIf non-null, count the number of occurrences of each character in the String, and make a feature for each character, binned according to {@code countCharsBins}Char-ch-range
countCharsBinsString"0,1"Treat as a sequence of comma separated integer bounds, where character counts above the previous bound up to and including the next bound are binned. For instance, a value of "0,2" will give 3 bins, dividing a character count into bins of 0, 1-or-2, and 3-or-more occurrences.
splitWordsRegexpStringnullIf defined, use this as a regular expression on which to split the whole string (as in the String.split() function, which will return the things between delimiters, and discard the delimiters). The resulting split-up "words" will be used in classifier features iff one of the other "useSplit" options is turned on.
splitWordsTokenizerRegexpStringnullIf defined, use this as a regular expression to cut initial pieces off a String. Either this regular expression or {@code splitWordsIgnoreRegexp} should always match the start of the String, and the size of the token is the number of characters matched. So, for example, one can group letter and number characters but do nothing else with a regular expression like {@code ([A-Za-z]+|[0-9]+|.)}, where the last disjunct will match any other single character. (If neither regular expression matches, the first character of the string is treated as a one character word, and then matching is tried again, but in this case a warning message is printed.) Note that, for Java regular expressions with disjunctions like this, the match is the first matching disjunction, not the longest matching disjunction, so patterns with common prefixes need to be ordered from most specific (longest) to least specific (shortest).) The resulting split up "words" will be used in classifier features iff one of the other "useSplit" options is turned on. Note that as usual for Java String processing, backslashes must be doubled in the regular expressions that you write.
splitWordsIgnoreRegexpString\\s+If non-empty, this regexp is used to determine character sequences which should not be returned as tokens when using {@code splitWordsTokenizerRegexp} or {@code splitWordsRegexp}. With the former, first the program attempts to match this regular expression at the start of the string (with {@code lookingAt()}) and if it matches, those characters are discarded, but if it doesn't match then {@code splitWordsTokenizerRegexp} is tried. With {@code splitWordsRegexp}, this is used to filter tokens (with {@code matches()} resulting from the splitting. By default this regular expression is set to be all whitespace tokens (i.e., \\s+). Set it to an empty string to get all tokens returned.
splitWordsWithPTBTokenizerbooleanfalseIf true, and {@code splitWordsRegexp} and {@code splitWordsTokenizerRegexp} are false, then will tokenize using the {@code PTBTokenizer}
useSplitWordsbooleanfalseMake features from the "words" that are returned by dividing the string on splitWordsRegexp or splitWordsTokenizerRegexp. Requires splitWordsRegexp or splitWordsTokenizerRegexp.SW-str
useLowercaseSplitWordsbooleanfalseMake features from the "words" that are returned by dividing the string on splitWordsRegexp or splitWordsTokenizerRegexp and then lowercasing the result. Requires splitWordsRegexp or splitWordsTokenizerRegexp. Note that this can be specified independently of useSplitWords. You can put either or both original cased and lowercased words in as features.LSW-str
useSplitWordPairsbooleanfalseMake features from the pairs of adjacent "words" that are returned by dividing the string into splitWords. Requires splitWordsRegexp or splitWordsTokenizerRegexp. This doesn't add features for the first and last words being next to a boundary; if you want those, also set useSplitFirstLastWords.SWP-str1-str2
useLowercaseSplitWordPairsbooleanfalseMake features from the lowercased form of the pairs of adjacent "words" that are returned by dividing the string into splitWords. Requires splitWordsRegexp or splitWordsTokenizerRegexp. This doesn't add features for the first and last words being next to a boundary; if you want those, also set useLowercaseSplitFirstLastWords.LSWP-str1-str2
useAllSplitWordPairsbooleanfalseMake features from all pairs of "words" that are returned by dividing the string into splitWords. Requires splitWordsRegexp or splitWordsTokenizerRegexp.ASWP-str1-str2
useAllSplitWordTriplesbooleanfalseMake features from all triples of "words" that are returned by dividing the string into splitWords. Requires splitWordsRegexp or splitWordsTokenizerRegexp.ASWT-str1-str2-str3
useSplitWordNGramsbooleanfalseMake features of adjacent word n-grams of lengths between minWordNGramLeng and maxWordNGramLeng inclusive. Note that these are word sequences, not character n-grams.SW#-str1-str2-strN
splitWordCountbooleanfalseIf true, the value of this real-valued feature is the number of split word tokens in the column.SWNUM
logSplitWordCountbooleanfalseIf true, the value of this real-valued feature is the log of the number of split word tokens in the column.LSWNUM
binnedSplitWordCountsStringnullIf non-null, treat as a sequence of comma-separated integer bounds, where items above the previous bound (if any) up to the next bound (inclusive) are binned (e.g., "1,5,15,30,60"). The feature represents the number of split words in this column.SWNUMBIN-range
maxWordNGramLengint-1If this number is positive, word n-grams above this size will not be used in the model
minWordNGramLengint1Must be positive. word n-grams below this size will not be used in the model
wordNGramBoundaryRegexpStringnullIf this is defined and the regexp matches, then the ngram stops
useSplitFirstLastWordsbooleanfalseMake a feature from each of the first and last "words" that are returned as splitWords. This is equivalent to having word bigrams with boundary tokens at each end of the sequence (they get a special feature). Requires splitWordsRegexp or splitWordsTokenizerRegexp.SFW-str, SLW-str
useLowercaseSplitFirstLastWordsbooleanfalseMake a feature from each of the first and last "words" that are returned as splitWords. This is equivalent to having word bigrams with boundary tokens at each end of the sequence (they get a special feature). Requires splitWordsRegexp or splitWordsTokenizerRegexp.LSFW-str, LSLW-str
useSplitNGramsbooleanfalseMake features from letter n-grams - internal as well as edge all treated the same - after the data string has been split into tokens. Requires splitWordsRegexp or splitWordsTokenizerRegexp.S#-str
useSplitPrefixSuffixNGramsbooleanfalseMake features from prefixes and suffixes of each token, after splitting string with splitWordsRegexp. Requires splitWordsRegexp or splitWordsTokenizerRegexp.S#B-str, S#E-str
useSplitWordVectorsStringnull(If non-null) load word vectors from this file and add their average over all split words as real-valued features. Requires splitWordsRegexp or splitWordsTokenizerRegexp. Note that for best results you need a close match between your tokenization and that used by the word vectors.SWV-num
useNGramsbooleanfalseMake features from letter n-grams - internal as well as edge all treated the same.#-str
usePrefixSuffixNGramsbooleanfalseMake features from prefix and suffix substrings of the string.#B-str, #E-str
lowercasebooleanfalseMake the input string lowercase so all features work uncased
lowercaseNGramsbooleanfalseMake features from letter n-grams all lowercase (for all of useNGrams, usePrefixSuffixNGrams, useSplitNGrams, and useSplitPrefixSuffixNGrams)
maxNGramLengint-1If this number is positive, n-grams above this size will not be used in the model
minNGramLengint2Must be positive. n-grams below this size will not be used in the model
partialNGramRegexpStringnullIf this is defined and the regexp matches, then n-grams are made only from the matching text (if no capturing groups are defined) or from the first capturing group of the regexp, if there is one. This substring is used for both useNGrams and usePrefixSuffixNGrams.
realValuedbooleanfalseTreat this column as real-valued and do not perform any transforms on the feature value.Value
logTransformbooleanfalseTreat this column as real-valued and use the log of the value as the feature value.Log
logitTransformbooleanfalseTreat this column as real-valued and use the logit of the value as the feature value.Logit
sqrtTransformbooleanfalseTreat this column as real-valued and use the square root of the value as the feature value.Sqrt
filenamebooleanfalseTreat this column as a filename (path) and then use the contents of that file (assumed to be plain text) in the calculation of features according to other flag specifications.
wordShapeStringnoneEither "none" for no wordShape use, or the name of a word shape function recognized by {@link edu.stanford.nlp.process.WordShapeClassifier#lookupShaper(String)}, such as "dan1" or "chris4". WordShape functions equivalence-class strings based on the pattern of letter, digit, and symbol characters that they contain. The details depend on the particular function chosen.SHAPE-str
splitWordShapeStringnoneEither "none" for no wordShape or the name of a word shape function recognized by {@link WordShapeClassifier#lookupShaper(String)}. This is applied to each "word" found by splitWordsRegexp or splitWordsTokenizerRegexp.SSHAPE-str
featureMinimumSupportint0A feature, that is, an (observed,class) pair, will only be included in the model providing it is seen a minimum of this number of times in the training data.
biasedHyperplaneStringnullIf non-null, a sequence of comma-separated pairs of className prob. An item will only be classified to a certain class className if its probability of class membership exceeds the given conditional probability prob; otherwise it will be assigned to a different class. If this list of classes is exhaustive, and no condition is satisfied, then the most probable class is chosen.
printFeaturesStringnullPrint out the features and their values for each instance to a file based on this name.
printClassifierStringnullStyle in which to print the classifier. One of: HighWeight, HighMagnitude, AllWeights, WeightHistogram, WeightDistribution. See LinearClassifier class for details.
printClassifierParamint100A parameter to the printing style, which may give, for example the number of parameters to print (for HighWeight or HighMagnitude).
justifybooleanfalseFor each test data item, print justification (weights) for active features used in classification.
exitAfterTrainingFeaturizationbooleanfalseIf true, the program exits after reading the training data (trainFile) and before building a classifier. This is useful in conjunction with printFeatures, if one only wants to convert data to features for use with another classifier.
verboseOptimizationbooleanfalseIf true, print much more detail about classifier optimization.
internbooleanfalseIf true, (String) intern all of the (final) feature names. Recommended (this saves memory, but slows down feature generation in training).
cacheNGramsbooleanfalseIf true, record the NGram features that correspond to a String (under the current option settings and reuse rather than recalculating if the String is seen again. Disrecommended (speeds training but can require enormous amounts of memory).
useNBbooleanfalseUse a Naive Bayes generative classifier (over all features) rather than a discriminative logistic regression classifier. (Set {@code useClass} to true to get a prior term.)
useBinarybooleanfalseUse the binary classifier (i.e. use LogisticClassifierFactory, rather than LinearClassifierFactory) to get classifier
l1regdouble0.0If set to be larger than 0, uses L1 regularization
useAdaptL1booleanfalseIf true, uses adaptive L1 regularization to find value of l1reg that gives the desired number of features set by limitFeatures
l1regmindouble0.0Minimum L1 in search
l1regmaxdouble500.0Maximum L1 in search
featureWeightThresholddouble0.0Threshold of model weight at which feature is kept. "Unimportant" low weight features are discarded. (Currently only implemented for adaptL1.)
limitFeaturesLabelsStringnullIf set, only include features for these labels in the desired number of features
limitFeaturesint0If set to be larger than 0, uses adaptive L1 regularization to find value of l1reg that gives the desired number of features
priorString/intquadraticType of prior (regularization penalty on weights). Possible values are null, "no", "quadratic", "huber", "quartic", "cosh", or "adapt". See {@link edu.stanford.nlp.classify.LogPrior LogPrior} for more information.
useSumbooleanfalseDo optimization via summed conditional likelihood, rather than the product. (This is expensive, non-standard, and somewhat unstable, but can be quite effective: see Klein and Manning 2002 EMNLP paper.)
tolerancedouble1e-4Convergence tolerance in parameter optimization
sigmadouble1.0A parameter to several of the smoothing (i.e., regularization) methods, usually giving a degree of smoothing as a standard deviation (with small positive values being stronger smoothing, and bigger values weaker smoothing). However, for Naive Bayes models it is the amount of add-sigma smoothing, so a bigger number is more smoothing.
epsilondouble0.01Used only as a parameter in the Huber loss: this is the distance from 0 at which the loss changes from quadratic to linear
useQNbooleantrueUse Quasi-Newton optimization if true, otherwise use Conjugate Gradient optimization. Recommended.
QNsizeint15Number of previous iterations of Quasi-Newton to store (this increases memory use, but speeds convergence by letting the Quasi-Newton optimization more effectively approximate the second derivative).
featureFormatbooleanfalseAssumes the input file isn't text strings but already featurized. One column is treated as the class column (as defined by {@code goldAnswerColumn}, and all other columns are treated as features of the instance. (If answers are not present, set {@code goldAnswerColumn} to a negative number.)
trainFromSVMLightbooleanfalseAssumes the trainFile is in SVMLight format (see SVMLight web page for more information)
testFromSVMLightbooleanfalseAssumes the testFile is in SVMLight format
printSVMLightFormatToStringnullIf non-null, print the featurized training data to an SVMLight format file (usually used with exitAfterTrainingFeaturization). This is just an option to write out data in a particular format. After that, you're on your own using some other piece of software that reads SVMlight format files.
crossValidationFoldsint-1If positive, the training data is divided in to this many folds and cross-validation is done on the training data (prior to testing on test data, if it is also specified)
printCrossValidationDecisionsbooleanfalseWhether to print the individual classification decisions in cross-validation training, if crossValidationFolds is positive.
shuffleTrainingDatabooleanfalseIf true, the training data is shuffled prior to training and cross-validation. This is vital in cross-validation if the training data is otherwise sorted by class.
shuffleSeedlong0If non-zero, and the training data is being shuffled, this is used as the seed for the Random. Otherwise, System.nanoTime() is used.
csvInputbooleanfalseIf true, reads train and test file in csv format, with support for quoted fields.
inputFormatStringnullIf "header" then reads file with first line treated as header; if "comments" treats lines starting with # as comments; else treated as "plain" tsv/csv file
csvOutputStringnullIf non-null, used to format the output of the classifier. This is a printf-style format specification where %0 through %9 can print columns of the input, %c prints the assigned class and %n a newline character. This option can produce Kaggle-format output files!
* * @author Christopher Manning * @author Anna Rafferty * @author Angel Chang (added options for using l1reg) */ public class ColumnDataClassifier { private static final double DEFAULT_VALUE = 1.0; // default value for setting categorical, boolean features private static final String DEFAULT_IGNORE_REGEXP = "\\s+"; private final Flags[] flags; private final Flags globalFlags; // simply points to flags[0] private Classifier classifier; // really only assigned once too (either in train or load in setProperties) private TokenizerFactory ptbFactory; private enum InputFormat { PLAIN, COMMENTS, HEADER } /** A logger for this class */ private static final Redwood.RedwoodChannels logger = Redwood.channels(ColumnDataClassifier.class); /** * Entry point for taking a String (formatted as a line of a TSV file) and * translating it into a Datum of features. If real-valued features are used, * this method returns an RVFDatum; otherwise, categorical features are used. * * @param line Line of file * @return A Datum (may be an RVFDatum; never null) */ public Datum makeDatumFromLine(String line) { return makeDatumFromStrings(splitLineToFields(line)); } /** * Takes a String[] of elements and translates them into a Datum of features. * If real-valued features are used, this method accesses makeRVFDatumFromLine * and returns an RVFDatum; otherwise, categorical features are used. * * @param strings The elements that features are made from (the columns of a TSV/CSV file) * @return A Datum (may be an RVFDatum; never null) */ public Datum makeDatumFromStrings(String[] strings) { if (globalFlags.usesRealValues) { return makeRVFDatumFromStrings(strings); } if (globalFlags.featureFormat) { Collection theFeatures = new ArrayList<>(); for (int i = 0; i < strings.length; i++) { if (i != globalFlags.goldAnswerColumn) if (globalFlags.significantColumnId) { theFeatures.add(String.format("%d:%s", i, strings[i])); } else { theFeatures.add(strings[i]); } } return new BasicDatum<>(theFeatures, strings[globalFlags.goldAnswerColumn]); } else { //logger.info("Read in " + strings); return makeDatum(strings); } } private RVFDatum makeRVFDatumFromStrings(String[] strings) { if (globalFlags.featureFormat) { ClassicCounter theFeatures = new ClassicCounter<>(); for (int i = 0; i < strings.length; i++) { if (i != globalFlags.goldAnswerColumn) { if (flags[i] != null && (flags[i].isRealValued || flags[i].logTransform || flags[i].logitTransform || flags[i].sqrtTransform)) { addFeatureValue(strings[i], flags[i], theFeatures); } else { theFeatures.setCount(strings[i], 1.0); } } } return new RVFDatum<>(theFeatures, strings[globalFlags.goldAnswerColumn]); } else { //logger.info("Read in " + strings); return makeRVFDatum(strings); } } /** Read a set of training examples from a file, and return the data in a * featurized form. If feature selection is asked for, the returned * featurized form is after feature selection has been applied. * * @param fileName File with supervised training examples. * @return A GeneralDataset, where the labels and features are Strings. */ public GeneralDataset readTrainingExamples(String fileName) { return readAndReturnTrainingExamples(fileName).first(); } /** Read a set of training examples from a file, and return the data in a * featurized form and in String form. If feature selection is asked for, the returned * featurized form is after feature selection has been applied. * * @param fileName File with supervised training examples. * @return A Pair of a GeneralDataset, where the labels and features are Strings and a List of the input examples */ public Pair, List> readAndReturnTrainingExamples(String fileName) { if (globalFlags.printFeatures != null) { newFeaturePrinter(globalFlags.printFeatures, "train", Flags.encoding); } Pair, List> dataInfo = readDataset(fileName, true); GeneralDataset train = dataInfo.first(); if (globalFlags.featureMinimumSupport > 1) { logger.info("Removing Features with counts < " + globalFlags.featureMinimumSupport); train.applyFeatureCountThreshold(globalFlags.featureMinimumSupport); } train.summaryStatistics(); return dataInfo; } /** Read a data set from a file at test time, and return it. * * @param filename The file to read the examples from. * @return A Pair. The first item of the pair is the featurized data set, * ready for passing to the classifier. The second item of the pair * is simply each line of the file split into tab-separated columns. * This is at present necessary for the built-in evaluation, which uses * the gold class from here, and may also be helpful when wanting to * print extra output about the classification process. */ public Pair,List> readTestExamples(String filename) { return readDataset(filename, true); } private static List makeSVMLightLineInfos(List lines) { List lineInfos = new ArrayList<>(lines.size()); for (String line : lines) { line = line.replaceFirst("#.*$", ""); // remove any trailing comments // in principle, it'd be nice to save the comment, though, for possible use as a displayedColumn - make it column 1?? lineInfos.add(line.split("\\s+")); } return lineInfos; } /** NB: This is meant to do splitting strictly only on tabs, and to thus * work with things that are exactly TSV files. It shouldn't split on * all whitespace, because it is useful to be able to have spaces inside * fields for short text documents, and then to be able to split them into * words with features like useSplitWords. */ private static final Pattern tab = Pattern.compile("\\t"); /** Read a data set from a file and convert it into a Dataset object. * In test phase, returns the {@code List} with the data columns for printing purposes. * Otherwise, returns {@code null} for the second item. * * @param filename Where to read data from * @param inTestPhase Whether to return the read String[] for each data item * @return A Pair of a GeneralDataSet of Datums and a List of datums in String form. */ @SuppressWarnings("NonThreadSafeLazyInitialization") private Pair, List> readDataset(String filename, boolean inTestPhase) { Timing tim = new Timing(); GeneralDataset dataset; List lineInfos = null; if ((inTestPhase && Flags.testFromSVMLight) || (!inTestPhase && Flags.trainFromSVMLight)) { List lines = null; if (inTestPhase) { lines = new ArrayList<>(); } if (globalFlags.usesRealValues) { dataset = RVFDataset.readSVMLightFormat(filename, lines); } else { dataset = Dataset.readSVMLightFormat(filename, lines); } if (lines != null) { lineInfos = makeSVMLightLineInfos(lines); } } else { try { if (inTestPhase) { lineInfos = new ArrayList<>(); } if (globalFlags.usesRealValues) { dataset = new RVFDataset<>(); } else { dataset = new Dataset<>(); } int lineNo = 0; int minColumns = Integer.MAX_VALUE; int maxColumns = 0; for (String line : ObjectBank.getLineIterator(new File(filename), Flags.encoding)) { lineNo++; if (Flags.inputFormat == InputFormat.HEADER) { if (lineNo == 1) { if (storedHeader == null) { storedHeader = line; // store it because need elements of it to print header in output } continue; } } else if (Flags.inputFormat == InputFormat.COMMENTS) { if (line.matches("\\s#.*")) { continue; } } String[] strings = splitLineToFields(line); if (strings.length < 2) { throw new RuntimeException("Line format error at line " + lineNo + ": " + line); } if (strings.length < minColumns) { minColumns = strings.length; } if (strings.length > maxColumns) { maxColumns = strings.length; } if (inTestPhase) { lineInfos.add(strings); } if (strings.length < flags.length) { throw new RuntimeException("Error: Line has too few tab-separated columns (" + maxColumns + ") for " + flags.length + " columns required by specified properties: " + line); } dataset.add(makeDatumFromStrings(strings)); } if (lineNo > 0 && minColumns != maxColumns) { logger.info("WARNING: Number of tab-separated columns in " + filename + " varies between " + minColumns + " and " + maxColumns); } } catch (Exception e) { throw new RuntimeException("Dataset could not be loaded", e); } } logger.info("Reading dataset from " + filename + " ... done [" + tim.toSecondsString() + "s, " + dataset.size() + " items]."); return new Pair<>(dataset, lineInfos); } //Split according to whether we are using tsv file (default) or csv files private String[] splitLineToFields(String line) { if(globalFlags.csvInput) { String[] strings = StringUtils.splitOnCharWithQuoting(line, ',', '"', '"'); for (int i = 0; i < strings.length; ++i) { if (strings[i].startsWith("\"") && strings[i].endsWith("\"")) strings[i] = strings[i].substring(1,strings[i].length()-1); } return strings; } else return tab.split(line); } /** * Write summary statistics about a group of answers. */ private Pair writeResultsSummary(int num, Counter contingency, Collection labels) { logger.info(""); String message = ""; message += num + " examples"; if (globalFlags.groupingColumn >= 0 && globalFlags.rankingAccuracyClass != null) { message += " and " + numGroups + " ranking groups"; } logger.info(message + " in test set"); int numClasses = 0; double microAccuracy = 0.0; double macroF1 = 0.0; for (String key : labels) { numClasses++; int tp = (int) contingency.getCount(key + "|TP"); int fn = (int) contingency.getCount(key + "|FN"); int fp = (int) contingency.getCount(key + "|FP"); int tn = (int) contingency.getCount(key + "|TN"); double p = (tp + fp == 0) ? 1.0 : ((double) tp) / (tp + fp); // If nothing selected, then vacuous 1.0 double r = (tp + fn == 0) ? 1.0 : ((double) tp) / (tp + fn); // If nothing to find, then vacuous 1.0 double f = (p == 0.0 && r == 0.0) ? 0.0 : 2 * p * r / (p + r); double acc = ((double) tp + tn)/num; macroF1 += f; microAccuracy += tp; logger.info("Cls " + key + ": TP=" + tp + " FN=" + fn + " FP=" + fp + " TN=" + tn + "; Acc " + nf.format(acc) + " P " + nf.format(p) + " R " + nf.format(r) + " F1 " + nf.format(f)); } if (globalFlags.groupingColumn >= 0 && globalFlags.rankingAccuracyClass != null) { double cor = (int) contingency.getCount("Ranking|Correct"); double err = (int) contingency.getCount("Ranking|Error"); double rankacc = (cor + err == 0) ? 0 : cor / (cor + err); logger.info("Ranking accuracy: " + nf.format(rankacc)); double cov = (int) contingency.getCount("Ranking|Covered"); double coverr = (int) contingency.getCount("Ranking|Uncovered"); double covacc = (cov + coverr == 0) ? 0 : cov / (cov + coverr); if (coverr > 0.5) { double ce = (int) (contingency.getCount("Ranking|Error") - contingency.getCount("Ranking|Uncovered")); double crankacc = (cor + ce == 0) ? 0 : cor / (cor + ce); logger.info(" (on " + nf.format(covacc) + " of groups with correct answer: " + nf.format(crankacc) + ')'); } else { logger.info(""); } if (globalFlags.rankingScoreColumn >= 0) { double totalSim = contingency.getCount("Ranking|Score"); double ranksim = (cor + err == 0) ? 0 : totalSim / (cor + err); logger.info("Ranking average score: " + nf.format(ranksim)); } } microAccuracy = microAccuracy / num; macroF1 = macroF1 / numClasses; NumberFormat nf2 = new DecimalFormat("0.00000"); logger.info("Accuracy/micro-averaged F1: " + nf2.format(microAccuracy)); logger.info("Macro-averaged F1: " + nf2.format(macroF1)); return new Pair<>(microAccuracy, macroF1); } // These variables are only used by the private methods used by main() for displaying // performance statistics when running the command-line version. So their being // static does no harm. private static int numGroups = 0; private static String lastGroup = ""; private static int numInGroup = 0; private static double bestProb = 0.0; private static double bestSim = 0.0; private static boolean currentHighestProbCorrect = false; private static boolean foundAnswerInGroup = false; private static String storedHeader; private static final NumberFormat nf = new DecimalFormat("0.000"); /** * Write out an answer, and update statistics. */ private void writeAnswer(String[] strs, String clAnswer, Distribution cntr) { String goldAnswer = globalFlags.goldAnswerColumn < strs.length ? strs[globalFlags.goldAnswerColumn]: ""; String printedText = ""; if (globalFlags.displayedColumn >= 0) { printedText = strs[globalFlags.displayedColumn]; } String results; if (globalFlags.displayAllAnswers) { // sort the labels by probability TreeSet> sortedLabels = new TreeSet<>(); for (String key : cntr.keySet()) { sortedLabels.add(new Pair<>(cntr.probabilityOf(key), key)); } StringBuilder builder = new StringBuilder(); for (Pair pair : sortedLabels.descendingSet()) { if (builder.length() > 0) { builder.append('\t'); } builder.append(pair.first().toString()).append('\t').append(pair.second()); } results = builder.toString(); } else { results = clAnswer + '\t' + nf.format(cntr.probabilityOf(clAnswer)) + '\t' + nf.format(cntr.probabilityOf(goldAnswer)); } String line; if (printedText.isEmpty()) { line = goldAnswer + '\t' + results; } else { line = printedText + '\t' + goldAnswer + '\t' + results; } System.out.println(line); } private void updatePerformanceStatistics(String[] strs, String clAnswer, Distribution cntr, Counter contingency, Classifier c, double sim) { String goldAnswer = globalFlags.goldAnswerColumn < strs.length ? strs[globalFlags.goldAnswerColumn]: ""; // NB: This next bit judges correctness by surface String equality, not our internal indices, so strs has to be right even for svmlightFormat for (String next : c.labels()) { if (next.equals(goldAnswer)) { if (next.equals(clAnswer)) { contingency.incrementCount(next + "|TP"); } else { contingency.incrementCount(next + "|FN"); } } else { if (next.equals(clAnswer)) { contingency.incrementCount(next + "|FP"); } else { contingency.incrementCount(next + "|TN"); } } } if (globalFlags.groupingColumn >= 0 && globalFlags.rankingAccuracyClass != null) { String group = strs[globalFlags.groupingColumn]; // logger.info("Group is " + group); if (group.equals(lastGroup)) { numInGroup++; double prob = cntr.probabilityOf(globalFlags.rankingAccuracyClass); // logger.info(" same group; prob is " + prob); if (prob > bestProb) { bestProb = prob; bestSim = sim; // logger.info(" better prob than before"); currentHighestProbCorrect = goldAnswer.equals(globalFlags.rankingAccuracyClass); } if (globalFlags.rankingAccuracyClass.equals(goldAnswer)) { foundAnswerInGroup = true; } } else { finishRanking(contingency, bestSim); numGroups++; lastGroup = group; bestProb = cntr.probabilityOf(globalFlags.rankingAccuracyClass); bestSim = sim; // logger.info(" different; prob is " + bestProb); numInGroup = 1; currentHighestProbCorrect = goldAnswer.equals(globalFlags.rankingAccuracyClass); foundAnswerInGroup = globalFlags.rankingAccuracyClass.equals(goldAnswer); } } } private void finishRanking(Counter contingency, double sim) { if (numInGroup > 0) { if (globalFlags.justify) { String message = "Previous group of " + numInGroup + ": "; if (!foundAnswerInGroup) { message += "no correct answer; "; } message += "highest ranked guess was: " + ((currentHighestProbCorrect ? "correct" : "incorrect")); logger.info(message); logger.info(" (sim. = " + nf.format(sim) + ')'); } if (currentHighestProbCorrect) { contingency.incrementCount("Ranking|Correct"); } else { contingency.incrementCount("Ranking|Error"); } if (foundAnswerInGroup) { contingency.incrementCount("Ranking|Covered"); } else { contingency.incrementCount("Ranking|Uncovered"); } contingency.incrementCount("Ranking|Score", sim); } } /** Test and evaluate classifier on examples with their String representation and gold classification available. * * @param cl The classifier to test * @param test The dataset to test on * @param lineInfos Duplicate List of the items to be classified, each an array of Strings (like a line of a TSV file) * @return A Pair consisting of the accuracy (micro-averaged F1) and macro-averaged F1 for the dataset */ private Pair testExamples(Classifier cl, GeneralDataset test, List lineInfos) { // usually suppress item level printing in crossvalidation if ( ! (globalFlags.crossValidationFolds > 0 && ! globalFlags.printCrossValidationDecisions)) { String message = ""; if (globalFlags.csvOutput != null) { message += formatCsv(globalFlags.csvOutput, storedHeader.split("\t"), null); } else { message += "Output format: "; if (globalFlags.displayedColumn >= 0) { message += "dataColumn" + globalFlags.displayedColumn + "\t"; } message += "goldAnswer\t"; if (globalFlags.displayAllAnswers) { logger.info(message + "[P(class) class]+ {sorted by probability}"); } else { logger.info(message + "classifierAnswer\tP(clAnswer)\tP(goldAnswer)"); } } } Counter contingency = new ClassicCounter<>(); // store tp,fp,fn,tn for (int i = 0, sz = test.size(); i < sz; i++) { testExample(cl, test, lineInfos, contingency, i); } if (globalFlags.groupingColumn >= 0 && globalFlags.rankingAccuracyClass != null) finishRanking(contingency, bestSim); return writeResultsSummary(test.size(), contingency, cl.labels()); } private void testExample(Classifier cl, GeneralDataset test, List lineInfos, Counter contingency, int i) { String[] example = lineInfos.get(i); Datum d; if (globalFlags.usesRealValues) { d = test.getRVFDatum(i); } else { d = test.getDatum(i); } if (globalFlags.justify) { logger.info("### Test item " + i); String message = ""; for (String field : example) { message += field + "\t"; } logger.info(message + "\n"); if (cl instanceof LinearClassifier) { ((LinearClassifier)cl).justificationOf(d); } logger.info(""); } Counter logScores; if (globalFlags.usesRealValues) { logScores = ErasureUtils.>uncheckedCast(cl).scoresOf((RVFDatum)d); } else { logScores = cl.scoresOf(d); } Distribution dist = Distribution.distributionFromLogisticCounter(logScores); String answer = null; if (globalFlags.biasedHyperplane != null) { // logger.info("Biased using counter: " + // globalFlags.biasedHyperplane); List biggestKeys = new ArrayList<>(logScores.keySet()); Collections.sort(biggestKeys, Counters.toComparatorDescending(logScores)); for (String key : biggestKeys) { double prob = dist.probabilityOf(key); double threshold = globalFlags.biasedHyperplane.getCount(key); // logger.info(" Trying " + key + " prob is " + prob + // " threshold is " + threshold); if (prob > threshold) { answer = key; break; } } } if (answer == null) { if (globalFlags.usesRealValues) { answer = ErasureUtils.>uncheckedCast(cl).classOf((RVFDatum) d); } else { answer = cl.classOf(d); } } double sim = 0.0; if (globalFlags.rankingScoreColumn >= 0) { try { sim = Double.parseDouble(example[globalFlags.rankingScoreColumn]); } catch (NumberFormatException nfe) { // just don't print it } } if ( ! (globalFlags.crossValidationFolds > 0 && ! globalFlags.printCrossValidationDecisions)) { if (globalFlags.csvOutput != null) { System.out.print(formatCsv(globalFlags.csvOutput, example, answer)); } else { writeAnswer(example, answer, dist); } } updatePerformanceStatistics(example, answer, dist, contingency, cl, sim); } private String formatCsv(String format, String[] fields, String answer) { StringBuilder out = new StringBuilder(); for (int i = 0, len = format.length(); i < len; i++) { char ch = format.charAt(i); if (ch == '%' && i + 1 < len) { char ch2 = format.charAt(i + 1); if (ch2 >= '0' && ch2 <= '9') { int field = ch2 - '0'; if (field < fields.length) { out.append(fields[field]); } else { throw new IllegalArgumentException("Not enough columns for format " + format); } } else if (ch2 == 'c') { if (answer != null) { out.append(answer); } else if (globalFlags.goldAnswerColumn < fields.length) { out.append(fields[globalFlags.goldAnswerColumn]); } else { out.append("Class"); } } else if (ch2 == 'n') { out.append('\n'); } else { throw new IllegalArgumentException("Unrecognized format specification in " + format); } i++; // have also dealt with next character giving format } else { out.append(ch); } } return out.toString(); } /** * Extracts all the features from a certain input datum. * * @param strs The data String[] to extract features from * @return The constructed Datum */ private Datum makeDatum(String[] strs) { String goldAnswer = globalFlags.goldAnswerColumn < strs.length ? strs[globalFlags.goldAnswerColumn]: ""; List theFeatures = new ArrayList<>(); Collection globalFeatures = Generics.newHashSet(); if (globalFlags.useClassFeature) { globalFeatures.add("CLASS"); } addAllInterningAndPrefixing(theFeatures, globalFeatures, ""); for (int i = 0; i < flags.length; i++) { Collection featuresC = Generics.newHashSet();//important that this is a hash set to prevent same feature from being added multiple times makeDatum(strs[i], flags[i], featuresC, goldAnswer); addAllInterningAndPrefixing(theFeatures, featuresC, i + "-"); } if (globalFlags.printFeatures != null) { printFeatures(strs, theFeatures); } //System.out.println("Features are: " + theFeatures); return new BasicDatum<>(theFeatures, goldAnswer); } /** * Extracts all the features from a certain input array and makes * a real valued feature datum; those features that are not real valued * are given value 1.0. * * @param strs The data String[] to extract features from * @return The constructed RVFDatum */ private RVFDatum makeRVFDatum(String[] strs) { String goldAnswer = globalFlags.goldAnswerColumn < strs.length ? strs[globalFlags.goldAnswerColumn]: ""; ClassicCounter theFeatures = new ClassicCounter<>(); ClassicCounter globalFeatures = new ClassicCounter<>(); if (globalFlags.useClassFeature) { globalFeatures.setCount("CLASS", 1.0); } addAllInterningAndPrefixingRVF(theFeatures, globalFeatures, ""); for (int i = 0; i < flags.length; i++) { ClassicCounter featuresC = new ClassicCounter<>(); makeDatum(strs[i], flags[i], featuresC, goldAnswer); addAllInterningAndPrefixingRVF(theFeatures, featuresC, i + "-"); } if (globalFlags.printFeatures != null) { printFeatures(strs, theFeatures); } //System.out.println("Features are: " + theFeatures); return new RVFDatum<>(theFeatures, goldAnswer); } private void addAllInterningAndPrefixingRVF(ClassicCounter accumulator, ClassicCounter addend, String prefix) { assert prefix != null; for (String protoFeat : addend.keySet()) { double count = addend.getCount(protoFeat); if ( ! prefix.isEmpty()) { protoFeat = prefix + protoFeat; } if (globalFlags.intern) { protoFeat = protoFeat.intern(); } accumulator.incrementCount(protoFeat, count); } } private void addAllInterningAndPrefixing(Collection accumulator, Collection addend, String prefix) { assert prefix != null; for (String protoFeat : addend) { if ( ! prefix.isEmpty()) { protoFeat = prefix + protoFeat; } if (globalFlags.intern) { protoFeat = protoFeat.intern(); } accumulator.add(protoFeat); } } /** * This method takes care of adding features to the collection-ish object features when * the value of the feature must be parsed as a real number, including performing * math transforms. */ private static void addFeatureValue(String cWord, Flags flags, Object featuresC) { double value = Double.valueOf(cWord); if (flags.logTransform) { double log = Math.log(value); if(Double.isInfinite(log) || Double.isNaN(log)) { logger.info("WARNING: Log transform attempted on out of range value; feature ignored"); } else addFeature(featuresC, "Log", log); } else if(flags.logitTransform) { double logit = Math.log(value/(1-value)); if(Double.isInfinite(logit) || Double.isNaN(logit)) { logger.info("WARNING: Logit transform attempted on out of range value; feature ignored"); } else { addFeature(featuresC, "Logit", logit); } } else if(flags.sqrtTransform) { double sqrt = Math.sqrt(value); addFeature(featuresC, "Sqrt", sqrt); } else { addFeature(featuresC, Flags.realValuedFeaturePrefix, value); } } /** * This method takes care of adding features to the collection-ish object features via * instanceof checks. Features must be a type of collection or a counter, and value is used * iff it is a counter */ private static void addFeature(Object features, F newFeature, double value) { if (features instanceof Counter) { ErasureUtils.>uncheckedCast(features).setCount(newFeature, value); } else if(features instanceof Collection) { ErasureUtils.>uncheckedCast(features).add(newFeature); } else { throw new RuntimeException("addFeature was called with a features object that is neither a counter nor a collection!"); } } /** * Extracts all the features from a certain input column. * * @param cWord The String to extract data from * @param goldAns The goldAnswer for this whole datum or emptyString if none. * This is used only for filling in the binned lengths histogram counters */ private void makeDatum(String cWord, Flags flags, Object featuresC, String goldAns) { //logger.info("Making features for " + cWord + " flags " + flags); if (flags == null) { // no features for this column return; } if (flags.filename) { cWord = IOUtils.slurpFileNoExceptions(cWord); } if (flags.lowercase) { cWord = cWord.toLowerCase(Locale.ENGLISH); } if (flags.useString) { addFeature(featuresC,"S-"+cWord,DEFAULT_VALUE); } if (flags.binnedLengths != null) { int len = cWord.length(); String featureName = null; for (int i = 0; i <= flags.binnedLengths.length; i++) { if (i == flags.binnedLengths.length || len <= flags.binnedLengths[i]) { featureName = "Len-" + ((i == 0) ? 0 : (flags.binnedLengths[i - 1] + 1)) + '-' + ((i == flags.binnedLengths.length) ? "Inf" : Integer.toString(flags.binnedLengths[i])); if (flags.binnedLengthsCounter != null) { flags.binnedLengthsCounter.incrementCount(featureName, goldAns); } break; } } addFeature(featuresC,featureName,DEFAULT_VALUE); } if (flags.binnedValues != null) { double val = flags.binnedValuesNaN; try { val = Double.parseDouble(cWord); } catch (NumberFormatException nfe) { // do nothing -- keeps value of flags.binnedValuesNaN } String featureName = null; for (int i = 0; i <= flags.binnedValues.length; i++) { if (i == flags.binnedValues.length || val <= flags.binnedValues[i]) { featureName = "Val-(" + ((i == 0) ? "-Inf" : Double.toString(flags.binnedValues[i - 1])) + ',' + ((i == flags.binnedValues.length) ? "Inf" : Double.toString(flags.binnedValues[i])) + ']'; if (flags.binnedValuesCounter != null) { flags.binnedValuesCounter.incrementCount(featureName, goldAns); } break; } } addFeature(featuresC,featureName,DEFAULT_VALUE); } if (flags.countChars != null) { int[] cnts = new int[flags.countChars.length]; for (int i = 0; i < cnts.length; i++) { cnts[i] = 0; } for (int i = 0, len = cWord.length(); i < len; i++) { char ch = cWord.charAt(i); for (int j = 0; j < cnts.length; j++) { if (ch == flags.countChars[j]) { cnts[j]++; } } } for (int j = 0; j < cnts.length; j++) { String featureName = null; for (int i = 0; i <= flags.countCharsBins.length; i++) { if (i == flags.countCharsBins.length || cnts[j] <= flags.countCharsBins[i]) { featureName = "Char-" + flags.countChars[j] + '-' + ((i == 0) ? 0 : (flags.countCharsBins[i - 1] + 1)) + '-' + ((i == flags.countCharsBins.length) ? "Inf" : Integer.toString(flags.countCharsBins[i])); break; } } addFeature(featuresC,featureName,DEFAULT_VALUE); } } if (flags.splitWordsPattern != null || flags.splitWordsTokenizerPattern != null || flags.splitWordsWithPTBTokenizer) { String[] bits; if (flags.splitWordsTokenizerPattern != null) { bits = regexpTokenize(flags.splitWordsTokenizerPattern, flags.splitWordsIgnorePattern, cWord); } else if (flags.splitWordsPattern != null) { bits = splitTokenize(flags.splitWordsPattern, flags.splitWordsIgnorePattern, cWord); } else { //PTB tokenizer bits = ptbTokenize(cWord); } if (flags.showTokenization) { logger.info("Tokenization: " + Arrays.toString(bits)); } if (flags.splitWordCount) { addFeature(featuresC, "SWNUM", bits.length); } if (flags.logSplitWordCount) { addFeature(featuresC, "LSWNUM", Math.log(bits.length)); } if (flags.binnedSplitWordCounts != null) { String featureName = null; for (int i = 0; i <= flags.binnedSplitWordCounts.length; i++) { if (i == flags.binnedSplitWordCounts.length || bits.length <= flags.binnedSplitWordCounts[i]) { featureName = "SWNUMBIN-" + ((i == 0) ? 0 : (flags.binnedSplitWordCounts[i - 1] + 1)) + '-' + ((i == flags.binnedSplitWordCounts.length) ? "Inf" : Integer.toString(flags.binnedSplitWordCounts[i])); break; } } addFeature(featuresC, featureName, DEFAULT_VALUE); } // add features over splitWords for (int i = 0; i < bits.length; i++) { if (flags.useSplitWords) { addFeature(featuresC, "SW-" + bits[i], DEFAULT_VALUE); } if (flags.useLowercaseSplitWords) { addFeature(featuresC, "LSW-" + bits[i].toLowerCase(), DEFAULT_VALUE); } if (flags.useSplitWordPairs) { if (i + 1 < bits.length) { addFeature(featuresC, "SWP-" + bits[i] + '-' + bits[i + 1], DEFAULT_VALUE); } } if (flags.useLowercaseSplitWordPairs) { if (i + 1 < bits.length) { addFeature(featuresC, "LSWP-" + bits[i].toLowerCase() + '-' + bits[i + 1].toLowerCase(), DEFAULT_VALUE); } } if (flags.useAllSplitWordPairs) { for (int j = i + 1; j < bits.length; j++) { // sort lexicographically if (bits[i].compareTo(bits[j]) < 0) { addFeature(featuresC, "ASWP-" + bits[i] + '-' + bits[j], DEFAULT_VALUE); } else { addFeature(featuresC, "ASWP-" + bits[j] + '-' + bits[i], DEFAULT_VALUE); } } } if (flags.useAllSplitWordTriples) { for (int j = i + 1; j < bits.length; j++) { for (int k = j + 1; k < bits.length; k++) { // sort lexicographically String[] triple = new String[3]; triple[0] = bits[i]; triple[1] = bits[j]; triple[2] = bits[k]; Arrays.sort(triple); addFeature(featuresC, "ASWT-" + triple[0] + '-' + triple[1] + '-' + triple[2], DEFAULT_VALUE); } } } if (flags.useSplitWordNGrams) { StringBuilder sb = new StringBuilder("SW#"); for (int j = i; j < i+flags.minWordNGramLeng-1 && j < bits.length; j++) { sb.append('-'); sb.append(bits[j]); } int maxIndex = (flags.maxWordNGramLeng > 0)? Math.min(bits.length, i + flags.maxWordNGramLeng): bits.length; for (int j = i + flags.minWordNGramLeng-1; j < maxIndex; j++) { if (flags.wordNGramBoundaryRegexp != null) { if (flags.wordNGramBoundaryPattern.matcher(bits[j]).matches()) { break; } } sb.append('-'); sb.append(bits[j]); addFeature(featuresC, sb.toString(), DEFAULT_VALUE); } } // this is equivalent to having boundary tokens in splitWordPairs -- they get a special feature if (flags.useSplitFirstLastWords) { if (i == 0) { addFeature(featuresC,"SFW-" + bits[i], DEFAULT_VALUE); } else if (i == bits.length - 1) { addFeature(featuresC,"SLW-" + bits[i], DEFAULT_VALUE); } } if (flags.useLowercaseSplitFirstLastWords) { if (i == 0) { addFeature(featuresC,"LSFW-" + bits[i].toLowerCase(), DEFAULT_VALUE); } else if (i == bits.length - 1) { addFeature(featuresC,"SLW-" + bits[i].toLowerCase(), DEFAULT_VALUE); } } if (flags.useSplitNGrams || flags.useSplitPrefixSuffixNGrams) { Collection featureNames = makeNGramFeatures(bits[i], flags, true, "S#"); for(String featureName : featureNames) addFeature(featuresC, featureName, DEFAULT_VALUE); } if (flags.splitWordShape > edu.stanford.nlp.process.WordShapeClassifier.NOWORDSHAPE) { String shape = edu.stanford.nlp.process.WordShapeClassifier.wordShape(bits[i], flags.splitWordShape); // logger.info("Shaper is " + flags.splitWordShape + " word len " + bits[i].length() + " shape is " + shape); addFeature(featuresC,"SSHAPE-" + shape,DEFAULT_VALUE); } } // for bits if (flags.wordVectors != null) { double[] averages = null; for (String bit : bits) { float[] wv = flags.wordVectors.get(bit); if (wv != null) { if (averages == null) { averages = new double[wv.length]; for (int j = 0; j < wv.length; j++) { averages[j] += wv[j]; } } } } if (averages != null) { for (int j = 0; j < averages.length; j++) { averages[j] /= bits.length; addFeature(featuresC, "SWV-" + j, averages[j]); } // } else { // logger.info("No word vectors found for words in |" + cWord + '|'); } } // end if wordVectors } // end if uses some split words features if (flags.wordShape > WordShapeClassifier.NOWORDSHAPE) { String shape = edu.stanford.nlp.process.WordShapeClassifier.wordShape(cWord, flags.wordShape); addFeature(featuresC, "SHAPE-" + shape, DEFAULT_VALUE); } if (flags.useNGrams || flags.usePrefixSuffixNGrams) { Collection featureNames = makeNGramFeatures(cWord, flags, false, "#"); for(String featureName : featureNames) addFeature(featuresC,featureName,DEFAULT_VALUE); } if (flags.isRealValued || flags.logTransform || flags.logitTransform || flags.sqrtTransform) { addFeatureValue(cWord, flags, featuresC); } //logger.info("Made featuresC " + featuresC); } //end makeDatum //return the tokens using PTB tokenizer private String[] ptbTokenize(String cWord) { if(ptbFactory==null) ptbFactory = PTBTokenizer.factory(); Tokenizer tokenizer = ptbFactory.getTokenizer(new StringReader(cWord)); List words = tokenizer.tokenize(); String[] res = new String[words.size()]; for(int i = 0; i < words.size(); ++i) { res[i] = words.get(i).word(); } return res; } /** * Caches a hash of word to all substring features. A lot of memory! * If the String space is large, you shouldn't turn this on. */ private static final Map> wordToSubstrings = new ConcurrentHashMap<>(); private String intern(String s) { if (globalFlags.intern) { return s.intern(); } return s; } /** * Return a Collection of NGrams from the input String. */ private Collection makeNGramFeatures(final String input, Flags flags, boolean useSplit, String featPrefix) { String toNGrams = input; boolean internalNGrams; boolean prefixSuffixNGrams; if (useSplit) { internalNGrams = flags.useSplitNGrams; prefixSuffixNGrams = flags.useSplitPrefixSuffixNGrams; } else { internalNGrams = flags.useNGrams; prefixSuffixNGrams = flags.usePrefixSuffixNGrams; } if (flags.lowercaseNGrams) { toNGrams = toNGrams.toLowerCase(Locale.ENGLISH); } if (flags.partialNGramRegexp != null) { Matcher m = flags.partialNGramPattern.matcher(toNGrams); // log.info("Matching |" + flags.partialNGramRegexp + // "| against |" + toNGrams + "|"); if (m.find()) { if (m.groupCount() > 0) { toNGrams = m.group(1); } else { toNGrams = m.group(); } // log.info(" Matched |" + toNGrams + "|"); } // logger.info(); } Collection subs = null; if (flags.cacheNGrams) { subs = wordToSubstrings.get(toNGrams); } if (subs == null) { subs = new ArrayList<>(); String strN = featPrefix + '-'; String strB = featPrefix + "B-"; String strE = featPrefix + "E-"; int wleng = toNGrams.length(); for (int i = 0; i < wleng; i++) { for (int j = i + flags.minNGramLeng, min = Math.min(wleng, i + flags.maxNGramLeng); j <= min; j++) { if (prefixSuffixNGrams) { if (i == 0) { subs.add(intern(strB + toNGrams.substring(i, j))); } if (j == wleng) { subs.add(intern(strE + toNGrams.substring(i, j))); } } if (internalNGrams) { subs.add(intern(strN + toNGrams.substring(i, j))); } } } if (flags.cacheNGrams) { wordToSubstrings.put(toNGrams, subs); } } return subs; } private static PrintWriter cliqueWriter; private static void newFeaturePrinter(String prefix, String suffix, String encoding) { if (cliqueWriter != null) { closeFeaturePrinter(); } try { cliqueWriter = IOUtils.getPrintWriter(prefix + '.' + suffix, encoding); } catch (IOException ioe) { cliqueWriter = null; } } private static void closeFeaturePrinter() { cliqueWriter.close(); cliqueWriter = null; } private static void printFeatures(String[] wi, ClassicCounter features) { if (cliqueWriter != null) { for (int i = 0; i < wi.length; i++) { if (i > 0) { cliqueWriter.print("\t"); } cliqueWriter.print(wi[i]); } for (String feat : features.keySet()) { cliqueWriter.print("\t"); cliqueWriter.print(feat); cliqueWriter.print("\t"); cliqueWriter.print(features.getCount(feat)); } cliqueWriter.println(); } } private static void printFeatures(String[] wi, List features) { if (cliqueWriter != null) { for (int i = 0; i < wi.length; i++) { if (i > 0) { cliqueWriter.print("\t"); } cliqueWriter.print(wi[i]); } for (String feat : features) { cliqueWriter.print("\t"); cliqueWriter.print(feat); } cliqueWriter.println(); } } /** * Creates a classifier from training data. A specialized training regimen. * It searches for an appropriate l1reg parameter to use to get specified number of features. * This is called from makeClassifier() when certain properties are set. * * @param train training data * @return trained classifier */ private Classifier makeClassifierAdaptL1(GeneralDataset train) { assert(globalFlags.useAdaptL1 && globalFlags.limitFeatures > 0); Classifier lc = null; double l1reg = globalFlags.l1reg; double l1regmax = globalFlags.l1regmax; double l1regmin = globalFlags.l1regmin; if (globalFlags.l1reg <= 0.0) { logger.info("WARNING: useAdaptL1 set and limitFeatures to " + globalFlags.limitFeatures + ", but invalid value of l1reg=" + globalFlags.l1reg + ", defaulting to " + globalFlags.l1regmax); l1reg = l1regmax; } else { logger.info("TRAIN: useAdaptL1 set and limitFeatures to " + globalFlags.limitFeatures + ", l1reg=" + globalFlags.l1reg + ", l1regmax=" + globalFlags.l1regmax + ", l1regmin=" + globalFlags.l1regmin); } Set limitFeatureLabels = null; if (globalFlags.limitFeaturesLabels != null) { String[] labels = globalFlags.limitFeaturesLabels.split(","); limitFeatureLabels = Generics.newHashSet(); for (String label:labels) { limitFeatureLabels.add(label.trim()); } } // Do binary search starting with specified l1reg to find reasonable value of l1reg that gives desired number of features double l1regtop = l1regmax; double l1regbottom = l1regmin; int limitFeatureTol = 5; double l1regminchange = 0.05; while (true) { logger.info("Training: l1reg=" + l1reg + ", threshold=" + globalFlags.featureWeightThreshold + ", target=" + globalFlags.limitFeatures); LinearClassifierFactory lcf; Minimizer minim = ReflectionLoading.loadByReflection("edu.stanford.nlp.optimization.OWLQNMinimizer", l1reg); lcf = new LinearClassifierFactory<>(minim, globalFlags.tolerance, globalFlags.useSum, globalFlags.prior, globalFlags.sigma, globalFlags.epsilon); int featureCount = -1; try { LinearClassifier c = lcf.trainClassifier(train); lc = c; featureCount = c.getFeatureCount(limitFeatureLabels, globalFlags.featureWeightThreshold, false /*useMagnitude*/); logger.info("Training Done: l1reg=" + l1reg + ", threshold=" + globalFlags.featureWeightThreshold + ", features=" + featureCount + ", target=" + globalFlags.limitFeatures); // String classifierDesc = c.toString(globalFlags.printClassifier, globalFlags.printClassifierParam); List> topFeatures = c.getTopFeatures( limitFeatureLabels, globalFlags.featureWeightThreshold, false /*useMagnitude*/, globalFlags.limitFeatures, true /*descending order*/); String classifierDesc = c.topFeaturesToString(topFeatures); logger.info("Printing top " + globalFlags.limitFeatures + " features with weights above " + globalFlags.featureWeightThreshold); if (globalFlags.limitFeaturesLabels != null) { logger.info(" Limited to labels: " + globalFlags.limitFeaturesLabels); } logger.info(classifierDesc); } catch (RuntimeException ex) { if (ex.getMessage() != null && ex.getMessage().startsWith("L-BFGS chose a non-descent direction")) { logger.info("Error in optimization, will try again with different l1reg"); ex.printStackTrace(System.err); } else { throw ex; } } if (featureCount < 0 || featureCount < globalFlags.limitFeatures - limitFeatureTol) { // Too few features or some other bad thing happened => decrease l1reg l1regtop = l1reg; l1reg = 0.5*(l1reg + l1regbottom); if (l1regtop - l1reg < l1regminchange) { logger.info("Stopping: old l1reg " + l1regtop + "- new l1reg " + l1reg + ", difference less than " + l1regminchange); break; } } else if (featureCount > globalFlags.limitFeatures + limitFeatureTol) { // Too many features => increase l1reg l1regbottom = l1reg; l1reg = 0.5*(l1reg + l1regtop); if (l1reg - l1regbottom < l1regminchange) { logger.info("Stopping: new l1reg " + l1reg + "- old l1reg " + l1regbottom + ", difference less than " + l1regminchange); break; } } else { logger.info("Stopping: # of features within " + limitFeatureTol + " of target"); break; } } // Update flags for later serialization globalFlags.l1reg = l1reg; return lc; } /** * Creates a classifier from training data. * * @param train training data * @return trained classifier */ public Classifier makeClassifier(GeneralDataset train) { Classifier lc; if (globalFlags.useClassifierFactory != null) { ClassifierFactory> cf; if (globalFlags.classifierFactoryArgs != null) { cf = ReflectionLoading.loadByReflection(globalFlags.useClassifierFactory, globalFlags.classifierFactoryArgs); } else { cf = ReflectionLoading.loadByReflection(globalFlags.useClassifierFactory); } lc = cf.trainClassifier(train); } else if (globalFlags.useNB) { double sigma = (globalFlags.prior == 0) ? 0.0 : globalFlags.sigma; lc = new NBLinearClassifierFactory(sigma, globalFlags.useClassFeature).trainClassifier(train); } else if (globalFlags.useBinary) { LogisticClassifierFactory lcf = new LogisticClassifierFactory<>(); LogPrior prior = new LogPrior(globalFlags.prior, globalFlags.sigma, globalFlags.epsilon); lc = lcf.trainClassifier(train, globalFlags.l1reg, globalFlags.tolerance, prior, globalFlags.biased); } else if (globalFlags.biased) { LogisticClassifierFactory lcf = new LogisticClassifierFactory<>(); LogPrior prior = new LogPrior(globalFlags.prior, globalFlags.sigma, globalFlags.epsilon); lc = lcf.trainClassifier(train, prior, true); } else if (globalFlags.useAdaptL1 && globalFlags.limitFeatures > 0) { lc = makeClassifierAdaptL1(train); } else { LinearClassifierFactory lcf; if (globalFlags.l1reg > 0.0) { Minimizer minim = ReflectionLoading.loadByReflection("edu.stanford.nlp.optimization.OWLQNMinimizer", globalFlags.l1reg); lcf = new LinearClassifierFactory<>(minim, globalFlags.tolerance, globalFlags.useSum, globalFlags.prior, globalFlags.sigma, globalFlags.epsilon); } else { lcf = new LinearClassifierFactory<>(globalFlags.tolerance, globalFlags.useSum, globalFlags.prior, globalFlags.sigma, globalFlags.epsilon, globalFlags.QNsize); } lcf.setVerbose(globalFlags.verboseOptimization); if ( ! globalFlags.useQN) { lcf.useConjugateGradientAscent(); } lc = lcf.trainClassifier(train); } return lc; } private static String[] regexpTokenize(Pattern tokenizerRegexp, Pattern ignoreRegexp, String inWord) { List al = new ArrayList<>(); String word = inWord; while ( ! word.isEmpty()) { // logger.info("String to match on is " + word); Matcher mig = null; if (ignoreRegexp != null) { mig = ignoreRegexp.matcher(word); } if (mig != null && mig.lookingAt()) { word = word.substring(mig.end()); } else { Matcher m = tokenizerRegexp.matcher(word); if (m.lookingAt()) { // Logging.logger(ColumnDataClassifier.class).info("Matched " + m.end() + " chars: " + // word.substring(0, m.end())); al.add(word.substring(0, m.end())); word = word.substring(m.end()); } else { logger.info("Warning: regexpTokenize pattern " + tokenizerRegexp + " didn't match on |" + word.substring(0, 1) + "| of |" + word + '|'); // logger.info("Default matched 1 char: " + // word.substring(0, 1)); al.add(word.substring(0, 1)); word = word.substring(1); } } } String[] bits = al.toArray(new String[al.size()]); return bits; } private static String[] splitTokenize(Pattern splitRegexp, Pattern ignoreRegexp, String cWord) { String[] bits = splitRegexp.split(cWord); if (ignoreRegexp != null) { List keepBits = new ArrayList<>(bits.length); for (String bit : bits) { if ( ! ignoreRegexp.matcher(bit).matches()) { keepBits.add(bit); } } if (keepBits.size() != bits.length) { bits = new String[keepBits.size()]; keepBits.toArray(bits); } } return bits; } private static Map loadWordVectors(String filename) { Timing timing = new Timing(); Map map = new HashMap<>(10000); // presumably they'll load a fair-sized vocab!? BufferedReader br = null; try { br = IOUtils.readerFromString(filename); int numDimensions = -1; boolean warned = false; for (String line; (line = br.readLine()) != null; ) { String[] fields = line.split("\\s+"); if (numDimensions < 0) { numDimensions = fields.length - 1; } else { if (numDimensions != fields.length -1 && ! warned) { logger.info("loadWordVectors: Inconsistent vector size: " + numDimensions + " vs. " + (fields.length - 1)); warned = true; } } float[] vector = new float[fields.length - 1]; for (int i = 1; i < fields.length; i++) { vector[i-1] = Float.parseFloat(fields[i]); } map.put(fields[0], vector); } } catch (IOException ioe) { throw new RuntimeIOException("Couldn't load word vectors", ioe); } finally { IOUtils.closeIgnoringExceptions(br); } timing.done("Loading word vectors from " + filename + " ... "); return map; } /** * Initialize using values in Properties file. * * @param props Properties, with the special format of column.flag used in ColumnDataClassifier * @return An array of flags for each data column, with additional global flags in element [0] */ private Flags[] setProperties(Properties props) { Flags[] myFlags; boolean myUsesRealValues = false; Pattern prefix; try { prefix = Pattern.compile("([0-9]+)\\.(.*)"); } catch (PatternSyntaxException pse) { throw new RuntimeException(pse); } // if we are loading a classifier then we have to load it before we do anything // else and have its Properties be the defaults that can then be overridden by // other command-line arguments String loadPath = props.getProperty("loadClassifier"); if (loadPath != null) { logger.info("Loading classifier from " + loadPath + "..."); ObjectInputStream ois = null; try { // load the classifier ois = IOUtils.readStreamFromString(loadPath); classifier = ErasureUtils.>uncheckedCast(ois.readObject()); myFlags = (Flags[]) ois.readObject(); assert myFlags.length > 0; logger.info("Done."); } catch (Exception e) { throw new RuntimeIOException("Error deserializing " + loadPath, e); } finally { IOUtils.closeIgnoringExceptions(ois); } } else { myFlags = new Flags[1]; myFlags[0] = new Flags(); // initialize zero column flags used for global flags; it can't be null } for (String key : props.stringPropertyNames()) { String val = props.getProperty(key); int col = 0; // the default (first after class) logger.info(key + " = " + val); Matcher matcher = prefix.matcher(key); if (matcher.matches()) { col = Integer.parseInt(matcher.group(1)); key = matcher.group(2); } if (col >= myFlags.length) { myFlags = Arrays.copyOf(myFlags, col + 1); } if (myFlags[col] == null) { myFlags[col] = new Flags(); } if (key.equals("useString")) { myFlags[col].useString = Boolean.parseBoolean(val); } else if (key.equals("binnedLengths")) { if (val != null) { String[] binnedLengthStrs = val.split("[, ]+"); myFlags[col].binnedLengths = new int[binnedLengthStrs.length]; for (int i = 0; i < myFlags[col].binnedLengths.length; i++) { myFlags[col].binnedLengths[i] = Integer.parseInt(binnedLengthStrs[i]); } } } else if (key.equals("binnedLengthsStatistics")) { if (Boolean.parseBoolean(val)) { myFlags[col].binnedLengthsCounter = new TwoDimensionalCounter<>(); } } else if (key.equals("splitWordCount")) { myFlags[col].splitWordCount = Boolean.parseBoolean(val); } else if (key.equals("logSplitWordCount")) { myFlags[col].logSplitWordCount = Boolean.parseBoolean(val); } else if (key.equals("binnedSplitWordCounts")) { if (val != null) { String[] binnedSplitWordCountStrs = val.split("[, ]+"); myFlags[col].binnedSplitWordCounts = new int[binnedSplitWordCountStrs.length]; for (int i = 0; i < myFlags[col].binnedSplitWordCounts.length; i++) { myFlags[col].binnedSplitWordCounts[i] = Integer.parseInt(binnedSplitWordCountStrs[i]); } } } else if (key.equals("countChars")) { myFlags[col].countChars = val.toCharArray(); } else if (key.equals("countCharsBins")) { if (val != null) { String[] binnedCountStrs = val.split("[, ]+"); myFlags[col].countCharsBins = new int[binnedCountStrs.length]; for (int i = 0; i < binnedCountStrs.length; i++) { myFlags[col].countCharsBins[i] = Integer.parseInt(binnedCountStrs[i]); } } } else if (key.equals("binnedValues")) { if (val != null) { String[] binnedValuesStrs = val.split("[, ]+"); myFlags[col].binnedValues = new double[binnedValuesStrs.length]; for (int i = 0; i < myFlags[col].binnedValues.length; i++) { myFlags[col].binnedValues[i] = Double.parseDouble(binnedValuesStrs[i]); } } } else if (key.equals("binnedValuesNaN")) { myFlags[col].binnedValuesNaN = Double.parseDouble(val); } else if (key.equals("binnedValuesStatistics")) { if (Boolean.parseBoolean(val)) { myFlags[col].binnedValuesCounter = new TwoDimensionalCounter<>(); } } else if (key.equals("useNGrams")) { myFlags[col].useNGrams = Boolean.parseBoolean(val); } else if (key.equals("usePrefixSuffixNGrams")) { myFlags[col].usePrefixSuffixNGrams = Boolean.parseBoolean(val); } else if (key.equals("useSplitNGrams")) { myFlags[col].useSplitNGrams = Boolean.parseBoolean(val); } else if (key.equals("wordShape")) { myFlags[col].wordShape = WordShapeClassifier.lookupShaper(val); } else if (key.equals("splitWordShape")) { myFlags[col].splitWordShape = WordShapeClassifier.lookupShaper(val); } else if (key.equals("useSplitPrefixSuffixNGrams")) { myFlags[col].useSplitPrefixSuffixNGrams = Boolean.parseBoolean(val); } else if (key.equals("lowercaseNGrams")) { myFlags[col].lowercaseNGrams = Boolean.parseBoolean(val); } else if (key.equals("lowercase")) { myFlags[col].lowercase = Boolean.parseBoolean(val); } else if (key.equals("useLowercaseSplitWords")) { myFlags[col].useLowercaseSplitWords = Boolean.parseBoolean(val); } else if (key.equals("useSum")) { myFlags[col].useSum = Boolean.parseBoolean(val); } else if (key.equals("tolerance")) { myFlags[col].tolerance = Double.parseDouble(val); } else if (key.equals("printFeatures")) { myFlags[col].printFeatures = val; } else if (key.equals("printClassifier")) { myFlags[col].printClassifier = val; } else if (key.equals("printClassifierParam")) { myFlags[col].printClassifierParam = Integer.parseInt(val); } else if (key.equals("exitAfterTrainingFeaturization")) { myFlags[col].exitAfterTrainingFeaturization = Boolean.parseBoolean(val); } else if (key.equals("intern") || key.equals("intern2")) { myFlags[col].intern = Boolean.parseBoolean(val); } else if (key.equals("cacheNGrams")) { myFlags[col].cacheNGrams = Boolean.parseBoolean(val); } else if (key.equals("useClassifierFactory")) { myFlags[col].useClassifierFactory = val; } else if (key.equals("classifierFactoryArgs")) { myFlags[col].classifierFactoryArgs = val; } else if (key.equals("useNB")) { myFlags[col].useNB = Boolean.parseBoolean(val); } else if (key.equals("useBinary")) { myFlags[col].useBinary = Boolean.parseBoolean(val); } else if (key.equals("l1reg")) { myFlags[col].l1reg = Double.parseDouble(val); } else if (key.equals("useAdaptL1")) { myFlags[col].useAdaptL1 = Boolean.parseBoolean(val); } else if (key.equals("limitFeatures")) { myFlags[col].limitFeatures = Integer.parseInt(val); } else if (key.equals("l1regmin")) { myFlags[col].l1regmin = Double.parseDouble(val); } else if (key.equals("l1regmax")) { myFlags[col].l1regmax = Double.parseDouble(val); } else if (key.equals("limitFeaturesLabels")) { myFlags[col].limitFeaturesLabels = val; } else if (key.equals("featureWeightThreshold")) { myFlags[col].featureWeightThreshold = Double.parseDouble(val); } else if (key.equals("useClassFeature")) { myFlags[col].useClassFeature = Boolean.parseBoolean(val); } else if (key.equals("featureMinimumSupport")) { myFlags[col].featureMinimumSupport = Integer.parseInt(val); } else if (key.equals("prior")) { if (val.equalsIgnoreCase("no")) { myFlags[col].prior = LogPrior.LogPriorType.NULL.ordinal(); } else if (val.equalsIgnoreCase("huber")) { myFlags[col].prior = LogPrior.LogPriorType.HUBER.ordinal(); } else if (val.equalsIgnoreCase("quadratic")) { myFlags[col].prior = LogPrior.LogPriorType.QUADRATIC.ordinal(); } else if (val.equalsIgnoreCase("quartic")) { myFlags[col].prior = LogPrior.LogPriorType.QUARTIC.ordinal(); } else { try { myFlags[col].prior = Integer.parseInt(val); } catch (NumberFormatException nfe) { logger.info("Unknown prior " + val + "; using none."); } } } else if (key.equals("sigma")) { myFlags[col].sigma = Double.parseDouble(val); } else if (key.equals("epsilon")) { myFlags[col].epsilon = Double.parseDouble(val); } else if (key.equals("maxNGramLeng")) { myFlags[col].maxNGramLeng = Integer.parseInt(val); } else if (key.equals("minNGramLeng")) { myFlags[col].minNGramLeng = Integer.parseInt(val); } else if (key.equals("partialNGramRegexp")) { myFlags[col].partialNGramRegexp = val; try { myFlags[col].partialNGramPattern = Pattern.compile(myFlags[col].partialNGramRegexp); } catch (PatternSyntaxException pse) { logger.info("Ill-formed partialNGramPattern: " + myFlags[col].partialNGramPattern); myFlags[col].partialNGramRegexp = null; } } else if (key.equals("splitWordsRegexp")) { try { myFlags[col].splitWordsPattern = Pattern.compile(val); } catch (PatternSyntaxException pse) { logger.info("Ill-formed splitWordsRegexp: " + val); } } else if (key.equals("splitWordsTokenizerRegexp")) { try { myFlags[col].splitWordsTokenizerPattern = Pattern.compile(val); } catch (PatternSyntaxException pse) { logger.info("Ill-formed splitWordsTokenizerRegexp: " + val); } } else if (key.equals("splitWordsIgnoreRegexp")) { String trimVal = val.trim(); if (trimVal.isEmpty()) { myFlags[col].splitWordsIgnorePattern = null; } else { try { myFlags[col].splitWordsIgnorePattern = Pattern.compile(trimVal); } catch (PatternSyntaxException pse) { logger.info("Ill-formed splitWordsIgnoreRegexp: " + trimVal); } } } else if (key.equals("useSplitWords")) { myFlags[col].useSplitWords = Boolean.parseBoolean(val); } else if (key.equals("useSplitWordPairs")) { myFlags[col].useSplitWordPairs = Boolean.parseBoolean(val); } else if (key.equals("useLowercaseSplitWordPairs")) { myFlags[col].useLowercaseSplitWordPairs = Boolean.parseBoolean(val); } else if (key.equals("useAllSplitWordPairs")) { myFlags[col].useAllSplitWordPairs = Boolean.parseBoolean(val); } else if (key.equals("useAllSplitWordTriples")) { myFlags[col].useAllSplitWordTriples = Boolean.parseBoolean(val); } else if (key.equals("useSplitWordNGrams")) { myFlags[col].useSplitWordNGrams = Boolean.parseBoolean(val); } else if (key.equals("maxWordNGramLeng")) { myFlags[col].maxWordNGramLeng = Integer.parseInt(val); } else if (key.equals("minWordNGramLeng")) { myFlags[col].minWordNGramLeng = Integer.parseInt(val); if (myFlags[col].minWordNGramLeng < 1) { logger.info("minWordNGramLeng set to " + myFlags[col].minWordNGramLeng + ", resetting to 1"); myFlags[col].minWordNGramLeng = 1; } } else if (key.equals("wordNGramBoundaryRegexp")) { myFlags[col].wordNGramBoundaryRegexp = val; try { myFlags[col].wordNGramBoundaryPattern = Pattern.compile(myFlags[col].wordNGramBoundaryRegexp); } catch (PatternSyntaxException pse) { logger.info("Ill-formed wordNGramBoundary regexp: " + myFlags[col].wordNGramBoundaryRegexp); myFlags[col].wordNGramBoundaryRegexp = null; } } else if (key.equals("useSplitFirstLastWords")) { myFlags[col].useSplitFirstLastWords = Boolean.parseBoolean(val); } else if (key.equals("useLowercaseSplitFirstLastWords")) { myFlags[col].useLowercaseSplitFirstLastWords = Boolean.parseBoolean(val); } else if (key.equals("loadClassifier")) { myFlags[col].loadClassifier = val; } else if (key.equals("serializeTo")) { Flags.serializeTo = val; } else if (key.equals("printTo")) { Flags.printTo = val; } else if (key.equals("trainFile")) { Flags.trainFile = val; } else if (key.equals("displayAllAnswers")) { Flags.displayAllAnswers = Boolean.parseBoolean(val); } else if (key.equals("testFile")) { myFlags[col].testFile = val; } else if (key.equals("trainFromSVMLight")) { Flags.trainFromSVMLight = Boolean.parseBoolean(val); } else if (key.equals("testFromSVMLight")) { Flags.testFromSVMLight = Boolean.parseBoolean(val); } else if (key.equals("encoding")) { Flags.encoding = val; } else if (key.equals("printSVMLightFormatTo")) { Flags.printSVMLightFormatTo = val; } else if (key.equals("displayedColumn")) { myFlags[col].displayedColumn = Integer.parseInt(val); } else if (key.equals("groupingColumn")) { myFlags[col].groupingColumn = Integer.parseInt(val); // logger.info("Grouping column is " + (myFlags[col].groupingColumn)); } else if (key.equals("rankingScoreColumn")) { myFlags[col].rankingScoreColumn = Integer.parseInt(val); // logger.info("Ranking score column is " + (myFlags[col].rankingScoreColumn)); } else if (key.equals("rankingAccuracyClass")) { myFlags[col].rankingAccuracyClass = val; } else if (key.equals("goldAnswerColumn")) { myFlags[col].goldAnswerColumn = Integer.parseInt(val); // logger.info("Gold answer column is " + (myFlags[col].goldAnswerColumn)); // it's a nuisance to print this when used programmatically } else if (key.equals("useQN")) { myFlags[col].useQN = Boolean.parseBoolean(val); } else if (key.equals("QNsize")) { myFlags[col].QNsize = Integer.parseInt(val); } else if (key.equals("featureFormat")) { myFlags[col].featureFormat = Boolean.parseBoolean(val); } else if (key.equals("significantColumnId")) { myFlags[col].significantColumnId = Boolean.parseBoolean(val); } else if (key.equals("justify")) { myFlags[col].justify = Boolean.parseBoolean(val); } else if (key.equals("verboseOptimization")) { myFlags[col].verboseOptimization = Boolean.parseBoolean(val); } else if (key.equals("realValued")) { myFlags[col].isRealValued = Boolean.parseBoolean(val); myUsesRealValues = myUsesRealValues || myFlags[col].isRealValued; } else if (key.equals("logTransform")) { myFlags[col].logTransform = Boolean.parseBoolean(val); myUsesRealValues = myUsesRealValues || myFlags[col].logTransform; } else if (key.equals("logitTransform")) { myFlags[col].logitTransform = Boolean.parseBoolean(val); myUsesRealValues = myUsesRealValues || myFlags[col].logitTransform; } else if (key.equals("sqrtTransform")) { myFlags[col].sqrtTransform = Boolean.parseBoolean(val); myUsesRealValues = myUsesRealValues || myFlags[col].sqrtTransform; } else if (key.equals("filename")) { myFlags[col].filename = Boolean.parseBoolean(val); } else if (key.equals("biased")) { myFlags[col].biased = Boolean.parseBoolean(val); } else if (key.equals("biasedHyperplane")) { // logger.info("Constraints is " + constraints); if (val != null && val.trim().length() > 0) { String[] bits = val.split("[, ]+"); myFlags[col].biasedHyperplane = new ClassicCounter<>(); for (int i = 0; i < bits.length; i += 2) { myFlags[col].biasedHyperplane.setCount(bits[i], Double.parseDouble(bits[i + 1])); } } // logger.info("Biased Hyperplane is " + biasedHyperplane); } else if (key.equals("crossValidationFolds")) { myFlags[col].crossValidationFolds = Integer.parseInt(val); } else if (key.equals("printCrossValidationDecisions")) { myFlags[col].printCrossValidationDecisions = Boolean.parseBoolean(val); } else if (key.equals("shuffleTrainingData")) { myFlags[col].shuffleTrainingData = Boolean.parseBoolean(val); } else if (key.equals("shuffleSeed")) { myFlags[col].shuffleSeed = Long.parseLong(val); } else if (key.equals("csvInput")) { myFlags[col].csvInput = Boolean.parseBoolean(val); } else if (key.equals("inputFormat")) { if (val.equalsIgnoreCase("header")) { myFlags[col].inputFormat = InputFormat.HEADER; } else if (val.equalsIgnoreCase("comments")) { myFlags[col].inputFormat = InputFormat.COMMENTS; } else if (val.equalsIgnoreCase("plain")) { myFlags[col].inputFormat = InputFormat.PLAIN; } else { logger.info("Unknown inputFormat: " + val); } } else if (key.equals("splitWordsWithPTBTokenizer")) { // System.out.println("splitting with ptb tokenizer"); myFlags[col].splitWordsWithPTBTokenizer = Boolean.parseBoolean(val); } else if (key.equals("useSplitWordVectors")) { myFlags[col].wordVectors = loadWordVectors(val); myUsesRealValues = true; } else if (key.equals("showTokenization")) { myFlags[col].showTokenization = Boolean.parseBoolean(val); } else if (key.equals("csvOutput")) { myFlags[col].csvOutput = val; } else if ( ! key.isEmpty() && ! key.equals("prop")) { logger.info("Unknown property: |" + key + '|'); } } myFlags[0].usesRealValues = myUsesRealValues; return myFlags; } /** Construct a ColumnDataClassifier. * * @param filename The file with properties which specifies all aspects of behavior. * See the class documentation for details of the properties. */ public ColumnDataClassifier(String filename) { this(StringUtils.propFileToProperties(filename)); } /** Construct a ColumnDataClassifier. * * @param props The properties object specifies all aspects of its behavior. * See the class documentation for details of the properties. */ public ColumnDataClassifier(Properties props) { flags = setProperties(props); globalFlags = flags[0]; } /** * Runs the ColumnDataClassifier from the command-line. Usage:

{@code * java edu.stanford.nlp.classify.ColumnDataClassifier -trainFile trainFile * -testFile testFile [-useNGrams|-useString|-sigma sigma|...] * }

or

{@code * java ColumnDataClassifier -prop propFile * } * * @param args Command line arguments, as described in the class * documentation * @throws IOException If IO problems */ public static void main(String[] args) throws IOException { StringUtils.logInvocationString(logger, args); // the constructor will load a classifier if one is specified with loadClassifier ColumnDataClassifier cdc = new ColumnDataClassifier(StringUtils.argsToProperties(args)); String testFile = cdc.globalFlags.testFile; // check that we have roughly sensible options or else warn and exit if ((testFile == null && Flags.serializeTo == null && cdc.globalFlags.crossValidationFolds < 2) || (Flags.trainFile == null && cdc.globalFlags.loadClassifier == null)) { logger.info("usage: java edu.stanford.nlp.classify.ColumnDataClassifier -prop propFile"); logger.info(" and/or: -trainFile trainFile -testFile testFile|-serializeTo modelFile [-useNGrams|-sigma sigma|...]"); return; // ENDS PROCESSING } if (cdc.globalFlags.loadClassifier == null) { // Otherwise we attempt to train one and exit if we don't succeed if ( ! cdc.trainClassifier()) { return; } } if (testFile != null) { cdc.testClassifier(testFile); } } // end main() private boolean trainClassifier() throws IOException { // build dataset of training data featurized Pair, List> dataInfo = readAndReturnTrainingExamples(Flags.trainFile); GeneralDataset train = dataInfo.first(); List lineInfos = dataInfo.second(); // For things like cross validation, we may well need to sort data! Data sets are often ordered by class. if (globalFlags.shuffleTrainingData) { long seed; if (globalFlags.shuffleSeed != 0) { seed = globalFlags.shuffleSeed; } else { seed = System.nanoTime(); } train.shuffleWithSideInformation(seed, lineInfos); } // print any binned value histograms for (int i = 0; i < flags.length; i++) { if (flags[i] != null && flags[i].binnedValuesCounter != null) { logger.info("BinnedValuesStatistics for column " + i); logger.info(flags[i].binnedValuesCounter.toString()); } } // print any binned length histograms for (int i = 0; i < flags.length; i++) { if (flags[i] != null && flags[i].binnedLengthsCounter != null) { logger.info("BinnedLengthsStatistics for column " + i); logger.info(flags[i].binnedLengthsCounter.toString()); } } // print the training data in SVMlight format if desired if (Flags.printSVMLightFormatTo != null) { PrintWriter pw = IOUtils.getPrintWriter(Flags.printSVMLightFormatTo, Flags.encoding); train.printSVMLightFormat(pw); IOUtils.closeIgnoringExceptions(pw); train.featureIndex().saveToFilename(Flags.printSVMLightFormatTo + ".featureIndex"); train.labelIndex().saveToFilename(Flags.printSVMLightFormatTo + ".labelIndex"); } if (globalFlags.crossValidationFolds > 1) { crossValidate(train, lineInfos); } if (globalFlags.exitAfterTrainingFeaturization) { return false; // ENDS PROCESSING } // build the classifier classifier = makeClassifier(train); printClassifier(classifier); // serialize the classifier String serializeTo = Flags.serializeTo; if (serializeTo != null) { logger.info("Serializing classifier to " + serializeTo + "..."); ObjectOutputStream oos = IOUtils.writeStreamFromString(serializeTo); oos.writeObject(classifier); // Fiddle: Don't write a testFile to the serialized classifier. It makes no sense and confuses people String testFile = globalFlags.testFile; globalFlags.testFile = null; oos.writeObject(flags); globalFlags.testFile = testFile; oos.close(); logger.info("Done."); } return true; } private void printClassifier(Classifier classifier) { String classString; if (classifier instanceof LinearClassifier) { classString = ((LinearClassifier)classifier).toString(globalFlags.printClassifier, globalFlags.printClassifierParam); } else { classString = classifier.toString(); } if (Flags.printTo != null) { PrintWriter fw = null; try { fw = IOUtils.getPrintWriter(Flags.printTo, Flags.encoding); fw.write(classString); fw.println(); } catch (IOException ioe) { logger.warn(ioe); } finally { IOUtils.closeIgnoringExceptions(fw); } logger.info("Built classifier described in file " + Flags.printTo); } else { logger.info("Built this classifier: " + classString); } } private void testClassifier(String testFile) { if (globalFlags.printFeatures != null) { newFeaturePrinter(globalFlags.printFeatures, "test", Flags.encoding); } Pair,List> testInfo = readTestExamples(testFile); GeneralDataset test = testInfo.first(); List lineInfos = testInfo.second(); testExamples(classifier, test, lineInfos); // ((LinearClassifier) classifier).dumpSorted(); if (globalFlags.printFeatures != null) { closeFeaturePrinter(); } } /** Run cross-validation on a dataset, and return accuracy and macro-F1 scores. * The number of folds is given by the crossValidationFolds property. * * @param dataset The dataset of examples to cross-validate on. * @param lineInfos The String form of the items in the dataset. (Must be present.) * @return Accuracy and macro F1 */ public Pair crossValidate(GeneralDataset dataset, List lineInfos) { final int numFolds = globalFlags.crossValidationFolds; double accuracySum = 0.0; double macroF1Sum = 0.0; for (int fold = 0; fold < numFolds; fold++) { logger.info(""); logger.info("### Fold " + fold); Pair,GeneralDataset> split = dataset.splitOutFold(fold, numFolds); GeneralDataset devTrain = split.first(); GeneralDataset devTest = split.second(); Classifier cl = makeClassifier(devTrain); printClassifier(cl); int normalFoldSize = lineInfos.size()/numFolds; int start = normalFoldSize * fold; int end = start + normalFoldSize; if (fold == (numFolds - 1)) { end = lineInfos.size(); } List devTestLineInfos = lineInfos.subList(start, end); Pair accuracies = testExamples(cl, devTest, devTestLineInfos); accuracySum += accuracies.first(); macroF1Sum += accuracies.second(); } double averageAccuracy = accuracySum / numFolds; double averageMacroF1 = macroF1Sum / numFolds; NumberFormat nf2 = new DecimalFormat("0.00000"); logger.info("Average accuracy/micro-averaged F1: " + nf2.format(averageAccuracy)); logger.info("Average macro-averaged F1: " + nf2.format(averageMacroF1)); logger.info(""); return new Pair<>(averageAccuracy, averageMacroF1); } public String classOf(Datum example) { if(classifier==null) throw new RuntimeException("Classifier is not initialized"); return classifier.classOf(example); } static class Flags implements Serializable { // PLEASE ADD NEW FLAGS AT THE END OF THIS CLASS (SO AS TO NOT UNNECESSARILY BREAK SERIALIZED CLASSIFIERS) private static final long serialVersionUID = -7076671761070232566L; boolean useNGrams = false; boolean usePrefixSuffixNGrams = false; boolean lowercaseNGrams = false; boolean lowercase; boolean useSplitNGrams = false; boolean useSplitPrefixSuffixNGrams = false; boolean cacheNGrams = false; int maxNGramLeng = -1; int minNGramLeng = 2; String partialNGramRegexp = null; Pattern partialNGramPattern = null; boolean useSum = false; double tolerance = 1e-4; String printFeatures = null; String printClassifier = null; int printClassifierParam = 100; boolean exitAfterTrainingFeaturization = false; boolean intern = false; Pattern splitWordsPattern = null; Pattern splitWordsTokenizerPattern = null; Pattern splitWordsIgnorePattern = Pattern.compile(DEFAULT_IGNORE_REGEXP); boolean useSplitWords = false; boolean useSplitWordPairs = false; boolean useLowercaseSplitWordPairs = false; boolean useSplitFirstLastWords = false; boolean useLowercaseSplitWords = false; boolean useLowercaseSplitFirstLastWords = false; int wordShape = edu.stanford.nlp.process.WordShapeClassifier.NOWORDSHAPE; int splitWordShape = WordShapeClassifier.NOWORDSHAPE; boolean useString = false; boolean useClassFeature = false; int[] binnedLengths = null; TwoDimensionalCounter binnedLengthsCounter = null; double[] binnedValues = null; TwoDimensionalCounter binnedValuesCounter = null; double binnedValuesNaN = -1.0; //true iff this feature is real valued boolean isRealValued = false; public static final String realValuedFeaturePrefix = "Value"; //prefix to add before column number for denoting real valued features boolean logitTransform = false; boolean logTransform = false; boolean sqrtTransform = false; char[] countChars = null; int[] countCharsBins = {0, 1}; ClassicCounter biasedHyperplane = null; boolean justify = false; boolean featureFormat = false; boolean significantColumnId = false; String useClassifierFactory; String classifierFactoryArgs; boolean useNB = false; boolean useQN = true; int QNsize = 15; int prior = LogPrior.LogPriorType.QUADRATIC.ordinal(); double sigma = 1.0; double epsilon = 0.01; int featureMinimumSupport = 0; int displayedColumn = 1; // = 2nd column of data file! (Because we count from 0.) int groupingColumn = -1; int rankingScoreColumn = -1; String rankingAccuracyClass = null; int goldAnswerColumn = 0; boolean biased; boolean useSplitWordNGrams = false; int maxWordNGramLeng = -1; int minWordNGramLeng = 1; boolean useBinary = false; double l1reg = 0.0; String wordNGramBoundaryRegexp; Pattern wordNGramBoundaryPattern; boolean useAdaptL1 = false; int limitFeatures = 0; String limitFeaturesLabels = null; double l1regmin = 0.0; double l1regmax = 500.0; double featureWeightThreshold = 0; String testFile = null; // this one would be better off static (we avoid serializing it) String loadClassifier = null; // this one could also be static // these are static because we don't want them serialized static String trainFile = null; static String serializeTo = null; static String printTo = null; static boolean trainFromSVMLight = false; //train file is in SVMLight format static boolean testFromSVMLight = false; //test file is in SVMLight format static String encoding = null; static String printSVMLightFormatTo; static boolean displayAllAnswers = false; // Distinguishes whether this file has real valued features or if the more efficient non-RVF representation can be used. // This is set as a summary flag in globalFeatures based on whether anything uses real values. boolean usesRealValues; boolean filename; boolean useAllSplitWordPairs; boolean useAllSplitWordTriples; boolean showTokenization = false; int crossValidationFolds = -1; boolean shuffleTrainingData = false; long shuffleSeed = 0; static boolean csvInput = false; //train and test files are in csv format static InputFormat inputFormat = InputFormat.PLAIN; boolean splitWordsWithPTBTokenizer = false; boolean splitWordCount; boolean logSplitWordCount; int[] binnedSplitWordCounts; Map wordVectors; static String csvOutput = null; boolean printCrossValidationDecisions = false; boolean verboseOptimization = false; @Override public String toString() { return "Flags[" + "goldAnswerColumn = " + goldAnswerColumn + ", useString = " + useString + ", useNGrams = " + useNGrams + ", usePrefixSuffixNGrams = " + usePrefixSuffixNGrams + ']'; } } // end class Flags } // end class ColumnDataClassifier





© 2015 - 2024 Weber Informatics LLC | Privacy Policy