All downloads are free. Search and download functionality uses the official Maven repository.

streams.weka.LogSizedPerformance Maven / Gradle / Ivy

The newest version!
/**
 * 
 */
package streams.weka;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import stream.AbstractProcessor;
import stream.Data;
import stream.ProcessContext;
import stream.data.Statistics;

/**
 * Collects per-bin classification statistics, where items are binned by the
 * base-10 logarithm of a numeric attribute (by default {@code "size"}).
 * <p>
 * Only items carrying the bin attribute, {@code @label} and
 * {@code @prediction} are considered. An item with value {@code v} goes into
 * bin {@code (int) log10(v)}, clamped to {@code [0, bins - 1]}: bin 0 collects
 * every value below 10 (including zero, negative and NaN values) and the last
 * bin collects everything at or above {@code 10^(bins - 1)}. Per bin, the
 * processor counts {@code <label>.predicted.<prediction>} pairs and
 * {@code <label>.total}, and sums every numeric {@code @prob:*} attribute.
 * <p>
 * On {@link #finish()}, it logs for each bin the gamma purity and a q value
 * derived from the {@code gamma}/{@code proton} counts.
 *
 * @author chris
 */
public class LogSizedPerformance extends AbstractProcessor {

	static final Logger log = LoggerFactory.getLogger(LogSizedPerformance.class);

	/** Number of logarithmic bins. */
	int bins = 10;

	/** Name of the numeric attribute whose log10 selects the bin. */
	String binKey = "size";

	/** Per-bin statistics; (re)allocated in {@link #init(ProcessContext)}. */
	Statistics[] binStats = new Statistics[bins];

	/**
	 * Allocates one fresh {@link Statistics} instance per bin.
	 *
	 * @see stream.AbstractProcessor#init(stream.ProcessContext)
	 */
	@Override
	public void init(ProcessContext ctx) throws Exception {
		super.init(ctx);
		binStats = new Statistics[bins];
		for (int i = 0; i < binStats.length; i++) {
			binStats[i] = new Statistics();
		}
	}

	/**
	 * Adds a labelled, predicted item to the statistics of its size bin.
	 * Items lacking the bin attribute, {@code @label} or {@code @prediction},
	 * or whose bin attribute is not numeric, are passed through untouched.
	 *
	 * @param input the data item
	 * @return {@code input}, unmodified
	 * @see stream.Processor#process(stream.Data)
	 */
	@Override
	public Data process(Data input) {
		if (!input.containsKey(binKey) || !input.containsKey("@prediction")
				|| !input.containsKey("@label")) {
			return input;
		}

		// Accept any numeric type (Integer, Long, ...): a plain (Double) cast
		// would throw ClassCastException for those, and NPE for null.
		Object sizeValue = input.get(binKey);
		if (!(sizeValue instanceof Number)) {
			log.debug("Ignoring item: attribute '{}' is not numeric: {}", binKey,
					sizeValue);
			return input;
		}

		int bin = binIndex(((Number) sizeValue).doubleValue(), binStats.length);
		Statistics st = binStats[bin];

		String label = input.get("@label").toString();
		String pred = input.get("@prediction").toString();

		st.add(label + ".predicted." + pred, 1.0d);
		st.add(label + ".total", 1.0d);

		for (String key : input.keySet()) {
			if (key.startsWith("@prob:")) {
				Object p = input.get(key);
				// Skip missing or non-numeric probabilities instead of failing.
				if (p instanceof Number) {
					st.add(key, ((Number) p).doubleValue());
				}
			}
		}
		return input;
	}

	/**
	 * Returns the logarithmic bin for {@code value}: the integer part of
	 * {@code log10(value)}, clamped to {@code [0, binCount - 1]}. Values below
	 * 10 (including 0, negatives and NaN) therefore land in bin 0.
	 */
	private static int binIndex(double value, int binCount) {
		// (int) truncates toward zero and maps NaN to 0 and -Infinity (value == 0)
		// to Integer.MIN_VALUE; clamping keeps every result a valid index.
		int bin = (int) Math.log10(value);
		if (bin >= binCount) {
			bin = binCount - 1;
		}
		if (bin < 0) {
			bin = 0;
		}
		return bin;
	}

	/**
	 * Logs, for every bin that has gamma predictions, the gamma purity and a
	 * q value computed from the gamma/proton confusion counts.
	 *
	 * @see stream.AbstractProcessor#finish()
	 */
	@Override
	public void finish() throws Exception {
		super.finish();

		for (int i = 0; i < binStats.length; i++) {
			Statistics s = binStats[i];
			// Keys are "<label>.predicted.<prediction>"; "gamma" is treated as
			// the positive class.
			double truePos = get(s, "gamma.predicted.gamma");
			double falseNeg = get(s, "gamma.predicted.proton");
			double trueNeg = get(s, "proton.predicted.proton");
			double falsePos = get(s, "proton.predicted.gamma");

			// Fraction of gamma predictions that are really gamma (precision).
			double purity = truePos / (truePos + falsePos);
			// Fraction of proton predictions that are really gamma.
			double gammaLeakage = falseNeg / (falseNeg + trueNeg);

			// NaN when the bin has no gamma predictions at all.
			if (!Double.isNaN(purity)) {
				// NOTE(review): the conventional Q-factor is TPR / sqrt(FPR); this
				// keeps the original formula purity / sqrt(gammaLeakage) — confirm.
				double q = purity / Math.sqrt(gammaLeakage);
				log.info("bin[{}].gamma-purity = {}      q={}", new Object[] { i,
						purity, q });
			}
		}
	}

	/**
	 * Returns the value stored under {@code key}, or {@code 0.0} when absent.
	 *
	 * @param s   the statistics to read
	 * @param key the statistic name
	 * @return the stored value, never {@code null}
	 */
	public Double get(Statistics s, String key) {
		Double d = s.get(key);
		if (d == null) {
			return 0.0;
		}
		return d;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy