All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.classifier.ConfusionMatrix Maven / Gradle / Ivy

There is a newer version: 0.13.0
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

import com.google.common.base.Preconditions;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * The ConfusionMatrix Class stores the result of Classification of a Test Dataset.
 *
 * <p>Rows are indexed by the correct (actual) label and columns by the classified (predicted) label,
 * so {@code confusionMatrix[r][c]} is the number of examples of class {@code r} classified as {@code c}.
 *
 * <p>The fact of whether there is a default is not stored. A row of zeros is the only indicator that
 * there is no default.
 *
 * <p>See http://en.wikipedia.org/wiki/Confusion_matrix for background
 */
public class ConfusionMatrix {
  private static final Logger LOG = LoggerFactory.getLogger(ConfusionMatrix.class);
  /** Label to row/column index; insertion-ordered so iteration follows index order. */
  private final Map<String,Integer> labelMap = new LinkedHashMap<>();
  private final int[][] confusionMatrix;
  private String defaultLabel = "unknown";

  /**
   * Creates an empty matrix with one row/column per label plus a trailing one for the default label.
   *
   * @param labels       the class labels, indexed in iteration order
   * @param defaultLabel label that receives the last index
   */
  public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
    confusionMatrix = new int[labels.size() + 1][labels.size() + 1];
    this.defaultLabel = defaultLabel;
    int i = 0;
    for (String label : labels) {
      labelMap.put(label, i++);
    }
    labelMap.put(defaultLabel, i);
  }

  /**
   * Creates a matrix from the counts (and label bindings, if any) of the given matrix.
   *
   * @param m a square matrix of counts
   * @throws IllegalArgumentException if {@code m} is not square or its label bindings do not
   *                                  provide exactly one label per row
   */
  public ConfusionMatrix(Matrix m) {
    confusionMatrix = new int[m.numRows()][m.numRows()];
    setMatrix(m);
  }

  /** @return the backing count array itself (not a copy), indexed {@code [correct][classified]} */
  public int[][] getConfusionMatrix() {
    return confusionMatrix;
  }

  /** @return an unmodifiable view of the labels, in index order */
  public Collection<String> getLabels() {
    return Collections.unmodifiableCollection(labelMap.keySet());
  }

  private int numLabels() {
    return labelMap.size();
  }

  /**
   * Resolves a label to its row/column index.
   *
   * @throws IllegalArgumentException if the label is unknown (instead of failing on a null unboxing)
   */
  private int indexOf(String label) {
    Integer id = labelMap.get(label);
    Preconditions.checkArgument(id != null, "Label not found: " + label);
    return id;
  }

  /** Sums the row with the given index over all labelled columns. */
  private int rowTotal(int labelId) {
    int sum = 0;
    for (int i = 0; i < numLabels(); i++) {
      sum += confusionMatrix[labelId][i];
    }
    return sum;
  }

  /**
   * Percentage of the examples of the given class that were classified correctly.
   *
   * @param label the correct label
   * @return a value in [0, 100], or NaN if the class has no examples
   * @throws IllegalArgumentException if the label is unknown
   */
  public double getAccuracy(String label) {
    int labelId = indexOf(label);
    int correct = confusionMatrix[labelId][labelId];
    return 100.0 * correct / rowTotal(labelId);
  }

  /**
   * Overall accuracy: percentage of all counted examples that lie on the diagonal.
   *
   * @return a value in [0, 100], or NaN if the matrix holds no examples
   */
  public double getAccuracy() {
    int total = 0;
    int correct = 0;
    for (int i = 0; i < numLabels(); i++) {
      correct += confusionMatrix[i][i];
      for (int j = 0; j < numLabels(); j++) {
        total += confusionMatrix[i][j];
      }
    }
    return 100.0 * correct / total;
  }

  /** Sum of true positives and false negatives */
  private int getActualNumberOfTestExamplesForClass(String label) {
    return rowTotal(indexOf(label));
  }

  /** Row totals for every label, in label order; used as weights by the weighted scores. */
  private double[] classWeights() {
    double[] weights = new double[numLabels()];
    int index = 0;
    for (String label : labelMap.keySet()) {
      weights[index++] = getActualNumberOfTestExamplesForClass(label);
    }
    return weights;
  }

  /**
   * Precision of the given class: true positives / (true positives + false positives).
   *
   * @return a value in [0, 1]; 0 if nothing was classified as this label
   * @throws IllegalArgumentException if the label is unknown
   */
  public double getPrecision(String label) {
    int labelId = indexOf(label);
    int truePositives = confusionMatrix[labelId][labelId];
    int falsePositives = 0;
    for (int i = 0; i < numLabels(); i++) {
      if (i != labelId) {
        falsePositives += confusionMatrix[i][labelId];
      }
    }

    if (truePositives + falsePositives == 0) {
      return 0;
    }

    return ((double) truePositives) / (truePositives + falsePositives);
  }

  /** @return the mean of the per-class precisions, weighted by the number of examples of each class */
  public double getWeightedPrecision() {
    double[] precisions = new double[numLabels()];
    int index = 0;
    for (String label : labelMap.keySet()) {
      precisions[index++] = getPrecision(label);
    }
    return new Mean().evaluate(precisions, classWeights());
  }

  /**
   * Recall of the given class: true positives / (true positives + false negatives).
   *
   * @return a value in [0, 1]; 0 if the class has no examples
   * @throws IllegalArgumentException if the label is unknown
   */
  public double getRecall(String label) {
    int labelId = indexOf(label);
    int truePositives = confusionMatrix[labelId][labelId];
    int falseNegatives = 0;
    for (int i = 0; i < numLabels(); i++) {
      if (i != labelId) {
        falseNegatives += confusionMatrix[labelId][i];
      }
    }
    if (truePositives + falseNegatives == 0) {
      return 0;
    }
    return ((double) truePositives) / (truePositives + falseNegatives);
  }

  /** @return the mean of the per-class recalls, weighted by the number of examples of each class */
  public double getWeightedRecall() {
    double[] recalls = new double[numLabels()];
    int index = 0;
    for (String label : labelMap.keySet()) {
      recalls[index++] = getRecall(label);
    }
    return new Mean().evaluate(recalls, classWeights());
  }

  /**
   * F1 score of the given class: harmonic mean of precision and recall.
   *
   * @return a value in [0, 1]; 0 if both precision and recall are 0
   * @throws IllegalArgumentException if the label is unknown
   */
  public double getF1score(String label) {
    double precision = getPrecision(label);
    double recall = getRecall(label);
    if (precision + recall == 0) {
      return 0;
    }
    return 2 * precision * recall / (precision + recall);
  }

  /** @return the mean of the per-class F1 scores, weighted by the number of examples of each class */
  public double getWeightedF1score() {
    double[] f1Scores = new double[numLabels()];
    int index = 0;
    for (String label : labelMap.keySet()) {
      f1Scores[index++] = getF1score(label);
    }
    return new Mean().evaluate(f1Scores, classWeights());
  }

  /**
   * Unweighted average of {@link #getAccuracy(String)} over every label except the default label.
   *
   * <p>The divisor counts only the labels that contribute to the sum; previously the skipped default
   * label was still counted, which deflated the average.
   *
   * @return the average per-class accuracy in percent; 0 if there is no non-default label; NaN if
   *         any non-default class has no examples
   */
  public double getReliability() {
    int count = 0;
    double accuracy = 0;
    for (String label : labelMap.keySet()) {
      if (!label.equals(defaultLabel)) {
        accuracy += getAccuracy(label);
        count++;
      }
    }
    return count == 0 ? 0 : accuracy / count;
  }

  /**
   * Accuracy v.s. randomly classifying all samples.
   * kappa() = (totalAccuracy() - randomAccuracy()) / (1 - randomAccuracy())
   * Cohen, Jacob. 1960. A coefficient of agreement for nominal scales.
   * Educational And Psychological Measurement 20:37-46.
   *
   * Formula and variable names from:
   * http://www.yale.edu/ceo/OEFS/Accuracy.pdf
   *
   * <p>The sample count is taken from the matrix itself (sum of all cells), so the result is
   * correct regardless of whether the counts came from addInstance, putCount, merge or setMatrix.
   *
   * @return Cohen's kappa; NaN when the denominator is zero (e.g. an empty matrix)
   */
  public double getKappa() {
    double a = 0.0;        // sum of the diagonal
    double b = 0.0;        // sum over i of (row total i * column total i)
    double samples = 0.0;  // total number of counted examples
    for (int i = 0; i < confusionMatrix.length; i++) {
      a += confusionMatrix[i][i];
      double br = 0;
      for (int j = 0; j < confusionMatrix.length; j++) {
        br += confusionMatrix[i][j];
      }
      double bc = 0;
      for (int[] vec : confusionMatrix) {
        bc += vec[i];
      }
      b += br * bc;
      samples += br;
    }
    return (samples * a - b) / (samples * samples - b);
  }

  /**
   * Running average and standard deviation of the normalized per-row accuracy
   * (diagonal cell divided by its row total). Not a standard score.
   *
   * @return the accumulated statistics, one datum per row
   */
  public RunningAverageAndStdDev getNormalizedStats() {
    RunningAverageAndStdDev summer = new FullRunningAverageAndStdDev();
    for (int d = 0; d < confusionMatrix.length; d++) {
      double total = 0;
      for (int j = 0; j < confusionMatrix.length; j++) {
        total += confusionMatrix[d][j];
      }
      // the small epsilon keeps an all-zero row from dividing by zero
      summer.addDatum(confusionMatrix[d][d] / (total + 0.000001));
    }

    return summer;
  }

  /**
   * @return the number of examples of the given class that were classified correctly
   * @throws IllegalArgumentException if the label is unknown
   */
  public int getCorrect(String label) {
    int labelId = indexOf(label);
    return confusionMatrix[labelId][labelId];
  }

  /**
   * @return the number of examples whose correct label is the given label
   * @throws IllegalArgumentException if the label is unknown
   */
  public int getTotal(String label) {
    return rowTotal(indexOf(label));
  }

  /** Records one example with the given correct label and the label carried by the classifier result. */
  public void addInstance(String correctLabel, ClassifierResult classifiedResult) {
    incrementCount(correctLabel, classifiedResult.getLabel());
  }

  /** Records one example with the given correct and classified labels. */
  public void addInstance(String correctLabel, String classifiedLabel) {
    incrementCount(correctLabel, classifiedLabel);
  }

  /**
   * @return the count stored for the (correct, classified) pair; 0 (with a warning logged) if the
   *         correct label is unknown
   * @throws IllegalArgumentException if the classified label is unknown
   */
  public int getCount(String correctLabel, String classifiedLabel) {
    if (!labelMap.containsKey(correctLabel)) {
      LOG.warn("Label {} did not appear in the training examples", correctLabel);
      return 0;
    }
    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
    int correctId = labelMap.get(correctLabel);
    int classifiedId = labelMap.get(classifiedLabel);
    return confusionMatrix[correctId][classifiedId];
  }

  /**
   * Overwrites the count stored for the (correct, classified) pair. An unknown correct label is
   * ignored with a warning.
   *
   * @throws IllegalArgumentException if the classified label is unknown
   */
  public void putCount(String correctLabel, String classifiedLabel, int count) {
    if (!labelMap.containsKey(correctLabel)) {
      LOG.warn("Label {} did not appear in the training examples", correctLabel);
      return;
    }
    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
    int correctId = labelMap.get(correctLabel);
    int classifiedId = labelMap.get(classifiedLabel);
    confusionMatrix[correctId][classifiedId] = count;
  }

  public String getDefaultLabel() {
    return defaultLabel;
  }

  /** Adds {@code count} to the cell for the (correct, classified) pair. */
  public void incrementCount(String correctLabel, String classifiedLabel, int count) {
    putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel));
  }

  /** Adds one to the cell for the (correct, classified) pair. */
  public void incrementCount(String correctLabel, String classifiedLabel) {
    incrementCount(correctLabel, classifiedLabel, 1);
  }

  /**
   * Adds the counts of another confusion matrix to this one, cell by cell, using this matrix's labels.
   *
   * @param b the matrix to add; must have the same number of labels
   * @return this matrix
   * @throws IllegalArgumentException if the numbers of labels differ
   */
  public ConfusionMatrix merge(ConfusionMatrix b) {
    Preconditions.checkArgument(labelMap.size() == b.getLabels().size(), "The label sizes do not match");
    for (String correctLabel : this.labelMap.keySet()) {
      for (String classifiedLabel : this.labelMap.keySet()) {
        incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel));
      }
    }
    return this;
  }

  /**
   * @return a new dense matrix holding a copy of the counts, with row and column label bindings
   *         set from a copy of the label map
   */
  public Matrix getMatrix() {
    int length = confusionMatrix.length;
    Matrix m = new DenseMatrix(length, length);
    for (int r = 0; r < length; r++) {
      for (int c = 0; c < length; c++) {
        m.set(r, c, confusionMatrix[r][c]);
      }
    }
    Map<String,Integer> labels = new HashMap<>(labelMap);
    m.setRowLabelBindings(labels);
    m.setColumnLabelBindings(labels);
    return m;
  }

  /**
   * Replaces the counts with the (rounded) values of the given matrix. If the matrix carries row
   * label bindings (or, failing that, column label bindings) the labels are replaced as well;
   * otherwise the current labels are kept.
   *
   * @param m a square matrix of counts
   * @throws IllegalArgumentException if {@code m} is not square or its label bindings do not
   *                                  provide exactly one label per row
   */
  public void setMatrix(Matrix m) {
    int length = confusionMatrix.length;
    if (m.numRows() != m.numCols()) {
      throw new IllegalArgumentException(
          "ConfusionMatrix: matrix(" + m.numRows() + ',' + m.numCols() + ") must be square");
    }
    for (int r = 0; r < length; r++) {
      for (int c = 0; c < length; c++) {
        confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
      }
    }
    Map<String,Integer> labels = m.getRowLabelBindings();
    if (labels == null) {
      labels = m.getColumnLabelBindings();
    }
    if (labels != null) {
      String[] sorted = sortLabels(labels);
      verifyLabels(length, sorted);
      labelMap.clear();
      for (int i = 0; i < length; i++) {
        labelMap.put(sorted[i], i);
      }
    }
  }

  /** Inverts a label-to-index map into an array of labels positioned by index. */
  private static String[] sortLabels(Map<String,Integer> labels) {
    String[] sorted = new String[labels.size()];
    for (Map.Entry<String,Integer> entry : labels.entrySet()) {
      sorted[entry.getValue()] = entry.getKey();
    }
    return sorted;
  }

  /** Checks that there are exactly {@code length} labels and that no index was left without one. */
  private static void verifyLabels(int length, String[] sorted) {
    Preconditions.checkArgument(sorted.length == length, "One label, one row");
    for (String label : sorted) {
      Preconditions.checkArgument(label != null, "One label, one row");
    }
  }

  /**
   * This is overridden. toString() is not a formatted report you print for a manager :)
   * Assume that if there are no default assignments, the default feature was not used
   */
  @Override
  public String toString() {
    StringBuilder returnString = new StringBuilder(200);
    returnString.append("=======================================================").append('\n');
    returnString.append("Confusion Matrix\n");
    returnString.append("-------------------------------------------------------").append('\n');

    // A matrix built from a Matrix may not contain the default label at all.
    int unclassified = labelMap.containsKey(defaultLabel) ? getTotal(defaultLabel) : 0;
    for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
      if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
        continue;
      }

      returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)).append('\t');
    }

    returnString.append("<--Classified as").append('\n');
    for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
      if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
        continue;
      }
      String correctLabel = entry.getKey();
      int labelTotal = 0;
      for (String classifiedLabel : this.labelMap.keySet()) {
        if (classifiedLabel.equals(defaultLabel) && unclassified == 0) {
          continue;
        }
        int cell = getCount(correctLabel, classifiedLabel);
        returnString.append(StringUtils.rightPad(Integer.toString(cell), 5)).append('\t');
        labelTotal += cell;
      }
      returnString.append(" |  ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t')
      .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
      .append(" = ").append(correctLabel).append('\n');
    }
    if (unclassified > 0) {
      returnString.append("Default Category: ").append(defaultLabel).append(": ").append(unclassified).append('\n');
    }
    returnString.append('\n');
    return returnString.toString();
  }

  /**
   * Encodes a non-negative index as a short lower-case tag by writing it in base 26 with the
   * digits 'a'..'z' (0 is "a", 25 is "z", 26 is "ba").
   */
  static String getSmallLabel(int i) {
    int val = i;
    StringBuilder returnString = new StringBuilder();
    do {
      int n = val % 26;
      returnString.insert(0, (char) ('a' + n));
      val /= 26;
    } while (val > 0);
    return returnString.toString();
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy