org.kie.kogito.predictions.smile.SmileRandomForest Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of kogito-predictions-smile-addon
Show all versions of kogito-predictions-smile-addon
Kogito Prediction based on SMILE addon
package org.kie.kogito.predictions.smile;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import org.kie.api.runtime.process.WorkItem;
import org.kie.kogito.prediction.api.PredictionOutcome;
import org.kie.kogito.prediction.api.PredictionService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.RandomForest;
import smile.data.Attribute;
import smile.data.AttributeDataset;
import smile.data.NominalAttribute;
import smile.data.NumericAttribute;
import smile.data.StringAttribute;
public class SmileRandomForest extends AbstractPredictionEngine implements PredictionService {
public static final String IDENTIFIER = "SMILERandomForest";
private static final String UNABLE_PARSE_TEXT = "Unable to parse text";
private static final Logger logger = LoggerFactory.getLogger(SmileRandomForest.class);
private final AttributeDataset dataset;
private final Map smileAttributes;
private final Attribute outcomeAttribute;
private final AttributeType outcomeAttributeType;
private final int numAttributes;
private final int numberTrees;
protected List attributeNames = new ArrayList<>();
private Set outcomeSet = new HashSet<>();
private static final int MINIMUM_OBSERVATIONS = 1200;
private int observations = 0;
public SmileRandomForest(RandomForestConfiguration configuration) {
this(configuration.getInputFeatures(),
configuration.getOutcomeName(),
configuration.getOutcomeType(),
configuration.getConfidenceThreshold(),
configuration.getNumTrees());
}
public SmileRandomForest(Map inputFeatures,
String outputFeatureName,
AttributeType outputFeatureType,
double confidenceThreshold,
int numberTrees) {
super(inputFeatures, outputFeatureName, outputFeatureType, confidenceThreshold);
this.numberTrees = numberTrees;
smileAttributes = new HashMap<>();
for (Entry inputFeature : inputFeatures.entrySet()) {
final String name = inputFeature.getKey();
final AttributeType type = inputFeature.getValue();
smileAttributes.put(name, createAttribute(name, type));
attributeNames.add(name);
}
numAttributes = smileAttributes.size();
outcomeAttribute = createAttribute(outputFeatureName, outputFeatureType);
outcomeAttributeType = outputFeatureType;
dataset = new AttributeDataset("dataset", smileAttributes.values().toArray(new Attribute[numAttributes]), outcomeAttribute);
}
protected Attribute createAttribute(String name, AttributeType type) {
if (type == AttributeType.NOMINAL || type == AttributeType.BOOLEAN) {
return new NominalAttribute(name);
} else if (type == AttributeType.NUMERIC) {
return new NumericAttribute(name);
} else {
return new StringAttribute(name);
}
}
protected Object convertValue(String value, AttributeType type) {
if (type == AttributeType.NOMINAL) {
return value;
} else if (type == AttributeType.NUMERIC) {
return Long.valueOf(value);
} else if (type == AttributeType.BOOLEAN) {
return Boolean.valueOf(value);
} else {
return value;
}
}
/**
* Add the data provided as a map to a Smile {@link smile.data.Dataset}.
*
* @param data A map containing the input attribute names as keys and the attribute values as values.
* @param outcome The value of the outcome (output data).
*/
public void addData(Map data, Object outcome) {
final double[] features = new double[numAttributes];
int i = 0;
for (Entry entry : smileAttributes.entrySet()) {
try {
features[i] = smileAttributes.get(entry.getKey()).valueOf(data.get(entry.getKey()).toString());
} catch (ParseException e) {
logger.error(UNABLE_PARSE_TEXT, e);
}
i++;
}
try {
final String outcomeStr = outcome.toString();
outcomeSet.add(outcomeStr);
dataset.add(features, outcomeAttribute.valueOf(outcomeStr));
} catch (ParseException e) {
logger.error(UNABLE_PARSE_TEXT, e);
}
}
/**
* Build a set of features compatible with Smile's datasets from the map of input data
*
* @param data A map containing the input attribute names as keys and the attribute values as values.
* @return A feature vector as a array of doubles.
*/
protected double[] buildFeatures(Map data) {
final double[] features = new double[numAttributes];
for (int i = 0; i < numAttributes; i++) {
final String attrName = attributeNames.get(i);
try {
features[i] = smileAttributes.get(attrName).valueOf(data.get(attrName).toString());
} catch (ParseException e) {
logger.error(UNABLE_PARSE_TEXT, e);
}
}
return features;
}
/**
* Returns the service's identifier
*
* @return The service identifier
*/
@Override
public String getIdentifier() {
return IDENTIFIER;
}
/**
* Returns a model prediction given the input data
*
* @param task Human task data
* @param inputData A map containing the input attribute names as keys and the attribute values as values.
* @return A {@link PredictionOutcome} containing the model's prediction for the input data.
*/
@Override
public PredictionOutcome predict(WorkItem task, Map inputData) {
logger.debug("Predicting with input data: {}", inputData);
RandomForest model = null;
if (observations > MINIMUM_OBSERVATIONS) {
this.confidenceThreshold = 0.75;
}
Map outcomes = new HashMap<>();
if (outcomeSet.size() >= 2) {
model = new RandomForest(dataset, this.numberTrees);
final double[] features = buildFeatures(inputData);
final double[] posteriori = new double[outcomeSet.size()];
double prediction = model.predict(features, posteriori);
String predictionStr = dataset.responseAttribute().toString(prediction);
outcomes.put(outcomeAttribute.getName(), convertValue(predictionStr, outcomeAttributeType));
final double confidence = posteriori[(int) prediction];
outcomes.put("confidence", confidence);
logger.debug("task id {}, total {} observations, prediction = {}, confidence = {} (threshold = {})", task.getId(), this.observations, predictionStr, confidence, this.confidenceThreshold);
return new PredictionOutcome(confidence, this.confidenceThreshold, outcomes);
} else {
outcomes.put("confidence", 0.0);
return new PredictionOutcome(0.0, this.confidenceThreshold, outcomes);
}
}
/**
* Train the random forest model using data from the human task.
*
* @param task Human task data
* @param inputData A map containing the input attribute names as keys and the attribute values as values.
* @param outputData A map containing the output attribute names as keys and the attribute values as values.
*/
@Override
public void train(WorkItem task, Map inputData, Map outputData) {
logger.debug("Training with input data: {}", inputData);
logger.debug("Training with output data: {}", outputData);
this.observations += 1;
addData(inputData, outputData.get(outcomeAttribute.getName()));
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy