All downloads are free. Search and download functionality uses the official Maven repository.

streams.weka.WekaUtils Maven / Gradle / Ivy

The newest version!
/**
 * 
 */
package streams.weka;

import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.TreeSet;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import stream.Data;
import stream.Keys;
import stream.io.Stream;
import weka.classifiers.AbstractClassifier;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.ProtectedProperties;

/**
 * Static helpers that bridge the streams framework and Weka: inspecting the
 * options of Weka objects and converting a {@link Stream} of {@link Data}
 * items into a Weka {@link Instances} data set.
 * 
 * @author chris
 * 
 */
public class WekaUtils {

	static Logger log = LoggerFactory.getLogger(WekaUtils.class);

	/** Item key that is mapped to the nominal class attribute. */
	private static final String LABEL_KEY = "@label";

	/**
	 * Hard-coded nominal values of the {@value #LABEL_KEY} attribute.
	 * NOTE(review): domain specific; labels outside this list are presumably
	 * rejected by Weka when an instance is built - confirm.
	 */
	private static final List<String> LABEL_VALUES = Arrays.asList("gamma",
			"proton");

	/**
	 * Collects the options reported by an object's {@code getOptions()}
	 * method. Weka classifiers ({@link AbstractClassifier}) are queried
	 * directly; any other object is queried via a public, no-argument
	 * {@code getOptions()} method found by reflection.
	 * 
	 * <p>
	 * The options are returned as a sorted set, so the original order of
	 * flags and their arguments is NOT preserved and duplicates collapse.
	 * </p>
	 * 
	 * @param o
	 *            the object to inspect; may be {@code null}
	 * @return a sorted set of option strings; empty if {@code o} is
	 *         {@code null}, has no usable {@code getOptions()} method, or the
	 *         reflective call fails
	 */
	public static Set<String> getOptions(Object o) {
		Set<String> options = new TreeSet<String>();
		if (o == null) {
			return options;
		}

		// AbstractClassifier.getOptions() is called directly; invoking it a
		// second time via reflection (as before) yields the same options.
		String[] found = (o instanceof AbstractClassifier) ? ((AbstractClassifier) o)
				.getOptions() : invokeGetOptions(o);
		if (found != null) {
			for (String opt : found) {
				// TreeSet rejects null elements
				if (opt != null) {
					options.add(opt);
				}
			}
		}

		log.info("Found options for '{}': {}", o.getClass().getCanonicalName(),
				options);
		return options;
	}

	/**
	 * Reflectively invokes a public, no-argument {@code getOptions()} method
	 * on {@code o}.
	 * 
	 * @return the returned array, or {@code null} if there is no such method,
	 *         it did not return a {@code String[]}, or the call failed
	 */
	private static String[] invokeGetOptions(Object o) {
		String className = o.getClass().getCanonicalName();
		try {
			Method m = o.getClass().getMethod("getOptions");
			log.info("Found 'getOptions()' method!");
			Object result = m.invoke(o);
			if (result != null && !(result instanceof String[])) {
				log.warn("getOptions() of {} returned {} instead of a String[]",
						className, result);
				return null;
			}
			String[] opts = (String[]) result;
			// Arrays.toString: passing the array as SLF4J varargs would only
			// log its first element.
			log.info("Result from invoking getOptions(): {}",
					Arrays.toString(opts));
			return opts;
		} catch (NoSuchMethodException e) {
			// Expected for objects that simply do not expose options.
			log.debug("Class {} has no 'getOptions()' method", className);
		} catch (Exception e) {
			log.warn("Failed to invoke getOptions() on " + className, e);
		}
		return null;
	}

	/**
	 * Reads all items of {@code stream} into a Weka data set, using every key
	 * of the first item. Equivalent to
	 * {@code readInstances(stream, new Keys("*"))}.
	 * 
	 * @param stream
	 *            the stream to read until it is exhausted
	 * @return the data set containing one instance per item read
	 * @throws Exception
	 *             if reading from the stream or building an instance fails
	 */
	public static Instances readInstances(Stream stream) throws Exception {
		return readInstances(stream, new Keys("*"));
	}

	/**
	 * Reads all items of {@code stream} into a Weka data set.
	 * 
	 * <p>
	 * The attribute schema is derived from the FIRST item only, restricted to
	 * the keys selected by {@code keys}:
	 * </p>
	 * <ul>
	 * <li>{@link Number} values become numeric attributes,</li>
	 * <li>the key {@value #LABEL_KEY} becomes a nominal attribute with the
	 * values {@code gamma} and {@code proton},</li>
	 * <li>every other key becomes a string attribute.</li>
	 * </ul>
	 * <p>
	 * The stream is then read until {@code read()} returns {@code null}; each
	 * item is converted with {@link #createInstance(ArrayList, Data)}.
	 * </p>
	 * 
	 * @param stream
	 *            the stream to read until it is exhausted
	 * @param keys
	 *            selects which keys of the first item become attributes
	 * @return the data set; it has no attributes and no instances if the
	 *         stream yields no item at all
	 * @throws Exception
	 *             if reading from the stream or building an instance fails
	 */
	public static Instances readInstances(Stream stream, Keys keys)
			throws Exception {
		log.info("Reading instances from {}", stream);

		Data item = stream.read();
		String name = "DataSet[" + stream.getId() + "]";
		if (item == null) {
			// Empty stream: nothing to derive a schema from.
			log.info("Read 0 instances.");
			return new Instances(name, new ArrayList<Attribute>(), 0);
		}

		ArrayList<Attribute> attributes = createAttributes(item, keys);
		Instances instances = new Instances(name, attributes, 1000);

		while (item != null) {
			Instance instance = createInstance(attributes, item);
			instances.add(instance);
			item = stream.read();
		}

		log.info("Read {} instances.", instances.size());
		return instances;
	}

	/**
	 * Builds the attribute list from the selected keys of {@code item} (see
	 * {@link #readInstances(Stream, Keys)} for the type rules).
	 */
	private static ArrayList<Attribute> createAttributes(Data item, Keys keys) {
		ArrayList<Attribute> attributes = new ArrayList<Attribute>();

		for (String key : keys.select(item.keySet())) {
			Serializable value = item.get(key);

			if (value instanceof Number) {
				log.info("Adding new numeric attribute {}", key);
				attributes.add(new Attribute(key));
			} else if (LABEL_KEY.equals(key)) {
				// fresh copy so the attribute never shares the constant list
				List<String> vals = new ArrayList<String>(LABEL_VALUES);
				log.info("Adding nominal attribute '{}' with values {}", key,
						vals);
				attributes.add(new Attribute(key, vals,
						new ProtectedProperties(new Properties())));
			} else {
				log.info("Adding new string attribute {}", key);
				// A null value list makes Weka create a string attribute; the
				// cast selects the List overload of the constructor.
				attributes.add(new Attribute(key, (List<String>) null,
						new ProtectedProperties(new Properties())));
			}
		}

		return attributes;
	}

	/**
	 * Converts a single item into a Weka instance with one value per
	 * attribute, looked up in {@code item} by the attribute's name.
	 * 
	 * <p>
	 * Number values are stored as doubles; all other values are stored by
	 * their {@code toString()} form. Values missing from {@code item} (i.e.
	 * {@code null}) are marked as missing. As a side effect, previously unseen
	 * values are added to string attributes. The returned instance has no data
	 * set assigned.
	 * </p>
	 * 
	 * @param attributes
	 *            the attributes, in instance order
	 * @param item
	 *            the item providing the values
	 * @return the new instance
	 */
	public static Instance createInstance(ArrayList<Attribute> attributes,
			Data item) {

		Instance instance = new DenseInstance(attributes.size());

		for (int i = 0; i < attributes.size(); i++) {
			Attribute a = attributes.get(i);
			Serializable value = item.get(a.name());

			if (value == null) {
				// The item lacks this key (e.g. it was only present in the
				// first item): store a missing value instead of failing.
				instance.setMissing(a);
				continue;
			}

			if (value instanceof Number) {
				instance.setValue(a, ((Number) value).doubleValue());
				continue;
			}

			String str = value.toString();
			if (a.isString() && a.indexOfValue(str) < 0) {
				a.addStringValue(str);
			}
			instance.setValue(a, str);
		}

		return instance;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy