
package hex;
import water.Iced;
import water.MRTask;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.TwoDimTable;
import java.util.Arrays;
public class ConfusionMatrix extends Iced {
private TwoDimTable _table;
public final double[][] _cm; // [actual][predicted], typed as double because of observation weights (which can be doubles)
public final String[] _domain;
/**
* Constructor for Confusion Matrix
* @param value 2D square matrix with co-occurrence counts for actual vs predicted class membership
* @param domain class labels (unified domain between actual and predicted class labels)
*/
public ConfusionMatrix(double[][] value, String[] domain) { _cm = value; _domain = domain; }
/** Build the CM data from the actuals and predictions, using the default
* threshold. Print to Log.info if the number of classes is below the
* print_threshold. Actuals might have extra levels not trained on (hence
* never predicted). Actuals with NAs are not scored, and their predictions
* ignored. */
public static ConfusionMatrix buildCM(Vec actuals, Vec predictions) {
if (!actuals.isCategorical()) throw new IllegalArgumentException("actuals must be categorical.");
if (!predictions.isCategorical()) throw new IllegalArgumentException("predictions must be categorical.");
Scope.enter();
try {
Vec adapted = predictions.adaptTo(actuals.domain());
int len = actuals.domain().length;
CMBuilder cm = new CMBuilder(len).doAll(actuals, adapted);
return new ConfusionMatrix(cm._arr, actuals.domain());
} finally {
Scope.exit();
}
}
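// Usage sketch (hypothetical Vec and frame names; assumes both columns are
// categorical, e.g. a test frame's label column and the "predict" column of a
// scored frame):
//   Vec truth = testFrame.vec("label");
//   Vec preds = scoredFrame.vec("predict");
//   ConfusionMatrix cm = ConfusionMatrix.buildCM(truth, preds);
//   System.out.println(cm.toASCII());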
private static class CMBuilder extends MRTask<CMBuilder> {
final int _len;
double _arr[/*actuals*/][/*predicted*/];
CMBuilder(int len) { _len = len; }
@Override public void map( Chunk ca, Chunk cp ) {
// After adapting frames, the Actuals have all the levels of the
// prediction results, plus any extras the model was never trained on,
// i.e., the actual domain is at least as large as the predicted domain.
_arr = new double[_len][_len];
for( int i=0; i < ca._len; i++ )
if( !ca.isNA(i) )
_arr[(int)ca.at8(i)][(int)cp.at8(i)]++;
}
@Override public void reduce( CMBuilder cm ) { ArrayUtils.add(_arr,cm._arr); }
}
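// Note on the map/reduce above: each map() call fills a chunk-local matrix and
// reduce() merges two partials by elementwise addition, e.g. (hypothetical
// counts) [[3,1],[0,2]] + [[1,0],[2,4]] == [[4,1],[2,6]], so _arr ends up with
// the counts for the whole Vec regardless of how it is chunked.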
public void add(int i, int j) { _cm[i][j]++; }
public final int size() { return _cm.length; }
public final double class_error(int c) {
double s = ArrayUtils.sum(_cm[c]);
if( s == 0 ) return 0.0; // Either 0 or NaN, but 0 is nicer
return (s - _cm[c][c]) / s;
}
public double total_rows() {
double n = 0;
for (double[] a_arr : _cm)
n += ArrayUtils.sum(a_arr);
return n;
}
public void add(ConfusionMatrix other) {
ArrayUtils.add(_cm, other._cm);
}
/**
* @return overall classification error
*/
public double err() {
double n = total_rows();
double err = n;
for( int d = 0; d < _cm.length; ++d )
err -= _cm[d][d];
return err / n;
}
public double err_count() {
double err = total_rows();
for( int d = 0; d < _cm.length; ++d )
err -= _cm[d][d];
assert(err >= 0);
return err;
}
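// Worked example (hypothetical counts): for _cm == [[50, 10], [5, 35]],
// total_rows() == 100 and the diagonal sums to 85, so err_count() == 15,
// err() == 0.15 and accuracy() == 0.85.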
/**
* The fraction of predictions that are correct.
*/
public double accuracy() { return 1-err(); }
/**
* The fraction of negatively labeled instances that were predicted as negative.
* @return TNR / Specificity
*/
public double specificity() {
if(!isBinary())throw new UnsupportedOperationException("specificity is only implemented for 2 class problems.");
double tn = _cm[0][0];
double fp = _cm[0][1];
return tn / (tn + fp);
}
/**
* The fraction of positively labeled instances that were predicted as positive.
* @return Recall / TPR / Sensitivity
*/
public double recall() {
if(!isBinary())throw new UnsupportedOperationException("recall is only implemented for 2 class problems.");
double tp = _cm[1][1];
double fn = _cm[1][0];
return tp / (tp + fn);
}
/**
* The fraction of positive predictions that are correct.
* @return Precision
*/
public double precision() {
if(!isBinary())throw new UnsupportedOperationException("precision is only implemented for 2 class problems.");
double tp = _cm[1][1];
double fp = _cm[0][1];
return tp / (tp + fp);
}
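// For binomial problems the layout is _cm == [[TN, FP], [FN, TP]] (actuals on
// rows, predictions on columns). Continuing the hypothetical counts above
// (TN=50, FP=10, FN=5, TP=35):
//   specificity() == 50/60 ~= 0.833
//   recall()      == 35/40 == 0.875
//   precision()   == 35/45 ~= 0.778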
/**
* The Matthews Correlation Coefficient; unlike the F-score, it takes true negatives into account.
* See http://en.wikipedia.org/wiki/Matthews_correlation_coefficient
* MCC is the correlation between the observed and predicted binary classifications.
* @return mcc, ranging from -1 (total disagreement) through 0 (no better than random) to 1 (perfect)
*/
public double mcc() {
if(!isBinary())throw new UnsupportedOperationException("MCC is only implemented for 2 class problems.");
double tn = _cm[0][0];
double fp = _cm[0][1];
double tp = _cm[1][1];
double fn = _cm[1][0];
return (tp*tn - fp*fn)/Math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn));
}
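// Continuing the hypothetical counts (TN=50, FP=10, FN=5, TP=35):
//   mcc() == (35*50 - 10*5) / sqrt(45*40*60*55) ~= 0.698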
/**
* The maximum per-class error
* @return max over i of class_error(i)
*/
public double max_per_class_error() {
int n = nclasses();
if(n == 0)throw new UnsupportedOperationException("max per class error is only defined for classification problems");
double res = class_error(0);
for(int i = 1; i < n; ++i)
res = Math.max(res, class_error(i));
return res;
}
public final int nclasses(){return _cm == null?0: _cm.length;}
public final boolean isBinary(){return nclasses() == 2;}
/**
* Returns the F1 measure, the harmonic mean of precision and recall.
* Cf. the end of http://en.wikipedia.org/wiki/Precision_and_recall.
*/
public double F1() {
final double precision = precision();
final double recall = recall();
return 2. * (precision * recall) / (precision + recall);
}
/**
* Returns the F2 measure, which combines precision and recall and weights recall higher than precision (beta = 2).
* See http://en.wikipedia.org/wiki/F1_score
*/
public double F2() {
final double precision = precision();
final double recall = recall();
return 5. * (precision * recall) / (4. * precision + recall);
}
/**
* Returns the F0.5 measure, which combines precision and recall and weights precision higher than recall (beta = 0.5).
* See http://en.wikipedia.org/wiki/F1_score
*/
public double F0point5() {
final double precision = precision();
final double recall = recall();
return 1.25 * (precision * recall) / (.25 * precision + recall);
}
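// F1(), F2() and F0point5() are instances of the general F-beta measure,
//   F_beta = (1 + beta^2) * P * R / (beta^2 * P + R)
// with beta = 1, 2 and 0.5 respectively; beta = 2 yields 5PR/(4P+R) as coded
// above. For the hypothetical counts (P ~= 0.778, R == 0.875), F1() ~= 0.824.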
@Override public String toString() {
StringBuilder sb = new StringBuilder();
for( double[] r : _cm)
sb.append(Arrays.toString(r)).append('\n');
return sb.toString();
}
private static String[] createConfusionMatrixHeader( double xs[], String ds[] ) {
String ss[] = new String[xs.length]; // the same length
for( int i=0; i<ds.length; i++ )
if( xs[i] >= 0 || (ds[i] != null && ds[i].length() > 0) && !Double.toString(i).equals(ds[i]) )
ss[i] = ds[i];
if( ds.length == xs.length-1 && xs[xs.length-1] > 0 )
ss[xs.length-1] = "NA";
return ss;
}
public String toASCII() { return table() == null ? "" : _table.toString(); }
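// Usage sketch: toASCII() renders the lazily built TwoDimTable, e.g.
//   System.out.println(cm.toASCII());
// prints the counts plus the per-class Error and Rate columns and a Totals row.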
/** Convert this ConfusionMatrix into a fully annotated TwoDimTable
* @return TwoDimTable */
public TwoDimTable table() { return _table == null ? (_table=toTable()) : _table; }
// Do the work making a TwoDimTable
private TwoDimTable toTable() {
if (_cm == null || _domain == null) return null;
for( double cm[] : _cm ) assert(_cm.length == cm.length);
// Sum up predicted & actuals
double acts [] = new double[_cm.length];
double preds[] = new double[_cm[0].length];
boolean isInt = true;
for( int a=0; a< _cm.length; a++ ) {
double sum=0;
for( int p=0; p< _cm[a].length; p++ ) {
sum += _cm[a][p];
preds[p] += _cm[a][p];
isInt &= (_cm[a][p] == (long)_cm[a][p]);
}
acts[a] = sum;
}
String adomain[] = createConfusionMatrixHeader(acts , _domain);
String pdomain[] = createConfusionMatrixHeader(preds, _domain);
assert adomain.length == pdomain.length : "The confusion matrix should have the same length for both directions.";
String[] rowHeader = Arrays.copyOf(adomain,adomain.length+1);
rowHeader[adomain.length] = "Totals";
String[] colHeader = Arrays.copyOf(pdomain,pdomain.length+2);
colHeader[colHeader.length-2] = "Error";
colHeader[colHeader.length-1] = "Rate";
String[] colType = new String[colHeader.length];
String[] colFormat = new String[colHeader.length];
String[] colType = new String[colHeader.length];
String[] colFormat = new String[colHeader.length];
for (int i=0; i<colHeader.length-2; ++i) {
colType[i] = isInt ? "long" : "double";
colFormat[i] = isInt ? "%d" : "%.2f";
}
colType[colHeader.length-2] = "double"; colFormat[colHeader.length-2] = "%.4f"; // Error column
colType[colHeader.length-1] = "string"; colFormat[colHeader.length-1] = "%s";   // Rate column
TwoDimTable table = new TwoDimTable("Confusion Matrix",
"Row labels: Actual class; Column labels: Predicted class",
rowHeader, colHeader, colType, colFormat, null);
// Fill the matrix body plus the per-class Error and Rate columns
double terr = 0;
for (int a = 0; a < _cm.length; a++) {
if (adomain[a] == null) continue;
double err = acts[a] - _cm[a][a];
terr += err;
for (int p = 0; p < pdomain.length; p++) {
if (pdomain[p] == null) continue;
if (isInt) table.set(a, p, (long)_cm[a][p]);
else table.set(a, p, _cm[a][p]);
}
table.set(a, pdomain.length, err / acts[a]);
table.set(a, pdomain.length+1, isInt ?
String.format("%,d / %,d", (long)err, (long)acts[a]) :
String.format("%.4f / %.4f", err, acts[a]));
}
// Totals row: per-column sums, overall error and overall rate
double nrows = total_rows();
for (int p = 0; p < pdomain.length; p++) {
if (pdomain[p] == null) continue;
if (isInt) table.set(adomain.length, p, (long)preds[p]);
else table.set(adomain.length, p, preds[p]);
}
table.set(adomain.length, pdomain.length, terr / nrows);
table.set(adomain.length, pdomain.length+1, isInt ?
String.format("%,d / %,d", (long)terr, (long)nrows) :
String.format("%.4f / %.4f", terr, nrows));
return table;
}
}