All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.evaluation.classification.ConfusionMatrix Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.evaluation.classification;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import lombok.Getter;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class ConfusionMatrix> implements Serializable {
    @Getter
    private volatile Map> matrix;
    private List classes;

    /**
     * Creates an empty confusion Matrix
     */
    public ConfusionMatrix(List classes) {
        this.matrix = new ConcurrentHashMap<>();
        this.classes = classes;
    }

    public ConfusionMatrix() {
        this(new ArrayList());
    }

    /**
     * Creates a new ConfusionMatrix initialized with the contents of another ConfusionMatrix.
     */
    public ConfusionMatrix(ConfusionMatrix other) {
        this(other.getClasses());
        this.add(other);
    }

    /**
     * Increments the entry specified by actual and predicted by one.
     */
    public synchronized void add(T actual, T predicted) {
        add(actual, predicted, 1);
    }

    /**
     * Increments the entry specified by actual and predicted by count.
     */
    public synchronized void add(T actual, T predicted, int count) {
        if (matrix.containsKey(actual)) {
            matrix.get(actual).add(predicted, count);
        } else {
            Multiset counts = HashMultiset.create();
            counts.add(predicted, count);
            matrix.put(actual, counts);
        }
    }

    /**
     * Adds the entries from another confusion matrix to this one.
     */
    public synchronized void add(ConfusionMatrix other) {
        for (T actual : other.matrix.keySet()) {
            Multiset counts = other.matrix.get(actual);
            for (T predicted : counts.elementSet()) {
                int count = counts.count(predicted);
                this.add(actual, predicted, count);
            }
        }
    }

    /**
     * Gives the applyTransformToDestination of all classes in the confusion matrix.
     */
    public List getClasses() {
        if (classes == null)
            classes = new ArrayList<>();
        return classes;
    }

    /**
     * Gives the count of the number of times the "predicted" class was predicted for the "actual"
     * class.
     */
    public synchronized int getCount(T actual, T predicted) {
        if (!matrix.containsKey(actual)) {
            return 0;
        } else {
            return matrix.get(actual).count(predicted);
        }
    }

    /**
     * Computes the total number of times the class was predicted by the classifier.
     */
    public synchronized int getPredictedTotal(T predicted) {
        int total = 0;
        for (T actual : classes) {
            total += getCount(actual, predicted);
        }
        return total;
    }

    /**
     * Computes the total number of times the class actually appeared in the data.
     */
    public synchronized int getActualTotal(T actual) {
        if (!matrix.containsKey(actual)) {
            return 0;
        } else {
            int total = 0;
            for (T elem : matrix.get(actual).elementSet()) {
                total += matrix.get(actual).count(elem);
            }
            return total;
        }
    }

    @Override
    public String toString() {
        return matrix.toString();
    }

    /**
     * Outputs the ConfusionMatrix as comma-separated values for easy import into spreadsheets
     */
    public String toCSV() {
        StringBuilder builder = new StringBuilder();

        // Header Row
        builder.append(",,Predicted Class,\n");

        // Predicted Classes Header Row
        builder.append(",,");
        for (T predicted : classes) {
            builder.append(String.format("%s,", predicted));
        }
        builder.append("Total\n");

        // Data Rows
        String firstColumnLabel = "Actual Class,";
        for (T actual : classes) {
            builder.append(firstColumnLabel);
            firstColumnLabel = ",";
            builder.append(String.format("%s,", actual));

            for (T predicted : classes) {
                builder.append(getCount(actual, predicted));
                builder.append(",");
            }
            // Actual Class Totals Column
            builder.append(getActualTotal(actual));
            builder.append("\n");
        }

        // Predicted Class Totals Row
        builder.append(",Total,");
        for (T predicted : classes) {
            builder.append(getPredictedTotal(predicted));
            builder.append(",");
        }
        builder.append("\n");

        return builder.toString();
    }

    /**
     * Outputs Confusion Matrix in an HTML table. Cascading Style Sheets (CSS) can control the table's
     * appearance by defining the empty-space, actual-count-header, predicted-class-header, and
     * count-element classes. For example
     *
     * @return html string
     */
    public String toHTML() {
        StringBuilder builder = new StringBuilder();

        int numClasses = classes.size();
        // Header Row
        builder.append("\n");
        builder.append("%n",
                        numClasses + 1));

        // Predicted Classes Header Row
        builder.append("");
        // builder.append("");
        for (T predicted : classes) {
            builder.append("");
        }
        builder.append("");
        builder.append("\n");

        // Data Rows
        String firstColumnLabel = String.format(
                        "", numClasses + 1);
        for (T actual : classes) {
            builder.append(firstColumnLabel);
            firstColumnLabel = "";
            builder.append(String.format("", actual));

            for (T predicted : classes) {
                builder.append("");
            }

            // Actual Class Totals Column
            builder.append("");
            builder.append("\n");
        }

        // Predicted Class Totals Row
        builder.append("");
        for (T predicted : classes) {
            builder.append("");
        }
        builder.append("\n");
        builder.append("\n");
        builder.append("
"); builder.append(String.format("Predicted Class
"); builder.append(predicted); builder.append("Total
Actual Class
%s"); builder.append(getCount(actual, predicted)); builder.append(""); builder.append(getActualTotal(actual)); builder.append("
Total"); builder.append(getPredictedTotal(predicted)); builder.append("
\n"); return builder.toString(); } @Override public boolean equals(Object o) { if (!(o instanceof ConfusionMatrix)) return false; ConfusionMatrix c = (ConfusionMatrix) o; return matrix.equals(c.matrix) && classes.equals(c.classes); } @Override public int hashCode() { int result = 17; result = 31 * result + (matrix == null ? 0 : matrix.hashCode()); result = 31 * result + (classes == null ? 0 : classes.hashCode()); return result; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy