org.bigml.binding.LocalDeepnet Maven / Gradle / Ivy
package org.bigml.binding;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.bigml.binding.laminar.MathOps;
import org.bigml.binding.laminar.Preprocess;
import org.bigml.binding.resources.AbstractResource;
import org.bigml.binding.utils.Utils;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A local Predictive Deepnet.
*
* This module defines a Deepnet to make predictions locally or
* embedded into your application without needing to send requests to
* BigML.io.
*
* This module cannot only save you a few credits, but also enormously
* reduce the latency for each prediction and let you use your models
* offline.
*
* Example usage (assuming that you have previously set up the
* BIGML_USERNAME and BIGML_API_KEY environment variables and that you
* own the deepnet/id below):
*
*
* import org.bigml.binding.LocalDeepnet;
*
* // API client
* BigMLClient api = new BigMLClient();
*
* JSONObject deepnet = api.
* getDeepnet("deepnet/5026965515526876630001b2");
* LocalDeepnet localDeepnet = new LocalDeepnet(deepnet)
*
* JSONObject predictors = JSONValue.parse(
* "{"petal length": 3, "petal width": 1"});
*
* localDeepnet.predict(predictors)
*
*/
public class LocalDeepnet extends ModelFields implements SupervisedModelInterface {
private static final long serialVersionUID = 1L;
private static String DEEPNET_RE = "^deepnet/[a-f,0-9]{24}$";
/**
* Logging
*/
static Logger logger = LoggerFactory.getLogger(
LocalDeepnet.class.getName());
private String deepnetId;
private JSONArray inputFields = null;
private String objectiveField = null;
private JSONArray objectiveFields = null;
private Boolean regression = false;
private List classNames = new ArrayList();
private JSONObject network = null;
private JSONArray networks = null;
private JSONArray preprocess = null;
//private JSONObject optimizer = null;
public LocalDeepnet(JSONObject deepnet) throws Exception {
super((JSONObject) Utils.getJSONObject(
deepnet, "deepnet.fields", new JSONObject()));
// checks whether the information needed for local predictions
// is in the first argument
if (!checkModelFields(deepnet)) {
// if the fields used by the deepnet are not
// available, use only ID to retrieve it again
deepnetId = (String) deepnet.get("resource");
boolean validId = deepnetId.matches(DEEPNET_RE);
if (!validId) {
throw new Exception(
deepnetId + " is not a valid resource ID.");
}
}
if (!(deepnet.containsKey("resource")
&& deepnet.get("resource") != null)) {
BigMLClient client = new BigMLClient(null, null,
BigMLClient.STORAGE);
deepnet = client.getLogisticRegression(deepnetId);
if ((String) deepnet.get("resource") == null) {
throw new Exception(
deepnetId + " is not a valid resource ID.");
}
}
if (deepnet.containsKey("object") &&
deepnet.get("object") instanceof JSONObject) {
deepnet = (JSONObject) deepnet.get("object");
}
deepnetId = (String) deepnet.get("resource");
inputFields = (JSONArray) Utils.getJSONObject(
deepnet, "input_fields");
if (deepnet.containsKey("deepnet")
&& deepnet.get("deepnet") instanceof JSONObject) {
JSONObject status = (JSONObject) Utils.getJSONObject(
deepnet, "status");
if (status != null && status.containsKey("code")
&& AbstractResource.FINISHED == ((Number) status
.get("code")).intValue()) {
objectiveField = (String) Utils.getJSONObject(
deepnet, "objective_field");
objectiveFields = (JSONArray) Utils.getJSONObject(
deepnet, "objective_fields");
JSONObject deepnetInfo = (JSONObject) Utils
.getJSONObject(deepnet, "deepnet");
JSONObject fields = (JSONObject) Utils.getJSONObject(
deepnetInfo, "fields", new JSONObject());
// initialize ModelFields
objectiveField = objectiveField != null ?
objectiveField : (String) objectiveFields.get(0);
super.initialize((JSONObject) fields, objectiveField,
null, null, true, true, false);
String optype = (String) Utils.getJSONObject(
fields, objectiveFieldId + ".optype", "");
regression = optype.equals(Constants.OPTYPE_NUMERIC);
JSONArray categories = (JSONArray) Utils.getJSONObject(
(JSONObject) fields.get(objectiveField),
"summary.categories", new JSONArray());
if (!regression) {
for (Object cat: categories) {
classNames.add((String) ((JSONArray) cat).get(0));
}
Collections.sort(classNames);
}
missingNumerics = (Boolean) Utils.getJSONObject(
deepnetInfo, "missing_numerics");
if (deepnetInfo.containsKey("network")) {
network = (JSONObject) deepnetInfo.get("network");
networks = (JSONArray) Utils.getJSONObject(
network, "networks", new JSONArray());
preprocess = (JSONArray) Utils.getJSONObject(
network, "preprocess", new JSONArray());
//optimizer = (JSONObject) Utils.getJSONObject(
// network, "optimizer", new JSONObject());
}
} else {
throw new Exception(
"The deepnet isn't finished yet");
}
} else {
throw new Exception(String
.format("Cannot create the Deepnet instance. "
+ "Could not find the 'deepnet' key in "
+ "the resource:\n\n%s", deepnet));
}
}
/**
* Returns the resourceId
*/
public String getResourceId() {
return deepnetId;
}
/**
* Returns the class names
*/
public List getClassNames() {
return classNames;
}
/**
* Makes a prediction based on a number of field values.
*
* @param inputData Input data to be predicted
* @param operatingPoint
* In classification models, this is the point of the
* ROC curve where the model will be used at. The
* operating point can be defined in terms of:
* - the positive class, the class that is important to
* predict accurately
* - the probability threshold,
* the probability that is stablished
* as minimum for the positive_class to be predicted.
* The operating point is then defined as a map with
* two attributes, e.g.:
* {"positive_class": "Iris-setosa",
* "probability_threshold": 0.5}
* @param operatingKind
* "probability". Sets the property that decides the
* prediction. Used only if no operating point is used
*
* @param full
* Boolean that controls whether to include the prediction's
* attributes. By default, only the prediction is produced. If set
* to True, the rest of available information is added in a
* dictionary format. The dictionary keys can be:
* - prediction: the prediction value
* - probability: prediction's probability
* - distribution: distribution of probabilities for each
* of the objective field classes
* - unused_fields: list of fields in the input data that
*
*/
public HashMap predict(
JSONObject inputData, JSONObject operatingPoint,
String operatingKind, Boolean full) {
if (full == null) {
full = false;
}
// Checks and cleans inputData leaving the fields used in the model
inputData = filterInputData(inputData, full);
List unusedFields = (List)
inputData.get("unusedFields");
inputData = (JSONObject) inputData.get("newInputData");
// Strips affixes for numeric values and casts to the final field type
Utils.cast(inputData, fields);
// When operating_point is used, we need the probabilities
// of all possible classes to decide, so se use
// the `predict_probability` method
if (operatingPoint != null) {
if (regression) {
throw new IllegalArgumentException(
"The operating_point argument can only be" +
" used in classifications.");
}
return predictOperating(inputData, operatingPoint);
}
if (operatingKind != null) {
if (regression) {
throw new IllegalArgumentException(
"The operating_kind argument can only be" +
" used in classifications.");
}
return predictOperatingKind(inputData, operatingKind);
}
// Computes text and categorical field expansion
Map uniqueTerms = uniqueTerms(inputData);
ArrayList> inputArray =
fillArray(inputData, uniqueTerms);
HashMap prediction = null;
if (networks != null && networks.size() > 0) {
prediction = predictList(inputArray);
} else {
prediction = predictSingle(inputArray);
}
if (full) {
prediction.put("unused_fields", unusedFields);
}
return prediction;
}
/**
* Predicts a probability for each possible output class, based on
* input values. The input fields must be a dictionary keyed by
* field name or field ID.
*
* @param inputData Input data to be predicted
*/
private JSONArray predictProbability(JSONObject inputData) {
try {
return predictProbability(inputData, null);
} catch (Exception e) {
return null;
}
}
/**
* Predicts a probability for each possible output class, based on
* input values. The input fields must be a dictionary keyed by
* field name or field ID.
*
* @param inputData Input data to be predicted
*/
public JSONArray predictProbability(
JSONObject inputData, MissingStrategy missingStrategy)
throws Exception {
HashMap prediction = null;
if (regression) {
prediction = predict(inputData, null, null, true);
} else {
prediction = predict(inputData, null, null, true);
}
JSONArray distribution = (JSONArray) prediction.get("distribution");
Utils.sortPredictions(distribution, "probability", "prediction");
return distribution;
}
/**
* Computes the prediction based on a user-given operating point.
*/
private HashMap predictOperating(
JSONObject inputData, JSONObject operatingPoint) {
String[] operatingKinds = {"probability"};
Object[] operating = Utils.parseOperatingPoint(
operatingPoint, operatingKinds, classNames);
String kind = (String) operating[0];
Double threshold = (Double) operating[1];
String positiveClass = (String) operating[2];
JSONArray predictions = predictProbability(inputData);
for (Object pred: predictions) {
JSONObject prediction = (JSONObject) pred;
String category = (String) prediction.get("category");
if (category.equals(positiveClass) &&
(Double) prediction.get(kind) > threshold) {
return prediction;
}
}
HashMap prediction
= (HashMap) predictions.get(0);
String category = (String) prediction.get("category");
if (category.equals(positiveClass)) {
prediction = (HashMap) predictions.get(1);
}
prediction.put("prediction", prediction.get("category"));
prediction.remove("category");
return prediction;
}
/**
* Computes the prediction based on a user-given operating kind.
*/
private HashMap predictOperatingKind(
JSONObject inputData, String operatingKind) {
JSONArray predictions = null;
String kind = operatingKind.toLowerCase();
if (kind.equals("probability")) {
predictions = predictProbability(inputData);
} else {
throw new IllegalArgumentException(
"Only probability is allowed as operating kind " +
"for deepnets.");
}
HashMap prediction
= (HashMap) predictions.get(0);
prediction.put("prediction", prediction.get("category"));
prediction.remove("category");
return prediction;
}
/**
* Builds a list of occurrences for all the available terms
*/
private List expandTerms(
List termsList, Map inputTerms) {
Double[] termsOccurrences = new Double[termsList.size()];
Arrays.fill(termsOccurrences, 0.0);
if (inputTerms!=null) {
for (Object term: inputTerms.keySet()) {
double occurrences = ((Number)
inputTerms.get(term)).doubleValue();
Integer index = termsList.indexOf(term);
termsOccurrences[index] = occurrences;
}
}
return Arrays.asList(termsOccurrences);
}
/**
* Filling the input array for the network with the data in the
* input_data dictionary. Numeric missings are added as a new field
* and texts/items are processed.
*/
private ArrayList> fillArray(
JSONObject inputData, Map uniqueTerms) {
ArrayList
© 2015 - 2025 Weber Informatics LLC | Privacy Policy