Please wait. This can take a few minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
org.bigml.binding.MultiVote Maven / Gradle / Ivy
Go to download
An open source Java client that gives you a simple binding to interact with BigML. You can use it to
easily create, retrieve, list, update, and delete BigML resources.
package org.bigml.binding;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.bigml.binding.utils.Utils;
import org.json.simple.JSONArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A multiple vote prediction
*
* Uses a number of predictions to generate a combined prediction.
*/
public class MultiVote implements Serializable {
/** Serialization version of this Serializable class. */
private static final long serialVersionUID = 1L;
/**
 * Class logger. It is not referenced by any method visible in this class.
 */
static Logger LOGGER = LoggerFactory.getLogger(MultiVote.class.getName());
/**
 * Default column labels used by appendRow and extendRows when the caller
 * passes no headers.
 */
public final static String[] PREDICTION_HEADERS = new String[] {
"prediction", "confidence", "order", "distribution", "count" };
/**
 * Weight label handed to combineCategorical, indexed by
 * PredictionMethod.getCode() in combine. A null entry means that no
 * weight label is passed for that method.
 */
private final static String[] COMBINATION_WEIGHTS = new String[] {
null , "confidence", "probability", null, "weight" };
/**
 * Keys that every prediction must contain, indexed by
 * PredictionMethod.getCode() in combine. An empty entry skips the check.
 */
private final static String[][] WEIGHT_KEYS = new String[][] { {},
{ "confidence" }, { "distribution", "count" }, {}, { "weight" } };
/**
 * Weight labels accepted by combineCategorical. combineDistribution also
 * indexes this array with PredictionMethod.PROBABILITY.getCode() to pick
 * its default label. Note it has four entries while the two arrays above
 * have five.
 */
private final static String[] WEIGHT_LABELS = new String[] { "plurality",
"confidence", "probability", "threshold" };
/**
 * Size limit passed to Utils.mergeBins in getGroupedDistribution; a joined
 * distribution with more entries than this is labelled "bins".
 */
final static int BINS_LIMIT = 32;
/** Key under which boosting predictions store their objective class. */
final static String BOOSTING_CLASS = "class";
/**
 * The votes held by this instance. The field is public and is replaced by
 * a new array whenever predictions are appended.
 */
public HashMap[] predictions;
/** Set by the constructor: true when boostingOffsets is non-null and non-empty. */
public boolean boosting = false;
/** Boosting offsets received by the constructor, stored as given. */
public JSONArray boostingOffsets;
/**
 * Builds an empty MultiVote: no predictions and no boosting offsets.
 */
public MultiVote() {
    this(new HashMap[0], null);
}
/**
 * MultiVote: combiner class for ensembles voting predictions.
 *
 * When at least one prediction has no "order" entry, every prediction is
 * renumbered with its position in the array.
 *
 * @param predictionsArr array of model predictions; null is treated as an
 *        empty array. The array is stored as given, not copied
 * @param boostingOffsets offsets of a boosted ensemble; when it is null or
 *        empty the instance is not flagged as boosting
 */
public MultiVote(HashMap[] predictionsArr, JSONArray boostingOffsets) {
    this.predictions = (predictionsArr != null) ? predictionsArr : new HashMap[0];
    this.boosting = !(boostingOffsets == null || boostingOffsets.isEmpty());
    this.boostingOffsets = boostingOffsets;

    boolean missingOrder = false;
    for (HashMap vote : this.predictions) {
        if (!vote.containsKey("order")) {
            missingOrder = true;
            break;
        }
    }
    if (missingOrder) {
        int position = 0;
        for (HashMap vote : this.predictions) {
            vote.put("order", position++);
        }
    }
}
/**
 * Gives access to the predictions held by this instance.
 *
 * @return the internal predictions array itself, not a copy
 */
public HashMap[] getPredictions() {
    return this.predictions;
}
/**
 * Tells whether the votes are regression predictions.
 *
 * For a boosting instance the answer is true as soon as one vote has no
 * "class" value. Otherwise it is true only when every "prediction" value
 * is a Number.
 *
 * @return {boolean} True if the votes are treated as regression predictions.
 */
private boolean isRegression() {
    if (this.boosting) {
        for (HashMap vote : this.predictions) {
            if (vote.get("class") == null) {
                return true;
            }
        }
        return false;
    }
    for (HashMap vote : this.predictions) {
        Object value = vote.get("prediction");
        if (!(value instanceof Number)) {
            return false;
        }
    }
    return true;
}
/**
 * Computes the order value for the next prediction to be stored.
 *
 * It is the "order" of the last stored prediction plus one, or 0 when
 * there are no predictions. The order is used to break ties in the
 * combination methods for classifications.
 *
 * @return the next order to be assigned to a prediction
 */
private int nextOrder() {
    if (predictions == null || predictions.length == 0) {
        return 0;
    }
    HashMap newest = predictions[predictions.length - 1];
    Number newestOrder = (Number) newest.get("order");
    return newestOrder.intValue() + 1;
}
/**
 * Adds a new prediction to the list of predictions.
 *
 * predictionInfo should contain at least:
 * - prediction: whose value is the predicted category or value
 *
 * for instance:
 * {'prediction': 'Iris-virginica'}
 *
 * it may also contain the keys:
 * - confidence: whose value is the confidence/error of the prediction
 * - distribution: a list of [category/value, instances] pairs
 * describing the distribution at the prediction node
 * - count: the total number of instances of the training set in the
 * node
 *
 * The given map is stored as is and its "order" entry is overwritten with
 * the next order of this instance.
 *
 * @param predictionInfo the prediction to be appended
 * @return the this instance
 * @throws IllegalArgumentException if predictionInfo is null, empty or has
 *         no "prediction" key
 */
public MultiVote append(HashMap predictionInfo) {
    if (predictionInfo == null || predictionInfo.isEmpty()
            || !predictionInfo.containsKey("prediction")) {
        throw new IllegalArgumentException("Failed to add the prediction.\n" +
                "The minimal key for the prediction is 'prediction'" +
                ":\n{'prediction': 'Iris-virginica'}");
    }
    predictionInfo.put("order", nextOrder());

    // The new vote always goes into the last slot. Using the order value
    // as array index fails as soon as the stored orders are not identical
    // to the array positions (e.g. predictions given with their own order).
    HashMap[] current = (predictions != null) ? predictions : new HashMap[0];
    HashMap[] grown = Arrays.copyOf(current, current.length + 1);
    grown[grown.length - 1] = predictionInfo;
    predictions = grown;
    return this;
}
/**
 * Adds a new prediction, given as a row of values, to the list of
 * predictions.
 *
 * predictionHeaders should contain the labels for the predictionRow
 * values in the same order.
 *
 * predictionHeaders should contain at least the following string
 * - 'prediction': whose associated value in predictionRow
 * is the predicted category or value
 *
 * for instance:
 * predictionRow = ['Iris-virginica']
 * predictionHeaders = ['prediction']
 *
 * it may also contain the following headers and values:
 * - 'confidence': whose associated value in predictionRow
 * is the confidence/error of the prediction
 * - 'distribution': a list of [category/value, instances] pairs
 * describing the distribution at the prediction node
 * - 'count': the total number of instances of the training set in the
 * node
 *
 * Neither of the given lists is modified. Whatever value the row holds
 * for an 'order' header is replaced by the next order of this instance.
 *
 * @param predictionRow the list of predicted values and extra info
 * @param predictionHeaders the name of each value in the predictionRow;
 *        null selects PREDICTION_HEADERS
 * @return the this instance
 * @throws IllegalArgumentException if the row is null, its size differs
 *         from the number of headers, or no 'prediction' header is given
 */
public MultiVote appendRow(List predictionRow,
        List predictionHeaders) {
    if (predictionHeaders == null) {
        predictionHeaders = Arrays.asList(PREDICTION_HEADERS);
    }
    if (predictionRow == null
            || predictionRow.size() != predictionHeaders.size()
            || !predictionHeaders.contains("prediction")) {
        throw new IllegalArgumentException("WARNING: failed to add the prediction.\n" +
                "The row must have label 'prediction' at least. And the number" +
                " of headers must match with the number of elements in the row.");
    }

    int order = nextOrder();
    HashMap<Object, Object> predictionInfo = new HashMap<Object, Object>();
    for (int i = 0; i < predictionHeaders.size(); i++) {
        predictionInfo.put(predictionHeaders.get(i), predictionRow.get(i));
    }
    // Written last so it wins over any 'order' value carried by the row.
    predictionInfo.put("order", order);

    // Append in the last slot; the order value is not a safe array index.
    HashMap[] current = (predictions != null) ? predictions : new HashMap[0];
    HashMap[] grown = Arrays.copyOf(current, current.length + 1);
    grown[grown.length - 1] = predictionInfo;
    predictions = grown;
    return this;
}
/**
* Given a multi vote instance (a list of predictions), extends the list
* with another list of predictions and adds the order information.
*
* For instance, predictions_info could be:
*
* [{'prediction': 'Iris-virginica', 'confidence': 0.3},
* {'prediction': 'Iris-versicolor', 'confidence': 0.8}]
*
* where the expected prediction keys are: prediction (compulsory),
* confidence, distribution and count.
*
* @param votes
*/
public void extend(MultiVote votes) {
if( votes.predictions != null && votes.predictions.length > 0 ) {
int order = nextOrder();
List> predictionsList =
new ArrayList>(Arrays.asList(predictions));
for (HashMap prediction : votes.predictions) {
prediction.put("order", (order + 1));
predictionsList.add(prediction);
}
predictions = (HashMap[]) predictionsList.toArray( new HashMap[predictionsList.size()] );
}
}
/**
* Given a list of predictions, extends the list with another list of
* predictions and adds the order information. For instance,
* predictionsInfo could be:
*
* [{'prediction': 'Iris-virginica', 'confidence': 0.3},
* {'prediction': 'Iris-versicolor', 'confidence': 0.8}]
*
* where the expected prediction keys are: prediction (compulsory),
* confidence, distribution and count.
*
* @param predictionsInfo the list of predictions we want to append
* @return the this instance
*/
public MultiVote extend(List> predictionsInfo) {
if( predictionsInfo == null || predictionsInfo.isEmpty() ) {
throw new IllegalArgumentException("WARNING: failed to add the predictions.\\n" +
"No predictions informed.");
}
int order = nextOrder();
for (int i = 0; i < predictionsInfo.size(); i++) {
HashMap prediction = predictionsInfo.get(i);
prediction.put("order", order + i);
append(prediction);
}
return this;
}
/**
* Given a list of predictions, extends the list with another list of
* predictions and adds the order information. For instance,
* predictionsInfo could be:
*
* [{'prediction': 'Iris-virginica', 'confidence': 0.3},
* {'prediction': 'Iris-versicolor', 'confidence': 0.8}]
*
* where the expected prediction keys are: prediction (compulsory),
* confidence, distribution and count.
*
* @param predictionsRows the list of predictions (in list format) we want to append
* @return the this instance
*/
public MultiVote extendRows(List> predictionsRows,
List predictionsHeader) {
if( predictionsHeader == null ) {
predictionsHeader = Arrays.asList(PREDICTION_HEADERS);
}
int order = nextOrder();
int index = predictionsHeader.indexOf("order");
if( index < 0 ) {
index = predictionsHeader.size();
predictionsHeader.add("order");
}
for( int iPrediction = 0; iPrediction < predictionsRows.size(); iPrediction++ ) {
List predictionRow = predictionsRows.get(iPrediction);
if( index == predictionRow.size() ) {
predictionRow.add(order + 1);
} else {
predictionRow.set(index, order);
}
appendRow(predictionRow, predictionsHeader);
}
return this;
}
/**
 * Singles out the votes for a chosen category and returns a MultiVote
 * with them iff the number of those votes reaches at least the given
 * threshold. Otherwise the returned MultiVote holds the rest of the votes.
 *
 * @param threshold the number of the minimum positive predictions needed for
 * a final positive prediction.
 * @param category the positive category
 * @return a new MultiVote instance holding the selected votes
 * @throws IllegalArgumentException if threshold or category is missing, or
 *         threshold is lower than 1 or larger than the number of predictions
 */
protected MultiVote singleOutCategory(Integer threshold, String category) {
    if (threshold == null || category == null || category.length() == 0) {
        throw new IllegalArgumentException("No category and threshold information was" +
                " found. Add threshold and category info." +
                " E.g. {\"threshold\": 6, \"category\":" +
                " \"Iris-virginica\"}.");
    }
    if (threshold > predictions.length) {
        throw new IllegalArgumentException(String.format(
                "You cannot set a threshold value larger than " +
                "%s. The ensemble has not enough models to use" +
                " this threshold value.", predictions.length)
        );
    }
    if (threshold < 1) {
        throw new IllegalArgumentException("The threshold must be a positive value");
    }

    List<HashMap> categoryPredictions = new ArrayList<HashMap>();
    List<HashMap> restOfPredictions = new ArrayList<HashMap>();
    for (HashMap prediction : predictions) {
        if (category.equals(prediction.get("prediction"))) {
            categoryPredictions.add(prediction);
        } else {
            restOfPredictions.add(prediction);
        }
    }

    List<HashMap> selected = (categoryPredictions.size() >= threshold)
            ? categoryPredictions : restOfPredictions;
    // The array must be sized from the list being converted: a longer
    // array would be padded with null by toArray.
    return new MultiVote(selected.toArray(new HashMap[selected.size()]), null);
}
/**
 * Verifies that every prediction contains every one of the given keys.
 *
 * @param predictions {array} predictions Array of prediction objects
 * @param keys {array} keys Array of key strings
 * @return always true; a missing key is reported by throwing an Error
 */
private static boolean checkKeys(HashMap[] predictions,
        String[] keys) {
    for (HashMap vote : predictions) {
        for (String requiredKey : keys) {
            if (!vote.containsKey(requiredKey)) {
                throw new Error(
                        "Not enough data to use the selected prediction method.");
            }
        }
    }
    return true;
}
/**
 * Normalizes error to a [0, top_range] range and builds probabilities.
 *
 * Every prediction gets an "errorWeight" entry: e^-[scaled error] when the
 * confidences differ, 1.0 when they are all equal.
 *
 * @param topRange {number} The top range of error to which the original error is
 * normalized. It is only read when the confidences differ.
 * @return {number} The normalization factor as the sum of the normalized
 * error weights.
 */
public Double normalizeError(Double topRange) {
    double maxError = -1.0d;
    double minError = Double.MAX_VALUE;
    for (HashMap prediction : this.predictions) {
        if (!prediction.containsKey("confidence")) {
            throw new Error(
                    "Not enough data to use the selected prediction method.");
        }
        double error = ((Number) prediction.get("confidence")).doubleValue();
        maxError = Math.max(error, maxError);
        minError = Math.min(error, minError);
    }

    double errorRange = maxError - minError;
    double normalizeFactor = 0.0d;
    if (errorRange > 0.0d) {
        /*
         * Shifts and scales predictions errors to [0, top_range]. Then
         * builds e^-[scaled error] and returns the normalization factor to
         * fit them between [0, 1]
         */
        for (HashMap prediction : this.predictions) {
            double delta = minError
                    - ((Number) prediction.get("confidence")).doubleValue();
            double errorWeight = Math.exp(delta / errorRange * topRange);
            prediction.put("errorWeight", errorWeight);
            normalizeFactor += errorWeight;
        }
    } else {
        // All the errors are equal: every vote weighs the same.
        for (HashMap prediction : this.predictions) {
            prediction.put("errorWeight", 1.0d);
        }
        normalizeFactor = this.predictions.length;
    }
    return normalizeFactor;
}
/**
 * Wilson score interval computation of the distribution for the
 * prediction. The value returned is the lower bound of the interval,
 * rounded with Utils.roundOff.
 *
 * @param prediction {object} prediction Value of the prediction for which confidence
 * is computed
 * @param distribution {array} distribution Distribution-like structure of predictions
 * and the associated weights (only for categoricals). (e.g.
 * {'Iris-setosa': 10, 'Iris-versicolor': 5})
 * @param n {integer} n Total number of instances in the distribution. If
 * absent, the number is computed as the sum of weights in the
 * provided distribution, truncated to an integer
 * @param z {float} z Percentile of the standard normal distribution;
 * 1.96 when null
 * @return the rounded lower bound of the Wilson score interval
 * @throws Error if the prediction has no numeric weight in the
 * distribution, a weight is negative, the weights add up to 0, or the
 * total of instances (given or computed) is lower than 1
 */
protected static double wsConfidence(Object prediction,
        HashMap distribution, Integer n, Double z) {
    if (distribution == null || !(distribution.get(prediction) instanceof Number)) {
        // Error keeps the failure type used by the other checks below.
        throw new Error("The distribution has no weight for the prediction: "
                + prediction);
    }
    double p = ((Number) distribution.get(prediction)).doubleValue();
    double zValue = (z == null) ? 1.96d : z;
    if (p < 0) {
        throw new Error("The distribution weight must be a positive value");
    }
    if (n != null && n < 1) {
        throw new Error(
                "The total of instances in the distribution must be"
                        + " a positive integer");
    }

    double norm = 0.0d;
    for (Object weight : distribution.values()) {
        norm += ((Number) weight).doubleValue();
    }
    if (norm == 0.0d) {
        throw new Error("Invalid distribution norm: "
                + distribution.toString());
    }
    if (norm != 1.0d) {
        p = p / norm;
    }

    // Kept as a double from here on: n * n and 2 * n overflow in int
    // arithmetic for large totals.
    double total = (n == null) ? (int) norm : n;
    if (total < 1) {
        // A total computed from weights that add up to less than 1 would
        // otherwise divide by zero.
        throw new Error(
                "The total of instances in the distribution must be"
                        + " a positive integer");
    }

    double z2 = zValue * zValue;
    double n2 = total * total;
    double wsSqrt = Math.sqrt((p * (1 - p) / total) + (z2 / (4 * n2)));
    return Utils.roundOff((p + (z2 / (2 * total)) - (zValue * wsSqrt)) / (1 + (z2 / total)), Constants.PRECISION);
}
/**
 * Plain average of the regression predictions.
 *
 * The result holds the mean of the "prediction", "confidence" and
 * "median" values (medians are summed only for the votes that have one,
 * but divided by the total number of votes), the grouped distribution
 * information and the sum of the "count" values.
 *
 * @return the averaged prediction
 */
private HashMap avg() {
    final int total = this.predictions.length;
    double predictionSum = 0.0d;
    double medianSum = 0.0d;
    double confidenceSum = 0.0d;
    long instances = 0;

    for (HashMap vote : this.predictions) {
        predictionSum += ((Number) vote.get("prediction")).doubleValue();
        if (vote.containsKey("median")) {
            medianSum += ((Number) vote.get("median")).doubleValue();
        }
        confidenceSum += ((Number) vote.get("confidence")).doubleValue();
        instances += ((Number) vote.get("count")).longValue();
    }

    final boolean hasVotes = total > 0;
    HashMap average = new HashMap();
    average.put("prediction", hasVotes ? predictionSum / total : Double.NaN);
    average.put("confidence", hasVotes ? confidenceSum / total : 0.0d);
    average.put("median", hasVotes ? medianSum / total : Double.NaN);
    average.putAll(getGroupedDistribution(this));
    average.put("count", instances);
    return average;
}
/**
 * Returns the prediction combining votes using error to compute weight.
 *
 * Calling this method adds an "errorWeight" entry to every prediction
 * (see normalizeError). Every prediction must have "prediction",
 * "confidence" and "count" values.
 *
 * @return {{'prediction': {string|number}, 'confidence': {number}}} The
 * combined error is an average of the errors in the MultiVote
 * predictions. The result also holds "count", the grouped
 * distribution information and, when at least one vote has a
 * median, the weighted "median".
 */
public HashMap errorWeighted() {
    checkKeys(this.predictions, new String[] { "confidence" });

    double topRange = 10.0d;
    double normalizationFactor = this.normalizeError(topRange);

    HashMap newPrediction = new HashMap();
    if (normalizationFactor == 0.0d) {
        newPrediction.put("prediction", Double.NaN);
        newPrediction.put("confidence", 0.0d);
        return newPrediction;
    }

    double result = 0.0d;
    double medianResult = 0.0d;
    double combinedError = 0.0d;
    long instances = 0L;
    boolean hasMedian = false;
    for (HashMap prediction : this.predictions) {
        double errorWeight = ((Number) prediction.get("errorWeight")).doubleValue();
        result += ((Number) prediction.get("prediction")).doubleValue() * errorWeight;
        if (prediction.get("median") != null) {
            medianResult += ((Number) prediction.get("median")).doubleValue() * errorWeight;
            hasMedian = true;
        }
        instances += ((Number) prediction.get("count")).longValue();
        combinedError += ((Number) prediction.get("confidence")).doubleValue() * errorWeight;
    }

    newPrediction.put("prediction", result / normalizationFactor);
    newPrediction.put("confidence", combinedError / normalizationFactor);
    newPrediction.put("count", instances);
    // Presence is tracked explicitly: testing the sum for "> 0" dropped
    // the median of zero or negative valued targets.
    if (hasMedian) {
        newPrediction.put("median", medianResult / normalizationFactor);
    }
    newPrediction.putAll(getGroupedDistribution(this));
    return newPrediction;
}
/**
 * Adds up the "prediction" value of every given prediction multiplied by
 * the value it stores under the given key.
 *
 * Both values may be any Number (Double, Long, Integer...).
 *
 * @param predictions the predictions to be added up
 * @param key the key of the value used as weight
 * @return the weighted sum; 0 for an empty array
 */
private Double weightedSum(HashMap[] predictions, String key) {
    double weightedSum = 0;
    for (HashMap prediction : predictions) {
        double pred = ((Number) prediction.get("prediction")).doubleValue();
        double weight = ((Number) prediction.get(key)).doubleValue();
        weightedSum += pred * weight;
    }
    return weightedSum;
}
/**
 * Returns the softmax values from a distribution given as a
 * dictionary like:
 * {"category": {"probability": probability, "order": order}}
 *
 * The given map is not modified. The "order" values are copied as they
 * are; the "probability" values may be any Number.
 *
 * @param predictions the per category information
 * @return a new map with the same keys whose "probability" values are
 * e^probability divided by the sum of all of them, or an empty
 * map when that sum is 0
 */
private HashMap softmax(HashMap predictions) {
    HashMap<Object, HashMap<String, Object>> normalized =
            new HashMap<Object, HashMap<String, Object>>();
    double total = 0;
    for (Object entryObject : predictions.entrySet()) {
        Map.Entry entry = (Map.Entry) entryObject;
        HashMap catInfo = (HashMap) entry.getValue();
        double probability = Math.exp(
                ((Number) catInfo.get("probability")).doubleValue());
        HashMap<String, Object> pred = new HashMap<String, Object>();
        pred.put("probability", probability);
        pred.put("order", catInfo.get("order"));
        normalized.put(entry.getKey(), pred);
        total += probability;
    }
    if (total == 0) {
        return new HashMap();
    }
    for (HashMap<String, Object> catInfo : normalized.values()) {
        catInfo.put("probability", ((Double) catInfo.get("probability")) / total);
    }
    return normalized;
}
/**
* Combines the predictions for a boosted classification ensemble
* Applies the regression boosting combiner, but per class. Tie breaks
* use the order of the categories in the ensemble summary to decide.
*/
private HashMap classifictionBoostingCombiner(Map options) {
HashMap prediction = new HashMap();
int index, len;
Map groupedPredictions = new HashMap();
for (index = 0, len = this.predictions.length; index < len; index++) {
prediction = this.predictions[index];
if (prediction.get(BOOSTING_CLASS) != null) {
String objectiveClass = (String) prediction.get(BOOSTING_CLASS);
if (!groupedPredictions.containsKey(objectiveClass)) {
List> classList = new ArrayList>();
groupedPredictions.put(objectiveClass, classList);
}
((List>) groupedPredictions.get(objectiveClass)).add(prediction);
}
}
List categories = new ArrayList();
for (Object cats: (JSONArray) options.get("categories")) {
JSONArray cat = (JSONArray) cats;
categories.add((String) cat.get(0));
}
HashMap predictions = new HashMap();
for (Map.Entry entry : groupedPredictions.entrySet()) {
String key = entry.getKey();
ArrayList value = (ArrayList) entry.getValue();
Double boostingOffset = null;
for (Object bOffset: (JSONArray) boostingOffsets) {
JSONArray offset = (JSONArray) bOffset;
if (key.equals((String) offset.get(0))) {
boostingOffset = (Double) offset.get(1);
break;
}
}
HashMap[] preds = new HashMap[value.size()];
for (index = 0, len = preds.length; index < len; index++) {
preds[index] = (HashMap) value.get(index);
}
HashMap pred = new HashMap();
pred.put("probability",
weightedSum(preds, "weight") + boostingOffset);
pred.put("order", categories.indexOf(key));
predictions.put(key, pred);
}
predictions = softmax(predictions);
String predictionName = (String) predictions.keySet().toArray()[0];
HashMap predictionInfo = (HashMap) predictions.get(predictionName);
for (Map.Entry entry : predictions.entrySet()) {
String key = (String) entry.getKey();
HashMap predInfo = (HashMap) entry.getValue();
Double predProbability = (Double) predInfo.get("probability");
Double predictionProbability = (Double) predictionInfo.get("probability");
if (predProbability > predictionProbability) {
predictionName = key;
predictionInfo = predInfo;
} else {
if (predProbability == predictionProbability) {
if ((Integer) predInfo.get("order") <= (Integer) predictionInfo.get("order")) {
predictionName = key;
predictionInfo = predInfo;
}
}
}
}
prediction = new HashMap();
prediction.put("prediction", predictionName);
prediction.put("probability", Utils.roundOff(
(Double) predictionInfo.get("probability"), Constants.PRECISION));
return prediction;
}
/**
* Creates a new predictions array based on the training data probability
*/
public HashMap[] probabilityWeight() {
int index, len, total, order;
Map prediction = new HashMap();
List> predictionsList = new ArrayList>();
for (index = 0, len = this.predictions.length; index < len; index++) {
prediction = this.predictions[index];
if (!prediction.containsKey("distribution")
|| !prediction.containsKey("count")) {
throw new Error(
"Probability weighting is not available because"
+ " distribution information is missing.");
}
total = prediction.get("count") instanceof Long ? ((Long) prediction
.get("count")).intValue() : (Integer) prediction
.get("count");
if (total < 1) {
throw new Error(
"Probability weighting is not available because"
+ " distribution seems to have " + total
+ " as number of instances in a node");
}
order = (Integer) prediction.get("order");
HashMap distribution = (HashMap) prediction.get("distribution");
for (Object key : distribution.keySet()) {
Map newPred = new HashMap();
newPred.put("prediction", key);
newPred.put("probability", ((Integer) distribution.get(key) / (double) total));
newPred.put("count", distribution.get(key));
newPred.put("order", order);
predictionsList.add(newPred);
}
}
HashMap[] predictions = new HashMap[predictionsList.size()];
for (index = 0, len = predictions.length; index < len; index++) {
predictions[index] = (HashMap) predictionsList
.get(index);
}
return predictions;
};
/**
 * Builds a distribution based on the predictions of the MultiVote.
 *
 * @param weightLabel {string} weightLabel Label of the value in the prediction object
 * whose sum will be used as count in the distribution. Null or
 * blank selects the "probability" label
 * @return a two element array: the distribution as a map from the
 * prediction name to the sum of its weights (Double), and the sum
 * of the "count" values of all the predictions (Integer)
 * @throws Error if a prediction does not contain the weight label
 */
public Object[] combineDistribution(String weightLabel) {
    if (weightLabel == null || weightLabel.trim().length() == 0) {
        weightLabel = WEIGHT_LABELS[PredictionMethod.PROBABILITY.getCode()];
    }

    HashMap<String, Double> distribution = new HashMap<String, Double>();
    int total = 0;
    for (HashMap prediction : this.predictions) {
        if (!prediction.containsKey(weightLabel)) {
            throw new Error(
                    "Not enough data to use the selected prediction"
                            + " method. Try creating your model anew.");
        }
        String predictionName = (String) prediction.get("prediction");
        double weight = ((Number) prediction.get(weightLabel)).doubleValue();
        Double accumulated = distribution.get(predictionName);
        distribution.put(predictionName,
                (accumulated == null ? 0.0 : accumulated) + weight);
        total += ((Number) prediction.get("count")).intValue();
    }
    return new Object[] { distribution, total };
}
/**
 * Returns the prediction combining votes by using the given weight.
 *
 * The winner is the category with the largest sum of weights; ties are
 * decided by TupleComparator using the order of the first vote seen for
 * each category.
 *
 * When the first prediction has a "confidence" value the result is the
 * one built by weightedConfidence (keys "prediction" and "confidence").
 * Otherwise the result holds "prediction" and, under "probability", the
 * value computed by wsConfidence from the combined distribution.
 *
 * @param weightLabel {string} weightLabel Type of combination method: 'plurality':
 * plurality (1 vote per prediction) 'confidence': confidence
 * weighted (confidence as a vote value) 'probability': probability
 * weighted (probability as a vote value). Null gives every vote
 * a weight of 1
 * @return the combined prediction
 * @throws Error if the label is not one of WEIGHT_LABELS, there are no
 * predictions, or a prediction lacks the weight label
 */
public HashMap combineCategorical(String weightLabel) {
    // The label does not change between votes, so it is checked only once.
    if (weightLabel != null && !Arrays.asList(WEIGHT_LABELS).contains(weightLabel)) {
        throw new Error("Wrong weightLabel value.");
    }
    if (this.predictions.length == 0) {
        throw new Error("No predictions to be combined.");
    }

    HashMap<Object, HashMap<String, Object>> mode =
            new HashMap<Object, HashMap<String, Object>>();
    for (HashMap prediction : this.predictions) {
        double weight = 1.0;
        if (weightLabel != null) {
            if (!prediction.containsKey(weightLabel)) {
                throw new Error(
                        "Not enough data to use the selected prediction"
                                + " method. Try creating your model anew.");
            }
            weight = ((Number) prediction.get(weightLabel)).doubleValue();
        }
        Object category = prediction.get("prediction");
        HashMap<String, Object> categoryHash = mode.get(category);
        if (categoryHash == null) {
            categoryHash = new HashMap<String, Object>();
            categoryHash.put("count", weight);
            categoryHash.put("order", prediction.get("order"));
            mode.put(category, categoryHash);
        } else {
            categoryHash.put("count", ((Double) categoryHash.get("count")) + weight);
        }
    }

    List<Object[]> tuples = new ArrayList<Object[]>(mode.size());
    for (Map.Entry<Object, HashMap<String, Object>> entry : mode.entrySet()) {
        tuples.add(new Object[] { entry.getKey(), entry.getValue() });
    }
    // Only the best tuple is needed: the minimum under the comparator is
    // the element a full sort would place first.
    Object[] winner = (Object[]) Collections.min(tuples, new TupleComparator());
    Object predictionName = winner[0];

    if (this.predictions[0].get("confidence") != null) {
        return this.weightedConfidence(predictionName, weightLabel);
    }

    // If prediction had no confidence, compute it from distribution
    Object[] distributionInfo = this.combineDistribution(weightLabel);
    int count = (Integer) distributionInfo[1];
    HashMap distribution = (HashMap) distributionInfo[0];
    double combinedConfidence = wsConfidence(predictionName, distribution,
            count, null);

    HashMap output = new HashMap();
    output.put("prediction", predictionName);
    output.put("probability", combinedConfidence);
    return output;
}
/**
 * Compute the combined weighted confidence from a list of predictions.
 *
 * Only the votes whose "prediction" equals combinedPrediction take part
 * in the average. Each one weighs the value it stores under weightLabel,
 * or 1 when weightLabel is null.
 *
 * @param combinedPrediction {object} combinedPrediction Prediction object
 * @param weightLabel {string} weightLabel Label of the value in the prediction object
 * that will be used to weight confidence
 * @return a map with the "prediction" and its combined "confidence"; the
 * confidence is null when the total weight is not positive
 * @throws Error if weightLabel is given and any prediction of this
 * instance lacks a "confidence" or a weightLabel value
 */
public HashMap weightedConfidence(
        Object combinedPrediction, Object weightLabel) {
    if (weightLabel != null) {
        for (HashMap prediction : this.predictions) {
            if (prediction.get("confidence") == null
                    || prediction.get(weightLabel) == null) {
                throw new Error(
                        "Not enough data to use the selected prediction"
                                + " method. Lacks " + weightLabel + " information");
            }
        }
    }

    double confidenceSum = 0.0;
    double totalWeight = 0.0;
    for (HashMap prediction : this.predictions) {
        Object votedFor = prediction.get("prediction");
        if (votedFor == null || !votedFor.equals(combinedPrediction)) {
            continue;
        }
        double confidence = ((Number) prediction.get("confidence")).doubleValue();
        // The weight comes from the requested label, not from "confidence".
        double weight = (weightLabel == null)
                ? 1.0
                : ((Number) prediction.get(weightLabel)).doubleValue();
        confidenceSum += weight * confidence;
        totalWeight += weight;
    }

    Double finalConfidence = null;
    if (totalWeight > 0) {
        finalConfidence = confidenceSum / totalWeight;
    }

    HashMap result = new HashMap();
    result.put("prediction", combinedPrediction);
    result.put("confidence", finalConfidence);
    return result;
}
/**
 * Joins the "distribution" of every prediction of the given instance into
 * a single one.
 *
 * A distribution that is not a Map is converted with
 * Utils.convertDistributionArrayToMap. After each merge the joined
 * distribution goes through Utils.mergeBins with BINS_LIMIT; once it has
 * held more than BINS_LIMIT entries the unit reported is "bins" instead of
 * "counts".
 *
 * @param multiVoteInstance the instance whose predictions are grouped
 * @return a map with the keys "distribution" (the result of
 * Utils.convertDistributionMapToSortedArray) and "distributionUnit"
 */
protected static Map getGroupedDistribution(MultiVote multiVoteInstance) {
    String unit = "counts";
    Map merged = new HashMap();
    HashMap[] votes = multiVoteInstance.getPredictions();
    for (int i = 0; i < votes.length; i++) {
        Object rawDistribution = votes[i].get("distribution");
        HashMap voteDistribution = (rawDistribution instanceof Map)
                ? (HashMap) rawDistribution
                : (HashMap) Utils.convertDistributionArrayToMap((JSONArray) rawDistribution);
        merged = Utils.mergeDistributions(merged, voteDistribution);
        if ("counts".equals(unit) && merged.size() > BINS_LIMIT) {
            unit = "bins";
        }
        merged = Utils.mergeBins(merged, BINS_LIMIT);
    }

    Map info = new HashMap();
    info.put("distribution", Utils.convertDistributionMapToSortedArray(merged));
    info.put("distributionUnit", unit);
    return info;
}
/**
 * Combines the predictions with the PLURALITY method and no options:
 * voting for classifications and averaging for regressions.
 *
 * @return {{"prediction": prediction}}
 */
public HashMap combine() {
    return combine(PredictionMethod.PLURALITY, null);
}
/**
 * Reduces a number of predictions voting for classification and
 * averaging predictions for regression.
 *
 * Boosting instances add their weighted gradients (regression) or use
 * the per class boosting combiner (classification). Non boosting
 * regressions are averaged, weighting by error for the CONFIDENCE method.
 * Classifications are decided by combineCategorical, after singling out a
 * category (THRESHOLD) or expanding the distributions (PROBABILITY).
 *
 * Predictions may be modified: missing "boosting" or "confidence" entries
 * are filled with 0.0 on the paths that need them.
 *
 * @param method the combination method; null selects PLURALITY
 * @param options method options: "threshold" (any Number) and "category"
 * for THRESHOLD, "categories" for boosted classifications
 * @return the combined prediction
 * @throws Error if there are no predictions or they lack the keys that
 * the method needs
 * @throws IllegalArgumentException if the THRESHOLD method is used
 * without valid "threshold" and "category" options
 */
public HashMap combine(PredictionMethod method, Map options) {
    if (method == null) {
        method = PredictionMethod.PLURALITY;
    }

    // there must be at least one prediction to be combined
    if (this.predictions.length == 0) {
        throw new Error("No predictions to be combined.");
    }

    // and all predictions should have the weight-related keys
    String[] keys = WEIGHT_KEYS[method.getCode()];
    if (keys.length > 0) {
        checkKeys(this.predictions, keys);
    }

    if (this.boosting) {
        for (HashMap prediction : predictions) {
            if (!prediction.containsKey("boosting")) {
                prediction.put("boosting", 0.0);
            }
        }
        if (!this.isRegression()) {
            return classifictionBoostingCombiner(options);
        }
        // sum all gradients weighted by their "weight" plus the
        // boosting offset
        HashMap prediction = new HashMap();
        prediction.put("prediction",
                weightedSum(predictions, "weight") +
                ((Number) this.boostingOffsets.get(0)).doubleValue());
        return prediction;
    }

    if (this.isRegression()) {
        for (HashMap prediction : predictions) {
            if (!prediction.containsKey("confidence")) {
                prediction.put("confidence", 0.0);
            }
        }
        if (method == PredictionMethod.CONFIDENCE) {
            return this.errorWeighted();
        }
        return this.avg();
    }

    MultiVote multiVote;
    if (method == PredictionMethod.THRESHOLD) {
        // Missing options reach singleOutCategory as nulls, which reports
        // them with a descriptive IllegalArgumentException.
        Integer threshold = null;
        String category = null;
        if (options != null) {
            Object rawThreshold = options.get("threshold");
            if (rawThreshold != null) {
                threshold = ((Number) rawThreshold).intValue();
            }
            category = (String) options.get("category");
        }
        multiVote = singleOutCategory(threshold, category);
    } else if (method == PredictionMethod.PROBABILITY) {
        multiVote = new MultiVote(this.probabilityWeight(), null);
    } else {
        multiVote = this;
    }
    return multiVote.combineCategorical(COMBINATION_WEIGHTS[method.getCode()]);
}
/**
* Comparator
*/
class TupleComparator implements Comparator {
@Override
public int compare(Object[] o1, Object[] o2) {
HashMap hash1 = (HashMap) o1[1];
HashMap hash2 = (HashMap) o2[1];
double weight1 = (Double) hash1.get("count");
double weight2 = (Double) hash2.get("count");
int order1 = (Integer) hash1.get("order");
int order2 = (Integer) hash2.get("order");
return weight1 > weight2 ? -1 : (weight1 < weight2 ? 1
: order1 < order2 ? -1 : 1);
}
}
}