
streams.weka.Train Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of streams-weka Show documentation
Show all versions of streams-weka Show documentation
Weka adaptation for streaming processes in streams.
The newest version!
/**
*
*/
package streams.weka;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Map;
import java.util.zip.GZIPOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Element;
import stream.AbstractProcessor;
import stream.Configurable;
import stream.Data;
import stream.runtime.setup.ParameterInjection;
import stream.runtime.setup.factory.ObjectFactory;
import stream.util.Variables;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
/**
 * Processor that collects incoming data items into a Weka training dataset
 * and, when the stream finishes, trains a Weka classifier on it and optionally
 * writes the resulting model to a file.
 *
 * <p>The attribute schema is derived from the first processed item: every key
 * not starting with {@code @} whose value is a {@link Double} becomes a numeric
 * attribute, and an extra nominal attribute {@code @label} (built from
 * {@link #classes}) is used as the class attribute.
 *
 * @author chris
 */
public class Train extends AbstractProcessor implements Configurable {

    static final Logger log = LoggerFactory.getLogger(Train.class);

    /** Weka attributes of the training set; filled from the first processed item. */
    final ArrayList<Attribute> attributes = new ArrayList<>();

    /**
     * Values of the nominal {@code @label} class attribute.
     *
     * <p>NOTE(review): nothing in this class adds values to this list. Weka
     * nominal attributes presumably take their value set when they are
     * constructed, so label values that are missing here at the first
     * {@link #process(Data)} call are likely rejected by
     * {@code Instance.setValue}. Confirm how callers or subclasses populate this.
     */
    final ArrayList<String> classes = new ArrayList<>();

    /** Accumulated training data; {@code null} until the first item has been processed. */
    Instances dataset;

    /** Optional model output file; a name ending in {@code .gz} is written gzip-compressed. */
    File output;

    /** Classifier instance, created by {@link #configure(Element)} or lazily by {@link #createClassifier()}. */
    Classifier classifier;

    /** Fully qualified class name of the Weka classifier to instantiate. */
    String classifierType = "weka.classifiers.trees.RandomForest";

    /** Creates a trainer using the default classifier type. */
    public Train() {
    }

    /**
     * Creates a trainer with a different default classifier type.
     *
     * @param defaultType fully qualified class name of a Weka {@link Classifier}
     */
    public Train(String defaultType) {
        this();
        classifierType = defaultType;
    }

    /**
     * Returns the classifier to train. If {@link #configure(Element)} has not
     * already created one, a new instance of {@link #classifierType} is created
     * with default settings. Subclasses may override this to supply their own
     * instance.
     *
     * @return the classifier to train
     * @throws IllegalStateException if {@link #classifierType} cannot be instantiated
     */
    protected Classifier createClassifier() {
        if (classifier == null) {
            try {
                classifier = newClassifier(classifierType);
            } catch (ReflectiveOperationException | ClassCastException e) {
                throw new IllegalStateException(
                        "Failed to create classifier of type '" + classifierType + "'", e);
            }
        }
        return classifier;
    }

    /**
     * Adds the given item as a training example. The first item also fixes the
     * attribute schema (see the class documentation).
     *
     * @see stream.Processor#process(stream.Data)
     */
    @Override
    public Data process(Data input) {
        if (attributes.isEmpty()) {
            initDataset(input);
        }

        // DenseInstance(int) starts with every value marked as missing, so
        // skipping an absent attribute leaves it missing in the example.
        Instance example = new DenseInstance(attributes.size());
        for (Attribute a : attributes) {
            Serializable value = input.get(a.name());
            if (value == null) {
                log.warn("Missing value for attribute '{}'", a.name());
                continue;
            }
            if (a.isNumeric()) {
                example.setValue(a, Double.parseDouble(value.toString()));
            } else {
                example.setValue(a, value.toString());
            }
        }
        dataset.add(example);
        return input;
    }

    /** Builds the attribute list from the keys of {@code first} and creates the empty dataset. */
    private void initDataset(Data first) {
        for (String key : first.keySet()) {
            // Keys starting with '@' are treated as annotations, not features.
            if (key.startsWith("@")) {
                continue;
            }
            Serializable value = first.get(key);
            if (value instanceof Double) {
                attributes.add(new Attribute(key));
                continue;
            }
            String typeName = (value == null) ? "null" : value.getClass().getCanonicalName();
            log.warn("Unsupported feature type '{}' in key '{}' (skipping)", typeName, key);
        }

        // add the class attribute ('label')
        Attribute label = new Attribute("@label", classes);
        attributes.add(label);
        dataset = new Instances("trainingData", attributes, 10000);
        dataset.setClass(label);
    }

    /**
     * Trains the classifier on all collected examples and writes the model to
     * {@link #output} if that is set. Does nothing beyond {@code super.finish()}
     * when no item was processed.
     *
     * @see stream.AbstractProcessor#finish()
     */
    @Override
    public void finish() throws Exception {
        super.finish();

        if (dataset == null) {
            log.warn("No data items processed, skipping classifier training.");
            return;
        }

        log.info("Building classifier using training dataset with {} samples.",
                dataset.size());
        Classifier model = createClassifier();
        log.info("Training classifier of type '{}'", model.getClass()
                .getCanonicalName());

        long start = System.currentTimeMillis();
        WekaModel wm = new WekaModel(model);
        wm.train(dataset);
        long end = System.currentTimeMillis();
        log.info("Training took {} ms.", (end - start));
        log.info("Classifier trained:\n{}", wm.toInfoString());

        if (output == null) {
            log.info("Parameter 'output' not set, not writing model.");
            return;
        }
        log.info("Writing model to {}", output);
        try (OutputStream os = openOutput(output)) {
            WekaModel.write(wm, os);
        }
    }

    /**
     * Opens {@code file} for writing. If the name ends in {@code .gz}, the
     * stream is wrapped in gzip compression. The underlying file stream is
     * closed if the wrapper cannot be created.
     */
    private static OutputStream openOutput(File file) throws IOException {
        OutputStream out = new FileOutputStream(file);
        if (!file.getName().endsWith(".gz")) {
            return out;
        }
        try {
            return new GZIPOutputStream(out);
        } catch (IOException e) {
            out.close();
            throw e;
        }
    }

    /**
     * Instantiates the named Weka classifier through its no-argument constructor.
     *
     * @throws ClassCastException if the class does not implement {@link Classifier}
     */
    private static Classifier newClassifier(String type) throws ReflectiveOperationException {
        return Class.forName(type).asSubclass(Classifier.class)
                .getDeclaredConstructor().newInstance();
    }

    /**
     * @return the model output file, or {@code null} if the model is not written
     */
    public File getOutput() {
        return output;
    }

    /**
     * @param output
     *            the model output file; a name ending in {@code .gz} is written
     *            gzip-compressed
     */
    public void setOutput(File output) {
        this.output = output;
    }

    /**
     * Creates the classifier from the element's {@code classifier} attribute,
     * falling back to {@link #classifierType}. All element attributes are then
     * injected as parameters into the new classifier.
     *
     * @throws RuntimeException if the classifier cannot be created or configured
     * @see stream.Configurable#configure(org.w3c.dom.Element)
     */
    @Override
    public void configure(Element document) {
        log.debug("Configuring element {}", document.getNodeName());
        ObjectFactory factory = ObjectFactory.newInstance();
        Map<String, String> parameters = factory.getAttributes(document);

        String type = classifierType;
        String requested = parameters.get("classifier");
        if (requested != null) {
            type = requested;
        }

        try {
            log.debug("Creating new element of class '{}'", type);
            classifier = newClassifier(type);
            log.debug("New classifier is: {}", classifier);
            log.debug("Injecting parameters into classifier...");
            ParameterInjection.inject(classifier, parameters, new Variables());
            log.debug("Classifier ready to be trained ;-)");
        } catch (Exception e) {
            log.error("Failed to set up classifier of type '{}'", type, e);
            throw new RuntimeException("Failed to set up classifier!", e);
        }
    }

    /**
     * @return the fully qualified class name of the classifier type (not the
     *         classifier instance itself)
     */
    public String getClassifier() {
        return classifierType;
    }

    /**
     * Sets the classifier class name used when no classifier instance exists
     * yet. A classifier already created by {@link #configure(Element)} is not
     * replaced.
     *
     * @param classifierType
     *            fully qualified class name of a Weka {@link Classifier}
     */
    public void setClassifier(String classifierType) {
        this.classifierType = classifierType;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy