org.apache.mahout.classifier.ConfusionMatrixDumper Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-integration Show documentation
Show all versions of mahout-integration Show documentation
Optional components of Mahout which generally support interaction with third party systems,
formats, APIs, etc.
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier;
import com.google.common.collect.Lists;
import org.apache.commons.io.Charsets;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintStream;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
/**
* Export a ConfusionMatrix in various text formats: the toString version, a grayscale HTML table,
* a summary HTML table, and a table of counts, all with optional HTML wrappers.
*
* Input format: Hadoop SequenceFile with Text key and MatrixWritable value, 1 pair
*
* Intended to consume ConfusionMatrix SequenceFile output by Bayes TestClassifier class
*/
public final class ConfusionMatrixDumper extends AbstractJob {
// Column delimiter used by exportText between the fields of the plain-text report.
private static final String TAB_SEPARATOR = "|";
// HTML wrapper - default CSS
// printHeader replaces the literal token TITLE below with the page title.
// NOTE(review): HEADER and FOOTER carry no markup or CSS rules in this copy of the file; the
// tags appear to have been stripped during extraction - confirm against the upstream source.
private static final String HEADER = ""
+ "\n"
+ "TITLE \n"
+ ""
+ "\n"
+ "\n";
private static final String FOOTER = "";
// CSS style names.
// These are passed as format arguments by the HTML printing methods further down.
private static final String CSS_TABLE = "table";
private static final String CSS_LABEL = "label";
private static final String CSS_TALL_HEADER = "tall";
private static final String CSS_VERTICAL = "verticalText";
private static final String CSS_CELL = "cell";
private static final String CSS_EMPTY = "empty";
// Indexed by the 0..4 rating computed in printGrayCell: index 0 is "white", index 4 is "black".
private static final String[] CSS_GRAY_CELLS = {"white", "gray1", "gray2", "gray3", "black"};
// Private constructor: within this file the only instantiation is in main().
private ConfusionMatrixDumper() {}
/**
 * Command-line entry point: hands a fresh dumper instance and the raw arguments to
 * {@link ToolRunner}. The value returned by {@code ToolRunner.run} is not inspected.
 *
 * @param args command-line arguments, forwarded unchanged
 * @throws Exception whatever {@code ToolRunner.run} propagates
 */
public static void main(String[] args) throws Exception {
  ConfusionMatrixDumper dumper = new ConfusionMatrixDumper();
  ToolRunner.run(dumper, args);
}
/**
 * Parses the command line, then writes the confusion matrix found at the input path either as
 * plain text ({@code --text}) or as HTML tables (default; {@code --html} adds the page wrapper).
 *
 * <p>Output goes to the file named by {@code --output} when given, otherwise to
 * {@code System.out}.
 *
 * @param args command-line arguments
 * @return {@code -1} when {@code parseArguments} returns {@code null}, {@code 0} otherwise
 * @throws IOException if reading the input or writing the output fails
 */
@Override
public int run(String[] args) throws IOException {
  addInputOption();
  addOption("output", "o", "Output path", null); // AbstractJob output feature requires param
  addOption(DefaultOptionCreator.overwriteOption().create());
  addFlag("html", null, "Create complete HTML page");
  addFlag("text", null, "Dump simple text");
  Map<String, List<String>> parsedArgs = parseArguments(args);
  if (parsedArgs == null) {
    return -1;
  }
  Path inputPath = getInputPath();
  String outputFile = hasOption("output") ? getOption("output") : null;
  boolean text = parsedArgs.containsKey("--text");
  boolean wrapHtml = parsedArgs.containsKey("--html");
  PrintStream out = getPrintStream(outputFile);
  try {
    if (text) {
      exportText(inputPath, out);
    } else {
      exportTable(inputPath, out, wrapHtml);
    }
  } finally {
    // Always flush, and release the file handle even when an export throws.
    // System.out is shared with the rest of the JVM and must stay open.
    out.flush();
    if (out != System.out) {
      out.close();
    }
  }
  return 0;
}
/**
 * Writes a fixed-width, {@code |}-delimited summary of the confusion matrix stored at
 * {@code inputPath}: a header row, one row per label returned by {@code stripDefault} (label,
 * total count, correct count, accuracy rounded to an int), and finally {@code cm.toString()}.
 *
 * @param inputPath SequenceFile holding the matrix
 * @param out destination stream; not closed by this method
 * @throws IOException if the matrix cannot be read
 */
private static void exportText(Path inputPath, PrintStream out) throws IOException {
  MatrixWritable mw = new MatrixWritable();
  Text key = new Text();
  readSeqFile(inputPath, key, mw);
  Matrix m = mw.get();
  ConfusionMatrix cm = new ConfusionMatrix(m);

  // One format for header and data rows keeps the column widths in a single place.
  String rowFormat = "%-40s" + TAB_SEPARATOR + "%-10s" + TAB_SEPARATOR + "%-10s" + TAB_SEPARATOR
      + "%-6s" + TAB_SEPARATOR;
  // 70 dashes: pad a single '-' to width 70, then turn the padding spaces into dashes.
  String rule = String.format("%-70s", "-").replace(' ', '-');

  out.println(String.format(rowFormat, "Label", "Total", "Correct", "%"));
  out.println(rule);
  List<String> labels = stripDefault(cm);
  for (String label : labels) {
    int correct = cm.getCorrect(label);
    double accuracy = cm.getAccuracy(label);
    int count = getCount(cm, label);
    out.println(String.format(rowFormat, label, count, correct, (int) Math.round(accuracy)));
  }
  out.println(rule);
  out.println(cm.toString());
}
/**
 * Reads the confusion matrix stored at inputPath and prints, in order, the summary table, the
 * grayscale table, the counts table and the matrix's toString() text to out.
 * When wrapHtml is true the output is bracketed by printHeader (titled with the input file
 * name) and printFooter.
 *
 * NOTE(review): the empty-string println calls between the sections look like separator markup
 * that was stripped from this copy of the file - confirm against the upstream source.
 */
private static void exportTable(Path inputPath, PrintStream out, boolean wrapHtml) throws IOException {
MatrixWritable mw = new MatrixWritable();
Text key = new Text();
readSeqFile(inputPath, key, mw);
// The page title is the path's name, reduced to whatever follows its last '/' (if any).
String fileName = inputPath.getName();
fileName = fileName.substring(fileName.lastIndexOf('/') + 1, fileName.length());
Matrix m = mw.get();
ConfusionMatrix cm = new ConfusionMatrix(m);
if (wrapHtml) {
printHeader(out, fileName);
}
out.println("");
printSummaryTable(cm, out);
out.println("");
printGrayTable(cm, out);
out.println("");
printCountsTable(cm, out);
out.println("");
printTextInBox(cm, out);
out.println("");
if (wrapHtml) {
printFooter(out);
}
}
/**
 * Returns a fresh, mutable copy of the matrix's labels. The default label is removed from the
 * copy unless {@code cm.getTotal(defaultLabel)} is positive.
 *
 * @param cm confusion matrix supplying the labels
 * @return new list of labels; changes to it do not affect {@code cm}
 */
private static List<String> stripDefault(ConfusionMatrix cm) {
  List<String> stripped = Lists.newArrayList(cm.getLabels().iterator());
  String defaultLabel = cm.getDefaultLabel();
  int unclassified = cm.getTotal(defaultLabel);
  if (unclassified <= 0) {
    // Nothing was counted under the default label, so it is dropped from the report.
    stripped.remove(defaultLabel);
  }
  return stripped;
}
// TODO: test - this should work with HDFS files
/**
 * Reads the first key/value pair of the SequenceFile at {@code path} into {@code key} and
 * {@code m}.
 *
 * @param path location of the SequenceFile
 * @param key receives the first record's key
 * @param m receives the first record's matrix
 * @throws IOException if the file cannot be opened or read, or contains no record
 */
private static void readSeqFile(Path path, Text key, MatrixWritable m) throws IOException {
  Configuration conf = new Configuration();
  // Resolve the file system from the path itself so a fully qualified URI (hdfs://, file://)
  // is honoured even when it differs from the configured default file system.
  FileSystem fs = path.getFileSystem(conf);
  SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
  try {
    if (!reader.next(key, m)) {
      throw new IOException("No confusion matrix record found in " + path);
    }
  } finally {
    // The reader holds an open input stream; release it on every exit path.
    reader.close();
  }
}
// TODO: test - this might not work with HDFS files?
// after all, it does no seeks
/**
 * Opens the destination for the report.
 *
 * @param outputFilename local file to (re)create, or {@code null} to write to
 *        {@code System.out}
 * @return a UTF-8 {@link PrintStream} over the file (no auto-flush), or {@code System.out}
 * @throws IOException if the file cannot be opened for writing
 */
private static PrintStream getPrintStream(String outputFilename) throws IOException {
  if (outputFilename == null) {
    return System.out;
  }
  File outputFile = new File(outputFilename);
  if (outputFile.exists()) {
    // Best effort: the result is deliberately ignored because FileOutputStream below either
    // truncates the surviving file or reports the failure with an exception.
    outputFile.delete();
  }
  // FileOutputStream creates the file when it is missing, so no createNewFile() is needed.
  OutputStream os = new FileOutputStream(outputFile);
  // name() is the canonical charset name; displayName() is a locale-dependent label.
  return new PrintStream(os, false, Charsets.UTF_8.name());
}
/**
 * Sums {@code cm.getCount(rowLabel, other)} over every label of the matrix, including the
 * default label. This is the same computation that {@code getCount} performs.
 *
 * @param cm confusion matrix to query
 * @param rowLabel label passed as the first argument of each count lookup
 * @return the sum of the counts
 */
private static int getLabelTotal(ConfusionMatrix cm, String rowLabel) {
  int count = 0;
  for (String columnLabel : cm.getLabels()) {
    count += cm.getCount(rowLabel, columnLabel);
  }
  return count;
}
// HTML generator code
/**
 * Prints cm.toString() to out, surrounded by several wrapper lines.
 *
 * NOTE(review): the wrapper literals are empty and one of them is split across two source
 * lines; the enclosing markup appears to have been stripped from this copy of the file -
 * confirm against the upstream source before relying on this method.
 */
private static void printTextInBox(ConfusionMatrix cm, PrintStream out) {
out.println("");
out.println("");
out.println(cm.toString());
out.println("
");
out.println("");
}
/**
 * Prints the summary table: a header line naming the columns Label, Total, Correct and %,
 * followed by one printSummaryRow call per label returned by stripDefault.
 *
 * NOTE(review): the format strings here contain no specifiers although CSS_TABLE / CSS_LABEL
 * are passed as arguments, and the closing literal is split across two source lines; the
 * table markup appears to have been stripped from this copy - confirm against upstream.
 */
public static void printSummaryTable(ConfusionMatrix cm, PrintStream out) {
format("\n", out, CSS_TABLE);
format("", out, CSS_LABEL);
out.println("Label Total Correct % ");
out.println(" ");
List labels = stripDefault(cm);
for (String label : labels) {
printSummaryRow(cm, out, label);
}
out.println("
");
}
/**
 * Prints one summary row for label: the label, its total count (getCount), its correct count
 * and its accuracy rounded to an int.
 *
 * NOTE(review): the row format string has four specifiers (%s %d %d %d) but is given five
 * arguments starting with CSS_CELL, so as written the String label lines up with a %d.
 * The cell markup (and presumably one more specifier) appears to have been stripped from this
 * copy - confirm against the upstream source.
 */
private static void printSummaryRow(ConfusionMatrix cm, PrintStream out, String label) {
format("", out, CSS_CELL);
int correct = cm.getCorrect(label);
double accuracy = cm.getAccuracy(label);
int count = getCount(cm, label);
format("%s %d %d %d ", out, CSS_CELL, label, count, correct,
(int) Math.round(accuracy));
out.println(" ");
}
/**
 * Adds up {@code cm.getCount(label, other)} for every label known to the matrix.
 *
 * @param cm confusion matrix to query
 * @param label label passed as the first argument of each count lookup
 * @return the accumulated count
 */
private static int getCount(ConfusionMatrix cm, String label) {
  int total = 0;
  Iterator<String> others = cm.getLabels().iterator();
  while (others.hasNext()) {
    total += cm.getCount(label, others.next());
  }
  return total;
}
/**
 * Prints the grayscale table: the vertical-style counts header (printCountsHeader with
 * vertical = true) followed by the rows produced by printGrayRows.
 *
 * NOTE(review): the opening format string has no specifier for CSS_TABLE and the closing
 * literal is split across two source lines; the table markup appears to have been stripped
 * from this copy - confirm against the upstream source.
 */
public static void printGrayTable(ConfusionMatrix cm, PrintStream out) {
format("\n", out, CSS_TABLE);
printCountsHeader(cm, out, true);
printGrayRows(cm, out);
out.println("
");
}
/**
 * Prints one grayscale row per label returned by {@code stripDefault}. Each cell is shaded by
 * {@code printGrayCell} according to count/total for its row, using one of the five classes in
 * {@code CSS_GRAY_CELLS} (white, three grays, black). The intent is a mostly white matrix with
 * grays where inputs were misclassified and black on the diagonal.
 * TODO: Using sqrt(count/max) as the rating is more stringent.
 *
 * @param cm confusion matrix to render
 * @param out destination stream
 */
private static void printGrayRows(ConfusionMatrix cm, PrintStream out) {
  List<String> labels = stripDefault(cm);
  for (String label : labels) {
    printGrayRow(cm, out, labels, label);
  }
}
/**
 * Prints one row of the grayscale table: the row label, then one printGrayCell per entry of
 * labels. The row total handed to each cell comes from getLabelTotal, which sums over all
 * labels of the matrix rather than only those in labels.
 *
 * NOTE(review): the first format string has no specifier for CSS_LABEL and the row delimiters
 * are blank; the row/cell markup appears to have been stripped from this copy - confirm
 * against the upstream source.
 */
private static void printGrayRow(ConfusionMatrix cm,
PrintStream out,
Iterable labels,
String rowLabel) {
format("", out, CSS_LABEL);
format("%s ", out, rowLabel);
int total = getLabelTotal(cm, rowLabel);
for (String columnLabel : labels) {
printGrayCell(cm, out, total, rowLabel, columnLabel);
}
out.println(" ");
}
// assign white/light/medium/dark to 0,1/4,1/2,3/4 of total number of inputs
// assign black to count = total, meaning complete success
// alternative rating is to use sqrt(total) instead of total - this is more drastic
/**
 * Prints one cell of the grayscale table for the pair (rowLabel, columnLabel).
 * A zero count takes the CSS_EMPTY branch; otherwise the cell's style is
 * CSS_GRAY_CELLS[(int) (count / total * 4)], so count == total selects index 4 ("black").
 *
 * Assumes 0 < count <= total in the non-empty branch so the index stays within 0..4 - TODO
 * confirm that callers always pass the row total.
 *
 * NOTE(review): both format strings lack specifiers for most of their arguments (only css is
 * consumed by the single %s); the cell markup appears to have been stripped from this copy -
 * confirm against the upstream source.
 */
private static void printGrayCell(ConfusionMatrix cm,
PrintStream out,
int total,
String rowLabel,
String columnLabel) {
int count = cm.getCount(rowLabel, columnLabel);
if (count == 0) {
out.format(" ", CSS_EMPTY);
} else {
// 0 is white, full is black, everything else gray
// The cast truncates, so the rating only reaches 4 when count equals total.
int rating = (int) ((count / (double) total) * 4);
String css = CSS_GRAY_CELLS[rating];
format("%s ", out, css, columnLabel, count);
}
}
/**
 * Prints the counts table: the horizontal counts header (printCountsHeader with
 * vertical = false) followed by the rows produced by printCountsRows.
 *
 * NOTE(review): the opening format string has no specifier for CSS_TABLE and the closing
 * literal is split across two source lines; the table markup appears to have been stripped
 * from this copy - confirm against the upstream source.
 */
public static void printCountsTable(ConfusionMatrix cm, PrintStream out) {
format("\n", out, CSS_TABLE);
printCountsHeader(cm, out, false);
printCountsRows(cm, out);
out.println("
");
}
/**
 * Prints one counts row per label returned by {@code stripDefault}; the same label list is
 * used for the rows and, via {@code printCountsRow}, for the columns of each row.
 *
 * @param cm confusion matrix to render
 * @param out destination stream
 */
private static void printCountsRows(ConfusionMatrix cm, PrintStream out) {
  List<String> labels = stripDefault(cm);
  for (String label : labels) {
    printCountsRow(cm, out, labels, label);
  }
}
/**
 * Prints one row of the counts table: a leading label cell, then one printCountsCell per
 * entry of labels.
 *
 * NOTE(review): the label format string has a single %s but receives CSS_LABEL and rowLabel,
 * so as written it prints the style name rather than the label; the cell markup appears to
 * have been stripped from this copy - confirm against the upstream source.
 */
private static void printCountsRow(ConfusionMatrix cm,
PrintStream out,
Iterable labels,
String rowLabel) {
out.println("");
format("%s ", out, CSS_LABEL, rowLabel);
for (String columnLabel : labels) {
printCountsCell(cm, out, rowLabel, columnLabel);
}
out.println(" ");
}
/**
 * Prints one cell of the counts table for the pair (rowLabel, columnLabel). The cell text is
 * the count, or an empty string when the count is zero.
 *
 * NOTE(review): the format string has a single %s but receives CSS_CELL, columnLabel and the
 * cell text, so as written only the style name is printed; the cell markup appears to have
 * been stripped from this copy - confirm against the upstream source.
 */
private static void printCountsCell(ConfusionMatrix cm, PrintStream out, String rowLabel, String columnLabel) {
int count = cm.getCount(rowLabel, columnLabel);
// Zero counts are shown as blank cells.
String s = count == 0 ? "" : Integer.toString(count);
format("%s ", out, CSS_CELL, columnLabel, s);
}
/**
 * Prints the header row shared by the grayscale and counts tables, one entry per label
 * returned by stripDefault.
 * With vertical = true the header is formatted with CSS_TALL_HEADER, half the longest label
 * length, and CSS_VERTICAL per label; otherwise with CSS_TABLE / CSS_LABEL and the plain
 * labels.
 *
 * NOTE(review): several format strings have fewer specifiers than arguments (for example the
 * vertical per-label call consumes CSS_VERTICAL with its only %s and drops the label); the
 * header markup appears to have been stripped from this copy - confirm against upstream.
 */
private static void printCountsHeader(ConfusionMatrix cm, PrintStream out, boolean vertical) {
List labels = stripDefault(cm);
int longest = getLongestHeader(labels);
if (vertical) {
// vertical header: labels use the CSS_VERTICAL style; longest / 2 is passed to size the header
out.format(" %n", CSS_TALL_HEADER, longest / 2);
for (String label : labels) {
out.format("%s ", CSS_VERTICAL, label);
}
out.println(" ");
} else {
// header - empty cell in upper left
out.format(" %n", CSS_TABLE, CSS_LABEL);
for (String label : labels) {
out.format("%s ", label);
}
out.format(" ");
}
}
/**
 * Returns the length of the longest label.
 *
 * @param labels labels to measure; must not contain {@code null}
 * @return the maximum {@code String.length()} among the labels, or {@code 0} when there are
 *         none
 */
private static int getLongestHeader(Iterable<String> labels) {
  int max = 0;
  for (String label : labels) {
    max = Math.max(label.length(), max);
  }
  return max;
}
/**
 * Renders {@code format} with {@code args} via {@link String#format(String, Object...)} and
 * writes the result to {@code out} followed by a line terminator.
 *
 * @param format format string
 * @param out destination stream
 * @param args values for the format specifiers
 */
private static void format(String format, PrintStream out, Object... args) {
  out.println(String.format(format, args));
}
/**
 * Writes the page header to {@code out}, with every occurrence of the token {@code TITLE} in
 * the {@code HEADER} template replaced by {@code title}.
 *
 * @param out destination stream
 * @param title text substituted for the {@code TITLE} token
 */
public static void printHeader(PrintStream out, CharSequence title) {
  String page = HEADER.replace("TITLE", title);
  out.println(page);
}
/**
 * Writes the {@code FOOTER} constant to {@code out} as a single line.
 *
 * @param out destination stream
 */
public static void printFooter(PrintStream out) {
  final String closing = FOOTER;
  out.println(closing);
}
}