
streams.weka.PredictionError Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of streams-weka Show documentation
Show all versions of streams-weka Show documentation
Weka adaption for streaming processes in streams.
The newest version!
/**
*
*/
package streams.weka;
import java.io.File;
import java.io.FileOutputStream;
import java.io.PrintStream;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.TreeSet;
import net.minidev.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import stream.AbstractProcessor;
import stream.Data;
import stream.ProcessContext;
import stream.annotations.Parameter;
import stream.data.Statistics;
import stream.io.Sink;
/**
* This processor computes the prediction error over all the processed data
* items. The items are expected to provide a `@label` and a `@prediction` key
* that holds the true label of an item and the predicted value of some
* classifier (use the parameters `label` and `prediction` to choose different
* keys).
*
*
* @author Christian Bockermann <[email protected]>
*
*/
public class PredictionError extends AbstractProcessor {

    static Logger logger = LoggerFactory.getLogger(PredictionError.class);

    @Parameter(description = "The key under which the true label of the item is found.", required = false, defaultValue = "@label")
    String label = "@label";

    @Parameter(description = "The key under which the predicted class value is found.", required = false, defaultValue = "@prediction")
    String prediction = "@prediction";

    /**
     * Running counters. Keys written by this class are
     * <code>prediction.correct</code>, <code>prediction.wrong</code>,
     * <code>&lt;trueLabel&gt;.total</code> and
     * <code>&lt;trueLabel&gt;.predicted.&lt;predictedLabel&gt;</code>; the
     * per-class summary values are added in {@link #finish()}.
     */
    Statistics statistics = new Statistics();

    /** Number of items that carried both a label and a prediction. */
    Long tests = 0L;

    /** Optional list of class names that is copied into {@link #labels} in {@link #init(ProcessContext)}. */
    String[] classes;

    /** Identifier that is only used in the log output of {@link #finish()}. */
    String id = null;

    /** All true labels seen so far (plus the pre-configured {@link #classes}). */
    ArrayList<String> labels = new ArrayList<String>();

    Sink[] output;

    /** Optional file to which the statistics are appended as one JSON line in {@link #finish()}. */
    File log;

    /**
     * Registers the pre-configured {@link #classes} (if any) as known labels,
     * skipping duplicates.
     *
     * @see stream.AbstractProcessor#init(stream.ProcessContext)
     */
    @Override
    public void init(ProcessContext ctx) throws Exception {
        super.init(ctx);
        if (classes == null) {
            return;
        }
        for (String cls : classes) {
            if (!labels.contains(cls)) {
                labels.add(cls);
            }
        }
    }

    /**
     * Updates the counters with the given item and annotates the item with
     * the current error figures. Items that lack the label key or the
     * prediction key are passed through untouched.
     * <p>
     * The item receives <code>@error</code> ("yes"/"no") and the running
     * <code>@accuracy</code>, <code>@precision</code> and
     * <code>@f-score</code>, computed one-vs-rest with the item's true label
     * as the positive class. A value is NaN while its denominator is still
     * zero.
     *
     * @param input
     *            the item to evaluate
     * @return the same item, possibly annotated
     * @see stream.Processor#process(stream.Data)
     */
    @Override
    public Data process(Data input) {
        if (!input.containsKey(this.label) || !input.containsKey(this.prediction)) {
            return input;
        }

        String actual = input.get(this.label).toString();
        if (!labels.contains(actual)) {
            logger.debug("Found new class '{}'", actual);
            if (labels.isEmpty()) {
                logger.debug("Using class '{}' as positive class.", actual);
            }
            labels.add(actual);
        }

        String pred = input.get(this.prediction).toString();
        if (pred.equals(actual)) {
            input.put("@error", "no");
            statistics.add("prediction.correct", 1.0);
        } else {
            input.put("@error", "yes");
            statistics.add("prediction.wrong", 1.0);
        }
        statistics.add(actual + ".total", 1.0);
        statistics.add(actual + ".predicted." + pred, 1.0);

        // The counters above already include this item, so the figures below
        // describe the state after processing it.
        double[] c = confusion(actual);
        double tp = c[0];
        double fp = c[1];
        double fn = c[2];
        double tn = c[3];

        input.put("@accuracy", Double.valueOf((tp + tn) / (tp + fp + fn + tn)));
        input.put("@precision", Double.valueOf(tp / (tp + fp)));
        input.put("@f-score", Double.valueOf((2 * tp) / (2 * tp + fp + fn)));
        tests++;
        return input;
    }

    /**
     * Returns all known labels except the given one, in the current order of
     * the label list.
     *
     * @param label
     *            the label to leave out
     * @return a new list with the remaining labels
     */
    public List<String> other(String label) {
        ArrayList<String> other = new ArrayList<String>();
        for (String l : labels) {
            if (!l.equals(label)) {
                other.add(l);
            }
        }
        return other;
    }

    /**
     * Logs a per-class summary (total, correct, accuracy, precision, recall,
     * f-score), stores accuracy, precision, fScore and qfactor per class in
     * the statistics, logs all <code>*.predicted.*</code> counters and, if a
     * log file is set, appends the statistics to it as a single JSON line.
     * <p>
     * Ratios whose denominator is zero come out as NaN (or Infinity for the
     * q-factor); they are logged and stored as such.
     *
     * @see stream.AbstractProcessor#finish()
     */
    @Override
    public void finish() throws Exception {
        super.finish();

        // Force '.' as decimal separator regardless of the default locale.
        DecimalFormatSymbols dfs = new DecimalFormatSymbols();
        dfs.setDecimalSeparator('.');
        DecimalFormat fmt = new DecimalFormat("0.00", dfs);

        logger.info("+---------------------------------------------------------");
        logger.info("| Prediction Error (ID '{}', {} elements tested.)", id,
                tests);
        logger.info("| {} classes in test set.", labels.size());

        for (String pos : this.labels) {
            double[] c = confusion(pos);
            double tp = c[0];
            double fp = c[1];
            double fn = c[2];
            double tn = c[3];
            double p = tp + fn;

            Double acc = (tp + tn) / (tp + fp + fn + tn);
            Double precision = tp / (tp + fp);
            Double recall = tp / p;
            Double fScore = (2 * tp) / (2 * tp + fp + fn);
            Double qfactor = qfactor(tp, fp, tn, fn);

            statistics.put("class['" + pos + "'].accuracy", acc);
            statistics.put("class['" + pos + "'].precision", precision);
            statistics.put("class['" + pos + "'].fScore", fScore);
            statistics.put("class['" + pos + "'].qfactor", qfactor);

            logger.info("|");
            logger.info("| class['{}'].total: {}", pos,
                    fmt.format(value(pos + ".total")));
            logger.info("| class['{}'].correct: {}", pos,
                    fmt.format(value(pos + ".predicted." + pos)));
            logger.info("| class['" + pos + "'].accuracy: {}",
                    fmt.format(acc));
            logger.info("| class['" + pos + "'].precision: {}",
                    fmt.format(precision));
            logger.info("| class['" + pos + "'].recall: {}",
                    fmt.format(recall));
            logger.info("| class['" + pos + "'].f-score: {}",
                    fmt.format(fScore));
        }

        logger.info("|");
        Collections.sort(labels);
        logger.info("|");

        // TreeSet gives the confusion counters in sorted key order.
        for (String k : new TreeSet<String>(statistics.keySet())) {
            if (k.indexOf(".predicted.") > 0) {
                logger.info("| {} => {}", k, fmt.format(statistics.get(k)));
            }
        }
        logger.info("|");
        logger.info("+---------------------------------------------------------");

        if (log != null) {
            writeLog();
        }
    }

    /**
     * Appends the statistics as one JSON line to {@link #log}. Writing is
     * best effort: a failure is logged and does not abort {@link #finish()}.
     */
    private void writeLog() {
        PrintStream ps = null;
        try {
            logger.info("Writing performance log to {}", log);
            ps = new PrintStream(new FileOutputStream(log, true));
            ps.println(JSONObject.toJSONString(statistics));
        } catch (Exception e) {
            logger.error("Failed to write performance log to " + log, e);
        } finally {
            // Close even if serialising the statistics failed.
            if (ps != null) {
                ps.close();
            }
        }
    }

    /**
     * Computes the one-vs-rest confusion counts for the given positive class
     * from the <code>&lt;true&gt;.predicted.&lt;predicted&gt;</code> and
     * <code>&lt;true&gt;.total</code> counters.
     *
     * @param pos
     *            the class treated as positive
     * @return <code>{tp, fp, fn, tn}</code>
     */
    private double[] confusion(String pos) {
        double tp = value(pos + ".predicted." + pos);
        double p = value(pos + ".total");
        double fp = 0.0;
        double n = 0.0;
        for (String neg : other(pos)) {
            // true class 'neg' but predicted as 'pos' => false positive
            fp += value(neg + ".predicted." + pos);
            n += value(neg + ".total");
        }
        // Every positive item not predicted as 'pos' is a false negative and
        // every negative item not predicted as 'pos' is a true negative, no
        // matter which other class was predicted.
        return new double[] { tp, fp, p - tp, n - fp };
    }

    /**
     * Computes <code>(tp / (tp + fp)) / sqrt(fp / (fp + tn))</code>, i.e. the
     * precision divided by the square root of the false-positive rate. The
     * result is NaN or Infinity if one of the denominators is zero.
     *
     * @param tp
     *            number of true positives
     * @param fp
     *            number of false positives
     * @param tn
     *            number of true negatives
     * @param fn
     *            number of false negatives (not used by the formula)
     * @return the q-factor
     */
    protected static double qfactor(double tp, double fp, double tn, double fn) {
        return (tp / (tp + fp)) / Math.sqrt(fp / (fp + tn));
    }

    /**
     * Looks up a counter in the statistics.
     *
     * @param key
     *            the statistics key
     * @return the stored value, or 0.0 if the key is absent; never
     *         <code>null</code>
     */
    public Double value(String key) {
        Double d = statistics.get(key);
        if (d == null) {
            return 0.0;
        }
        return d;
    }

    /**
     * @return the id
     */
    public String getId() {
        return id;
    }

    /**
     * @param id
     *            the id to set
     */
    public void setId(String id) {
        this.id = id;
    }

    /**
     * @return the output
     */
    public Sink[] getOutput() {
        return output;
    }

    /**
     * @param output
     *            the output to set
     */
    public void setOutput(Sink[] output) {
        this.output = output;
    }

    /**
     * @return the log
     */
    public File getLog() {
        return log;
    }

    /**
     * @param log
     *            the log to set
     */
    public void setLog(File log) {
        this.log = log;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy