
// io.github.mianalysis.mia.module.objects.measure.miscellaneous.ApplyWekaObjectClassification Maven / Gradle / Ivy
package io.github.mianalysis.mia.module.objects.measure.miscellaneous;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import org.scijava.Priority;
import org.scijava.plugin.Plugin;
import io.github.mianalysis.mia.MIA;
import io.github.mianalysis.mia.module.Categories;
import io.github.mianalysis.mia.module.Category;
import io.github.mianalysis.mia.module.Module;
import io.github.mianalysis.mia.module.Modules;
import io.github.mianalysis.mia.object.Measurement;
import io.github.mianalysis.mia.object.Obj;
import io.github.mianalysis.mia.object.ObjMetadata;
import io.github.mianalysis.mia.object.Objs;
import io.github.mianalysis.mia.object.Workspace;
import io.github.mianalysis.mia.object.parameters.BooleanP;
import io.github.mianalysis.mia.object.parameters.FilePathP;
import io.github.mianalysis.mia.object.parameters.InputObjectsP;
import io.github.mianalysis.mia.object.parameters.Parameters;
import io.github.mianalysis.mia.object.parameters.SeparatorP;
import io.github.mianalysis.mia.object.refs.ObjMeasurementRef;
import io.github.mianalysis.mia.object.refs.ObjMetadataRef;
import io.github.mianalysis.mia.object.refs.collections.ImageMeasurementRefs;
import io.github.mianalysis.mia.object.refs.collections.MetadataRefs;
import io.github.mianalysis.mia.object.refs.collections.ObjMeasurementRefs;
import io.github.mianalysis.mia.object.refs.collections.ObjMetadataRefs;
import io.github.mianalysis.mia.object.refs.collections.ParentChildRefs;
import io.github.mianalysis.mia.object.refs.collections.PartnerRefs;
import io.github.mianalysis.mia.object.system.Status;
import weka.classifiers.AbstractClassifier;
import weka.core.Instances;
import weka.core.SparseInstance;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
/**
* Apply a previously-prepared WEKA object classifier to a specified object
* collection from the workspace. Classification can be based on a range of
* measurements associated with the input objects. All measurements used to
* create this model should be present in the input objects and have the same
* names (i.e. measurement names shouldn't be changed during preparation of
* training data).
*
* The probability of each input object belonging to each class is output as a
* measurement associated with that object. Each object also has a class index
* (based on the order the classes are listed in the .model file) indicating the
* most probable class that object belongs to.
*/
@Plugin(type = Module.class, priority = Priority.LOW, visible = true)
public class ApplyWekaObjectClassification extends Module {
/**
*
*/
public static final String INPUT_SEPARATOR = "Objects input";
/**
* Input objects from workspace which will be classified based on model
* specified by "Classifier path" parameter.
*/
public static final String INPUT_OBJECTS = "Input objects";
/**
*
*/
public static final String CLASSIFIER_SEPARATOR = "Classifier controls";
/**
* WEKA model (.model extension) that will be used to classify input objects
* based on a variety of measurements. This model must be created in the
* WEKA software. All
* measurements used to create this model should be present in the input objects
* and have the same names (i.e. measurement names shouldn't be changed during
* preparation of training data).
*/
public static final String CLASSIFIER_PATH = "Classifier path";
/**
* When selected, measurements will be normalised (set to the range 0-1) within
* their respective classes.
*/
public static final String APPLY_NORMALISATION = "Apply normalisation";
private String currClassifierPath = "";
private Instances currInstances = null;
private AbstractClassifier currClassifier = null;
public interface ObjMetadataItems {
public static final String CLASS = "CLASSIFIER // CLASS";
}
public ApplyWekaObjectClassification(Modules modules) {
super("Apply Weka object classification", modules);
}
@Override
public Category getCategory() {
return Categories.OBJECTS_MEASURE_MISCELLANEOUS;
}
@Override
public String getVersionNumber() {
return "1.0.0";
}
@Override
public String getDescription() {
return "Apply a previously-prepared WEKA object classifier to a specified object collection from the workspace. Classification can be based on a range of measurements associated with the input objects. All measurements used to create this model should be present in the input objects and have the same names (i.e. measurement names shouldn't be changed during preparation of training data).
"
+
"The probability of each input object belonging to each class is output as a measurement associated with that object. Each object also has a class index (based on the order the classes are listed in the .model file) indicating the most probable class that object belongs to.";
}
public static String getProbabilityMeasurementName(String className) {
return "CLASSIFIER // " + className + "_PROB";
}
public static String getClassMeasurementName(Instances instances) {
StringBuilder sb = new StringBuilder("CLASSIFIER // CLASS (");
for (int i = 0; i < instances.numClasses(); i++) {
if (i != 0)
sb.append(",");
sb.append(instances.classAttribute().value(i));
}
sb.append(")");
return sb.toString();
}
Instances getInstances(String classifierPath) throws FileNotFoundException, IOException, ClassNotFoundException {
if (currClassifierPath.equals(classifierPath) && currInstances != null)
return currInstances;
currClassifierPath = classifierPath;
if (!new File(classifierPath).exists())
return null;
ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(classifierPath));
currClassifier = (AbstractClassifier) objectInputStream.readObject();
currInstances = (Instances) objectInputStream.readObject();
objectInputStream.close();
return currInstances;
}
AbstractClassifier getClassifier(String classifierPath) throws FileNotFoundException, IOException, ClassNotFoundException {
if (currClassifierPath.equals(classifierPath) && currInstances != null)
return currClassifier;
currClassifierPath = classifierPath;
if (!new File(classifierPath).exists())
return null;
ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(classifierPath));
currClassifier = (AbstractClassifier) objectInputStream.readObject();
currInstances = (Instances) objectInputStream.readObject();
objectInputStream.close();
return currClassifier;
}
public static void addProbabilityMeasurements(Obj inputObject, Instances instances, double[] classification) {
for (int i = 0; i < instances.numClasses(); i++) {
String measName = getProbabilityMeasurementName(instances.classAttribute().value(i));
Measurement measurement = new Measurement(measName, classification[i]);
inputObject.addMeasurement(measurement);
}
}
@Override
public Status process(Workspace workspace) {
// Getting input objects
String inputObjectsName = parameters.getValue(INPUT_OBJECTS, workspace);
Objs inputObjects = workspace.getObjects().get(inputObjectsName);
// Getting other parameters
String classifierPath = parameters.getValue(CLASSIFIER_PATH, workspace);
boolean applyNormalisation = parameters.getValue(APPLY_NORMALISATION, workspace);
Instances instances = null;
AbstractClassifier classifier = null;
try {
instances = getInstances(classifierPath);
classifier = getClassifier(classifierPath);
} catch (IOException | ClassNotFoundException e) {
MIA.log.writeError(e);
return Status.FAIL;
}
// Getting list of measurement names
ArrayList measurementNames = new ArrayList<>();
for (int i = 0; i < instances.numAttributes(); i++) {
String measName = instances.attribute(i).name();
measName = measName.replace("", "");
measurementNames.add(measName);
}
// Adding object instances
ArrayList processedObjects = new ArrayList<>();
for (Obj inputObject : inputObjects.values()) {
double[] objAttr = new double[measurementNames.size() - 1];
for (int i = 0; i < measurementNames.size() - 1; i++) {
Measurement measurement = inputObject.getMeasurement(measurementNames.get(i));
// Objects with missing measurements will cause problems for normalisation
if (measurement == null || Double.isNaN(measurement.getValue()))
continue;
objAttr[i] = measurement.getValue();
}
processedObjects.add(inputObject);
instances.add(new SparseInstance(1, objAttr));
}
try {
// Applying normalisation
if (applyNormalisation) {
Normalize normalize = new Normalize();
normalize.setInputFormat(instances);
instances = Filter.useFilter(instances, normalize);
}
// Applying classifications
double[][] classifications = classifier.distributionsForInstances(instances);
String classMeasName = getClassMeasurementName(instances);
for (int i = 0; i < processedObjects.size(); i++) {
Obj inputObject = processedObjects.get(i);
double[] classification = classifications[i];
addProbabilityMeasurements(inputObject, instances, classification);
int classIndex = (int) classifier.classifyInstance(instances.get(i));
inputObject.addMeasurement(new Measurement(classMeasName, classIndex));
inputObject.addMetadataItem(
new ObjMetadata(ObjMetadataItems.CLASS, instances.classAttribute().value(classIndex)));
}
} catch (Exception e) {
MIA.log.writeError(e);
return Status.FAIL;
}
if (showOutput) {
inputObjects.showMeasurements(this, modules);
inputObjects.showMetadata(this, modules);
}
return Status.PASS;
}
@Override
protected void initialiseParameters() {
parameters.add(new SeparatorP(INPUT_SEPARATOR, this));
parameters.add(new InputObjectsP(INPUT_OBJECTS, this));
parameters.add(new SeparatorP(CLASSIFIER_SEPARATOR, this));
parameters.add(new FilePathP(CLASSIFIER_PATH, this));
parameters.add(new BooleanP(APPLY_NORMALISATION, this, true));
addParameterDescriptions();
}
@Override
public Parameters updateAndGetParameters() {
return parameters;
}
@Override
public ImageMeasurementRefs updateAndGetImageMeasurementRefs() {
return null;
}
@Override
public ObjMeasurementRefs updateAndGetObjectMeasurementRefs() {
Workspace workspace = null;
String inputObjectsName = parameters.getValue(INPUT_OBJECTS, workspace);
try {
// Getting class names
String currClassifierPath = parameters.getValue(CLASSIFIER_PATH, workspace);
ObjMeasurementRefs returnedRefs = new ObjMeasurementRefs();
Instances instances = getInstances(currClassifierPath);
if (instances == null)
return returnedRefs;
for (int i = 0; i < instances.numClasses(); i++) {
String className = instances.classAttribute().value(i);
ObjMeasurementRef ref = objectMeasurementRefs.getOrPut(getProbabilityMeasurementName(className));
ref.setObjectsName(inputObjectsName);
returnedRefs.add(ref);
}
ObjMeasurementRef ref = objectMeasurementRefs.getOrPut(getClassMeasurementName(instances));
ref.setObjectsName(inputObjectsName);
returnedRefs.add(ref);
return returnedRefs;
} catch (IOException | ClassNotFoundException e) {
MIA.log.writeError(e);
return null;
}
}
@Override
public ObjMetadataRefs updateAndGetObjectMetadataRefs() {
Workspace workspace = null;
String inputObjectsName = parameters.getValue(INPUT_OBJECTS, workspace);
ObjMetadataRefs returnedRefs = new ObjMetadataRefs();
ObjMetadataRef ref = objectMetadataRefs.getOrPut(ObjMetadataItems.CLASS);
ref.setObjectsName(inputObjectsName);
returnedRefs.add(ref);
return returnedRefs;
}
@Override
public MetadataRefs updateAndGetMetadataReferences() {
return null;
}
@Override
public ParentChildRefs updateAndGetParentChildRefs() {
return null;
}
@Override
public PartnerRefs updateAndGetPartnerRefs() {
return null;
}
@Override
public boolean verify() {
return true;
}
void addParameterDescriptions() {
parameters.get(INPUT_OBJECTS)
.setDescription("Input objects from workspace which will be classified based on model specified by \""
+ CLASSIFIER_PATH + "\" parameter.");
parameters.get(CLASSIFIER_PATH).setDescription(
"WEKA model (.model extension) that will be used to classify input objects based on a variety of measurements. This model must be created in the WEKA software. All measurements used to create this model should be present in the input objects and have the same names (i.e. measurement names shouldn't be changed during preparation of training data).");
parameters.get(APPLY_NORMALISATION).setDescription(
"When selected, measurements will be normalised (set to the range 0-1) within their respective classes.");
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy