All Downloads are FREE. Search and download functionalities are using the official Maven repository.

streams.weka.Train Maven / Gradle / Ivy

The newest version!
/**
 * 
 */
package streams.weka;

import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Map;
import java.util.zip.GZIPOutputStream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Element;

import stream.AbstractProcessor;
import stream.Configurable;
import stream.Data;
import stream.runtime.setup.ParameterInjection;
import stream.runtime.setup.factory.ObjectFactory;
import stream.util.Variables;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Processor that collects the data items it sees into a Weka training set
 * and, once the stream has finished, trains a classifier on that set and
 * optionally writes the resulting model to a file.
 * <p>
 * The attribute layout is derived from the first item processed: every key
 * that does not start with {@code '@'} and holds a {@link Double} becomes an
 * attribute, and a class attribute named {@code "@label"} is appended last.
 * 
 * @author chris
 * 
 */
public class Train extends AbstractProcessor implements Configurable {

	static final Logger log = LoggerFactory.getLogger(Train.class);

	/** Attributes of the training set, determined from the first item seen. */
	final ArrayList<Attribute> attributes = new ArrayList<Attribute>();

	/**
	 * Values handed to the {@code "@label"} class attribute.
	 * <p>
	 * NOTE(review): nothing visible in this class adds entries to this list
	 * (the former defaults "gamma" / "proton" are disabled). Presumably the
	 * label values have to be registered here before the first item is
	 * processed - confirm against the callers.
	 */
	final ArrayList<String> classes = new ArrayList<String>();

	/** Training set; stays {@code null} until the first item is processed. */
	Instances dataset;

	/** Target file for the trained model, or {@code null} to skip writing. */
	File output;

	/** Classifier instance created by {@link #configure(Element)}. */
	Classifier classifier;

	/** Class name used when the configuration names no classifier. */
	String classifierType = "weka.classifiers.trees.RandomForest";

	/**
	 * Creates a trainer that uses the default classifier type.
	 */
	public Train() {
	}

	/**
	 * Creates a trainer with a different default classifier type.
	 * 
	 * @param defaultType
	 *            class name of the classifier to instantiate when the
	 *            configuration does not specify one
	 */
	public Train(String defaultType) {
		this();
		classifierType = defaultType;
	}

	/**
	 * Provides the classifier that {@link #finish()} trains.
	 * 
	 * @return the classifier set up by {@link #configure(Element)}, or
	 *         {@code null} if none has been set up
	 */
	protected Classifier createClassifier() {
		return classifier;
	}

	/**
	 * Appends the given item to the training set. The first item processed
	 * additionally defines the attributes of the training set.
	 * 
	 * @param input
	 *            the item to add
	 * @return the unmodified input item
	 * @see stream.Processor#process(stream.Data)
	 */
	@Override
	public Data process(Data input) {

		if (attributes.isEmpty()) {

			for (String key : input.keySet()) {

				// keys with an '@' prefix are never used as features
				if (key.startsWith("@")) {
					continue;
				}

				Serializable value = input.get(key);
				if (value instanceof Double) {
					attributes.add(new Attribute(key));
					continue;
				}

				// a null value has no class to report
				String typeName = value == null ? "null" : value.getClass()
						.getCanonicalName();
				log.warn(
						"Unsupported feature type '{}' in key '{}'  (skipping)",
						typeName, key);
			}

			// add the class attribute ('label')
			Attribute label = new Attribute("@label", classes);
			attributes.add(label);

			dataset = new Instances("trainingData", attributes, 10000);
			dataset.setClass(label);
		}

		Instance example = new DenseInstance(attributes.size());
		for (Attribute a : attributes) {
			Serializable value = input.get(a.name());
			if (value == null) {
				// no value is set for this attribute of the example
				log.warn("Missing value for attribute '{}'", a.name());
				continue;
			}

			if (a.isNumeric()) {
				example.setValue(a, Double.parseDouble(value.toString()));
			} else {
				example.setValue(a, value.toString());
			}
		}
		dataset.add(example);
		return input;
	}

	/**
	 * Trains the classifier on the collected training set and, if an output
	 * file has been set, writes the trained model to it. A file name ending
	 * in {@code ".gz"} results in a GZIP-compressed model file.
	 * 
	 * @throws IllegalStateException
	 *             if no classifier has been set up
	 * @throws Exception
	 *             if training or writing the model fails
	 * @see stream.AbstractProcessor#finish()
	 */
	@Override
	public void finish() throws Exception {
		super.finish();

		// no item was ever processed, so there is nothing to train on
		if (dataset == null) {
			log.warn("No training data has been collected, not training a classifier.");
			return;
		}
		log.info("Building classifier using training dataset with {} samples.",
				dataset.size());

		Classifier learner = createClassifier();
		if (learner == null) {
			throw new IllegalStateException(
					"No classifier has been set up, configure(Element) has to be called before finish()");
		}
		log.info("Training classifier of type '{}'", learner.getClass()
				.getCanonicalName());

		long start = System.currentTimeMillis();
		WekaModel wm = new WekaModel(learner);
		wm.train(dataset);
		long end = System.currentTimeMillis();
		log.info("Training took {} ms.", (end - start));

		log.info("Classifier trained:\n{}", wm.toInfoString());

		if (output == null) {
			log.info("Parameter 'output' not set, not writing model.");
			return;
		}

		log.info("Writing model to {}", output);
		OutputStream os = new FileOutputStream(output);
		try {
			if (output.getName().endsWith(".gz")) {
				os = new GZIPOutputStream(os);
			}
			WekaModel.write(wm, os);
		} finally {
			// release the file handle even if writing the model failed
			os.close();
		}
	}

	/**
	 * @return the file the trained model is written to, may be {@code null}
	 */
	public File getOutput() {
		return output;
	}

	/**
	 * @param output
	 *            the file to write the trained model to, or {@code null} to
	 *            not write the model at all
	 */
	public void setOutput(File output) {
		this.output = output;
	}

	/**
	 * Instantiates the classifier named by the {@code classifier} attribute
	 * of the given element (falling back to the default classifier type) and
	 * passes the element's attributes to the parameter injection for that
	 * classifier.
	 * 
	 * @param document
	 *            the XML element configuring this processor
	 * @throws RuntimeException
	 *             if the classifier cannot be created or configured; the
	 *             original failure is attached as the cause
	 * @see stream.Configurable#configure(org.w3c.dom.Element)
	 */
	@Override
	public void configure(Element document) {
		log.debug("Configuring element {}", document.getNodeName());

		ObjectFactory factory = ObjectFactory.newInstance();
		Map<String, String> parameters = factory.getAttributes(document);

		String type = classifierType;
		String requested = parameters.get("classifier");
		if (requested != null) {
			type = requested;
		}

		try {
			log.debug("Creating new element of class '{}'", type);
			// asSubclass rejects classes that are not classifiers
			Class<? extends Classifier> clazz = Class.forName(type)
					.asSubclass(Classifier.class);
			classifier = clazz.getDeclaredConstructor().newInstance();
			log.debug("New classifier is: {}", classifier);

			log.debug("Injecting parameters into classifier...");
			ParameterInjection.inject(classifier, parameters, new Variables());

			log.debug("Classifier ready to be trained ;-)");
		} catch (Exception e) {
			log.error("Failed to set up classifier of type '{}'", type, e);
			throw new RuntimeException("Failed to set up classifier!", e);
		}
	}

	/**
	 * @return the class name of the default classifier type
	 */
	public String getClassifier() {
		return classifierType;
	}

	/**
	 * @param classifierType
	 *            the class name of the default classifier type
	 */
	public void setClassifier(String classifierType) {
		this.classifierType = classifierType;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy