boofcv.gui.learning.ConfusionMatrixPanel Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of boofcv-swing Show documentation
Show all versions of boofcv-swing Show documentation
BoofCV is an open source Java library for real-time computer vision and robotics applications.
/*
* Copyright (c) 2021, Peter Abeles. All Rights Reserved.
*
* This file is part of BoofCV (http://boofcv.org).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package boofcv.gui.learning;
import boofcv.gui.image.ShowImages;
import lombok.Getter;
import lombok.Setter;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.RandomMatrices_DDRM;
import javax.swing.*;
import java.awt.*;
import java.awt.geom.Rectangle2D;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
 * Visualizes a confusion matrix. Each element is assumed to have a value from 0 to 1.0. Values outside of that
 * range are clamped when choosing a cell's color, but the number printed inside the cell is not altered.
 *
 * <p>{@link #setMatrix} copies the matrix into a staging buffer. The buffer is transferred into the rendered
 * copy the next time the panel is painted.</p>
 *
 * @author Peter Abeles
 */
@SuppressWarnings({"NullAway.Init"})
public class ConfusionMatrixPanel extends JPanel {
	// Staging buffer written by setMatrix(). Guarded by 'this'
	DMatrixRMaj temp = new DMatrixRMaj(1, 1);
	// The matrix which is rendered. Updated from 'temp' at the start of paint()
	DMatrixRMaj confusion = new DMatrixRMaj(1, 1);
	// true if 'temp' holds a matrix which has not yet been copied into 'confusion'. Guarded by 'this'
	boolean dirty = false;

	@Getter @Setter boolean gray = false;
	@Getter @Setter boolean showNumbers = true;
	@Getter @Setter boolean showLabels = true;
	@Getter @Setter boolean showZeros = true;

	// fraction of the width that labels occupy
	double labelViewFraction = 0.30;

	// Optional name of each category. null means no labels are rendered
	List<String> labels;

	// if set to a valid category then that category will be highlighted
	int highlightCategory = -1;

	// internal variables used for rendering
	int viewHeight, viewWidth;
	int gridHeight, gridWidth;
	boolean renderLabels;

	/**
	 * Constructor that specifies the confusion matrix and width/height
	 *
	 * @param M The confusion matrix. It is copied internally.
	 * @param labels Optional labels for the confusion matrix. Can be null.
	 * @param widthPixels preferred width and height of the panel in pixels
	 * @param gray Render gray scale or color image
	 */
	public ConfusionMatrixPanel( DMatrixRMaj M, List<String> labels, int widthPixels, boolean gray ) {
		this(widthPixels, labels != null);
		setLabels(labels);
		setMatrix(M);
		this.gray = gray;
	}

	/**
	 * Constructor in which the preferred width and height is specified in pixels
	 *
	 * @param widthPixels preferred width and height
	 * @param hasLabels If true, the preferred height is reduced by the fraction of the width labels occupy
	 */
	public ConfusionMatrixPanel( int widthPixels, boolean hasLabels ) {
		int heightPixels = widthPixels;
		if (hasLabels) {
			heightPixels = (int)(heightPixels*(1.0 - labelViewFraction));
		}
		setPreferredSize(new Dimension(widthPixels, heightPixels));
	}

	/**
	 * Specifies the matrix to display. A copy is made, which is used the next time the panel is painted.
	 *
	 * @param A The confusion matrix
	 */
	public void setMatrix( DMatrixRMaj A ) {
		synchronized (this) {
			temp.setTo(A);
			dirty = true;
		}
		repaint();
	}

	/**
	 * Specifies the name of each category. A copy of the list is saved.
	 *
	 * @param labels Names of each category, or null to not render any labels
	 */
	public void setLabels( List<String> labels ) {
		// Labels are optional. Copying a null list would throw a NullPointerException
		this.labels = labels == null ? null : new ArrayList<>(labels);
	}

	public int getHighlightCategory() {
		return highlightCategory;
	}

	/**
	 * Specifies which category to highlight and requests a repaint so the change becomes visible.
	 *
	 * @param highlightCategory Index of the category to highlight. An invalid index disables highlighting.
	 */
	public void setHighlightCategory( int highlightCategory ) {
		this.highlightCategory = highlightCategory;
		repaint();
	}

	@Override
	public void paint( Graphics g ) {
		// The lock is only held while shared state is copied, so setMatrix() never waits on rendering
		synchronized (this) {
			if (dirty) {
				confusion.setTo(temp);
				dirty = false;
			}

			viewHeight = getHeight();
			viewWidth = getWidth();
			gridHeight = viewHeight;
			gridWidth = viewWidth;
			renderLabels = this.showLabels && labels != null;
			if (renderLabels) {
				gridWidth = (int)(gridWidth*(1.0 - labelViewFraction));
			}
		}

		Graphics2D g2 = (Graphics2D)g;
		int numCategories = confusion.getNumRows();

		// An empty matrix has no cells. Returning early also avoids dividing by zero below
		if (numCategories == 0) {
			g2.setColor(Color.WHITE);
			g2.fillRect(0, 0, viewWidth, viewHeight);
			return;
		}

		double fontSize = Math.min(gridWidth/numCategories, gridHeight/numCategories);

		g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);

		if (renderLabels) {
			renderLabels(g2, fontSize);
		}
		renderMatrix(g2, fontSize);
		renderHighlight(g2, numCategories);
	}

	/**
	 * Draws a translucent yellow band over the column of the highlighted category. It also draws one across its
	 * row, which extends over the label area. Does nothing if the highlighted category is not valid.
	 */
	private void renderHighlight( Graphics2D g2, int numCategories ) {
		int category = highlightCategory;
		if (category < 0 || category >= numCategories)
			return;

		g2.setColor(new Color(255, 255, 0, 100));
		// inset each band by 10% of a cell on both sides
		int ry = (int)(0.1*gridHeight/numCategories);
		int rx = (int)(0.1*gridWidth/numCategories);

		int y0 = category*gridHeight/numCategories;
		int y1 = (category + 1)*gridHeight/numCategories;
		int x0 = category*gridWidth/numCategories;
		int x1 = (category + 1)*gridWidth/numCategories;

		g2.fillRect(x0 + rx, 0, x1 - x0 - 2*rx, gridHeight);
		g2.fillRect(0, y0 + ry, viewWidth, y1 - y0 - 2*ry);
	}

	/**
	 * Renders the names on each category to the side of the confusion matrix
	 */
	private void renderLabels( Graphics2D g2, double fontSize ) {
		int numCategories = confusion.getNumRows();
		// Only categories which have a label can be drawn. Protects against a label list that's too short
		int numLabels = Math.min(numCategories, labels.size());

		int longestLabel = 0;
		for (int i = 0; i < numLabels; i++) {
			longestLabel = Math.max(longestLabel, labels.get(i).length());
		}

		// NOTE(review): the font size grows with the length of the longest label. Long labels can therefore
		// overflow the label area. Confirm whether it should shrink instead.
		Font fontLabel = new Font("monospaced", Font.BOLD, (int)(0.055*longestLabel*fontSize + 0.5));
		g2.setFont(fontLabel);
		FontMetrics metrics = g2.getFontMetrics(fontLabel);

		// clear the background
		g2.setColor(Color.WHITE);
		g2.fillRect(gridWidth, 0, viewWidth - gridWidth, viewHeight);

		// draw the text centered horizontally in the label area and vertically in the category's row
		g2.setColor(Color.BLACK);
		float centerX = (viewWidth + gridWidth)/2f;
		for (int i = 0; i < numLabels; i++) {
			int y0 = i*gridHeight/numCategories;
			int y1 = (i + 1)*gridHeight/numCategories;
			drawCentered(g2, metrics, labels.get(i), centerX, (y1 + y0)/2f);
		}
	}

	/**
	 * Renders the confusion matrix. Each cell's value is shown with a color and, optionally, as a number.
	 */
	private void renderMatrix( Graphics2D g2, double fontSize ) {
		int numCategories = confusion.getNumRows();

		Font fontNumber = new Font("Serif", Font.BOLD, (int)(0.6*fontSize + 0.5));
		g2.setFont(fontNumber);
		FontMetrics metrics = g2.getFontMetrics(fontNumber);

		for (int i = 0; i < numCategories; i++) {
			int y0 = i*gridHeight/numCategories;
			int y1 = (i + 1)*gridHeight/numCategories;

			for (int j = 0; j < numCategories; j++) {
				int x0 = j*gridWidth/numCategories;
				int x1 = (j + 1)*gridWidth/numCategories;

				double value = confusion.unsafe_get(i, j);

				// Color components must be within [0,255] or Color's constructor throws. Clamping keeps a
				// slightly out-of-range value from breaking every repaint.
				double shade = Math.max(0.0, Math.min(1.0, value));

				int red, green, blue;
				if (gray) {
					red = green = blue = (int)(255*(1.0 - shade));
				} else {
					green = 0;
					red = (int)(255*shade);
					blue = (int)(255*(1.0 - shade));
				}
				g2.setColor(new Color(red, green, blue));
				g2.fillRect(x0, y0, x1 - x0, y1 - y0);

				// Render numbers inside the squares. Black or white text is chosen so the number stays
				// visible on any cell color.
				if (showNumbers && (showZeros || value != 0)) {
					int brightness = (red + green + blue)/3;
					int textGray = brightness > 127 ? 0 : 255;
					String text = "" + (int)(value*100.0 + 0.5);

					g2.setColor(new Color(textGray, textGray, textGray));
					drawCentered(g2, metrics, text, (x1 + x0)/2f, (y1 + y0)/2f);
				}
			}
		}
	}

	/**
	 * Draws the text so that the center of its bounding box is at (centerX, centerY)
	 */
	private static void drawCentered( Graphics2D g2, FontMetrics metrics, String text,
									  float centerX, float centerY ) {
		Rectangle2D r = metrics.getStringBounds(text, null);
		float adjX = (float)(r.getX()*2 + r.getWidth())/2.0f;
		float adjY = (float)(r.getY()*2 + r.getHeight())/2.0f;
		g2.drawString(text, centerX - adjX, centerY - adjY);
	}

	/**
	 * Use to sample the panel to see what is being displayed at the location clicked. All coordinates
	 * are in panel coordinates.
	 *
	 * @param pixelX x-axis in panel coordinates
	 * @param pixelY y-axis in panel coordinates
	 * @param output (Optional) storage for output.
	 * @return Information on what is at the specified location. If nothing has been rendered yet, or the
	 * matrix is empty, then insideMatrix is false and row and col are -1.
	 */
	public LocationInfo whatIsAtPoint( int pixelX, int pixelY, LocationInfo output ) {
		if (output == null)
			output = new LocationInfo();

		int numCategories = confusion.getNumRows();

		synchronized (this) {
			if (gridHeight <= 0 || numCategories == 0) {
				// Nothing rendered yet, or no cells exist. Without this check the divisions below would throw
				output.insideMatrix = false;
				output.row = output.col = -1;
			} else if (gridWidth <= 0 || pixelX >= gridWidth) {
				// In the label area. Only the row (which is also the category's column) is known
				output.insideMatrix = false;
				output.col = output.row = pixelY*numCategories/gridHeight;
			} else {
				output.insideMatrix = true;
				output.row = pixelY*numCategories/gridHeight;
				output.col = pixelX*numCategories/gridWidth;
			}
		}
		return output;
	}

	/**
	 * Contains information on what was at the point
	 */
	public static class LocationInfo {
		/** true if the point lies on a cell of the matrix, false if it's in the label area or invalid */
		public boolean insideMatrix;
		/** row and column of the cell or category at the point */
		public int row, col;
	}

	public static void main( String[] args ) {
		DMatrixRMaj m = RandomMatrices_DDRM.rectangle(5, 5, 0, 1, new Random(234));

		List<String> labels = new ArrayList<>();
		for (int i = 0; i < m.numRows; i++) {
			labels.add("Label " + i);
		}

		ConfusionMatrixPanel confusion = new ConfusionMatrixPanel(m, labels, 300, false);
		confusion.setHighlightCategory(2);

		ShowImages.showWindow(confusion, "Window", true);
	}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy