All Downloads are FREE. Search and download functionalities are using the official Maven repository.

streams.weka.PredictionError Maven / Gradle / Ivy

The newest version!
/**
 * 
 */
package streams.weka;

import java.io.File;
import java.io.FileOutputStream;
import java.io.PrintStream;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.TreeSet;

import net.minidev.json.JSONObject;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import stream.AbstractProcessor;
import stream.Data;
import stream.ProcessContext;
import stream.annotations.Parameter;
import stream.data.Statistics;
import stream.io.Sink;

/**
 * This processor computes the prediction error over all the processed data
 * items. The items are expected to provide a `@label` and a `@prediction` key
 * that holds the true label of an item and the predicted value of some
 * classifier (use the parameters `label` and `prediction` to choose different
 * keys).
 * 
 * For every item carrying both keys, the processor increments the counters
 * {@code <label>.total} and {@code <label>.predicted.<prediction>} in its
 * statistics. It also annotates the item with {@code @error} ("yes"/"no") and
 * with running one-vs-rest {@code @accuracy}, {@code @precision} and
 * {@code @f-score} values for the positive class (the first registered class).
 * On {@link #finish()}, per-class metrics are logged and stored in the
 * statistics. If a log file is configured, the statistics are appended to it
 * as one JSON line.
 * 
 * @author Christian Bockermann <[email protected]>
 * 
 */
public class PredictionError extends AbstractProcessor {

	static final Logger logger = LoggerFactory.getLogger(PredictionError.class);

	@Parameter(description = "The key under which the true label of the item is found.", required = false, defaultValue = "@label")
	String label = "@label";

	@Parameter(description = "The key under which the predicted class value is found.", required = false, defaultValue = "@prediction")
	String prediction = "@prediction";

	/** Counters filled in process() and metrics derived in finish(). */
	Statistics statistics = new Statistics();

	/** Number of items that carried both a label and a prediction. */
	Long tests = 0L;

	/** Optional classes registered (in this order) during init(). */
	String[] classes;

	String id = null;

	/** Known class labels in registration order; sorted in finish(). */
	ArrayList<String> labels = new ArrayList<String>();

	Sink[] output;
	File log;

	/**
	 * One-vs-rest confusion counts for a single positive class.
	 */
	private static final class Counts {
		final double tp;
		final double fp;
		final double tn;
		final double fn;

		Counts(double tp, double fp, double tn, double fn) {
			this.tp = tp;
			this.fp = fp;
			this.tn = tn;
			this.fn = fn;
		}

		/** (tp + tn) / all; NaN if nothing has been counted. */
		double accuracy() {
			return (tp + tn) / (tp + fp + tn + fn);
		}

		/** tp / (tp + fp); NaN if nothing was predicted as positive. */
		double precision() {
			return tp / (tp + fp);
		}

		/** tp / (tp + fn); NaN if no positive item has been seen. */
		double recall() {
			return tp / (tp + fn);
		}

		/** Harmonic mean of precision and recall: 2tp / (2tp + fp + fn). */
		double fScore() {
			return 2.0 * tp / (2.0 * tp + fp + fn);
		}
	}

	/**
	 * @see stream.AbstractProcessor#init(stream.ProcessContext)
	 */
	@Override
	public void init(ProcessContext ctx) throws Exception {

		super.init(ctx);

		// Pre-register configured classes so that their order (and thus the
		// positive class) is fixed before any item is seen.
		if (classes != null) {
			for (String cls : classes) {
				if (!labels.contains(cls)) {
					labels.add(cls);
				}
			}
		}
	}

	/**
	 * Counts the item's (label, prediction) pair and annotates the item with
	 * {@code @error} and running metrics for the positive class.
	 * 
	 * Items lacking either key are returned unchanged.
	 * 
	 * @see stream.Processor#process(stream.Data)
	 */
	@Override
	public Data process(Data input) {

		if (!input.containsKey(this.label) || !input.containsKey(this.prediction)) {
			return input;
		}

		String label = input.get(this.label).toString();

		if (!labels.contains(label)) {
			logger.debug("Found new class '{}'", label);
			if (labels.isEmpty()) {
				logger.debug("Using class '{}' as positive class.", label);
			}
			labels.add(label);
		}

		String pred = input.get(this.prediction).toString();

		if (pred.equals(label)) {
			input.put("@error", "no");
			statistics.add("prediction.correct", 1.0);
		} else {
			input.put("@error", "yes");
			statistics.add("prediction.wrong", 1.0);
		}

		statistics.add(label + ".total", 1.0);
		statistics.add(label + ".predicted." + pred, 1.0);

		// Running one-vs-rest metrics with respect to the positive class,
		// i.e. the first registered class. A metric is NaN while its
		// denominator is still zero.
		Counts counts = confusion(labels.get(0));
		input.put("@accuracy", counts.accuracy());
		input.put("@precision", counts.precision());
		input.put("@f-score", counts.fScore());

		tests++;
		return input;
	}

	/**
	 * Returns all known labels except the given one, in the current order of
	 * {@link #labels}.
	 * 
	 * @param label
	 *            the label to exclude
	 * @return a new list of the remaining labels
	 */
	public List<String> other(String label) {
		List<String> rest = new ArrayList<String>(labels.size());
		for (String l : labels) {
			if (!l.equals(label)) {
				rest.add(l);
			}
		}
		return rest;
	}

	/**
	 * Derives the one-vs-rest confusion counts for {@code pos} from the
	 * {@code <true>.predicted.<pred>} and {@code <true>.total} counters.
	 */
	private Counts confusion(String pos) {
		double tp = value(pos + ".predicted." + pos);
		double p = value(pos + ".total");

		double fp = 0.0; // true label is another class, predicted as pos
		double n = 0.0; // items whose true label is another class
		for (String neg : other(pos)) {
			fp += value(neg + ".predicted." + pos);
			n += value(neg + ".total");
		}

		// Every positive item not predicted as pos is a false negative, and
		// every negative item not predicted as pos is a true negative. This
		// includes predictions that never occur as a true label.
		return new Counts(tp, fp, n - fp, p - tp);
	}

	/**
	 * Logs a per-class report and stores the per-class metrics in the
	 * statistics. Then, if a log file is set, appends the statistics as JSON.
	 * 
	 * @see stream.AbstractProcessor#finish()
	 */
	@Override
	public void finish() throws Exception {
		super.finish();
		DecimalFormatSymbols dfs = new DecimalFormatSymbols();
		dfs.setDecimalSeparator('.');
		DecimalFormat fmt = new DecimalFormat("0.00", dfs);

		// Sort up front so the per-class report below is in a deterministic
		// order.
		Collections.sort(labels);

		logger.info("+---------------------------------------------------------");
		logger.info("|  Prediction Error (ID '{}', {} elements tested.)", id,
				tests);
		logger.info("|    {} classes in test set.", labels.size());

		for (String pos : labels) {

			Counts c = confusion(pos);
			double acc = c.accuracy();
			double precision = c.precision();
			double recall = c.recall();
			double fScore = c.fScore();
			double qfactor = qfactor(c.tp, c.fp, c.tn, c.fn);

			String prefix = "class['" + pos + "'].";
			statistics.put(prefix + "accuracy", acc);
			statistics.put(prefix + "precision", precision);
			statistics.put(prefix + "recall", recall);
			statistics.put(prefix + "fScore", fScore);
			statistics.put(prefix + "qfactor", qfactor);

			// The class name is passed as an argument, never concatenated into
			// the format string, so '{}' inside a label cannot break the output.
			logger.info("|");
			logger.info("|  class['{}'].total:    {}", pos,
					fmt.format(value(pos + ".total")));
			logger.info("|  class['{}'].correct:    {}", pos, fmt.format(c.tp));
			logger.info("|  class['{}'].accuracy:  {}", pos, fmt.format(acc));
			logger.info("|  class['{}'].precision: {}", pos,
					fmt.format(precision));
			logger.info("|  class['{}'].recall: {}", pos, fmt.format(recall));
			logger.info("|  class['{}'].f-score:   {}", pos,
					fmt.format(fScore));
		}

		logger.info("|");
		logger.info("|");
		for (String k : new TreeSet<String>(statistics.keySet())) {
			if (k.indexOf(".predicted.") > 0) {
				logger.info("|    {} => {}", k, fmt.format(statistics.get(k)));
			}
		}
		logger.info("|");
		logger.info("+---------------------------------------------------------");

		writeLog();
	}

	/**
	 * Appends the statistics as a single JSON line to {@link #log}, if set.
	 * Failures are logged, not thrown.
	 */
	private void writeLog() {
		if (log == null) {
			return;
		}

		logger.info("Writing performance log to {}", log);
		FileOutputStream fos = null;
		try {
			fos = new FileOutputStream(log, true);
			PrintStream ps = new PrintStream(fos, false, "UTF-8");
			ps.println(JSONObject.toJSONString(statistics));
			// PrintStream swallows IOExceptions; checkError() flushes and
			// reports whether one occurred.
			if (ps.checkError()) {
				logger.error("Failed to write performance log to {}", log);
			}
			ps.close();
		} catch (Exception e) {
			logger.error("Failed to write performance log to " + log, e);
		} finally {
			if (fos != null) {
				try {
					fos.close();
				} catch (Exception ignored) {
					// Best effort: the stream may already be closed.
				}
			}
		}
	}

	/**
	 * Computes the quality factor {@code Q = recall / sqrt(fp / (fp + tn))}.
	 * This is the signal efficiency divided by the square root of the
	 * false-positive (background) rate.
	 * 
	 * The result is Infinity or NaN when {@code fp} is zero.
	 */
	protected static double qfactor(double tp, double fp, double tn, double fn) {
		double signalEfficiency = tp / (tp + fn);
		double backgroundEfficiency = fp / (fp + tn);
		return signalEfficiency / Math.sqrt(backgroundEfficiency);
	}

	/**
	 * Returns the statistics value stored under {@code key}.
	 * 
	 * @return the stored value, or {@code 0.0} if the key is absent
	 */
	public Double value(String key) {
		Double d = statistics.get(key);
		if (d == null) {
			return 0.0;
		}
		return d;
	}

	/**
	 * @return the id
	 */
	public String getId() {
		return id;
	}

	/**
	 * @param id
	 *            the id to set
	 */
	public void setId(String id) {
		this.id = id;
	}

	/**
	 * @return the classes registered during init, or {@code null}
	 */
	public String[] getClasses() {
		return classes;
	}

	/**
	 * @param classes
	 *            class labels to register (in this order) during init. The
	 *            first entry becomes the positive class for the per-item
	 *            metrics.
	 */
	public void setClasses(String[] classes) {
		this.classes = classes;
	}

	/**
	 * @return the output
	 */
	public Sink[] getOutput() {
		return output;
	}

	/**
	 * @param output
	 *            the output to set
	 */
	public void setOutput(Sink[] output) {
		this.output = output;
	}

	/**
	 * @return the log
	 */
	public File getLog() {
		return log;
	}

	/**
	 * @param log
	 *            the log to set
	 */
	public void setLog(File log) {
		this.log = log;
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy