org.bigml.binding.LocalPredictiveModel Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of bigml-binding Show documentation
Show all versions of bigml-binding Show documentation
An open source Java client that gives you a simple binding to interact with BigML. You can use it to
easily create, retrieve, list, update, and delete BigML resources.
The newest version!
/*
A local Predictive Model.
This module defines a Model to make predictions locally or
embedded into your application without needing to send requests to
BigML.io.
This module can not only save you a few credits, but also enormously
reduce the latency for each prediction and let you use your models
offline.
You can also visualize your predictive model in IF-THEN rule format
and even generate a java function that implements the model.
Example usage (assuming that you have previously set up the BIGML_USERNAME
and BIGML_API_KEY environment variables and that you own the model/id below):
import org.bigml.binding.BigMLClient;
import org.bigml.binding.resources.Model;
BigMLClient bigmlClient = new BigMLClient();
Model model = new Model(
bigmlClient.getModel("model/5026965515526876630001b2"));
model.predict("{\"petal length\": 3, \"petal width\": 1}");
You can also see the model in IF-THEN rule format with:
model.rules()
Or auto-generate a java function code for the model with:
model.java()
*/
package org.bigml.binding;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.bigml.binding.localmodel.AbstractTree;
import org.bigml.binding.localmodel.BoostedTree;
import org.bigml.binding.localmodel.Predicate;
import org.bigml.binding.localmodel.Prediction;
import org.bigml.binding.localmodel.Tree;
import org.bigml.binding.localmodel.TreeNodeFilter;
import org.bigml.binding.resources.AbstractResource;
import org.bigml.binding.utils.Utils;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.json.simple.JSONValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.text.ParseException;
import java.util.*;
/**
* A lightweight wrapper around a Tree model.
*
* Uses a BigML remote model to build a local version that can be used to
* generate predictions locally.
*
*/
public class LocalPredictiveModel extends ModelFields
implements PredictionConverter, SupervisedModelInterface {
private static final long serialVersionUID = 1L;
// Model resource ids: "model/" followed by exactly 24 lower-case
// hexadecimal characters (no commas). Exposed through getModelIdRe().
private static final String MODEL_RE = "^model/[a-f0-9]{24}$";
/**
 * Logging
 */
static Logger logger = LoggerFactory.getLogger(
        LocalPredictiveModel.class.getName());
// Impurity threshold used by getImpureLeaves() when the caller passes null.
public static Double DEFAULT_IMPURITY = 0.2;
// Accepted values for the operating kind and for the kind of an
// operating point.
private static final String[] OPERATING_POINT_KINDS = {
        "probability", "confidence" };
// Resource id reported by getResourceId().
private String modelId;
// The "model.root" JSON node the local tree (or boosted tree) is built from.
private JSONObject root;
// Local decision tree; only built for non-boosted models (null otherwise).
private Tree tree;
// Local boosted tree; only built when isBoosting() is true.
private BoostedTree boostedTree;
// Passed to the Tree constructor and read by getIdsPath() to walk from a
// node id to its parents (presumably populated by Tree -- confirm).
private Map idsMap;
// NOTE(review): the type arguments of this declaration appear to have been
// lost in this copy of the file; not referenced in the visible code.
private Map> terms = new HashMap>();
// max_bins handed to the Tree via treeInfo; refreshed from the tree for
// regression models.
private int maxBins = 0;
// "model.importance" entries restricted to fields present in the model.
protected JSONArray fieldImportance;
// Objective field as read from the resource's "objective_fields".
protected String objectiveField;
// Regression flag computed in the constructor (works for boosted models too).
private Boolean regression = false;
// The resource's "boosting" object for boosted ensembles; null otherwise.
private JSONObject boosting = null;
// Sorted categories of the root distribution (classification models only).
private List classNames = new ArrayList();
// The resource's "default_numeric_value" setting, if any.
private String defaultNumericValue = null;
/**
 * Creates a local model from its resource JSON, without a BigML client.
 *
 * @param model the model resource JSON
 * @throws Exception if the two-argument constructor fails
 */
public LocalPredictiveModel(JSONObject model) throws Exception {
    this((BigMLClient) null, model);
}
/**
 * Builds a local model from a BigML model resource.
 *
 * The local tree (or boosted tree) is only built when the resource's
 * {@code status.code} equals {@code AbstractResource.FINISHED}. For any
 * other status the instance is left without a tree.
 *
 * @param bigmlClient client handed to the {@code ModelFields} parent; the
 *                    one-argument constructor passes {@code null}
 * @param model the model resource JSON
 * @throws Exception if some field listed in {@code model.model_fields} is
 *                   missing from {@code model.fields}, or if the parent
 *                   constructor fails
 */
public LocalPredictiveModel(
        BigMLClient bigmlClient, JSONObject model) throws Exception {
    super(bigmlClient, model);
    // Continue with the resource as stored by the parent constructor.
    model = this.model;

    // Remember the resource id so that getResourceId() can report it.
    Object resource = model.get("resource");
    if (resource instanceof String) {
        this.modelId = (String) resource;
    }

    JSONObject status = (JSONObject) Utils.getJSONObject(model, "status");
    boolean finished = status != null
            && status.containsKey("code")
            && AbstractResource.FINISHED
                == ((Number) status.get("code")).intValue();
    if (!finished) {
        return;
    }

    // Model-level setting: it does not depend on model_fields being present.
    this.defaultNumericValue = (String) model.get("default_numeric_value");

    JSONObject fields = (JSONObject) Utils.getJSONObject(model, "model.fields");
    JSONObject usedFields =
            (JSONObject) Utils.getJSONObject(model, "model.model_fields");
    if (usedFields != null) {
        // model_fields becomes the working field set; every one of its ids
        // must exist in model.fields, whose summary and name are copied in.
        fields = usedFields;
        modelFields = (JSONObject) Utils.getJSONObject(
                model, "model.fields");
        // Validate every id before copying anything, so the JSON is not
        // partially modified when a field is missing.
        for (Object key : fields.keySet()) {
            if (modelFields.get((String) key) == null) {
                throw new Exception(
                        "Some fields are missing to generate a local model. Please, provide a model with the complete list of fields.");
            }
        }
        for (Object key : fields.keySet()) {
            JSONObject field = (JSONObject) fields.get((String) key);
            JSONObject modelField = (JSONObject) modelFields.get((String) key);
            field.put("summary", modelField.get("summary"));
            field.put("name", modelField.get("name"));
        }
    }

    Object objectiveFields = Utils.getJSONObject(model, "objective_fields");
    objectiveField = objectiveFields instanceof JSONArray
            ? (String) ((JSONArray) objectiveFields).get(0)
            : (String) objectiveFields;
    super.initialize(fields, objectiveField, null, null);

    // Keep only the importance entries of fields known to this model.
    JSONArray modelFieldImportance = (JSONArray) Utils.getJSONObject(
            model, "model.importance", null);
    if (modelFieldImportance != null) {
        fieldImportance = new JSONArray();
        for (Object element : modelFieldImportance) {
            JSONArray elementItem = (JSONArray) element;
            if (fields.containsKey(elementItem.get(0).toString())) {
                fieldImportance.add(elementItem);
            }
        }
    }

    // boosting models are to be handled using the BoostedTree class
    boolean boostedEnsemble = (Boolean) Utils.getJSONObject(
            model, "boosted_ensemble", false);
    if (boostedEnsemble) {
        this.boosting = (JSONObject) Utils.getJSONObject(
                model, "boosting", null);
    }

    String optype = (String) Utils.getJSONObject(
            fields, objectiveField + ".optype");
    this.regression =
            (!isBoosting() && "numeric".equals(optype)) ||
            (isBoosting() && boosting.get("objective_class") == null);

    this.root = (JSONObject) Utils.getJSONObject(model, "model.root");
    this.idsMap = new HashMap();
    if (isBoosting()) {
        this.boostedTree = new BoostedTree(
                root, this.fields, objectiveField);
    } else {
        // will store global information in the tree: regression and
        // max_bins number
        JSONObject distribution = (JSONObject) Utils.getJSONObject(
                model, "model.distribution.training");
        JSONObject treeInfo = new JSONObject();
        treeInfo.put("max_bins", maxBins);
        this.tree = new Tree(root, this.fields, objectiveField,
                distribution, null, idsMap, true, treeInfo);
        if (this.tree.isRegression()) {
            this.maxBins = this.tree.getMaxBins();
        } else {
            // Class names: the categories of the root distribution, sorted.
            JSONArray rootDist = (JSONArray) this.tree.getDistribution();
            for (Object dist : rootDist) {
                classNames.add((String) ((JSONArray) dist).get(0));
            }
            Collections.sort(classNames);
        }
    }
}
/**
 * Gives the regular expression that BigML model ids must match.
 *
 * @return the model id pattern
 */
public String getModelIdRe() {
    return LocalPredictiveModel.MODEL_RE;
}
/**
 * Fetches a model resource through the configured BigML client.
 *
 * @param modelId id of the model to retrieve
 * @return the resource returned by the client, as a JSONObject
 */
public JSONObject getBigMLModel(String modelId) {
    Object resource = bigmlClient.getModel(modelId);
    return (JSONObject) resource;
}
/**
 * Returns the id of the model resource backing this local model.
 *
 * @return the value of the {@code modelId} field
 */
public String getResourceId() {
    return this.modelId;
}
/**
 * Lists the fields used by the underlying tree.
 *
 * Boosted models delegate to their boosted tree; all others delegate to
 * the regular tree.
 *
 * @return the fields for the model
 */
public JSONObject fields() {
    if (isBoosting()) {
        return boostedTree.listFields();
    }
    return tree.listFields();
}
/**
 * Replaces the field descriptions held by this model.
 *
 * @param fields the new fields JSON object
 */
public void setFields(JSONObject fields) {
    this.fields = fields;
}
/**
 * Gives the class names of a classification model.
 *
 * @return the (mutable) list of class names held by this model
 */
public List getClassNames() {
    return this.classNames;
}
/**
 * Replaces the class names of this model.
 *
 * The list is stored as given, not copied.
 *
 * @param classNames the list of classes names
 */
public void setClassNames(List classNames) {
    this.classNames = classNames;
}
/**
 * Builds the per-category correction terms from the root distribution.
 *
 * For weighted trees every category gets 0.0. Otherwise each category
 * gets its share of the training instances (count / total count).
 *
 * @return a map from category name to its correction term
 */
private HashMap laplacianTerm() {
    JSONArray rootDistribution = (JSONArray) this.tree.getDistribution();
    boolean weighted = this.tree.getWeighted();

    double instanceCount = 0.0;
    if (!weighted) {
        for (Object entry : rootDistribution) {
            instanceCount += ((Number) ((JSONArray) entry).get(1)).doubleValue();
        }
    }

    HashMap correction = new HashMap();
    for (Object entry : rootDistribution) {
        JSONArray pair = (JSONArray) entry;
        String category = (String) pair.get(0);
        if (weighted) {
            correction.put(category, 0.0);
        } else {
            Double count = ((Number) pair.get(1)).doubleValue();
            correction.put(category, count / instanceCount);
        }
    }
    return correction;
}
/**
 * Checks whether this is a regression model.
 *
 * Boosted models have no {@code tree} (only a boosted tree is built), so
 * for them, and for models whose tree was never built, the flag computed
 * in the constructor is returned instead of dereferencing a null tree.
 *
 * @return true for a regression model, false otherwise
 */
public boolean isRegression() {
    if (isBoosting() || tree == null) {
        return regression;
    }
    return tree.isRegression();
}
/**
 * Tells whether the model comes from a boosted ensemble.
 *
 * @return true when a non-empty "boosting" object was found in the resource
 */
public boolean isBoosting() {
    if (boosting == null) {
        return false;
    }
    return !boosting.isEmpty();
}
/**
 * Returns the boosting parameters of the model.
 *
 * @return the resource's "boosting" object for boosted ensembles, or
 *         {@code null} when the model is not boosted
 */
public JSONObject getBoosting() {
    return this.boosting;
}
/**
 * Returns a list that includes all the leaves of the model.
 *
 * @return all the leaf nodes
 * @throws IllegalArgumentException for boosted models, which have no
 *         single tree; use {@code getBoostedLeaves()} for them
 */
public List getLeaves() {
    if (isBoosting()) {
        // Boosted models only build boostedTree, so tree is null here.
        throw new IllegalArgumentException(
                "This method is not available for boosting models. "
                + "Use getBoostedLeaves() instead.");
    }
    return this.tree.getLeaves(null);
}
/**
 * Returns a list that includes all the leaves of the model accepted by a
 * filter.
 *
 * @param filter should be a function that returns a boolean
 *        when applied to each leaf node.
 *
 * @return all the leaf nodes after applying the filter
 * @throws IllegalArgumentException for boosted models, which have no
 *         single tree; use {@code getBoostedLeaves()} for them
 */
public List getLeaves(TreeNodeFilter filter) {
    if (isBoosting()) {
        // Boosted models only build boostedTree, so tree is null here.
        throw new IllegalArgumentException(
                "This method is not available for boosting models. "
                + "Use getBoostedLeaves() instead.");
    }
    return this.tree.getLeaves(filter);
}
/**
 * Returns a list that includes all the leaves of a boosted model.
 *
 * @return all the leaf nodes of the boosted tree
 * @throws IllegalArgumentException for non-boosted models, which have no
 *         boosted tree; use {@code getLeaves()} for them
 */
public List getBoostedLeaves() {
    if (!isBoosting()) {
        // Non-boosted models only build tree, so boostedTree is null here.
        throw new IllegalArgumentException(
                "This method is available for boosting models only.");
    }
    return this.boostedTree.getLeaves();
}
/**
 * Returns the leaves whose node impurity is above a threshold.
 *
 * Leaves without an impurity value are skipped.
 *
 * @param impurityThreshold the degree of impurity; DEFAULT_IMPURITY when null
 *
 * @return the leaf nodes whose impurity exceeds the threshold
 * @throws IllegalArgumentException for boosted or regression models
 */
public List getImpureLeaves(Double impurityThreshold) {
    boolean categorization = !isBoosting() && !isRegression();
    if (!categorization) {
        throw new IllegalArgumentException(
                "This method is available for non-boosting " +
                "categorization models only.");
    }

    final Double threshold;
    if (impurityThreshold == null) {
        threshold = DEFAULT_IMPURITY;
    } else {
        threshold = impurityThreshold;
    }

    TreeNodeFilter impureOnly = new TreeNodeFilter() {
        @Override
        public boolean filter(Tree node) {
            Double impurity = node.getImpurity();
            if (impurity == null) {
                return false;
            }
            return impurity > threshold;
        }
    };
    return this.tree.getLeaves(impureOnly);
}
/**
 * Makes a prediction from input data given as a JSON object string.
 *
 * The input fields must be keyed by field name. The string is parsed into
 * a JSONObject and the prediction is delegated to
 * {@code predict(JSONObject)}.
 *
 * @param args the data to be predicted, as a JSON object literal, e.g.
 *             {@code {"petal length": 3, "petal width": 1}}
 *
 * @return the prediction for the args
 * @throws InputDataParseException if {@code args} is null or not a JSON
 *         object, or if the prediction fails with a checked exception
 *         (the original exception is kept as the cause)
 */
public Prediction predict(final String args)
        throws InputDataParseException {
    // JSONValue.parse returns null for malformed input rather than throwing.
    Object parsed = (args == null) ? null : JSONValue.parse(args);
    if (!(parsed instanceof JSONObject)) {
        throw new InputDataParseException("Input data format not valid");
    }
    try {
        return predict((JSONObject) parsed);
    } catch (InputDataParseException e) {
        throw e;
    } catch (RuntimeException e) {
        throw e;
    } catch (Exception e) {
        // predict(JSONObject) declares Exception; this method may only
        // throw InputDataParseException, so wrap while keeping the cause.
        InputDataParseException wrapped =
                new InputDataParseException(e.getMessage());
        wrapped.initCause(e);
        throw wrapped;
    }
}
/**
 * Makes a prediction using the LAST_PREDICTION missing strategy and full
 * output.
 *
 * The input fields must be keyed by field name.
 *
 * @param args the data to be predicted
 *
 * @return the prediction for the args
 * @throws Exception a generic exception
 */
public Prediction predict(final JSONObject args)
        throws Exception {
    final Boolean includeFullDetails = Boolean.TRUE;
    return predict(args, MissingStrategy.LAST_PREDICTION,
            null, null, includeFullDetails);
}
/**
 * Makes a prediction with the given missing strategy, no operating point
 * or operating kind, and full output.
 *
 * The input fields must be keyed by field name.
 *
 * @param args
 *            the data to be predicted
 * @param strategy
 *            LAST_PREDICTION|PROPORTIONAL missing strategy for missing fields
 *
 * @return the prediction for the args
 * @throws Exception a generic exception
 */
public Prediction predict(final JSONObject args, MissingStrategy strategy)
        throws Exception {
    JSONObject operatingPoint = null;
    String operatingKind = null;
    List unusedFields = null;
    return predict(args, strategy, operatingPoint, operatingKind,
            Boolean.TRUE, unusedFields);
}
/**
 * Makes multiple predictions using the LAST_PREDICTION missing strategy.
 *
 * The input fields must be keyed by field name.
 *
 * @param args
 *            the data to be predicted
 * @param multiple
 *            For categorical fields, it will return the categories
 *            in the distribution of the predicted node as a list of dicts:
 *              [{'prediction': 'Iris-setosa',
 *                'confidence': 0.9154
 *                'probability': 0.97
 *                'count': 97},
 *               {'prediction': 'Iris-virginica',
 *                'confidence': 0.0103
 *                'probability': 0.03,
 *                'count': 3}]
 *
 *            The value of this argument can either be an integer
 *            (maximum number of categories to be returned), or the
 *            literal 'all', that will cause the entire distribution
 *            in the node to be returned.
 *
 * @return the predictions for the args
 * @throws InputDataParseException an input data parse exception
 */
public List predict(final JSONObject args, Object multiple)
        throws InputDataParseException {
    MissingStrategy defaultStrategy = MissingStrategy.LAST_PREDICTION;
    return predict(args, defaultStrategy, multiple);
}
/**
 * Convenience version of predict that takes a map from field ids or names
 * to their values as Java objects.
 *
 * The map is serialized to JSON text and parsed back into a JSONObject
 * before predicting with full output.
 *
 * @param inputs
 *            Input data to be predicted
 * @param missingStrategy
 *            LAST_PREDICTION|PROPORTIONAL missing strategy for missing fields
 *
 * @return the prediction for the inputs
 * @throws Exception a generic exception
 */
public Prediction predictWithMap(
        final Map inputs, MissingStrategy missingStrategy)
        throws Exception {
    String serialized = JSONValue.toJSONString(inputs);
    JSONObject asJson = (JSONObject) JSONValue.parse(serialized);
    List noUnusedFields = null;
    return predict(asJson, missingStrategy, null, null, true, noUnusedFields);
}
/**
 * Convenience version of predict that takes a map from field ids or names
 * to their values, using the LAST_PREDICTION missing strategy.
 *
 * @param inputs Input data to be predicted
 * @return the prediction for the inputs
 * @throws Exception a generic exception
 */
public Prediction predictWithMap(
        final Map inputs) throws Exception {
    String serialized = JSONValue.toJSONString(inputs);
    JSONObject asJson = (JSONObject) JSONValue.parse(serialized);
    return predict(asJson, MissingStrategy.LAST_PREDICTION, null, null, true);
}
/**
 * Makes a prediction based on a number of field values.
 *
 * @param args
 *            the data to be predicted
 * @param strategy
 *            LAST_PREDICTION|PROPORTIONAL missing strategy for missing
 *            fields; LAST_PREDICTION when null
 * @param multiple
 *            For categorical fields, it will return the categories
 *            in the distribution of the predicted node as a list of dicts:
 *              [{'prediction': 'Iris-setosa',
 *                'confidence': 0.9154
 *                'probability': 0.97
 *                'count': 97},
 *               {'prediction': 'Iris-virginica',
 *                'confidence': 0.0103
 *                'probability': 0.03,
 *                'count': 3}]
 *
 *            The value of this argument can either be an integer
 *            (maximum number of categories to be returned), or the
 *            literal 'all', that will cause the entire distribution
 *            in the node to be returned. When null (or for regression
 *            models) a single prediction, with its "next" field, is
 *            returned.
 *
 * @return the predictions for the inputData
 * @throws InputDataParseException if {@code args} is null
 * @throws IllegalArgumentException if {@code multiple} is invalid, or the
 *         model is boosted (this method works on the single tree only)
 */
public List predict(final JSONObject args, MissingStrategy strategy, Object multiple)
        throws InputDataParseException {
    final String multipleError = "The value of the \"multiple\"" +
            " argument can either be an integer" +
            " (maximum number of categories to be returned), or the" +
            " literal 'all', that will cause the entire distribution" +
            " in the node to be returned.";

    if (strategy == null) {
        strategy = MissingStrategy.LAST_PREDICTION;
    }
    if (args == null) {
        throw new InputDataParseException("Input data format not valid");
    }
    // Checks and cleans inputData leaving the fields used in the model
    JSONObject inputData = filterInputData(args);

    boolean multipleAll = false;
    int multipleNum = Integer.MAX_VALUE;
    if (multiple != null) {
        if ("all".equals(multiple)) {
            multipleAll = true;
        } else if (multiple instanceof Number) {
            multipleNum = ((Number) multiple).intValue();
        } else {
            throw new IllegalArgumentException(multipleError);
        }
    }

    if (isBoosting()) {
        // Boosted models only build boostedTree, so tree is null here.
        throw new IllegalArgumentException(
                "This method is not available for boosting models.");
    }

    // Strips affixes for numeric values and casts to the final field type
    Utils.cast(inputData, fields);
    Prediction predictionInfo = tree.predict(inputData, null, strategy);
    List outputs = new ArrayList();

    if (multiple != null && !tree.isRegression()) {
        JSONArray distribution = predictionInfo.getDistribution();
        Long instances = predictionInfo.getCount();
        for (int iDistIndex = 0; iDistIndex < distribution.size(); iDistIndex++) {
            if (!multipleAll && iDistIndex >= multipleNum) {
                break;
            }
            JSONArray distElement = (JSONArray) distribution.get(iDistIndex);
            Object category = distElement.get(0);
            Prediction categoryPrediction = new Prediction();
            categoryPrediction.setPrediction(category);
            categoryPrediction.setConfidence(Tree.wsConfidence(category, distribution));
            categoryPrediction.setProbability(((Number) distElement.get(1)).doubleValue() / instances);
            categoryPrediction.setCount(((Number) distElement.get(1)).longValue());
            outputs.add(categoryPrediction);
        }
        return outputs;
    }

    // Single prediction: report the field checked by the next split.
    List children = predictionInfo.getChildren();
    String field = (children == null || children.size() == 0 ? null : children.get(0).getPredicate().getField());
    if (field != null && fields.containsKey(field)) {
        field = fieldsNameById.get(field);
    }
    predictionInfo.setNext(field);
    outputs.add(predictionInfo);
    return outputs;
}
/**
 * Makes a prediction without an explicit list of unused fields.
 *
 * @param inputData Input data to be predicted
 * @param missingStrategy LAST_PREDICTION|PROPORTIONAL missing strategy
 * @param operatingPoint operating point map for classifications, or null
 * @param operatingKind "probability" or "confidence", or null
 * @param full whether to include all prediction attributes
 * @return the prediction for the inputData
 * @throws Exception a generic exception
 */
public Prediction predict(
        JSONObject inputData, MissingStrategy missingStrategy,
        JSONObject operatingPoint, String operatingKind, Boolean full)
        throws Exception {
    List unusedFields = null;
    return predict(inputData, missingStrategy, operatingPoint,
            operatingKind, full, unusedFields);
}
/**
 * Makes a prediction based on a number of field values.
 *
 * @param inputData Input data to be predicted
 * @param missingStrategy LAST_PREDICTION|PROPORTIONAL missing strategy for
 *                        missing fields; LAST_PREDICTION when null
 * @param operatingPoint
 *            In classification models, this is the point of the
 *            ROC curve where the model will be used at. The
 *            operating point can be defined in terms of:
 *              - the positive class, the class that is important to
 *                predict accurately
 *              - the probability_threshold (or confidence_threshold),
 *                the probability (or confidence) that is established
 *                as minimum for the positive_class to be predicted.
 *            The operating_point is then defined as a map with
 *            two attributes, e.g.:
 *                {"positive_class": "Iris-setosa",
 *                 "probability_threshold": 0.5}
 *            or
 *                {"positive_class": "Iris-setosa",
 *                 "confidence_threshold": 0.5}
 * @param operatingKind
 *            "probability" or "confidence". Sets the property that
 *            decides the prediction. Used only if no operating_point
 *            is used
 *
 * @param full
 *            Boolean that controls whether to include the prediction's
 *            attributes. By default, only the prediction is produced. If set
 *            to True, the rest of available information is added in a
 *            dictionary format. The dictionary keys can be:
 *              - prediction: the prediction value
 *              - confidence: prediction's confidence
 *              - probability: prediction's probability
 *              - path: rules that lead to the prediction
 *              - count: number of training instances supporting the
 *                       prediction
 *              - next: field to check in the next split
 *              - min: minimum value of the training instances in the
 *                     predicted node
 *              - max: maximum value of the training instances in the
 *                     predicted node
 *              - median: median of the values of the training instances
 *                        in the predicted node
 *              - unused_fields: list of fields in the input data that
 *                               are not being used in the model
 *
 * @param unusedFields
 *            Unused fields to include in the prediction
 *
 * @return the prediction for the inputData
 * @throws Exception a generic exception
 */
public Prediction predict(
        JSONObject inputData, MissingStrategy missingStrategy,
        JSONObject operatingPoint, String operatingKind, Boolean full,
        List unusedFields) throws Exception {
    if (missingStrategy == null) {
        missingStrategy = MissingStrategy.LAST_PREDICTION;
    }
    if (full == null) {
        full = false;
    }
    // Checks and cleans inputData leaving the fields used in the model
    inputData = filterInputData(inputData, full);
    if (unusedFields == null) {
        unusedFields = (List) inputData.get("unusedFields");
    }
    inputData = (JSONObject) inputData.get("newInputData");
    // Strips affixes for numeric values and casts to the final field type
    Utils.cast(inputData, fields);

    // When operating_point is used, we need the probabilities
    // (or confidences) of all possible classes to decide, so we use
    // the predictProbability or predictConfidence methods
    if (operatingPoint != null) {
        if (regression) {
            throw new IllegalArgumentException(
                    "The operating_point argument can only be" +
                    " used in classifications.");
        }
        return predictOperating(inputData, missingStrategy, operatingPoint);
    }
    if (operatingKind != null) {
        if (regression) {
            throw new IllegalArgumentException(
                    "The operating_kind argument can only be" +
                    " used in classifications.");
        }
        return predictOperatingKind(inputData, missingStrategy, operatingKind);
    }

    Prediction prediction = isBoosting() ?
            this.boostedTree.predict(inputData, null, missingStrategy) :
            this.tree.predict(inputData, null, missingStrategy);

    if (isBoosting() && missingStrategy == MissingStrategy.PROPORTIONAL) {
        // output has to be recomputed and comes in a different format
        HashMap pred = (HashMap) prediction.get("prediction");
        // JSON numbers may decode as Long or Double, so every numeric value
        // is read through Number instead of a fixed boxed type.
        double gSum = ((Number) pred.get("g_sum")).doubleValue();
        double hSum = ((Number) pred.get("h_sum")).doubleValue();
        Long population = ((Number) prediction.get("count")).longValue();
        List path = (List) prediction.get("path");
        double lambda = ((Number) this.boosting.get("lambda")).doubleValue();
        prediction = new Prediction(
                (- gSum / (hSum + lambda)), population, path, null);
    }

    // next: field checked by the first child split, reported by name
    List children = (List) prediction.get("children");
    String field = (children == null || children.size() == 0 ?
            null : ((AbstractTree) children.get(0)).getPredicate().getField());
    if (field != null && fields.containsKey(field)) {
        field = fieldsNameById.get(field);
    }
    prediction.setNext(field);
    prediction.remove("children");

    if (!isBoosting() && !isRegression()) {
        String pred = (String) prediction.get("prediction");
        HashMap probabilities = probabilities(
                (JSONArray) prediction.get("distribution"));
        prediction.put("probability", probabilities.get(pred));
    }
    if (full) {
        prediction.put("unused_fields", unusedFields);
    }
    return prediction;
}
/**
 * Turns a node distribution into per-class probabilities using the
 * Laplacian correction terms of the root distribution.
 *
 * @param distribution [category, count] pairs of the predicted node; may
 *                     be null
 * @return category to probability map; the bare correction terms when
 *         {@code distribution} is null
 */
private HashMap probabilities(JSONArray distribution) {
    HashMap categoryMap = laplacianTerm();
    // Non-weighted trees add 1 to account for the correction terms, which
    // sum to one; weighted trees add nothing.
    double total = this.tree.getWeighted() ? 0 : 1;
    if (distribution == null) {
        return categoryMap;
    }
    for (int i = 0; i < distribution.size(); i++) {
        JSONArray pair = (JSONArray) distribution.get(i);
        String category = (String) pair.get(0);
        double count = ((Number) pair.get(1)).doubleValue();
        double corrected =
                ((Number) categoryMap.get(category)).doubleValue() + count;
        categoryMap.put(category, corrected);
        total += count;
    }
    for (Object item : categoryMap.entrySet()) {
        Map.Entry entry = (Map.Entry) item;
        entry.setValue(((Number) entry.getValue()).doubleValue() / total);
    }
    return categoryMap;
}
/**
 * Builds one {category, key} entry per class name, in classNames order.
 *
 * @param categoryMap category to value map (probabilities or confidences)
 * @param key name under which each rounded value is stored
 * @return the list of per-class entries
 */
private JSONArray toOutput(HashMap categoryMap, String key) {
    JSONArray rows = new JSONArray();
    Iterator names = classNames.iterator();
    while (names.hasNext()) {
        String className = (String) names.next();
        Prediction row = new Prediction();
        row.put("category", className);
        row.put(key, Utils.roundOff(
                (Double) categoryMap.get(className), Constants.PRECISION));
        rows.add(row);
    }
    return rows;
}
/**
 * For classification models, predicts a probability for each possible
 * output class, based on input values. The input fields must be keyed by
 * field name or field ID.
 *
 * For boosted and regression models the output is a single-element list
 * containing the full prediction.
 *
 * @param inputData Input data to be predicted
 * @param missingStrategy LAST_PREDICTION|PROPORTIONAL missing strategy
 *                        for missing fields
 * @return one {category, probability} entry per class name, or the single
 *         prediction described above
 * @throws Exception if the prediction fails
 */
public JSONArray predictProbability(
        JSONObject inputData, MissingStrategy missingStrategy)
        throws Exception {
    boolean singleOutput = isBoosting() || isRegression();
    Prediction prediction = predict(inputData, missingStrategy,
            null, null, true);
    if (singleOutput) {
        JSONArray wrapped = new JSONArray();
        wrapped.add(prediction);
        return wrapped;
    }
    return toOutput(
            probabilities((JSONArray) prediction.get("distribution")),
            "probability");
}
/**
 * For classification models, predicts a confidence for each possible
 * output class, based on input values. The input fields must be keyed by
 * field name or field ID.
 *
 * For regressions (boosted or not), the output is a single element list
 * containing the prediction.
 *
 * @param inputData Input data to be predicted
 * @param missingStrategy LAST_PREDICTION|PROPORTIONAL missing strategy
 *                        for missing fields
 *
 * @return one {category, confidence} entry per class name, or the single
 *         regression prediction
 * @throws IllegalArgumentException for boosted classification models
 * @throws Exception a generic exception
 */
public JSONArray predictConfidence(
        JSONObject inputData, MissingStrategy missingStrategy)
        throws Exception {
    // Boosted models have no tree, so use the constructor's regression
    // flag for them instead of dereferencing tree.
    boolean regressionModel = isBoosting() ? this.regression : isRegression();
    if (regressionModel) {
        // Return right away: the per-class code below only makes sense for
        // categorical distributions.
        JSONArray output = new JSONArray();
        output.add(predict(inputData, missingStrategy, null, null, true));
        return output;
    }
    if (isBoosting()) {
        throw new IllegalArgumentException(
                "This method is available for non-boosting" +
                " models only.");
    }

    // Every class of the root distribution starts at 0.0 confidence.
    HashMap categoryMap = new HashMap();
    for (Object item : tree.getDistribution()) {
        JSONArray distInfo = (JSONArray) item;
        categoryMap.put((String) distInfo.get(0), 0.0);
    }
    Prediction prediction = predict(inputData, missingStrategy,
            null, null, true);
    JSONArray distribution = (JSONArray) prediction.get("distribution");
    for (Object item : distribution) {
        JSONArray distInfo = (JSONArray) item;
        String name = (String) distInfo.get(0);
        categoryMap.put(name, Tree.wsConfidence(name, distribution));
    }
    return toOutput(categoryMap, "confidence");
}
/**
 * Computes the prediction based on a user-given operating point.
 *
 * The positive class is predicted when its probability (or confidence,
 * depending on the operating point kind) is above the threshold. Otherwise
 * the best-scoring of the remaining classes is predicted.
 *
 * @param inputData input data, already filtered and cast
 * @param missingStrategy missing strategy; LAST_PREDICTION when null
 * @param operatingPoint map with the positive class and a threshold
 * @return the chosen class entry, with the class under "prediction"
 * @throws Exception propagated from parsing the operating point or from
 *         the probability/confidence predictions
 */
private Prediction predictOperating(
        JSONObject inputData, MissingStrategy missingStrategy,
        JSONObject operatingPoint) throws Exception {
    if (missingStrategy == null) {
        missingStrategy = MissingStrategy.LAST_PREDICTION;
    }
    Object[] operating = Utils.parseOperatingPoint(
            operatingPoint, OPERATING_POINT_KINDS, classNames);
    String kind = (String) operating[0];
    Double threshold = (Double) operating[1];
    String positiveClass = (String) operating[2];

    JSONArray predictions = null;
    if (kind.equals("probability")) {
        predictions = predictProbability(inputData, missingStrategy);
    } else {
        predictions = predictConfidence(inputData, missingStrategy);
    }

    for (Object pred : predictions) {
        Prediction prediction = (Prediction) pred;
        String category = (String) prediction.get("category");
        if (category.equals(positiveClass) &&
                (Double) prediction.get(kind) > threshold) {
            prediction.put("prediction", prediction.get("category"));
            prediction.remove("category");
            return prediction;
        }
    }

    // The positive class did not reach the threshold. Rank the classes
    // best-first (predictOperatingKind relies on sortPredictions for the
    // same ordering) and return the best class that is not the positive one.
    sortPredictions(predictions, kind);
    Prediction prediction = (Prediction) predictions.get(0);
    String category = (String) prediction.get("category");
    if (category.equals(positiveClass)) {
        prediction = (Prediction) predictions.get(1);
    }
    prediction.put("prediction", prediction.get("category"));
    prediction.remove("category");
    return prediction;
}
/**
 * Computes the prediction based on a user-given operating kind.
 *
 * The class with the best probability or confidence (depending on
 * {@code operatingKind}) is returned.
 *
 * @param inputData input data, already filtered and cast
 * @param missingStrategy missing strategy; LAST_PREDICTION when null
 * @param operatingKind "probability" or "confidence" (case-insensitive)
 * @return the best class entry, with the class under "prediction"
 * @throws IllegalArgumentException if the operating kind is not allowed
 * @throws Exception propagated from the probability/confidence predictions
 */
private Prediction predictOperatingKind(
        JSONObject inputData, MissingStrategy missingStrategy,
        String operatingKind) throws Exception {
    if (missingStrategy == null) {
        missingStrategy = MissingStrategy.LAST_PREDICTION;
    }
    // Locale.ROOT: the kinds are ASCII keywords, so case mapping must not
    // depend on the default locale (e.g. Turkish dotless i).
    String kind = operatingKind.toLowerCase(Locale.ROOT);
    if (!Arrays.asList(OPERATING_POINT_KINDS).contains(kind)) {
        throw new IllegalArgumentException(
                "Allowed operating kinds are "
                + Arrays.toString(OPERATING_POINT_KINDS));
    }
    JSONArray predictions = null;
    if (kind.equals("probability")) {
        predictions = predictProbability(inputData, missingStrategy);
    } else {
        predictions = predictConfidence(inputData, missingStrategy);
    }
    sortPredictions(predictions, kind);
    Prediction prediction = (Prediction) predictions.get(0);
    prediction.put("prediction", prediction.get("category"));
    prediction.remove("category");
    return prediction;
}
/**
 * Collects the chain of node ids that leads from {@code filterId} up to
 * the root of the tree.
 *
 * @param filterId id for the node to filter with the model
 *
 * @return the ids, starting with {@code filterId} and following parent
 *         ids; null when {@code filterId} is null/empty or the root node
 *         has no id
 * @throws IllegalArgumentException if {@code filterId} is not a known
 *         node id
 */
public List getIdsPath(String filterId) {
    if (filterId == null || filterId.length() == 0) {
        return null;
    }
    String rootId = tree.getId();
    if (rootId == null || rootId.length() == 0) {
        return null;
    }
    if (!idsMap.containsKey(filterId)) {
        throw new IllegalArgumentException(
                String.format("The given id for the filter does " +
                        "not exist. Filter Id: %s", filterId));
    }

    List path = new ArrayList();
    String current = filterId;
    path.add(current);
    while (idsMap.containsKey(current)
            && idsMap.get(current).getParentId() != null) {
        current = idsMap.get(current).getParentId();
        path.add(current);
    }
    return path;
}
/**
 * Renders the model as an IF-THEN rule set in pseudocode.
 *
 * @return the IF-THEN rule set for the model
 * @throws IllegalArgumentException for boosted models
 */
public String rules() {
    if (!isBoosting()) {
        return tree.rules(Predicate.RuleLanguage.PSEUDOCODE);
    }
    throw new IllegalArgumentException(
            "This method is not available for boosting models. ");
}
/**
 * Renders the model as an IF-THEN rule set in the requested language.
 *
 * @param language programming language for rules PSEUDOCODE, JAVA, PYTHON, TABLEAU
 *
 * @return the IF-THEN rule set for the model
 * @throws IllegalArgumentException for boosted models
 */
public String rules(Predicate.RuleLanguage language) {
    if (!isBoosting()) {
        return tree.rules(language);
    }
    throw new IllegalArgumentException(
            "This method is not available for boosting models. ");
}
/**
 * Renders the model, or the part of it selected by a node id, as an
 * IF-THEN rule set in the requested language.
 *
 * @param language programming language for rules PSEUDOCODE, JAVA, PYTHON, TABLEAU
 * @param filterId id for the node to filter with the model
 * @param subtree the subtree of the model to process
 *
 * @return the IF-THEN rule set for the model
 * @throws IllegalArgumentException for boosted models or unknown node ids
 */
public String rules(Predicate.RuleLanguage language, final String filterId, boolean subtree) {
    if (!isBoosting()) {
        return tree.rules(language, getIdsPath(filterId), subtree);
    }
    throw new IllegalArgumentException(
            "This method is not available for boosting models. ");
}
/**
 * Converts a prediction given as a string into the objective field's type.
 *
 * Only numeric objectives are converted. Floating-point datatypes become a
 * {@code Double`}; integer and date-time-part datatypes become a
 * {@code Long}. Both are parsed with {@code locale}. Any other objective,
 * a missing datatype, or a value that cannot be parsed is returned
 * unchanged.
 *
 * @param valueAsString the prediction value as string
 * @param locale locale used to parse numbers; BigMLClient.DEFAUL_LOCALE
 *               when null
 * @return the converted value, or {@code valueAsString} if no conversion
 *         applies
 */
@Override
public Object toPrediction(String valueAsString, Locale locale) {
    locale = (locale != null ? locale : BigMLClient.DEFAUL_LOCALE);
    String objectiveFieldName = isBoosting() ?
            boostedTree.getObjectiveField() :
            tree.getObjectiveField();
    if (!"numeric".equals(Utils.getJSONObject(fields, objectiveFieldName + ".optype"))) {
        return valueAsString;
    }
    String dataTypeStr = (String) Utils.getJSONObject(fields, objectiveFieldName + ".'datatype'");
    if (dataTypeStr == null) {
        // No datatype information: nothing to convert to.
        return valueAsString;
    }
    // Enum constant names are ASCII, so map them with Locale.ROOT. Under a
    // Turkish default locale, "integer".toUpperCase() yields "İNTEGER" and
    // valueOf() would fail.
    DataTypeEnum dataType = DataTypeEnum.valueOf(
            dataTypeStr.toUpperCase(Locale.ROOT).replace("-", ""));
    switch (dataType) {
        case DOUBLE:
        case FLOAT:
            try {
                return NumberFormat.getInstance(locale)
                        .parse(valueAsString).doubleValue();
            } catch (ParseException e) {
                logger.warn("Unable to parse numeric prediction: " + valueAsString, e);
                return valueAsString;
            }
        case INTEGER:
        case INT8:
        case INT16:
        case INT32:
        case INT64:
        case DAY:
        case MONTH:
        case YEAR:
        case HOUR:
        case MINUTE:
        case SECOND:
        case MILLISECOND:
        case DAYOFMONTH:
        case DAYOFWEEK:
            try {
                return NumberFormat.getInstance(locale)
                        .parse(valueAsString).longValue();
            } catch (ParseException e) {
                logger.warn("Unable to parse integer prediction: " + valueAsString, e);
                return valueAsString;
            }
        default:
            return valueAsString;
    }
}
/**
* Average for the confidence of the predictions resulting from
* running the training data through the model
*
* @return the average for the confidence of the predictions
*/
public double getAverageConfidence() {
double total = 0.0, cumulativeConfidence = 0.0;
Map