All downloads are free. Search and download functionality uses the official Maven repository.

boofcv.gui.learning.ConfusionMatrixPanel Maven / Gradle / Ivy

Go to download

BoofCV is an open source Java library for real-time computer vision and robotics applications.

There is a newer version: 1.1.7
Show newest version
/*
 * Copyright (c) 2021, Peter Abeles. All Rights Reserved.
 *
 * This file is part of BoofCV (http://boofcv.org).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package boofcv.gui.learning;

import boofcv.gui.image.ShowImages;
import lombok.Getter;
import lombok.Setter;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.RandomMatrices_DDRM;

import javax.swing.*;
import java.awt.*;
import java.awt.geom.Rectangle2D;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Visualizes a confusion matrix. Each element is assumed to have a value from 0 to 1.0. Values outside of that
 * range are clamped when selecting the color of a cell. The matrix is assumed to be square.
 *
 * @author Peter Abeles
 */
@SuppressWarnings({"NullAway.Init"})
public class ConfusionMatrixPanel extends JPanel {
	// Most recent matrix passed to setMatrix(). Copied into 'confusion' at the start of the next paint
	DMatrixRMaj temp = new DMatrixRMaj(1, 1);
	// Matrix which is rendered. Only modified inside of paint()
	DMatrixRMaj confusion = new DMatrixRMaj(1, 1);
	// true if 'temp' contains a matrix which has not yet been copied into 'confusion'
	boolean dirty = false;

	// If true cells are rendered in gray scale (0 = white, 1 = black), otherwise from blue (0) to red (1)
	@Getter @Setter boolean gray = false;

	// If true the value of each cell is drawn inside of it as a percentage
	@Getter @Setter boolean showNumbers = true;
	// If true and labels have been specified, they are drawn in a column to the right of the matrix
	@Getter @Setter boolean showLabels = true;
	// If false, numbers are not drawn inside of cells which have a value of exactly zero
	@Getter @Setter boolean showZeros = true;

	// fraction of the width that labels occupy
	double labelViewFraction = 0.30;
	// Optional name of each category, one per row. null if no labels have been specified
	List<String> labels;

	// if set to a valid category then that category will be highlighted
	int highlightCategory = -1;

	// internal variables used for rendering
	int viewHeight, viewWidth;
	int gridHeight, gridWidth;
	boolean renderLabels;

	/**
	 * Constructor that specifies the confusion matrix and width/height
	 *
	 * @param M The confusion matrix. A copy is made internally.
	 * @param labels Optional labels for the confusion matrix. Can be null.
	 * @param widthPixels preferred width and height of the panel in pixels
	 * @param gray Render gray scale or color image
	 */
	public ConfusionMatrixPanel( DMatrixRMaj M, List<String> labels, int widthPixels, boolean gray ) {
		this(widthPixels, labels != null);

		setLabels(labels);
		setMatrix(M);
		this.gray = gray;
	}

	/**
	 * Constructor in which the preferred width and height is specified in pixels
	 *
	 * @param widthPixels preferred width of the panel. Also the preferred height if there are no labels.
	 * @param hasLabels If true the preferred height is reduced by the fraction of the width used by labels
	 */
	public ConfusionMatrixPanel( int widthPixels, boolean hasLabels ) {

		int heightPixels = widthPixels;
		if (hasLabels) {
			heightPixels = (int)(heightPixels*(1.0 - labelViewFraction));
		}

		setPreferredSize(new Dimension(widthPixels, heightPixels));
	}

	/**
	 * Specifies the matrix which is to be displayed. The matrix is copied and will be rendered the next
	 * time the panel is painted. Safe to call from outside the GUI thread.
	 *
	 * @param A The new confusion matrix
	 */
	public void setMatrix( DMatrixRMaj A ) {
		synchronized (this) {
			temp.setTo(A);
			dirty = true;
		}
		repaint();
	}

	/**
	 * Specifies the name of each category. The list is copied.
	 *
	 * @param labels Name of each category, ordered by row. If null then no labels are rendered.
	 */
	public void setLabels( List<String> labels ) {
		this.labels = labels == null ? null : new ArrayList<>(labels);
		repaint();
	}

	public int getHighlightCategory() {
		return highlightCategory;
	}

	/**
	 * Selects a category whose row and column will be highlighted.
	 *
	 * @param highlightCategory Index of the category to highlight. Values outside of the valid range,
	 * e.g. -1, disable highlighting.
	 */
	public void setHighlightCategory( int highlightCategory ) {
		this.highlightCategory = highlightCategory;
		repaint();
	}

	@Override
	public void paint( Graphics g ) {
		// Grab the most recent matrix passed in by setMatrix(). Only this section and the layout
		// section below need to hold the lock, so the entire render isn't blocking setMatrix()
		synchronized (this) {
			if (dirty) {
				confusion.setTo(temp);
				dirty = false;
			}
		}

		Graphics2D g2 = (Graphics2D)g;

		int numCategories = confusion.getNumRows();

		synchronized (this) {
			viewHeight = getHeight();
			viewWidth = getWidth();

			gridHeight = viewHeight;
			gridWidth = viewWidth;

			renderLabels = this.showLabels && labels != null;
			if (renderLabels) {
				// Labels are drawn in a column to the right of the matrix
				gridWidth = (int)(gridWidth*(1.0 - labelViewFraction));
			}
		}

		// Nothing to draw and the cell size computations below would divide by zero
		if (numCategories == 0) {
			g2.setColor(getBackground());
			g2.fillRect(0, 0, viewWidth, viewHeight);
			return;
		}

		double fontSize = Math.min(gridWidth/numCategories, gridHeight/numCategories);

		g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);

		if (renderLabels) {
			renderLabels(g2, fontSize);
		}

		renderMatrix(g2, fontSize);

		int highlight = highlightCategory;
		if (highlight >= 0 && highlight < numCategories) {
			// Semi-transparent yellow so the cells underneath remain visible
			g2.setColor(new Color(255, 255, 0, 100));

			// Inset the bands by 10% of a cell so they are narrower than the row/column
			int ry = (int)(0.1*gridHeight/numCategories);
			int rx = (int)(0.1*gridWidth/numCategories);

			int y0 = highlight*gridHeight/numCategories;
			int y1 = (highlight + 1)*gridHeight/numCategories;

			int x0 = highlight*gridWidth/numCategories;
			int x1 = (highlight + 1)*gridWidth/numCategories;

			// Vertical band over the column, horizontal band over the row and its label
			g2.fillRect(x0 + rx, 0, x1 - x0 - 2*rx, gridHeight);
			g2.fillRect(0, y0 + ry, viewWidth, y1 - y0 - 2*ry);
		}
	}

	/**
	 * Renders the names on each category to the side of the confusion matrix. If there are fewer labels
	 * than categories then only the available labels are rendered.
	 */
	private void renderLabels( Graphics2D g2, double fontSize ) {
		int numCategories = confusion.getNumRows();
		int numLabels = Math.min(numCategories, labels.size());

		int longestLabel = 0;
		for (int i = 0; i < numLabels; i++) {
			longestLabel = Math.max(longestLabel, labels.get(i).length());
		}

		Font fontLabel = new Font("monospaced", Font.BOLD, (int)(0.055*longestLabel*fontSize + 0.5));
		FontMetrics metrics = g2.getFontMetrics(fontLabel);

		// The size above grows with the label length. Shrink the font if the widest label would
		// otherwise extend past the label column
		int columnWidth = viewWidth - gridWidth;
		double widest = 0;
		for (int i = 0; i < numLabels; i++) {
			widest = Math.max(widest, metrics.getStringBounds(labels.get(i), null).getWidth());
		}
		if (widest > columnWidth && widest > 0) {
			fontLabel = fontLabel.deriveFont((float)(fontLabel.getSize2D()*columnWidth/widest));
			metrics = g2.getFontMetrics(fontLabel);
		}
		g2.setFont(fontLabel);

		// clear the background
		g2.setColor(Color.WHITE);
		g2.fillRect(gridWidth, 0, viewWidth - gridWidth, viewHeight);

		// draw the text, centered in the label column and vertically aligned with its row
		g2.setColor(Color.BLACK);
		for (int i = 0; i < numLabels; i++) {
			String label = labels.get(i);

			int y0 = i*gridHeight/numCategories;
			int y1 = (i + 1)*gridHeight/numCategories;

			Rectangle2D r = metrics.getStringBounds(label, null);

			float adjX = (float)(r.getX()*2 + r.getWidth())/2.0f;
			float adjY = (float)(r.getY()*2 + r.getHeight())/2.0f;

			float x = ((viewWidth + gridWidth)/2f - adjX);
			float y = ((y1 + y0)/2f - adjY);

			g2.drawString(label, x, y);
		}
	}

	/**
	 * Renders the confusion matrix and visualizes the value in each cell with a color and optionally
	 * the value as a percentage.
	 */
	private void renderMatrix( Graphics2D g2, double fontSize ) {
		int numCategories = confusion.getNumRows();

		Font fontNumber = new Font("Serif", Font.BOLD, (int)(0.6*fontSize + 0.5));
		g2.setFont(fontNumber);
		FontMetrics metrics = g2.getFontMetrics(fontNumber);
		for (int i = 0; i < numCategories; i++) {
			int y0 = i*gridHeight/numCategories;
			int y1 = (i + 1)*gridHeight/numCategories;

			for (int j = 0; j < numCategories; j++) {
				int x0 = j*gridWidth/numCategories;
				int x1 = (j + 1)*gridWidth/numCategories;

				double value = confusion.unsafe_get(i, j);

				// Color() throws if a component is outside of 0 to 255, so clamp before converting
				double colorValue = Math.max(0.0, Math.min(1.0, value));

				int red, green, blue;
				if (gray) {
					red = green = blue = (int)(255*(1.0 - colorValue));
				} else {
					green = 0;
					red = (int)(255*colorValue);
					blue = (int)(255*(1.0 - colorValue));
				}
				g2.setColor(new Color(red, green, blue));

				g2.fillRect(x0, y0, x1 - x0, y1 - y0);

				// Render numbers inside the squares. Pick a color so that the number is visible no matter what
				// the color of the square is
				if (showNumbers && (showZeros || value != 0)) {
					int brightness = (red + green + blue)/3;

					String text = Integer.toString((int)(value*100.0 + 0.5));
					Rectangle2D r = metrics.getStringBounds(text, null);

					float adjX = (float)(r.getX()*2 + r.getWidth())/2.0f;
					float adjY = (float)(r.getY()*2 + r.getHeight())/2.0f;

					float x = ((x1 + x0)/2f - adjX);
					float y = ((y1 + y0)/2f - adjY);

					// black text on bright cells, white text on dark cells
					int textShade = brightness > 127 ? 0 : 255;

					g2.setColor(new Color(textShade, textShade, textShade));
					g2.drawString(text, x, y);
				}
			}
		}
	}

	/**
	 * Use to sample the panel to see what is being displayed at the location clicked. All coordinates
	 * are in panel coordinates. The returned row/col are not bounds checked against the number of categories.
	 *
	 * @param pixelX x-axis in panel coordinates
	 * @param pixelY y-axis in panel coordinates
	 * @param output (Optional) storage for output.
	 * @return Information on what is at the specified location. If the panel has not been laid out yet
	 * then row and col are -1 and insideMatrix is false.
	 */
	public LocationInfo whatIsAtPoint( int pixelX, int pixelY, LocationInfo output ) {
		if (output == null)
			output = new LocationInfo();

		synchronized (this) {
			int numCategories = confusion.getNumRows();

			// Layout is only known after the first paint. Avoid dividing by zero before then
			if (gridWidth <= 0 || gridHeight <= 0) {
				output.insideMatrix = false;
				output.row = output.col = -1;
				return output;
			}

			if (pixelX >= gridWidth) {
				// Inside the label column. Report the category of the label's row
				output.insideMatrix = false;
				output.col = output.row = pixelY*numCategories/gridHeight;
			} else {
				output.insideMatrix = true;
				output.row = pixelY*numCategories/gridHeight;
				output.col = pixelX*numCategories/gridWidth;
			}
		}

		return output;
	}

	/**
	 * Contains information on what was at the point
	 */
	public static class LocationInfo {
		// true if the point is inside the matrix, false if it is in the label column
		public boolean insideMatrix;
		// Row and column of the cell. Inside the label column both are set to the label's row
		public int row, col;
	}

	public static void main( String[] args ) {
		DMatrixRMaj m = RandomMatrices_DDRM.rectangle(5, 5, 0, 1, new Random(234));

		List<String> labels = new ArrayList<>();
		for (int i = 0; i < m.numRows; i++) {
			labels.add("Label " + i);
		}

		ConfusionMatrixPanel confusion = new ConfusionMatrixPanel(m, labels, 300, false);
		confusion.setHighlightCategory(2);
		ShowImages.showWindow(confusion, "Window", true);
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy