All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.tetrad.graph.MisclassificationUtils Maven / Gradle / Ivy

The newest version!
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard        //
// Scheines, Joseph Ramsey, and Clark Glymour.                               //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA //
///////////////////////////////////////////////////////////////////////////////

package edu.cmu.tetrad.graph;

import edu.cmu.tetrad.util.TextTable;

import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Some utilities for generating misclassification tables for graphs.
 *
 * @author josephramsey
 * @version $Id: $Id
 */
public class MisclassificationUtils {

    /**
     * 

Constructor for MisclassificationUtils.

*/ private MisclassificationUtils() { } /** *

getIndex.

* * @param endpoint a {@link edu.cmu.tetrad.graph.Endpoint} object * @return a int */ public static int getIndex(Endpoint endpoint) { if (endpoint == Endpoint.CIRCLE) return 0; if (endpoint == Endpoint.ARROW) return 1; if (endpoint == Endpoint.TAIL) return 2; if (endpoint == null) return 3; throw new IllegalArgumentException(); } /** *

convertNodes.

* * @param edges a {@link java.util.Set} object * @param newVariables a {@link java.util.List} object * @return a {@link java.util.Set} object */ public static Set convertNodes(Set edges, List newVariables) { Set newEdges = new HashSet<>(); Graph convertedGraph = new EdgeListGraph(newVariables); for (Edge edge : edges) { Node node1 = convertedGraph.getNode(edge.getNode1().getName()); Node node2 = convertedGraph.getNode(edge.getNode2().getName()); if (node1 == null) { node1 = edge.getNode1(); if (!convertedGraph.containsNode(node1)) { convertedGraph.addNode(node1); } } if (node2 == null) { node2 = edge.getNode2(); if (!convertedGraph.containsNode(node2)) { convertedGraph.addNode(node2); } } Endpoint endpoint1 = edge.getEndpoint1(); Endpoint endpoint2 = edge.getEndpoint2(); Edge newEdge = new Edge(node1, node2, endpoint1, endpoint2); newEdges.add(newEdge); } return newEdges; } /** *

endpointMisclassification.

* * @param estGraph a {@link edu.cmu.tetrad.graph.Graph} object * @param refGraph a {@link edu.cmu.tetrad.graph.Graph} object * @return a {@link java.lang.String} object */ public static String endpointMisclassification(Graph estGraph, Graph refGraph) { List _nodes = refGraph.getNodes(); estGraph = GraphUtils.replaceNodes(estGraph, _nodes); refGraph = GraphUtils.replaceNodes(refGraph, _nodes); _nodes = estGraph.getNodes(); int[][] counts = new int[4][4]; for (int i = 0; i < _nodes.size(); i++) { for (int j = 0; j < _nodes.size(); j++) { if (i == j) continue; Endpoint endpoint1 = refGraph.getEndpoint(_nodes.get(i), _nodes.get(j)); Endpoint endpoint2 = estGraph.getEndpoint(_nodes.get(i), _nodes.get(j)); int index1 = MisclassificationUtils.getIndex(endpoint1); int index2 = MisclassificationUtils.getIndex(endpoint2); counts[index1][index2]++; } } TextTable table2 = new TextTable(5, 5); table2.setToken(0, 1, "-o"); table2.setToken(0, 2, "->"); table2.setToken(0, 3, "--"); table2.setToken(0, 4, "no endpoint"); table2.setToken(1, 0, "-o"); table2.setToken(2, 0, "->"); table2.setToken(3, 0, "--"); table2.setToken(4, 0, "no endpoint"); for (int i = 0; i < 4; i++) { for (int j = 0; j < 4; j++) { if (i == 3 && j == 3) table2.setToken(3 + 1, 3 + 1, "*"); else table2.setToken(i + 1, j + 1, "" + counts[i][j]); } } return table2.toString(); } /** *

edgeMisclassifications.

* * @param estGraph a {@link edu.cmu.tetrad.graph.Graph} object * @param refGraph a {@link edu.cmu.tetrad.graph.Graph} object * @return a {@link java.lang.String} object */ public static String edgeMisclassifications(Graph estGraph, Graph refGraph) { estGraph = GraphUtils.replaceNodes(estGraph, refGraph.getNodes()); StringBuilder builder = new StringBuilder(); TextTable table2 = new TextTable(9, 7); table2.setToken(1, 0, "---"); table2.setToken(2, 0, "o-o"); table2.setToken(3, 0, "o->"); table2.setToken(4, 0, "<-o"); table2.setToken(5, 0, "-->"); table2.setToken(6, 0, "<--"); table2.setToken(7, 0, "<->"); table2.setToken(8, 0, "no edge"); table2.setToken(0, 1, "---"); table2.setToken(0, 2, "o-o"); table2.setToken(0, 3, "o->"); table2.setToken(0, 4, "-->"); table2.setToken(0, 5, "<->"); table2.setToken(0, 6, "no edge"); int[][] counts = new int[8][6]; for (Edge est1 : estGraph.getEdges()) { Node x = est1.getNode1(); Node y = est1.getNode2(); Edge true1 = refGraph.getEdge(x, y); if (true1 == null) { true1 = new Edge(x, y, Endpoint.NULL, Endpoint.NULL); } Edge trueConvert = new Edge(x, y, true1.getProximalEndpoint(x), true1.getProximalEndpoint(y)); int m = MisclassificationUtils.getTypeLeft(trueConvert, est1); int n = MisclassificationUtils.getTypeTop(est1); counts[m][n]++; } for (Edge true1 : refGraph.getEdges()) { Node x = true1.getNode1(); Node y = true1.getNode2(); Edge est1 = estGraph.getEdge(x, y); if (est1 == null) { est1 = new Edge(x, y, Endpoint.NULL, Endpoint.NULL); } Edge estConvert = new Edge(x, y, est1.getProximalEndpoint(x), est1.getProximalEndpoint(y)); int m = MisclassificationUtils.getTypeLeft(true1, estConvert); int n = MisclassificationUtils.getTypeTop(estConvert); if (n == 5) { counts[m][5]++; } } for (int i = 0; i < 8; i++) { for (int j = 0; j < 6; j++) { if (i == 7 && j == 5) table2.setToken(7 + 1, 5 + 1, "*"); else table2.setToken(i + 1, j + 1, "" + counts[i][j]); } } builder.append("\n").append(table2); return builder.toString(); } private static int getTypeTop(Edge edgeTop) { if (edgeTop == null) { return 5; } Endpoint e1 = edgeTop.getEndpoint1(); Endpoint e2 = edgeTop.getEndpoint2(); if (e1 == Endpoint.TAIL && e2 == Endpoint.TAIL) { return 0; } if (e1 == Endpoint.CIRCLE && e2 == Endpoint.CIRCLE) { return 1; } if (e1 == Endpoint.CIRCLE && e2 == Endpoint.ARROW) { return 2; } if (e1 == Endpoint.ARROW && e2 == Endpoint.CIRCLE) { return 2; } if (e1 == Endpoint.TAIL && e2 == Endpoint.ARROW) { return 3; } if (e1 == Endpoint.ARROW && e2 == Endpoint.TAIL) { return 3; } if (e1 == Endpoint.ARROW && e2 == Endpoint.ARROW) { return 4; } return 5; // throw new IllegalArgumentException("Unsupported edge type : " + e1 + " " + e2); } private static int getTypeLeft(Edge edgeLeft, Edge edgeTop) { if (edgeLeft == null) { return 7; } Endpoint e1 = edgeLeft.getEndpoint1(); Endpoint e2 = edgeLeft.getEndpoint2(); if (e1 == Endpoint.TAIL && e2 == Endpoint.TAIL) { return 0; } if (e1 == Endpoint.CIRCLE && e2 == Endpoint.CIRCLE) { return 1; } if (e1 == Endpoint.CIRCLE && e2 == Endpoint.ARROW && edgeTop.equals(edgeLeft.reverse())) { return 3; } if (e1 == Endpoint.CIRCLE && e2 == Endpoint.ARROW) { return 2; } if (e1 == Endpoint.TAIL && e2 == Endpoint.ARROW && edgeTop.equals(edgeLeft.reverse())) { return 5; } if (e1 == Endpoint.TAIL && e2 == Endpoint.ARROW) { return 4; } if (e1 == Endpoint.ARROW && e2 == Endpoint.ARROW) { return 6; } if (e1 == Endpoint.NULL && e2 == Endpoint.NULL) { return 7; } throw new IllegalArgumentException("Unsupported edge type : " + e1 + " " + e2); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy