org.bigml.binding.LocalFusion Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of bigml-binding Show documentation
An open source Java client that gives you a simple binding to interact with BigML. You can use it to
easily create, retrieve, list, update, and delete BigML resources.
package org.bigml.binding;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import org.bigml.binding.resources.AbstractResource;
import org.bigml.binding.utils.Utils;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A local Predictive Fusion.
 *
 * This module defines a Fusion to make predictions locally using its
 * associated models.
 *
 * This module can not only save you a few credits, but also enormously
 * reduce the latency for each prediction and let you use your models
 * offline.
 *
 * Example usage (assuming that you have previously set up the BIGML_USERNAME
 * and BIGML_API_KEY environment variables and that you own the
 * fusion/id below):
 *
 * <pre>{@code
 * import org.bigml.binding.LocalFusion;
 *
 * // API client
 * BigMLClient api = new BigMLClient();
 *
 * JSONObject fusion = api.getFusion("fusion/5026965515526876630001b2");
 * LocalFusion localFusion = new LocalFusion(fusion);
 *
 * JSONObject predictors = (JSONObject) JSONValue.parse(
 *     "{\"petal length\": 3, \"petal width\": 1}");
 *
 * localFusion.predict(predictors, null, null, null);
 * }</pre>
 */
public class LocalFusion extends ModelFields implements SupervisedModelInterface {

    private static final long serialVersionUID = 1L;

    /** Pattern a fusion resource ID must match. */
    static final String FUSION_RE = "^fusion/[a-f0-9]{24}$";

    /** Operating point kinds accepted by the operating-point prediction. */
    private static final String[] OPERATING_POINT_KINDS = {"probability"};

    /** Resource types that are accepted as components of a fusion. */
    private static final List<String> LOCAL_SUPERVISED = Arrays.asList(
            "model", "ensemble", "logisticregression", "deepnet", "fusion");

    /**
     * Logging
     */
    static final Logger logger = LoggerFactory.getLogger(
            LocalFusion.class.getName());

    private String fusionId;
    private String objectiveField = null;
    // IDs of the component models, in the order listed by the fusion.
    private JSONArray modelsIds;
    // Per-model weights aligned index-by-index with modelsIds. Empty when at
    // least one model lacks a numeric weight (the fusion is then unweighted).
    private final List<Double> weights = new ArrayList<>();
    // Groups of model IDs that are loaded and evaluated together.
    private final List<JSONArray> modelsSplit = new ArrayList<>();
    private boolean regression = false;
    // Sorted category names of the objective field (classifications only).
    private final List<String> classNames = new ArrayList<>();
    private Boolean missingNumerics = true;

    /**
     * Builds a local fusion from its JSON resource, evaluating all its
     * component models together.
     *
     * @param fusion the fusion resource (as returned by the API)
     * @throws Exception if the resource is invalid, not finished, or lists
     *         an unsupported component model type
     */
    public LocalFusion(JSONObject fusion)
            throws Exception {
        this(fusion, null);
    }

    /**
     * Builds a local fusion from its JSON resource.
     *
     * @param fusion    the fusion resource (as returned by the API)
     * @param maxModels maximum number of component models loaded at once
     *                  when predicting; {@code null} loads all of them
     * @throws Exception if the resource is invalid, not finished, or lists
     *         an unsupported component model type
     */
    @SuppressWarnings("unchecked")
    public LocalFusion(JSONObject fusion, Integer maxModels)
            throws Exception {
        super((JSONObject) Utils.getJSONObject(
                fusion, "fusion.fields", new JSONObject()));

        // checks whether the information needed for local predictions
        // is in the first argument
        if (!checkModelFields(fusion)) {
            // if the fields used by the fusion are not available,
            // use only the ID to retrieve it again
            fusionId = (String) fusion.get("resource");
            if (fusionId == null || !fusionId.matches(FUSION_RE)) {
                throw new Exception(
                        fusionId + " is not a valid resource ID.");
            }
        }

        // NOTE(review): the resource is only downloaded when the "resource"
        // key is absent, in which case fusionId may still be null. This
        // mirrors the original control flow; confirm it is intended.
        if (!(fusion.containsKey("resource")
                && fusion.get("resource") != null)) {
            BigMLClient client = new BigMLClient(null, null,
                    BigMLClient.STORAGE);
            fusion = client.getFusion(fusionId);
            if (fusion.get("resource") == null) {
                throw new Exception(
                        fusionId + " is not a valid resource ID.");
            }
        }

        if (fusion.containsKey("object")
                && fusion.get("object") instanceof JSONObject) {
            fusion = (JSONObject) fusion.get("object");
        }
        fusionId = (String) fusion.get("resource");

        if (!(fusion.containsKey("fusion")
                && fusion.get("fusion") instanceof JSONObject)) {
            throw new Exception(String
                    .format("Cannot create the Fusion instance. "
                            + "Could not find the 'fusion' key in "
                            + "the resource:\n\n%s", fusion));
        }

        JSONObject status = (JSONObject) Utils.getJSONObject(fusion,
                "status");
        if (status == null || !status.containsKey("code")
                || AbstractResource.FINISHED != ((Number) status
                        .get("code")).intValue()) {
            throw new Exception(
                    "The Fusion isn't finished yet");
        }

        JSONObject fusionInfo = (JSONObject) Utils
                .getJSONObject(fusion, "fusion");
        readModels((JSONArray) fusion.get("models"));
        missingNumerics = (Boolean) Utils.getJSONObject(
                fusion, "missing_numerics", true);

        JSONObject fusionFields = (JSONObject) Utils.getJSONObject(
                fusionInfo, "fields", new JSONObject());
        // initialize ModelFields
        super.initialize(fusionFields, null, null, null,
                true, true, true);
        objectiveField = (String) Utils.getJSONObject(
                fusion, "objective_field");

        splitModels(maxModels);

        String optype = (String) Utils.getJSONObject(
                fusionFields, objectiveField + ".optype");
        regression = "numeric".equals(optype);
        if (!regression) {
            JSONArray categories = (JSONArray) Utils.getJSONObject(
                    (JSONObject) fusionFields.get(objectiveField),
                    "summary.categories", new JSONArray());
            for (Object category : categories) {
                // each category entry is a [name, count] pair
                classNames.add((String) ((JSONArray) category).get(0));
            }
            Collections.sort(classNames);
        }
    }

    /**
     * Parses the fusion's "models" list into {@code modelsIds} and, when
     * every entry carries a numeric weight, into {@code weights}.
     *
     * Entries may be plain model ID strings or objects with "id" and
     * optional "weight" keys.
     */
    @SuppressWarnings("unchecked")
    private void readModels(JSONArray models) {
        if (models == null) {
            throw new IllegalArgumentException(
                    "Could not find the 'models' key in the fusion.");
        }
        modelsIds = new JSONArray();
        boolean allWeighted = true;
        for (Object entry : models) {
            String model;
            if (entry instanceof String) {
                model = (String) entry;
                allWeighted = false;
            } else {
                JSONObject info = (JSONObject) entry;
                model = (String) info.get("id");
                Object weight = info.get("weight");
                if (weight instanceof Number) {
                    weights.add(((Number) weight).doubleValue());
                } else {
                    allWeighted = false;
                }
            }
            String type = model == null ? null : model.split("/")[0];
            if (!LOCAL_SUPERVISED.contains(type)) {
                throw new IllegalArgumentException(String.format(
                        "The resource %s has not an allowed supervised model type.",
                        model));
            }
            modelsIds.add(model);
        }
        // Weights are only meaningful when they line up with every model.
        if (!allWeighted) {
            weights.clear();
        }
    }

    /**
     * Fills {@code modelsSplit} with consecutive groups of at most
     * {@code maxModels} model IDs (a single group when it is null). The
     * last group may be smaller, so that no model is left out.
     */
    @SuppressWarnings("unchecked")
    private void splitModels(Integer maxModels) {
        if (maxModels == null) {
            modelsSplit.add(modelsIds);
            return;
        }
        if (maxModels < 1) {
            throw new IllegalArgumentException(
                    "maxModels must be a positive integer, got " + maxModels);
        }
        int numberOfModels = modelsIds.size();
        for (int start = 0; start < numberOfModels; start += maxModels) {
            int end = Math.min(start + maxModels, numberOfModels);
            JSONArray arrayOfModels = new JSONArray();
            arrayOfModels.addAll(modelsIds.subList(start, end));
            modelsSplit.add(arrayOfModels);
        }
    }

    /**
     * Returns the resourceId
     */
    @Override
    public String getResourceId() {
        return fusionId;
    }

    /**
     * Returns the sorted class names of the objective field (empty for
     * regressions) as a read-only view.
     */
    @Override
    public List<String> getClassNames() {
        return Collections.unmodifiableList(classNames);
    }

    /**
     * For classification models, predicts a probability for each possible
     * output class, based on input values. The input fields must be a
     * dictionary keyed by field name or field ID.
     *
     * For regressions, the output is a single element list containing the
     * prediction. That prediction is the mean of the component predictions,
     * weighted by the model weights when the fusion defines them.
     *
     * Component models that fail to predict (e.g. logistic regressions with
     * missing_numerics=false receiving missing numerics) are skipped.
     *
     * @param inputData       Input data to be predicted
     * @param missingStrategy LAST_PREDICTION|PROPORTIONAL missing strategy
     *                        for missing fields (defaults to LAST_PREDICTION)
     * @return for classifications, one object per class (in class-name
     *         order) with "prediction" and "probability" keys; for
     *         regressions, a single object with the "prediction" key
     * @throws Exception if a component model cannot be retrieved or built
     */
    @Override
    @SuppressWarnings("unchecked")
    public JSONArray predictProbability(
            JSONObject inputData, MissingStrategy missingStrategy)
            throws Exception {
        if (missingStrategy == null) {
            missingStrategy = MissingStrategy.LAST_PREDICTION;
        }
        if (Boolean.FALSE.equals(this.missingNumerics)) {
            Utils.checkNoMissingNumerics(inputData, this.fields, null);
        }

        MultiVoteList votes = new MultiVoteList(null);
        // Regression accumulators: weighted sum of the predictions and the
        // total weight of the models that actually produced a vote.
        double regressionSum = 0.0;
        double votingWeight = 0.0;

        BigMLClient bigmlClient = new BigMLClient();
        for (JSONArray modelSplit : modelsSplit) {
            MultiVoteList votesSplit = new MultiVoteList(null);
            for (Object splitEntry : modelSplit) {
                String modelId = (String) splitEntry;
                SupervisedModelInterface model = loadModel(bigmlClient, modelId);
                if (model == null) {
                    continue;
                }

                JSONArray predictions;
                try {
                    predictions = model.predictProbability(inputData, missingStrategy);
                } catch (Exception e) {
                    // logistic regressions can raise this error if they
                    // have missing_numerics=False and some numeric missings
                    // are found
                    logger.debug("Skipping model {} in fusion {}: {}",
                            modelId, fusionId, e.getMessage());
                    continue;
                }

                List<Double> predictionList = new ArrayList<>();
                for (Object pred : predictions) {
                    predictionList.add(voteValue((HashMap<?, ?>) pred, modelId));
                }

                double modelWeight = modelWeight(modelId);
                if (!this.weights.isEmpty()) {
                    predictionList = weight(predictionList, modelWeight);
                }

                // we need to check that all classes in the fusion
                // are also in the composing model
                if (!this.regression && !this.classNames.equals(model.getClassNames())) {
                    try {
                        predictionList = rearrangePrediction(
                                model.getClassNames(), this.classNames, predictionList);
                    } catch (Exception e) {
                        // class names should be available; keep the
                        // prediction as-is if they are not
                        logger.debug("Could not align the classes of {} with fusion {}",
                                modelId, fusionId, e);
                    }
                }

                if (this.regression) {
                    for (Double value : predictionList) {
                        regressionSum += value;
                    }
                    votingWeight += modelWeight;
                }
                votesSplit.append(predictionList);
            }
            votes.extend(votesSplit);
        }

        JSONArray output = new JSONArray();
        if (this.regression) {
            JSONObject prediction = new JSONObject();
            // Weighted mean: sum(w_i * p_i) / sum(w_i), with w_i = 1 when the
            // fusion is unweighted. NaN when no model produced a vote.
            prediction.put("prediction", regressionSum / votingWeight);
            output.add(prediction);
        } else {
            List<?> probabilities = votes.combineToDistribution(true);
            for (int i = 0; i < classNames.size(); i++) {
                JSONObject prediction = new JSONObject();
                prediction.put("prediction", classNames.get(i));
                prediction.put("probability", probabilities.get(i));
                output.add(prediction);
            }
        }
        return output;
    }

    /**
     * Retrieves a component model by ID and wraps it in its local
     * predictor. Returns null for types that are not supported locally.
     */
    private static SupervisedModelInterface loadModel(
            BigMLClient client, String modelId) throws Exception {
        String type = modelId.split("/")[0];
        switch (type) {
            case "model":
                return new LocalPredictiveModel(client.getModel(modelId));
            case "ensemble":
                return new LocalEnsemble(client.getEnsemble(modelId));
            case "logisticregression":
                return new LocalLogisticRegression(
                        client.getLogisticRegression(modelId));
            case "deepnet":
                return new LocalDeepnet(client.getDeepnet(modelId));
            case "fusion":
                return new LocalFusion(client.getFusion(modelId));
            default:
                return null;
        }
    }

    /**
     * Extracts the numeric vote from one entry of a component model's
     * probability output.
     *
     * Classifications read "probability". Regressions read "prediction" (the
     * key this class itself emits for regressions) and fall back to
     * "probability".
     */
    private Double voteValue(HashMap<?, ?> prediction, String modelId) {
        Object value = prediction.get("probability");
        if (this.regression && prediction.get("prediction") != null) {
            value = prediction.get("prediction");
        }
        if (!(value instanceof Number)) {
            throw new IllegalStateException(String.format(
                    "Model %s returned no numeric vote: %s", modelId, prediction));
        }
        return ((Number) value).doubleValue();
    }

    /**
     * Returns the weight associated to the given component model, or 1.0
     * when the fusion is unweighted.
     */
    private double modelWeight(String modelId) {
        if (this.weights.isEmpty()) {
            return 1.0;
        }
        int index = this.modelsIds.indexOf(modelId);
        if (index < 0) {
            throw new IllegalStateException(String.format(
                    "Model %s is not part of fusion %s", modelId, fusionId));
        }
        return this.weights.get(index);
    }

    /**
     * Weighs the prediction according to the weight associated to the
     * current model in the fusion. Returns a new list; the input is not
     * modified.
     */
    private List<Double> weight(List<Double> predictions, double modelWeight) {
        List<Double> weighted = new ArrayList<>(predictions.size());
        for (Double probability : predictions) {
            weighted.add(probability * modelWeight);
        }
        return weighted;
    }

    /**
     * Rearranges the probabilities in a compact array when the
     * list of classes in the destination resource does not match the
     * ones in the origin resource. Classes missing from the origin get a
     * probability of 0.0.
     */
    private List<Double> rearrangePrediction(
            List<?> originClasses, List<String> destinationClasses,
            List<Double> predictions) {
        List<Double> newPrediction = new ArrayList<>(destinationClasses.size());
        for (String className : destinationClasses) {
            int originIndex = originClasses.indexOf(className);
            if (originIndex > -1) {
                newPrediction.add(predictions.get(originIndex));
            } else {
                newPrediction.add(0.0);
            }
        }
        return newPrediction;
    }

    /**
     * Computes the prediction based on a user-given operating point.
     *
     * The positive class is returned when its probability exceeds the
     * threshold. Otherwise the most probable of the remaining classes is
     * returned.
     */
    @SuppressWarnings("unchecked")
    private HashMap<String, Object> predictOperating(
            JSONObject inputData, MissingStrategy missingStrategy,
            JSONObject operatingPoint) throws Exception {
        if (missingStrategy == null) {
            missingStrategy = MissingStrategy.LAST_PREDICTION;
        }
        // only probability is allowed as operating kind
        Object[] operating = Utils.parseOperatingPoint(
                operatingPoint, OPERATING_POINT_KINDS, classNames);
        String kind = (String) operating[0];
        Double threshold = (Double) operating[1];
        String positiveClass = (String) operating[2];
        if (!Arrays.asList(OPERATING_POINT_KINDS).contains(kind)) {
            throw new IllegalArgumentException(String.format(
                    "Allowed operating kinds are %s",
                    Arrays.toString(OPERATING_POINT_KINDS)));
        }

        // predictProbability labels every entry with the "prediction" key
        JSONArray predictions = predictProbability(
                inputData, missingStrategy);
        for (Object pred : predictions) {
            HashMap<String, Object> prediction = (HashMap<String, Object>) pred;
            if (positiveClass.equals(prediction.get("prediction"))
                    && ((Number) prediction.get(kind)).doubleValue() > threshold) {
                return prediction;
            }
        }

        // Threshold not met: fall back to the most probable other class.
        Utils.sortPredictions(predictions, kind, "prediction");
        for (Object pred : predictions) {
            HashMap<String, Object> prediction = (HashMap<String, Object>) pred;
            if (!positiveClass.equals(prediction.get("prediction"))) {
                return prediction;
            }
        }
        // Only the positive class exists: there is no alternative.
        return (HashMap<String, Object>) predictions.get(0);
    }

    /**
     * Makes a prediction based on a number of field values.
     *
     * @param inputData Input data to be predicted
     * @param missingStrategy numeric key for the individual model's
     *              prediction method. See the model predict
     *              method.
     * @param operatingPoint
     *              In classification models, this is the point of the
     *              ROC curve where the model will be used at. The
     *              operating point can be defined in terms of:
     *              - the positive_class, the class that is important to
     *                predict accurately
     *              - the probability_threshold,
     *                the probability that is established
     *                as minimum for the positive_class to be predicted.
     *              The operating_point is then defined as a map with
     *              two attributes, e.g.:
     *                  {"positive_class": "Iris-setosa",
     *                   "probability_threshold": 0.5}
     * @param full
     *              Boolean that controls whether to include the prediction's
     *              attributes. By default, only the prediction is produced. If set
     *              to True, the rest of available information is added in a
     *              dictionary format. The dictionary keys can be:
     *                  - prediction: the prediction value
     *                  - probability: prediction's probability
     *                  - unused_fields: list of fields in the input data that
     *                    are not being used in the model
     * @return a map with at least the "prediction" key
     * @throws IllegalArgumentException if an operating point is given for a
     *         regression, or its kind is not supported
     * @throws Exception if a component model cannot be retrieved or built
     */
    @SuppressWarnings("unchecked")
    public HashMap<String, Object> predict(
            JSONObject inputData, MissingStrategy missingStrategy,
            JSONObject operatingPoint, Boolean full)
            throws Exception {
        if (missingStrategy == null) {
            missingStrategy = MissingStrategy.LAST_PREDICTION;
        }
        if (full == null) {
            full = false;
        }

        // Checks and cleans inputData leaving the fields used in the model
        inputData = filterInputData(inputData, full);
        List<?> unusedFields = (List<?>) inputData.get("unusedFields");
        inputData = (JSONObject) inputData.get("newInputData");

        if (Boolean.FALSE.equals(this.missingNumerics)) {
            Utils.checkNoMissingNumerics(inputData, this.fields, null);
        }

        // Strips affixes for numeric values and casts to the final field type
        Utils.cast(inputData, fields);

        HashMap<String, Object> prediction;
        if (operatingPoint != null) {
            // When operating_point is used, we need the probabilities
            // of all possible classes to decide, so we use
            // the predictProbability method
            if (regression) {
                throw new IllegalArgumentException(
                        "The operating_point argument can only be" +
                        " used in classifications.");
            }
            prediction = predictOperating(
                    inputData, missingStrategy, operatingPoint);
        } else {
            JSONArray predictions = predictProbability(
                    inputData, missingStrategy);
            if (!regression) {
                Utils.sortPredictions(predictions, "probability", "prediction");
            }
            prediction = (HashMap<String, Object>) predictions.get(0);
        }

        // adding unused fields, if any
        if (full) {
            prediction.put("unused_fields", unusedFields);
        }
        return prediction;
    }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy