package org.bigml.binding;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.bigml.binding.utils.Utils;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.bigml.binding.resources.AbstractResource;
/**
* A local Predictive Logistic Regression.
*
* This module defines a Logistic Regression to make predictions locally or
* embedded into your application without needing to send requests to BigML.io.
*
* This module can not only save you a few credits, but also enormously reduce
* the latency for each prediction and let you use your logistic regressions
* offline.
*
* Example usage (assuming that you have previously set up the BIGML_USERNAME
* and BIGML_API_KEY environment variables and that you own the
* logisticregression/id below):
*
*
* import org.bigml.binding.LocalLogisticRegression;
*
* // API client
* BigMLClient api = new BigMLClient();
*
* JSONObject logisticRegression = api.
* getLogisticRegression("logisticregression/5026965515526876630001b2");
* LocalLogisticRegression logistic =
* new LocalLogisticRegression(logisticRegression);
*
* JSONObject predictors = (JSONObject) JSONValue.parse(
* "{\"petal length\": 3, \"petal width\": 1,"
* + " \"sepal length\": 1, \"sepal width\": 0.5}");
*
* logistic.predict(predictors, null, null, true);
*
*/
public class LocalLogisticRegression extends ModelFields implements SupervisedModelInterface {
private static final long serialVersionUID = 1L;
static String LOGISTICREGRESSION_RE = "^logisticregression/[a-f0-9]{24}$";
static HashMap<String, String> EXPANSION_ATTRIBUTES = new HashMap<String, String>();
static {
EXPANSION_ATTRIBUTES.put("categorical", "categories");
EXPANSION_ATTRIBUTES.put("text", "tag_cloud");
EXPANSION_ATTRIBUTES.put("items", "items");
}
protected static final String[] OPTIONAL_FIELDS = {
"categorical", "text", "items", "datetime" };
/**
* Logging
*/
static Logger logger = LoggerFactory
.getLogger(LocalLogisticRegression.class.getName());
private String logisticRegressionId;
private JSONObject datasetFieldTypes = null;
private JSONArray inputFields = null;
private String objectiveField = null;
private JSONArray objectiveFields = null;
private JSONObject coefficients = null;
private Boolean bias;
//private Double c;
//private Double eps;
private Boolean normalize;
private Boolean balanceFields;
//private String regularization;
private JSONObject fieldCodings;
private List<String> classNames = new ArrayList<String>();
private String weightField;
public LocalLogisticRegression(JSONObject logistic) throws Exception {
super((JSONObject) Utils.getJSONObject(
logistic, "logistic_regression.fields", new JSONObject()));
// checks whether the information needed for local predictions
// is in the first argument
if (!checkModelFields(logistic)) {
// if the fields used by the logistic regression are not
// available, use only ID to retrieve it again
logisticRegressionId = (String) logistic.get("resource");
boolean validId = logisticRegressionId.matches(
LOGISTICREGRESSION_RE);
if (!validId) {
throw new Exception(
logisticRegressionId + " is not a valid resource ID.");
}
}
if (!(logistic.containsKey("resource")
&& logistic.get("resource") != null)) {
BigMLClient client = new BigMLClient(null, null,
BigMLClient.STORAGE);
logistic = client.getLogisticRegression(logisticRegressionId);
if ((String) logistic.get("resource") == null) {
throw new Exception(
logisticRegressionId + " is not a valid resource ID.");
}
}
if (logistic.containsKey("object") &&
logistic.get("object") instanceof JSONObject) {
logistic = (JSONObject) logistic.get("object");
}
logisticRegressionId = (String) logistic.get("resource");
// Check json structure
datasetFieldTypes = (JSONObject) Utils.getJSONObject(logistic,
"dataset_field_types");
inputFields = (JSONArray) Utils.getJSONObject(logistic, "input_fields");
objectiveField = (String) Utils.getJSONObject(logistic,
"objective_field");
objectiveFields = (JSONArray) Utils.getJSONObject(logistic,
"objective_fields");
weightField = (String) Utils.getJSONObject(logistic, "weight_field");
if (datasetFieldTypes == null || inputFields == null
|| (objectiveField == null && objectiveFields == null)) {
throw new Exception(
"Failed to find the logistic regression expected "
+ "JSON structure. Check your arguments.");
}
if (logistic.containsKey("logistic_regression")
&& logistic.get("logistic_regression") instanceof JSONObject) {
JSONObject status = (JSONObject) Utils.getJSONObject(logistic,
"status");
if (status != null && status.containsKey("code")
&& AbstractResource.FINISHED == ((Number) status
.get("code")).intValue()) {
JSONObject logisticInfo = (JSONObject) Utils
.getJSONObject(logistic, "logistic_regression");
// Check for the old format of coefficients
JSONArray coefficientsList = (JSONArray) Utils.getJSONObject(
logisticInfo, "coefficients", new JSONArray());
if (coefficientsList.get(0) instanceof String) {
throw new Exception(
"Detected old format of logistic regression detected.");
}
JSONObject fields = (JSONObject) Utils.getJSONObject(
logisticInfo, "fields", new JSONObject());
if (inputFields == null) {
inputFields = new JSONArray();
String[] inputFieldsArray = new String[fields.values().size()];
for (Object fieldId : fields.keySet()) {
int columnNumber = ((Number) Utils.getJSONObject(
fields, fieldId + ".column_number")).intValue();
inputFieldsArray[columnNumber] = (String) fieldId;
}
inputFields.addAll(Arrays.asList(inputFieldsArray));
}
coefficients = new JSONObject();
for (int i = 0; i < coefficientsList.size(); i++) {
JSONArray coefficientInfo = (JSONArray) coefficientsList.get(i);
coefficients.put((String) coefficientInfo.get(0),
(JSONArray) coefficientInfo.get(1));
}
bias = (Boolean) logisticInfo.get("bias");
normalize = (Boolean) logisticInfo.get("normalize");
balanceFields = (Boolean) logisticInfo.get("balance_fields");
missingNumerics = (Boolean) Utils.getJSONObject(
logisticInfo, "missing_numerics", false);
// The objective field ID may come in the objective_fields list
if (objectiveField == null && objectiveFields != null) {
objectiveField = (String) objectiveFields.get(0);
}
// Field codings are given as an array and stored as a map
// keyed by field ID
JSONArray fieldCodingsArray = (JSONArray) Utils.getJSONObject(
logisticInfo, "field_codings", new JSONArray());
formatFieldCodings(fieldCodingsArray);
// Class names are the categories of the objective field; an
// extra empty class accounts for additional coefficient rows
JSONArray categories = (JSONArray) Utils.getJSONObject(
(JSONObject) fields.get(objectiveField),
"summary.categories", new JSONArray());
if (coefficients.keySet().size() > categories.size()) {
classNames.add("");
}
for (Object cat: categories) {
classNames.add((String) ((JSONArray) cat).get(0));
}
} else {
throw new Exception(
"The logistic regression isn't finished yet");
}
} else {
throw new Exception(String
.format("Cannot create the LogisticRegression instance. "
+ "Could not find the 'logistic_regression' key in "
+ "the resource:\n\n%s", logistic));
}
}
/**
* Returns the resourceId
*/
public String getResourceId() {
return logisticRegressionId;
}
/**
* Returns the class names
*/
public List<String> getClassNames() {
return classNames;
}
/**
* Predicts a probability for each possible output class, based on
* input values. The input fields must be a dictionary keyed by
* field name or field ID.
*
* @param inputData Input data to be predicted
*/
public JSONArray predictProbability(JSONObject inputData) {
try {
return predictProbability(inputData, null);
} catch (Exception e) {
return null;
}
}
/**
* Predicts a probability for each possible output class, based on
* input values. The input fields must be a dictionary keyed by
* field name or field ID.
*
* @param inputData Input data to be predicted
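* @param missingStrategy strategy for handling missing values; it is
* currently ignored by this implementation
* @return a JSONArray of maps, one per objective class, each holding a
* "prediction" (the category name) and a "probability" entry,
* sorted by decreasing probability, e.g. (illustrative values):
* {"prediction": "Iris-setosa", "probability": 0.53}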
*/
public JSONArray predictProbability(JSONObject inputData,
MissingStrategy missingStrategy) throws Exception {
HashMap<String, Object> prediction = predict(inputData, null, null, true);
JSONArray distribution = (JSONArray) prediction.get("distribution");
Utils.sortPredictions(distribution, "probability", "prediction");
return distribution;
}
/**
* Computes the prediction based on a user-given operating point.
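*
* The positive class is returned when its probability in the
* distribution exceeds the given threshold; otherwise the most
* probable class is returned.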
*/
private HashMap<String, Object> predictOperating(
JSONObject inputData, JSONObject operatingPoint) {
String[] operatingKinds = {"probability"};
Object[] operating = Utils.parseOperatingPoint(
operatingPoint, operatingKinds, classNames);
String kind = (String) operating[0];
Double threshold = (Double) operating[1];
String positiveClass = (String) operating[2];
JSONArray predictions = predictProbability(inputData);
for (Object pred: predictions) {
HashMap<String, Object> prediction
= (HashMap<String, Object>) pred;
String category = (String) prediction.get("prediction");
if (category.equals(positiveClass) &&
(Double) prediction.get(kind) > threshold) {
return prediction;
}
}
return (HashMap<String, Object>) predictions.get(0);
}
/**
* Computes the prediction based on a user-given operating kind.
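*
* Only "probability" is allowed as operating kind for logistic
* regressions; the most probable class in the distribution is returned.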
*/
private HashMap<String, Object> predictOperatingKind(
JSONObject inputData, String operatingKind) {
JSONArray predictions = null;
String kind = operatingKind.toLowerCase();
if (kind.equals("probability")) {
predictions = predictProbability(inputData);
} else {
throw new IllegalArgumentException(
"Only probability is allowed as operating kind " +
"for logistic regressions.");
}
return (HashMap<String, Object>) predictions.get(0);
}
/**
* Returns the class prediction and the probability distribution
*
* @param inputData Input data to be predicted
* @param operatingPoint
* In classification models, this is the point of the
* ROC curve at which the model will be used. The
* operating point can be defined in terms of:
* - the positive class, the class that is important to
* predict accurately
* - the probability threshold, the probability that is
* established as the minimum for the positive_class
* to be predicted.
* The operating point is then defined as a map with
* two attributes, e.g.:
* {"positive_class": "Iris-setosa",
* "probability_threshold": 0.5}
*
* @param operatingKind
* "probability". Sets the property that decides the
* prediction. Used only if no operating point is used
*
* @param full
* Boolean that controls whether to include the prediction's
* attributes. By default, only the prediction is produced. If set
* to True, the rest of available information is added in a
* dictionary format. The dictionary keys can be:
* - prediction: the prediction value
* - probability: prediction's probability
* - distribution: distribution of probabilities for each
* of the objective field classes
* - unused_fields: list of fields in the input data that
* have not been used by the model
*
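* For instance, assuming the logistic object and the predictors map
* built in the class example above, a prediction at an operating point
* could look like this (class name and threshold are illustrative):
*
* JSONObject operatingPoint = (JSONObject) JSONValue.parse(
* "{\"positive_class\": \"Iris-setosa\","
* + " \"probability_threshold\": 0.5}");
* HashMap<String, Object> result =
* logistic.predict(predictors, operatingPoint, null, true);
*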
*/
public HashMap<String, Object> predict(
JSONObject inputData, JSONObject operatingPoint,
String operatingKind, Boolean full) {
if (full == null) {
full = false;
}
// Checks and cleans inputData leaving the fields used in the model
inputData = filterInputData(inputData, full);
List<String> unusedFields = (List<String>)
inputData.get("unusedFields");
inputData = (JSONObject) inputData.get("newInputData");
// Strips affixes for numeric values and casts to the final field type
Utils.cast(inputData, fields);
// When operating_point is used, we need the probabilities
// of all possible classes to decide, so we use
// the `predict_probability` method
if (operatingPoint != null) {
return predictOperating(inputData, operatingPoint);
}
if (operatingKind != null) {
return predictOperatingKind(inputData, operatingKind);
}
// When missing_numerics is false, check that all numeric
// fields are present in the input data.
if (!this.missingNumerics) {
Utils.checkNoMissingNumerics(inputData, this.fields, this.weightField);
}
if (balanceFields != null && balanceFields==true) {
balanceInput(inputData, fields);
}
// Computes text and categorical field expansion
Map<String, Object> uniqueTerms = uniqueTerms(inputData);
// Computes the contributions for each category
JSONObject probabilities = new JSONObject();
double total = 0;
for (Object coeff : coefficients.keySet()) {
String category = (String) coeff;
double probability = categoryProbability(
inputData, uniqueTerms, category);
JSONArray objectiveFieldCategory =
(JSONArray) categories.get(objectiveField);
int order = objectiveFieldCategory.indexOf(category);
if (order == -1) {
if (category.equals("")) {
order = objectiveFieldCategory.size();
}
}
JSONObject probabilityCategory = new JSONObject();
probabilityCategory.put("prediction", category);
probabilityCategory.put("probability", probability);
probabilityCategory.put("order", order);
probabilities.put(category, probabilityCategory);
total += probability;
}
// Normalizes the contributions to get a probability
for (Object category: probabilities.keySet()) {
JSONObject probabilityCategory = (JSONObject)
probabilities.get(category);
double probability = ((Number)
probabilityCategory.get("probability")).doubleValue();
probability /= total;
probabilityCategory.put("probability",
Utils.roundOff(probability, Constants.PRECISION));
}
// Chooses the most probable category as prediction
JSONArray distribution = new JSONArray();
for (Object category: probabilities.keySet()) {
JSONObject probabilityCategory = (JSONObject)
probabilities.get(category);
probabilityCategory.remove("order");
distribution.add(probabilityCategory);
}
Utils.sortPredictions(distribution, "probability", "prediction");
JSONObject prediction = (JSONObject) distribution.get(0);
HashMap<String, Object> result = new HashMap<String, Object>();
result.put("prediction", (String) prediction.get("prediction"));
result.put("probability", (Double) prediction.get("probability"));
result.put("distribution", distribution);
if (full) {
result.put("unused_fields", unusedFields);
}
return result;
}
/**
* Computes the probability for a concrete category
*
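* The raw score is the dot product of the category's coefficients and
* the expanded input vector (numeric values, text and item term
* occurrences, and one-hot or custom-coded categories), plus the
* missing-value coefficients and the bias term. When normalize is set,
* the score is divided by the square root of the accumulated squared
* input values, and the final probability is obtained with the
* logistic function 1 / (1 + e^(-score)).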
*/
private double categoryProbability(JSONObject numericInputs,
Map<String, Object> uniqueTerms, String category) {
double probability = 0.0;
double norm2 = 0.0;
// numeric input data
for (Object field: numericInputs.keySet()) {
String fieldId = (String) field;
JSONArray coefficients = getCoefficients(category, fieldId);
double value = ((Number) numericInputs.get(fieldId)).doubleValue();
double coeff = ((Number) coefficients.get(0)).doubleValue();
probability += coeff * value;
if (normalize) {
norm2 += Math.pow(value, 2);
}
}
// text, items and categories
for (Object field: uniqueTerms.keySet()) {
String fieldId = (String) field;
if (inputFields.contains(fieldId)) {
Map<String, Object> uniqueTerm = (Map<String, Object>)
uniqueTerms.get(fieldId);
JSONArray coefficients = getCoefficients(category, fieldId);
for (Object term: uniqueTerm.keySet()) {
int occurrences = ((Number) uniqueTerm.get(term)).intValue();
try {
boolean oneHot = true;
Integer index = null;
if (tagClouds.containsKey(fieldId)) {
index = ((List) tagClouds.get(fieldId)).indexOf(term);
} else {
if (items.containsKey(fieldId)) {
index = ((List) items.get(fieldId)).indexOf(term);
} else {
JSONObject fieldCoding = (JSONObject) fieldCodings.get(fieldId);
if (categories.containsKey(fieldId) &&
(!fieldCodings.containsKey(fieldId) ||
"dummy".equals( (String) fieldCoding.keySet().toArray()[0] ))) {
index = ((JSONArray) categories.get(fieldId)).indexOf(term);
} else {
if (categories.containsKey(fieldId)) {
oneHot = false;
index = ((JSONArray) categories.get(fieldId)).indexOf(term);
int coeffIndex = 0;
JSONArray contributions = (JSONArray) fieldCoding.values().toArray()[0];
for (Object contribValue: contributions) {
JSONArray contribution = (JSONArray) contribValue;
double contrib = ((Number)
contribution.get(index)).doubleValue();
double coeff = ((Number)
coefficients.get(coeffIndex)).doubleValue();
probability += coeff * contrib * occurrences;
coeffIndex++;
}
}
}
}
}
if (oneHot) {
double coeff = ((Number)
coefficients.get(index)).doubleValue();
probability += coeff * occurrences;
}
norm2 += Math.pow(occurrences, 2);
} catch (Exception e) {}
}
}
}
// missings
for (Object field: inputFields) {
String fieldId = (String) field;
boolean contribution = false;
JSONArray coefficients = getCoefficients(category, fieldId);
try {
if (numericFields.containsKey(fieldId) &&
!numericInputs.containsKey(fieldId)) {
Number coeffN = coefficients.size() == 1 ?
(Number) coefficients.get(0):
(Number) coefficients.get(1);
double coeff = ((Number) coeffN).doubleValue();
probability += coeff;
contribution = true;
} else {
boolean uniqueTerm =
!uniqueTerms.containsKey(fieldId) ||
uniqueTerms.get(fieldId)==null ||
((HashMap) uniqueTerms.get(fieldId)).keySet().size() == 0;
if (tagClouds.containsKey(fieldId) && uniqueTerm) {
double coeff = ((Number) coefficients.get(
tagClouds.get(fieldId).size())).doubleValue();
probability += coeff;
contribution = true;
} else {
if (items.containsKey(fieldId) && uniqueTerm) {
double coeff = ((Number) coefficients.get(
items.get(fieldId).size())).doubleValue();
probability += coeff;
contribution = true;
} else {
if (categories.containsKey(fieldId) &&
!objectiveField.equals(fieldId) &&
!uniqueTerms.containsKey(fieldId)) {
JSONObject fieldCoding = (JSONObject) fieldCodings.get(fieldId);
if (!fieldCodings.containsKey(fieldId) ||
"dummy".equals( (String) fieldCoding.keySet().toArray()[0] )) {
double coeff = ((Number)
coefficients.get(((List) categories.get(fieldId)).size())).doubleValue();
probability += coeff;
} else {
// codings are given as arrays of coefficients. The
// last one is for missings and the previous ones are
// one per category as found in summary
int coeffIndex = 0;
JSONArray contributions = (JSONArray) fieldCoding.values().toArray()[0];
for (Object contribValue: contributions) {
JSONArray contributionRow = (JSONArray) contribValue;
double coeff = ((Number)
coefficients.get(coeffIndex)).doubleValue();
double value = ((Number)
contributionRow.get(contributionRow.size()-1)).doubleValue();
probability += coeff * value;
coeffIndex++;
}
}
contribution = true;
}
}
}
}
} catch (Exception e) {
e.printStackTrace();
}
if (contribution && normalize) {
norm2 += 1;
}
}
// the bias term is the last in the coefficients list
JSONArray catCoeff = (JSONArray) coefficients.get(category);
probability += ((Number) ((JSONArray)
catCoeff.get(catCoeff.size()-1)).get(0)).doubleValue();
if (bias) {
norm2 += 1;
}
if (normalize) {
try {
probability /= Math.sqrt(norm2);
} catch (Exception e) {
// this should never happen
probability = 0.0;
}
}
try {
probability = 1 / (1 + Math.exp(-probability));
} catch (Exception e) {
probability = probability < 0 ? 0 : 1;
}
// truncate probability to 5 digits, as in the backend
return Utils.roundOff(probability, 5);
}
/**
* Balances the values in the input data using the corresponding
* field scales.
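*
* Each numeric value is standardized as (value - mean) / stddev using
* the mean and standard deviation in the field's summary; when the
* standard deviation is not positive, only the mean is subtracted.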
*/
private void balanceInput(JSONObject inputData, JSONObject fields) {
for (Object fieldId : inputData.keySet()) {
JSONObject field = (JSONObject) fields.get(fieldId);
if ("numeric".equals(field.get("optype"))) {
JSONObject summary = (JSONObject) field.get("summary");
double mean = ((Number) Utils.getJSONObject(
summary, "mean", 0)).doubleValue();
double stddev = ((Number) Utils.getJSONObject(
summary, "standard_deviation", 0)).doubleValue();
// if stddev is not positive, we only subtract the mean
double value = ((Number) inputData.get(fieldId)).doubleValue();
if (stddev <= 0) {
inputData.put(fieldId, (value - mean));
} else {
inputData.put(fieldId, ((value - mean) / stddev));
}
}
}
}
/**
* Returns the set of coefficients for the given category and field ID
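*
* The coefficients are looked up at the field's position in the
* input_fields list, within the coefficient rows of the given category.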
*/
private JSONArray getCoefficients(String category, String fieldId) {
int coeffIndex = inputFields.indexOf(fieldId);
return (JSONArray) ((JSONArray) coefficients.get(category))
.get(coeffIndex);
}
/**
* Changes the field codings format to the dict notation
*
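* For instance, an array entry such as (illustrative values)
* {"field": "000000", "coding": "contrast", "coefficients": [[1, -1]]}
* is turned into the map entry {"000000": {"contrast": [[1, -1]]}}.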
*/
private void formatFieldCodings(JSONArray fieldCodingsArray) {
fieldCodings = new JSONObject();
for (int i = 0; i < fieldCodingsArray.size(); i++) {
JSONObject fieldCoding = (JSONObject) fieldCodingsArray.get(i);
String fieldId = (String) fieldCoding.get("field");
String coding = (String) fieldCoding.get("coding");
JSONObject codingValue = new JSONObject();
if ("dummy".equals(coding)) {
codingValue.put(coding, fieldCoding.get("dummy_class"));
} else {
codingValue.put(coding, fieldCoding.get("coefficients"));
}
fieldCodings.put(fieldId, codingValue);
}
}
}