All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.tetrad.algcomparison.statistic.utils.BidirectedConfusion Maven / Gradle / Ivy

The newest version!
package edu.cmu.tetrad.algcomparison.statistic.utils;

import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;

import java.util.HashSet;
import java.util.Set;

/**
 * A confusion matrix for bidirected edges--i.e., TP, FP, TN, FN counts of bidirected edges.
 * <p>
 * The candidate bidirected edges are the union of the bidirected edges found in the true graph and
 * in the estimated graph. Each candidate is classified by which of the two graphs contains it:
 * both (TP), only the estimated graph (FP), or only the true graph (FN). True negatives are the
 * remaining unordered node pairs of the true graph, i.e. {@code n(n-1)/2 - (tp + fp + fn)}, where
 * {@code n} is the number of nodes in the true graph.
 *
 * @author josephramsey
 * @version $Id: $Id
 */
public class BidirectedConfusion {

    /**
     * The true negative count.
     */
    private final int tn;

    /**
     * The true positive count.
     */
    private final int tp;

    /**
     * The false positive count.
     */
    private final int fp;

    /**
     * The false negative count.
     */
    private final int fn;

    /**
     * Constructs a new confusion matrix for bidirected edges.
     *
     * @param truth The true graph.
     * @param est   The estimated graph.
     */
    public BidirectedConfusion(Graph truth, Graph est) {
        Set<Edge> allBidirected = new HashSet<>();
        addBidirectedEdges(truth, allBidirected);
        addBidirectedEdges(est, allBidirected);

        int truePositives = 0;
        int falsePositives = 0;
        int falseNegatives = 0;

        for (Edge edge : allBidirected) {
            boolean inTruth = truth.containsEdge(edge);
            boolean inEst = est.containsEdge(edge);

            if (inTruth && inEst) {
                truePositives++;
            } else if (inEst) {
                falsePositives++;
            } else if (inTruth) {
                falseNegatives++;
            }
        }

        this.tp = truePositives;
        this.fp = falsePositives;
        this.fn = falseNegatives;

        int numNodes = truth.getNumNodes();
        int allPairs = numNodes * (numNodes - 1) / 2;

        // Every unordered node pair not already counted as TP, FP, or FN is a true negative.
        this.tn = allPairs - truePositives - falsePositives - falseNegatives;
    }

    /**
     * Adds every bidirected edge of {@code graph} to {@code target}.
     *
     * @param graph  The graph whose edges are scanned.
     * @param target The set that receives the bidirected edges.
     */
    private static void addBidirectedEdges(Graph graph, Set<Edge> target) {
        for (Edge edge : graph.getEdges()) {
            if (Edges.isBidirectedEdge(edge)) {
                target.add(edge);
            }
        }
    }

    /**
     * Returns the number of true positives.
     *
     * @return The number of true positives.
     */
    public int getTp() {
        return this.tp;
    }

    /**
     * Returns the number of false positives.
     *
     * @return The number of false positives.
     */
    public int getFp() {
        return this.fp;
    }

    /**
     * Returns the number of false negatives.
     *
     * @return The number of false negatives.
     */
    public int getFn() {
        return this.fn;
    }

    /**
     * Returns the number of true negatives.
     *
     * @return The number of true negatives.
     */
    public int getTn() {
        return this.tn;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy