All downloads are free. Search and download functionality uses the official Maven repository.

com.github.lwhite1.tablesaw.api.ml.classification.StandardConfusionMatrix Maven / Gradle / Ivy

Go to download

High-performance Java Dataframe with integrated columnar storage

There is a newer version: 0.7.7.3
Show newest version
package com.github.lwhite1.tablesaw.api.ml.classification;

import com.github.lwhite1.tablesaw.api.CategoryColumn;
import com.github.lwhite1.tablesaw.api.IntColumn;
import com.google.common.collect.Table;
import com.google.common.collect.TreeBasedTable;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;

/**
 * A confusion matrix measures the accuracy of a classifier by counting, for every combination of predicted and
 * actual class, how many test observations fell into that combination.
 *
 * <p>Classes are identified by integer indices. The indices are assigned in the constructor by numbering the
 * supplied labels {@code 0, 1, 2, ...} in their sorted iteration order. Counts are stored with the predicted class
 * as the row key and the actual class as the column key.
 */
public class StandardConfusionMatrix implements ConfusionMatrix {

  /** Observation counts: row key = predicted class index, column key = actual class index. */
  private final Table<Integer, Integer, Integer> table = TreeBasedTable.create();

  /** Maps each class index to the label it stands for. */
  private final SortedMap<Integer, Object> labels = new TreeMap<>();

  /**
   * Creates an empty confusion matrix for the given class labels.
   *
   * @param labels the class labels; the i-th label in iteration order is assigned class index {@code i}
   */
  public StandardConfusionMatrix(SortedSet<?> labels) {
    int index = 0;
    for (Object label : labels) {
      this.labels.put(index, label);
      index++;
    }
  }

  /**
   * Records one observation by adding one to the count for the given (predicted, actual) pair.
   *
   * @param predicted the index of the class the classifier predicted
   * @param actual    the index of the class the observation really belongs to
   */
  public void increment(Integer predicted, Integer actual) {
    Integer current = table.get(predicted, actual);
    if (current == null) {
      // first observation for this combination
      table.put(predicted, actual, 1);
    } else {
      table.put(predicted, actual, current + 1);
    }
  }

  /**
   * Returns the printed form of {@link #toTable()}.
   */
  @Override
  public String toString() {
    return toTable().print();
  }

  /**
   * Builds a table view of the matrix with one row per predicted class and one column per actual class.
   *
   * <p>Only classes that were recorded at least once (as a prediction or as an actual value) are included.
   * Combinations that were never recorded are shown as 0. The first column holds the row headers and is named
   * {@code "n = <total number of observations>"}.
   *
   * @return a new table containing the counts
   */
  public com.github.lwhite1.tablesaw.api.Table toTable() {
    com.github.lwhite1.tablesaw.api.Table t = com.github.lwhite1.tablesaw.api.Table.create("Confusion Matrix");
    t.addColumn(CategoryColumn.create(""));

    // make a set of all the values needed, from the prediction set or the actual set
    TreeSet<Integer> allValues = new TreeSet<>();
    allValues.addAll(table.columnKeySet());
    allValues.addAll(table.rowKeySet());

    // one count column and one row header per observed class
    for (Integer classIndex : allValues) {
      t.addColumn(IntColumn.create(String.valueOf(labels.get(classIndex))));
      t.column(0).addCell("Predicted "  + labels.get(classIndex));
    }

    // put them in a list so they can be accessed by index number
    List<Integer> valuesList = new ArrayList<>(allValues);
    int n = 0;
    for (int r = 0; r < valuesList.size(); r++) {
      for (int c = 0; c < valuesList.size(); c++) {
        Integer value = table.get(valuesList.get(r), valuesList.get(c));
        if (value == null) {
          t.intColumn(c + 1).add(0);
        } else {
          t.intColumn(c + 1).add(value);
          n = n + value;
        }
      }
    }
    t.column(0).setName("n = " + n);
    for (int c = 1; c <= valuesList.size(); c++) {
      // Look the label up by class index, not by column position: the observed class indices
      // need not be 0..k-1, and the header must match the data stored in this column.
      t.column(c).setName("Actual " + labels.get(valuesList.get(c - 1)));
    }
    return t;
  }

  /**
   * Computes the fraction of recorded observations whose predicted class equals their actual class.
   *
   * @return hits / (hits + misses); {@code NaN} if nothing has been recorded yet
   */
  public double accuracy() {
    Set<Table.Cell<Integer, Integer, Integer>> cellSet = table.cellSet();
    int hits = 0;
    int misses = 0;
    for (Table.Cell<Integer, Integer, Integer> cell : cellSet) {
      if (cell.getRowKey().equals(cell.getColumnKey())) {
        hits = hits + cell.getValue();
      } else {
        misses = misses + cell.getValue();
      }
    }
    // multiply by 1.0 to force floating-point division
    return hits / ((hits + misses) * 1.0);
  }
}