All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.tetrad.algcomparison.CompareTwoGraphs Maven / Gradle / Ivy

There is a newer version: 7.6.5
Show newest version
package edu.cmu.tetrad.algcomparison;

import edu.cmu.tetrad.algcomparison.statistic.*;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.graph.*;
import edu.cmu.tetrad.search.utils.GraphSearchUtils;
import edu.cmu.tetrad.util.TextTable;
import org.jetbrains.annotations.NotNull;

import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;

import static java.util.Collections.sort;

/**
 * Compares a target (estimated) graph to a reference (true) graph and renders the result as printable text. Three
 * views are offered: an edgewise comparison, a table of summary statistics, and edge/endpoint misclassification
 * tables. Each method returns a String.
 *
 * @author josephramsey
 */
public class CompareTwoGraphs {

    /**
     * Returns an edgewise comparison of two graphs. This says, edge by edge, what the differences and similarities are
     * between the two graphs. The output has five sections, in order:
     * <ol>
     *     <li>adjacencies added (skipping pairs that form a directed 2-cycle in either graph);</li>
     *     <li>adjacencies removed;</li>
     *     <li>directed true edges that appear in both directions in the target graph;</li>
     *     <li>edges present in both graphs whose edge differs (incorrect orientation);</li>
     *     <li>edges present in both graphs whose edge is equal (correct orientation).</li>
     * </ol>
     * Sections with no entries print {@code --NONE--}.
     *
     * @param trueGraph   The true graph.
     * @param targetGraph The target graph.
     * @return The comparison string.
     */
    @NotNull
    public static String getEdgewiseComparisonString(Graph trueGraph, Graph targetGraph) {
        // Debug switch: when true, removed adjacencies lying on a semidirected path in the true graph are starred.
        boolean printStars = false;

        StringBuilder builder = new StringBuilder();
        GraphUtils.GraphComparison comparison = GraphSearchUtils.getGraphComparison(trueGraph, targetGraph);

        // ---- Adjacencies added: skip any pair that is a directed 2-cycle in either graph.
        List<Edge> edgesAdded2 = new ArrayList<>();

        for (Edge e1 : comparison.getEdgesAdded()) {
            Node n1 = e1.getNode1();
            Node n2 = e1.getNode2();

            boolean twoCycle1 = trueGraph.getDirectedEdge(n1, n2) != null && trueGraph.getDirectedEdge(n2, n1) != null;
            boolean twoCycle2 = targetGraph.getDirectedEdge(n1, n2) != null && targetGraph.getDirectedEdge(n2, n1) != null;

            if (!(twoCycle1 || twoCycle2)) {
                edgesAdded2.add(e1);
            }
        }

        sort(edgesAdded2);

        builder.append("\nAdjacencies added (not involving 2-cycles and not reoriented):");

        if (edgesAdded2.isEmpty()) {
            builder.append("\n  --NONE--");
        } else {
            for (int i = 0; i < edgesAdded2.size(); i++) {
                builder.append("\n").append(i + 1).append(". ").append(edgesAdded2.get(i));
            }
        }

        // ---- Adjacencies removed.
        builder.append("\n\nAdjacencies removed:");

        // Sort a copy so the list held by the comparison object is neither reordered nor (if unmodifiable) rejected.
        List<Edge> edgesRemoved = new ArrayList<>(comparison.getEdgesRemoved());
        sort(edgesRemoved);

        if (edgesRemoved.isEmpty()) {
            builder.append("\n  --NONE--");
        } else {
            for (int i = 0; i < edgesRemoved.size(); i++) {
                Edge edge = edgesRemoved.get(i);

                Node node1 = trueGraph.getNode(edge.getNode1().getName());
                Node node2 = trueGraph.getNode(edge.getNode2().getName());

                builder.append("\n").append(i + 1).append(". ").append(edge);

                if (printStars) {
                    boolean directedInGraph1 = false;

                    if (Edges.isDirectedEdge(edge) && trueGraph.paths().existsSemiDirectedPath(node1, node2)) {
                        directedInGraph1 = true;
                    } else if ((Edges.isUndirectedEdge(edge) || Edges.isBidirectedEdge(edge))
                            && (trueGraph.paths().existsSemiDirectedPath(node1, node2)
                            || trueGraph.paths().existsSemiDirectedPath(node2, node1))) {
                        directedInGraph1 = true;
                    }

                    if (directedInGraph1) {
                        builder.append(" *");
                    }
                }
            }
        }

        // ---- Partition the true graph's edges: a directed edge whose both directions are in the target graph is a
        // 2-cycle there; every other true edge is checked for orientation below.
        List<Edge> edges1 = new ArrayList<>(trueGraph.getEdges());

        List<Edge> twoCycles = new ArrayList<>();
        List<Edge> allSingleEdges = new ArrayList<>();

        for (Edge edge : edges1) {
            if (edge.isDirected() && targetGraph.containsEdge(edge) && targetGraph.containsEdge(edge.reverse())) {
                twoCycles.add(edge);
            } else {
                // Every edge here was taken from trueGraph, so it is always contained in it.
                allSingleEdges.add(edge);
            }
        }

        builder.append("\n\nTwo-cycles in true correctly adjacent in estimated");

        if (twoCycles.isEmpty()) {
            builder.append("\n  --NONE--");
        } else {
            // Sorted for deterministic output, like the other sections.
            sort(twoCycles);

            for (int i = 0; i < twoCycles.size(); i++) {
                // Index twoCycles itself (previously the full edge list was indexed, printing unrelated edges).
                Edge adj = twoCycles.get(i);
                builder.append("\n").append(i + 1).append(". ").append(adj).append(" ").append(adj.reverse())
                        .append(" ====> ").append(trueGraph.getEdge(adj.getNode1(), adj.getNode2()));
            }
        }

        // ---- Orientation comparison. Pairs with no edge in the target graph are skipped: they were already
        // reported under "Adjacencies removed", and including them would leave an empty, header-only section.
        sort(allSingleEdges);

        List<Edge> incorrect = new ArrayList<>();
        List<Edge> correct = new ArrayList<>();

        for (Edge adj : allSingleEdges) {
            Edge edge1 = trueGraph.getEdge(adj.getNode1(), adj.getNode2());
            Edge edge2 = targetGraph.getEdge(adj.getNode1(), adj.getNode2());

            if (edge1.equals(edge2)) {
                correct.add(edge1);
            } else if (edge2 != null) {
                incorrect.add(adj);
            }
        }

        builder.append("\n\nEdges incorrectly oriented");

        if (incorrect.isEmpty()) {
            builder.append("\n  --NONE--");
        } else {
            sort(incorrect);

            for (int i = 0; i < incorrect.size(); i++) {
                Edge adj = incorrect.get(i);
                Edge edge1 = trueGraph.getEdge(adj.getNode1(), adj.getNode2());
                Edge edge2 = targetGraph.getEdge(adj.getNode1(), adj.getNode2());
                builder.append("\n").append(i + 1).append(". ").append(edge1).append(" ====> ").append(edge2);
            }
        }

        builder.append("\n\nEdges correctly oriented");

        if (correct.isEmpty()) {
            builder.append("\n  --NONE--");
        } else {
            sort(correct);

            for (int i = 0; i < correct.size(); i++) {
                builder.append("\n").append(i + 1).append(". ").append(correct.get(i));
            }
        }

        return builder.toString();
    }

    /**
     * Returns a string representing a table of statistics that can be printed. Equivalent to calling
     * {@link #getStatsListTable(Graph, Graph, DataModel)} with a {@code null} data model.
     *
     * @param trueGraph   The true graph.
     * @param targetGraph The target graph.
     * @return The comparison string.
     */
    public static String getStatsListTable(Graph trueGraph, Graph targetGraph) {
        return getStatsListTable(trueGraph, targetGraph, null);
    }

    /**
     * Returns a string representing a table of statistics that can be printed. Each row holds a statistic's
     * abbreviation, its value (formatted as {@code 0.###}, or {@code -} for NaN), and its description. Statistics
     * that throw while being computed are omitted.
     *
     * @param trueGraph   The true graph.
     * @param targetGraph The target graph.
     * @param dataModel   The data model; some statistics (like BIC) may use this. May be {@code null}.
     * @return The comparison string.
     */
    public static String getStatsListTable(Graph trueGraph, Graph targetGraph, DataModel dataModel) {
        // Re-express the target graph over the true graph's node objects before computing statistics.
        Graph _targetGraph = GraphUtils.replaceNodes(targetGraph, trueGraph.getNodes());

        List<String> abbr = new ArrayList<>();
        List<String> desc = new ArrayList<>();
        List<Double> values = new ArrayList<>();

        for (Statistic statistic : statistics()) {
            try {
                // Evaluate all three fields before adding any, so the parallel lists stay aligned if one throws.
                double value = statistic.getValue(trueGraph, _targetGraph, dataModel);
                String abbreviation = statistic.getAbbreviation();
                String description = statistic.getDescription();

                values.add(value);
                abbr.add(abbreviation);
                desc.add(description);
            } catch (Exception ignored) {
                // Deliberate best-effort: a statistic that cannot be computed for these inputs (possibly one that
                // needs a data model when none is supplied) is left out rather than aborting the whole report.
            }
        }

        // Size the table by the statistics actually computed so that failures do not leave blank rows.
        // NOTE(review): Math.max keeps at least one row in case TextTable rejects a zero-row table — confirm.
        TextTable table = new TextTable(Math.max(1, abbr.size()), 3);
        NumberFormat nf = new DecimalFormat("0.###");

        for (int i = 0; i < abbr.size(); i++) {
            double value = values.get(i);
            table.setToken(i, 0, abbr.get(i));
            table.setToken(i, 1, Double.isNaN(value) ? "-" : nf.format(value));
            table.setToken(i, 2, desc.get(i));
        }

        table.setJustification(TextTable.LEFT_JUSTIFIED);

        return table.toString();
    }

    /**
     * Builds the list of statistics reported by {@link #getStatsListTable(Graph, Graph, DataModel)}, in display order.
     * Each statistic appears once.
     *
     * @return A new mutable list of statistics.
     */
    private static List<Statistic> statistics() {
        List<Statistic> statistics = new ArrayList<>();

        // Others
        statistics.add(new AdjacencyPrecision());
        statistics.add(new AdjacencyRecall());
        statistics.add(new ArrowheadPrecision());
        statistics.add(new ArrowheadRecall());
        statistics.add(new ArrowheadPrecisionCommonEdges());
        statistics.add(new ArrowheadRecallCommonEdges());
        statistics.add(new AdjacencyTn());
        statistics.add(new AdjacencyTp());
        statistics.add(new AdjacencyTpr());
        statistics.add(new AdjacencyFpr());
        statistics.add(new AdjacencyFn());
        statistics.add(new AdjacencyFp());
        statistics.add(new ArrowheadTn());
        statistics.add(new ArrowheadTp());
        statistics.add(new F1Adj());
        statistics.add(new F1All());
        statistics.add(new F1Arrow());
        statistics.add(new MathewsCorrAdj());
        statistics.add(new MathewsCorrArrow());
        statistics.add(new NumberOfEdgesEst());
        statistics.add(new NumberOfEdgesTrue());
        statistics.add(new NumCorrectVisibleAncestors());
        statistics.add(new PercentBidirectedEdges());
        statistics.add(new TailPrecision());
        statistics.add(new TailRecall());
        statistics.add(new TwoCyclePrecision());
        statistics.add(new TwoCycleRecall());
        statistics.add(new TwoCycleFalsePositive());
        statistics.add(new TwoCycleFalseNegative());
        statistics.add(new TwoCycleTruePositive());
        statistics.add(new AverageDegreeEst());
        statistics.add(new AverageDegreeTrue());
        statistics.add(new DensityEst());
        statistics.add(new DensityTrue());
        statistics.add(new StructuralHammingDistance());

        // Joe table.
        statistics.add(new NumDirectedEdges());
        statistics.add(new NumUndirectedEdges());
        statistics.add(new NumPartiallyOrientedEdges());
        statistics.add(new NumNondirectedEdges());
        statistics.add(new NumBidirectedEdgesEst());
        statistics.add(new TrueDagPrecisionTails());
        statistics.add(new TrueDagPrecisionArrow());
        statistics.add(new BidirectedLatentPrecision());

        // Greg table (currently disabled).
//        statistics.add(new AncestorPrecision());
//        statistics.add(new AncestorRecall());
//        statistics.add(new AncestorF1());
//        statistics.add(new SemidirectedPrecision());
//        statistics.add(new SemidirectedRecall());
//        statistics.add(new SemidirectedPathF1());
//        statistics.add(new NoSemidirectedPrecision());
//        statistics.add(new NoSemidirectedRecall());
//        statistics.add(new NoSemidirectedF1());

        return statistics;
    }


    /**
     * Returns a misclassification comparison of two graphs. This includes both an edge misclassification matrix and
     * an endpoint misclassification matrix.
     *
     * @param trueGraph   The true graph.
     * @param targetGraph The target graph.
     * @return The comparison string.
     */
    @NotNull
    public static String getMisclassificationTable(Graph trueGraph, Graph targetGraph) {
        return "Edge Misclassification Table:" +
                "\n" +
                MisclassificationUtils.edgeMisclassifications(targetGraph, trueGraph) +
                "\n\n" +
                "Endpoint Misclassification Table:" +
                "\n\n" +
                MisclassificationUtils.endpointMisclassification(targetGraph, trueGraph);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy