edu.cmu.tetrad.graph.GraphSaveLoadUtils Maven / Gradle / Ivy
The newest version!
package edu.cmu.tetrad.graph;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.SimpleDataLoader;
import edu.cmu.tetrad.util.*;
import edu.pitt.dbmi.data.reader.Data;
import edu.pitt.dbmi.data.reader.Delimiter;
import edu.pitt.dbmi.data.reader.tabular.ContinuousTabularDatasetFileReader;
import nu.xom.*;
import java.io.*;
import java.nio.file.Files;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* Methods to load or save graphs.
*
* @author josephramsey
* @version $Id: $Id
*/
public class GraphSaveLoadUtils {
/**
 * Private constructor to prevent instantiation; this class is used only through its static methods.
 */
private GraphSaveLoadUtils() {
}
/**
 * Reads an XML document from a file and parses its root element as a graph via
 * {@link #parseGraphXml(Element, Map)} (with fresh nodes created for each variable).
 *
 * @param file the XML file to read.
 * @return the parsed graph.
 * @throws IllegalArgumentException if the file cannot be read or its contents cannot be parsed; the original
 *                                  exception is attached as the cause.
 */
public static Graph loadGraph(File file) {
    try {
        Element rootElement = getRootElement(file);
        return parseGraphXml(rootElement, null);
    } catch (ParsingException parseFailure) {
        throw new IllegalArgumentException("Could not parse " + file, parseFailure);
    } catch (IOException readFailure) {
        throw new IllegalArgumentException("Could not read " + file, readFailure);
    }
}
/**
 * Loads a graph from a file in the sectioned text format ("Graph Nodes:" / "Graph Edges:") handled by
 * {@link #readerToGraphTxt(Reader)}.
 *
 * @param file the text file to read.
 * @return the parsed graph.
 * @throws IllegalStateException if the file cannot be opened or parsed; the underlying exception is attached as
 *                               the cause instead of being printed and discarded.
 */
public static Graph loadGraphTxt(File file) {
    try {
        // readerToGraphTxt wraps the reader in a try-with-resources BufferedReader, so it is closed there.
        return readerToGraphTxt(new FileReader(file));
    } catch (Exception e) {
        throw new IllegalStateException("Could not load graph from " + file, e);
    }
}
/**
 * Loads a graph from a headerless, comma-delimited numeric matrix file. The nodes are the variables of the loaded
 * data set. For every pair i &lt; j whose entry (i, j) is nonzero, a directed edge from node i to node j is added;
 * entries on or below the diagonal are not consulted.
 *
 * @param file the matrix file to read.
 * @return the resulting graph.
 * @throws IllegalStateException if the file cannot be loaded; the underlying exception is attached as the cause.
 */
public static Graph loadGraphRuben(File file) {
    try {
        final String commentMarker = "//";
        final char quoteCharacter = '"';
        final String missingValueMarker = "*";
        final boolean hasHeader = false;
        DataSet dataSet = SimpleDataLoader.loadContinuousData(file, commentMarker, quoteCharacter, missingValueMarker,
                hasHeader, Delimiter.COMMA, false);
        List<Node> nodes = dataSet.getVariables();
        Graph graph = new EdgeListGraph(nodes);
        for (int i = 0; i < nodes.size(); i++) {
            for (int j = i + 1; j < nodes.size(); j++) {
                if (dataSet.getDouble(i, j) != 0D) {
                    graph.addDirectedEdge(nodes.get(i), nodes.get(j));
                }
            }
        }
        return graph;
    } catch (Exception e) {
        // Keep the original exception type for callers, but no longer lose the cause.
        throw new IllegalStateException("Could not load graph from " + file, e);
    }
}
/**
 * Loads a graph from a JSON file using {@link #readerToGraphJson(Reader)}.
 *
 * @param file the JSON file to read.
 * @return the parsed graph.
 * @throws IllegalStateException if the file cannot be read or parsed; the underlying exception is attached as the
 *                               cause.
 */
public static Graph loadGraphJson(File file) {
    // readerToGraphJson does not close the reader it is given, so the file handle is closed here.
    try (Reader in = new FileReader(file)) {
        return readerToGraphJson(in);
    } catch (Exception e) {
        throw new IllegalStateException("Could not load graph from " + file, e);
    }
}
/**
 * Builds the incidence matrix of a directed graph, indexed by the order of {@code graph.getNodes()}. Entry (i, j)
 * is 1 when the edge between nodes i and j has an arrow at node i, -1 when it has a tail at node i, and 0 when the
 * two nodes are not adjacent.
 *
 * @param graph the graph; every edge must be directed.
 * @return the n-by-n incidence matrix.
 * @throws IllegalArgumentException if any edge is not directed.
 */
private static int[][] incidenceMatrix(Graph graph) throws IllegalArgumentException {
    List<Node> nodes = graph.getNodes();
    for (Edge candidate : graph.getEdges()) {
        if (!Edges.isDirectedEdge(candidate)) {
            throw new IllegalArgumentException("Not a directed graph.");
        }
    }
    int size = nodes.size();
    int[][] matrix = new int[size][size];
    for (int row = 0; row < size; row++) {
        Node here = nodes.get(row);
        for (int col = 0; col < size; col++) {
            Edge edge = graph.getEdge(here, nodes.get(col));
            if (edge == null) {
                continue;
            }
            Endpoint mark = edge.getProximalEndpoint(here);
            if (mark == Endpoint.ARROW) {
                matrix[row][col] = 1;
            } else if (mark == Endpoint.TAIL) {
                matrix[row][col] = -1;
            }
        }
    }
    return matrix;
}
/**
 * Builds a graph from a PDAG matrix in the style produced by the Bayes Net Toolbox (BNT). For every cell (i, j):
 * if both (i, j) and (j, i) hold 1, an undirected edge i---j is added (unless the nodes are already adjacent);
 * otherwise, if (i, j) holds -1 and (j, i) holds 0, a directed edge i--&gt;j is added.
 *
 * @param vars    the variables, in the same order as the matrix rows and columns.
 * @param dataSet the matrix, read as integers.
 * @return the resulting graph.
 */
public static Graph loadGraphBNTPcMatrix(List<Node> vars, DataSet dataSet) {
    Graph graph = new EdgeListGraph(vars);
    int rowCount = dataSet.getNumRows();
    int columnCount = dataSet.getNumColumns();
    for (int i = 0; i < rowCount; i++) {
        for (int j = 0; j < columnCount; j++) {
            int forward = dataSet.getInt(i, j);
            int backward = dataSet.getInt(j, i);
            if (forward == 1 && backward == 1) {
                if (!graph.isAdjacentTo(vars.get(i), vars.get(j))) {
                    graph.addUndirectedEdge(vars.get(i), vars.get(j));
                }
            } else if (forward == -1 && backward == 0) {
                graph.addDirectedEdge(vars.get(i), vars.get(j));
            }
        }
    }
    return graph;
}
/**
 * Renders a directed graph's incidence matrix as a text table. The first row holds the node names, the first
 * column holds 1-based row numbers, and cell (i, j) holds 1, -1 or 0 as produced by the private incidence-matrix
 * helper (arrow at node i, tail at node i, or not adjacent).
 *
 * @param graph a directed graph.
 * @return the table as text.
 * @throws IllegalArgumentException if the graph contains an edge that is not directed.
 */
public static String graphRMatrixTxt(Graph graph) throws IllegalArgumentException {
    int[][] m = GraphSaveLoadUtils.incidenceMatrix(graph);
    List<Node> nodes = graph.getNodes();
    // Size from the node count rather than m[0].length, which threw for a graph with no nodes.
    int n = nodes.size();
    TextTable table = new TextTable(n + 1, n + 1);
    for (int j = 0; j < n; j++) {
        table.setToken(0, j + 1, nodes.get(j).getName());
    }
    for (int i = 0; i < n; i++) {
        table.setToken(i + 1, 0, String.valueOf(i + 1));
        for (int j = 0; j < n; j++) {
            table.setToken(i + 1, j + 1, String.valueOf(m[i][j]));
        }
    }
    return table.toString();
}
/**
 * Loads a graph from a headerless, comma-delimited square matrix. The nodes are the variables of the loaded data
 * set. For every ordered pair (i, j) with i != j whose entry (i, j) is 1: if entry (j, i) is also 1, an undirected
 * edge is added (once); if entry (j, i) is 0, a directed edge i--&gt;j is added.
 *
 * @param file the matrix file.
 * @return the resulting graph.
 * @throws RuntimeException if the file cannot be read.
 */
public static Graph loadRSpecial(File file) {
    DataSet matrix;
    try {
        ContinuousTabularDatasetFileReader fileReader =
                new ContinuousTabularDatasetFileReader(file.toPath(), Delimiter.COMMA);
        fileReader.setHasHeader(false);
        Data rawData = fileReader.readInData();
        matrix = (DataSet) DataConvertUtils.toDataModel(rawData);
    } catch (IOException ioException) {
        throw new RuntimeException("Error reading from file.", ioException);
    }
    if (matrix == null) {
        throw new NullPointerException();
    }
    List<Node> vars = matrix.getVariables();
    Graph graph = new EdgeListGraph(vars);
    int n = vars.size();
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            if (i == j || matrix.getDouble(i, j) != 1) {
                continue;
            }
            double reverse = matrix.getDouble(j, i);
            if (reverse == 1) {
                if (!graph.isAdjacentTo(vars.get(i), vars.get(j))) {
                    graph.addUndirectedEdge(vars.get(i), vars.get(j));
                }
            } else if (reverse == 0) {
                graph.addDirectedEdge(vars.get(i), vars.get(j));
            }
        }
    }
    return graph;
}
/**
 * Loads a CPDAG in the "amat.cpdag" format of PCALG. We will assume here that the graph in R has been saved to disk
 * using the write.table(mat, path) method: a header line of (possibly quoted) variable names, then one line per
 * variable consisting of a row label followed by that row's integer entries. For the amat.cpdag format, for a
 * matrix m, there are two cases where edges occur in the graph: (1) m[i][j] = 0 and m[j][i] = 1, in which case an
 * edge i->j exists; or, (2) m[i][j] = 1 and m[j][i] = 1, in which case an undirected edge i--j exists. In all other
 * cases, there is no edge between i and j. Diagonal entries are ignored.
 *
 * @param file a file in the "amat.cpdag" format of PCALG.
 * @return a graph.
 * @throws IllegalArgumentException if the file is empty, has too few rows, a row has too few entries, or an entry
 *                                  is not an integer.
 * @throws RuntimeException         if the file cannot be read.
 */
public static Graph loadGraphAmatCpdag(File file) {
    try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
        String varNames = reader.readLine();
        if (varNames == null) {
            throw new IllegalArgumentException("Empty amat.cpdag file: " + file);
        }
        List<Node> nodes = new ArrayList<>();
        for (String token : varNames.split("[ \t\"]+")) {
            if (!token.isBlank()) {
                nodes.add(new GraphNode(token));
            }
        }
        int n = nodes.size();
        int[][] m = new int[n][n];
        for (int i = 0; i < n; i++) {
            String line = reader.readLine();
            if (line == null) {
                throw new IllegalArgumentException("Expected " + n + " matrix rows in " + file + " but found " + i);
            }
            // Token 0 is the row label. Trimming first keeps leading whitespace from producing an empty token 0.
            String[] tokens = line.trim().split("[ \t]+");
            if (tokens.length < n + 1) {
                throw new IllegalArgumentException("Row " + (i + 1) + " of " + file + " has fewer than " + n
                        + " entries");
            }
            for (int j = 0; j < n; j++) {
                m[i][j] = Integer.parseInt(tokens[j + 1]);
            }
        }
        Graph graph = new EdgeListGraph(nodes);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (i == j) {
                    continue;
                }
                int e1 = m[i][j];
                int e2 = m[j][i];
                if (e1 == 0 && e2 == 1) {
                    graph.addDirectedEdge(nodes.get(i), nodes.get(j));
                } else if (e1 == 1 && e2 == 1 && i < j) {
                    // Both (i, j) and (j, i) describe the same undirected edge; add it only once.
                    graph.addUndirectedEdge(nodes.get(i), nodes.get(j));
                }
            }
        }
        return graph;
    } catch (IOException e) {
        throw new RuntimeException("Error reading from file.", e);
    }
}
/**
 * Loads a PAG in the "amat.pag" format of PCALG. We will assume here that the graph in R has been saved to disk
 * using the write.table(mat, path) method: a header line of (possibly quoted) variable names, then one line per
 * variable consisting of a row label followed by that row's integer entries. For the amat.pag format, for a matrix
 * m, endpoints are explicitly represented, as follows. 1 is a circle endpoint, 2 is an arrow endpoint, 3 is a tail
 * endpoint, and 0 is a null endpoint (i.e., no edge). Entry m[i][j] is the mark at node j on the i-j edge, so for
 * an edge i->j, m[i][j] = 2 and m[j][i] = 3.
 *
 * @param file a file in the "amat.pag" format of PCALG.
 * @return a graph.
 * @throws IllegalArgumentException if the file is empty or truncated, an entry is not an integer in 0..3, or
 *                                  exactly one of the two marks for a pair of nodes is 0.
 * @throws RuntimeException         if the file cannot be read.
 */
public static Graph loadGraphAmatPag(File file) {
    try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
        String varNames = reader.readLine();
        if (varNames == null) {
            throw new IllegalArgumentException("Empty amat.pag file: " + file);
        }
        List<Node> nodes = new ArrayList<>();
        for (String token : varNames.split("[ \t\"]+")) {
            if (!token.isBlank()) {
                nodes.add(new GraphNode(token));
            }
        }
        int n = nodes.size();
        int[][] m = new int[n][n];
        for (int i = 0; i < n; i++) {
            String line = reader.readLine();
            if (line == null) {
                throw new IllegalArgumentException("Expected " + n + " matrix rows in " + file + " but found " + i);
            }
            // Token 0 is the row label. Trimming first keeps leading whitespace from producing an empty token 0.
            String[] tokens = line.trim().split("[ \t]+");
            if (tokens.length < n + 1) {
                throw new IllegalArgumentException("Row " + (i + 1) + " of " + file + " has fewer than " + n
                        + " entries");
            }
            for (int j = 0; j < n; j++) {
                m[i][j] = Integer.parseInt(tokens[j + 1]);
            }
        }
        Graph graph = new EdgeListGraph(nodes);
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                // m[a][b] is the mark at b, so the endpoint at node i is m[j][i] and the endpoint at node j is
                // m[i][j]. (Reading them the other way round reversed every edge.)
                Endpoint atI = amatPagCodeToEndpoint(m[j][i]);
                Endpoint atJ = amatPagCodeToEndpoint(m[i][j]);
                if (atI != Endpoint.NULL && atJ != Endpoint.NULL) {
                    graph.addEdge(new Edge(nodes.get(i), nodes.get(j), atI, atJ));
                } else if (atI != Endpoint.NULL || atJ != Endpoint.NULL) {
                    throw new IllegalArgumentException("Invalid endpoint combination: " + atI + " " + atJ);
                }
            }
        }
        return graph;
    } catch (IOException e) {
        throw new RuntimeException("Error reading from file.", e);
    }
}

/**
 * Maps a PCALG amat.pag mark code to a Tetrad endpoint: 0 null, 1 circle, 2 arrow, 3 tail.
 *
 * @param code the mark code.
 * @return the corresponding endpoint.
 * @throws IllegalArgumentException if the code is not in 0..3.
 */
private static Endpoint amatPagCodeToEndpoint(int code) {
    return switch (code) {
        case 0 -> Endpoint.NULL;
        case 1 -> Endpoint.CIRCLE;
        case 2 -> Endpoint.ARROW;
        case 3 -> Endpoint.TAIL;
        default -> throw new IllegalArgumentException("Unexpected endpoint type: " + code);
    };
}
// public static Graph loadGraphPcalg(File file) {
// try {
// DataSet dataSet = SimpleDataLoader.loadContinuousData(file, "//", '\"',
// "*", true, Delimiter.COMMA, false);
//
// List nodes = dataSet.getVariables();
// Graph graph = new EdgeListGraph(nodes);
//
// for (int i = 0; i < nodes.size(); i++) {
// for (int j = i + 1; j < nodes.size(); j++) {
// Node n1 = nodes.get(i);
// Node n2 = nodes.get(j);
//
// int e1 = dataSet.getInt(j, i);
// int e2 = dataSet.getInt(i, j);
//
// Endpoint e1a;
//
// switch (e1) {
// case 0:
// e1a = Endpoint.NULL;
// break;
// case 1:
// e1a = Endpoint.CIRCLE;
// break;
// case 2:
// e1a = Endpoint.ARROW;
// break;
// case 3:
// e1a = Endpoint.TAIL;
// break;
// default:
// throw new IllegalArgumentException("Unexpected endpoint type: " + e1);
// }
//
// Endpoint e2a;
//
// switch (e2) {
// case 0:
// e2a = Endpoint.NULL;
// break;
// case 1:
// e2a = Endpoint.CIRCLE;
// break;
// case 2:
// e2a = Endpoint.ARROW;
// break;
// case 3:
// e2a = Endpoint.TAIL;
// break;
// default:
// throw new IllegalArgumentException("Unexpected endpoint type: " + e1);
// }
//
// if (e1a != Endpoint.NULL && e2a != Endpoint.NULL) {
// Edge edge = new Edge(n1, n2, e1a, e2a);
// graph.addEdge(edge);
// }
// }
// }
//
// return graph;
// } catch (Exception e) {
// e.printStackTrace();
// throw new IllegalStateException();
// }
// }
/**
 * Renders a directed graph's incidence matrix as a text table. Despite its name, this method does not load
 * anything; it produces exactly the same output as {@link #graphRMatrixTxt(Graph)}, to which it delegates so that
 * the two cannot drift apart.
 *
 * @param graph a directed graph.
 * @return the table as text.
 * @throws IllegalArgumentException if the graph contains an edge that is not directed.
 */
public static String loadGraphRMatrix(Graph graph) throws IllegalArgumentException {
    return graphRMatrixTxt(graph);
}
/**
 * Parses a graph from a string in the sectioned text format ("Graph Nodes:" / "Graph Edges:") by delegating to
 * {@link #readerToGraphTxt(Reader)}.
 *
 * @param graphString the graph text.
 * @return the parsed graph.
 * @throws java.io.IOException if reading the text fails.
 */
public static Graph readerToGraphTxt(String graphString) throws IOException {
    Reader source = new StringReader(graphString);
    return readerToGraphTxt(source);
}
/**
 * Parses a graph from text in the sectioned format. A line reading "Graph Nodes:" (after trimming) starts a block
 * of node lines, and a line reading "Graph Edges:" starts a block of edge lines; each block ends at a blank line.
 * All other lines are ignored. The reader is closed when parsing finishes.
 *
 * @param reader the source of the text.
 * @return the parsed graph.
 * @throws java.io.IOException if reading fails.
 */
public static Graph readerToGraphTxt(Reader reader) throws IOException {
    Graph graph = new EdgeListGraph();
    try (BufferedReader in = new BufferedReader(reader)) {
        String current;
        while ((current = in.readLine()) != null) {
            String header = current.trim();
            if ("Graph Nodes:".equals(header)) {
                extractGraphNodes(graph, in);
            } else if ("Graph Edges:".equals(header)) {
                extractGraphEdges(graph, in);
            }
        }
    }
    return graph;
}
/**
 * Saves a graph to a file, either as XML (via {@link #graphToXml(Graph)}) or as the graph's {@code toString()}
 * text.
 *
 * @param graph The graph to be saved.
 * @param file  The file to save it in.
 * @param xml   True if to be saved in XML, false if in text.
 * @throws IllegalArgumentException if the file cannot be opened or the graph cannot be written to it.
 */
public static void saveGraph(Graph graph, File file, boolean xml) {
    // try-with-resources closes the file even if rendering the graph throws.
    try (PrintWriter out = new PrintWriter(Files.newOutputStream(file.toPath()))) {
        if (xml) {
            out.print(graphToXml(graph));
        } else {
            out.print(graph);
        }
        // PrintWriter swallows write failures; checkError() flushes and reports them.
        if (out.checkError()) {
            throw new IllegalArgumentException("Could not write graph to " + file);
        }
    } catch (IOException e1) {
        throw new IllegalArgumentException("Output file could not be opened: " + file, e1);
    }
}
/**
 * Parses a graph from text in the sectioned format ("Graph Nodes:" / "Graph Edges:"). The accepted format is
 * exactly that of {@link #readerToGraphTxt(Reader)}, to which this method delegates; the reader is closed when
 * parsing finishes.
 *
 * @param reader the source of the text.
 * @return the parsed graph.
 * @throws java.io.IOException if reading fails.
 */
public static Graph readerToGraphRuben(Reader reader) throws IOException {
    return readerToGraphTxt(reader);
}
/**
 * Reads edge lines from {@code in} until a blank line or end of input and adds each parsed edge to {@code graph}.
 * A line has the form {@code [<n>.] <node1> <endpoints> <node2> [extra]}, where the leading line number is optional.
 * If the extra text contains more than one ';'-separated part, each part is parsed as a bootstrap annotation;
 * otherwise, it is parsed as a space-separated list of edge properties.
 *
 * @param graph the graph to add edges to.
 * @param in    the reader, positioned after the "Graph Edges:" header.
 * @throws IOException              if reading fails.
 * @throws IllegalArgumentException if a non-blank line has fewer than three fields.
 */
private static void extractGraphEdges(Graph graph, BufferedReader in) throws IOException {
    // The dot is escaped so that only a literal "<digits>." prefix is removed. The former unescaped '.' matched any
    // character, so unnumbered lines whose first node name starts with digits had part of the name stripped.
    Pattern lineNumPattern = Pattern.compile("^\\d+\\.\\s*");
    Pattern spacePattern = Pattern.compile("\\s+");
    Pattern semicolonPattern = Pattern.compile(";");
    Pattern colonPattern = Pattern.compile(":");
    for (String line = in.readLine(); line != null; line = in.readLine()) {
        line = line.trim();
        if (line.isEmpty()) {
            return;
        }
        line = lineNumPattern.matcher(line).replaceFirst("");
        String[] fields = spacePattern.split(line, 4);
        if (fields.length < 3) {
            throw new IllegalArgumentException("Malformed edge line: " + line);
        }
        Edge edge = getEdge(fields[0], fields[1], fields[2], graph);
        if (fields.length > 3) {
            String[] extras = semicolonPattern.split(fields[3]);
            if (extras.length > 1) {
                for (String prop : extras) {
                    setEdgeTypeProperties(edge, prop, spacePattern, colonPattern);
                }
            } else {
                getEdgeProperties(extras[0], spacePattern).forEach(edge::addProperty);
            }
        }
        graph.addEdge(edge);
    }
}
/**
 * Parses one ';'-separated annotation from the trailing part of an edge line and records it on {@code edge}.
 * Square brackets are removed and the text is split on ':'; annotations that do not split into exactly two parts
 * are ignored. The left part selects what is recorded:
 * <ul>
 * <li>three or more whitespace-separated fields: an edge-type probability, with the edge type taken from the
 * second field, optional properties from the fourth field, and the probability from the right part;</li>
 * <li>exactly "edge": the edge's overall probability, taken from the first token of the right part, with any
 * remaining tokens parsed as edge properties;</li>
 * <li>exactly "no edge": the probability of the nil edge type;</li>
 * <li>anything else: ignored.</li>
 * </ul>
 * NOTE(review): presumably annotations look like "[X1 --> X2]:0.8000" — confirm against the writer.
 *
 * @param edge         the edge to annotate.
 * @param prop         one annotation.
 * @param spacePattern pattern splitting on whitespace.
 * @param colonPattern pattern splitting on ':'.
 * @throws NumberFormatException if a probability is not a valid double.
 */
private static void setEdgeTypeProperties(Edge edge, String prop, Pattern spacePattern, Pattern colonPattern) {
prop = prop.replace("[", "").replace("]", "");
String[] fields = colonPattern.split(prop);
if (fields.length == 2) {
String bootstrapEdge = fields[0];
String bootstrapEdgeTypeProb = fields[1];
// Split the descriptor into at most four fields: node1, edge symbol, node2, remaining properties.
fields = spacePattern.split(bootstrapEdge, 4);
if (fields.length >= 3) {
// Edge-type probability; fields[1] is the edge symbol such as "-->".
EdgeTypeProbability.EdgeType edgeType = getEdgeType(fields[1]);
List properties = new LinkedList<>();
if (fields.length > 3) {
// Trailing edge properties (e.g. for PAG edges).
properties.addAll(getEdgeProperties(fields[3], spacePattern));
}
edge.addEdgeTypeProbability(new EdgeTypeProbability(edgeType, properties, Double.parseDouble(bootstrapEdgeTypeProb)));
} else {
// Overall edge probability ("edge") or probability of no edge ("no edge").
if ("edge".equals(bootstrapEdge)) {
fields = spacePattern.split(bootstrapEdgeTypeProb, 2);
if (fields.length > 1) {
edge.setProbability(Double.parseDouble(fields[0]));
getEdgeProperties(fields[1], spacePattern).forEach(edge::addProperty);
} else {
edge.setProbability(Double.parseDouble(bootstrapEdgeTypeProb));
}
} else if ("no edge".equals(bootstrapEdge)) {
edge.addEdgeTypeProbability(new EdgeTypeProbability(EdgeTypeProbability.EdgeType.nil, Double.parseDouble(bootstrapEdgeTypeProb)));
}
}
}
}
/**
 * Classifies a three-character edge symbol (such as "--&gt;" or "o-o") by the endpoints at its first and third
 * characters. Combinations with no matching edge type map to {@code nil}.
 *
 * @param edgeType the edge symbol.
 * @return the corresponding edge type.
 */
private static EdgeTypeProbability.EdgeType getEdgeType(String edgeType) {
    Endpoint from = getEndpoint(edgeType.charAt(0));
    Endpoint to = getEndpoint(edgeType.charAt(2));
    if (from == Endpoint.TAIL) {
        if (to == Endpoint.ARROW) {
            return EdgeTypeProbability.EdgeType.ta;
        }
        if (to == Endpoint.TAIL) {
            return EdgeTypeProbability.EdgeType.tt;
        }
    } else if (from == Endpoint.ARROW) {
        if (to == Endpoint.TAIL) {
            return EdgeTypeProbability.EdgeType.at;
        }
        if (to == Endpoint.CIRCLE) {
            return EdgeTypeProbability.EdgeType.ac;
        }
        if (to == Endpoint.ARROW) {
            return EdgeTypeProbability.EdgeType.aa;
        }
    } else if (from == Endpoint.CIRCLE) {
        if (to == Endpoint.ARROW) {
            return EdgeTypeProbability.EdgeType.ca;
        }
        if (to == Endpoint.CIRCLE) {
            return EdgeTypeProbability.EdgeType.cc;
        }
    }
    return EdgeTypeProbability.EdgeType.nil;
}
/**
 * Parses whitespace-separated edge-property tokens. The tokens "dd", "nl", "pd" and "pl" are mapped to the
 * corresponding {@link Edge.Property}; any other token is skipped.
 *
 * @param props        the property text.
 * @param spacePattern pattern splitting on whitespace.
 * @return the recognized properties, in input order.
 */
private static List<Edge.Property> getEdgeProperties(String props, Pattern spacePattern) {
    List<Edge.Property> recognized = new LinkedList<>();
    for (String token : spacePattern.split(props)) {
        switch (token) {
            case "dd" -> recognized.add(Edge.Property.dd);
            case "nl" -> recognized.add(Edge.Property.nl);
            case "pd" -> recognized.add(Edge.Property.pd);
            case "pl" -> recognized.add(Edge.Property.pl);
            default -> {
                // Unknown tokens are ignored.
            }
        }
    }
    return recognized;
}
/**
 * Reads node lines from {@code in} until a blank line or end of input. Each line is split on ',' or ';'. A name
 * wrapped in parentheses, such as "(L1)", becomes a latent node (parentheses removed); any other name becomes a
 * measured node.
 *
 * @param graph the graph to add nodes to.
 * @param in    the reader, positioned after the "Graph Nodes:" header.
 * @throws IOException if reading fails.
 */
private static void extractGraphNodes(Graph graph, BufferedReader in) throws IOException {
    for (String line = in.readLine(); line != null; line = in.readLine()) {
        line = line.trim();
        if (line.isEmpty()) {
            break;
        }
        for (String rawToken : line.split("[,;]")) {
            // Trim so that "X1, X2" yields "X2" rather than " X2", and skip empty tokens (e.g. from "X1;;X2"),
            // which would otherwise create nodes with blank names.
            String token = rawToken.trim();
            if (token.isEmpty()) {
                continue;
            }
            Node node;
            if (token.startsWith("(") && token.endsWith(")")) {
                node = new GraphNode(token.replace("(", "").replace(")", ""));
                node.setNodeType(NodeType.LATENT);
            } else {
                node = new GraphNode(token);
                node.setNodeType(NodeType.MEASURED);
            }
            graph.addNode(node);
        }
    }
}
/**
 * Parses a graph from JSON text. Every line is trimmed and the lines are concatenated before being passed to
 * {@code JsonUtils.parseJSONObjectToTetradGraph}. The reader is not closed; that is the caller's responsibility.
 *
 * @param reader the source of the JSON text.
 * @return the parsed graph.
 * @throws java.io.IOException if reading fails.
 */
public static Graph readerToGraphJson(Reader reader) throws IOException {
    BufferedReader lines = new BufferedReader(reader);
    StringBuilder compact = new StringBuilder();
    for (String next = lines.readLine(); next != null; next = lines.readLine()) {
        compact.append(next.trim());
    }
    String json = compact.toString();
    return JsonUtils.parseJSONObjectToTetradGraph(json);
}
/**
 * Converts a graph to Graphviz DOT text. Edges are emitted in sorted order as {@code "n1" -> "n2"} with
 * {@code arrowtail}/{@code arrowhead} shapes derived from the endpoints; {@code dir=both} is added when the first
 * endpoint is not a tail. Edges with a NULL endpoint are skipped. When an edge carries bootstrap edge-type
 * probabilities, a label lists each type with positive probability. Double quotes in node names are escaped so
 * the output stays valid DOT.
 *
 * @param graph a {@link edu.cmu.tetrad.graph.Graph} object
 * @return the DOT text.
 */
public static String graphToDot(Graph graph) {
    StringBuilder builder = new StringBuilder();
    builder.append("digraph g {\n");
    List<Edge> edges = new ArrayList<>(graph.getEdges());
    Collections.sort(edges);
    // Created once per call; DecimalFormat is not thread-safe, so it is not shared statically.
    NumberFormat nf = new DecimalFormat("0.000");
    for (Edge edge : edges) {
        Endpoint end1 = edge.getEndpoint1();
        Endpoint end2 = edge.getEndpoint2();
        // These may be in the graph, but they represent edges not in the ensemble for which
        // bootstrap information is available.
        if (end1 == Endpoint.NULL || end2 == Endpoint.NULL) continue;
        String n1 = dotEscapeQuotes(edge.getNode1().getName());
        String n2 = dotEscapeQuotes(edge.getNode2().getName());
        builder.append(" \"").append(n1).append("\" -> \"").append(n2).append("\" [");
        if (end1 != Endpoint.TAIL) {
            builder.append("dir=both, ");
        }
        builder.append("arrowtail=").append(dotArrowShape(end1));
        builder.append(", arrowhead=").append(dotArrowShape(end2));
        // Bootstrapping
        List<EdgeTypeProbability> edgeTypeProbabilities = edge.getEdgeTypeProbabilities();
        if (edgeTypeProbabilities != null && !edgeTypeProbabilities.isEmpty()) {
            StringBuilder label = new StringBuilder(n1 + " - " + n2);
            for (EdgeTypeProbability edgeTypeProbability : edgeTypeProbabilities) {
                double probability = edgeTypeProbability.getProbability();
                if (probability > 0) {
                    StringBuilder edgeTypeString =
                            new StringBuilder(dotEdgeTypeSymbol(edgeTypeProbability.getEdgeType()));
                    List<Edge.Property> properties = edgeTypeProbability.getProperties();
                    if (properties != null) {
                        for (Edge.Property property : properties) {
                            edgeTypeString.append(" ").append(property);
                        }
                    }
                    // "\\n" is a literal backslash-n, which Graphviz renders as a line break inside the label.
                    label.append("\\n[").append(edgeTypeString).append("]:").append(nf.format(probability));
                }
            }
            builder.append(", label=\"").append(label).append("\", fontname=courier");
        }
        builder.append("]; \n");
    }
    builder.append("}");
    return builder.toString();
}

/**
 * Escapes double quotes so that a name can be embedded in a double-quoted DOT string.
 */
private static String dotEscapeQuotes(String name) {
    return name.replace("\"", "\\\"");
}

/**
 * Maps an endpoint to the Graphviz arrow shape used to draw it; unrecognized endpoints are drawn as "xdot".
 */
private static String dotArrowShape(Endpoint endpoint) {
    if (endpoint == Endpoint.ARROW) {
        return "normal";
    } else if (endpoint == Endpoint.TAIL) {
        return "none";
    } else if (endpoint == Endpoint.CIRCLE) {
        return "odot";
    }
    return "xdot";
}

/**
 * Returns the text symbol used in DOT labels for a bootstrap edge type, or an empty string for an unlisted type.
 */
private static String dotEdgeTypeSymbol(EdgeTypeProbability.EdgeType edgeType) {
    switch (edgeType) {
        case nil:
            return "no edge";
        case ta:
            return "-->";
        case at:
            return "<--";
        case ca:
            return "o->";
        case ac:
            return "<-o";
        case cc:
            return "o-o";
        case aa:
            return "<->";
        case tt:
            return "---";
        default:
            return "";
    }
}
/**
 * Writes the DOT rendering of a graph (see {@link #graphToDot(Graph)}) to a file. This is best-effort: an I/O
 * failure is reported on standard error and is not propagated to the caller.
 *
 * @param graph a {@link edu.cmu.tetrad.graph.Graph} object
 * @param file  the destination file.
 */
public static void graphToDot(Graph graph, File file) {
    // try-with-resources closes the writer even if the write itself fails.
    try (Writer writer = new FileWriter(file)) {
        writer.write(graphToDot(graph));
    } catch (IOException e) {
        e.printStackTrace();
    }
}
/**
 * Converts a graph to a XOM {@code graph} element. The element contains a {@code variables} list of node names
 * and an {@code edges} list of edge strings. When they are non-empty, it also contains the ambiguous, underlined
 * and dotted-underlined triples. Only this basic structure is written.
 *
 * @param graph a {@link edu.cmu.tetrad.graph.Graph} object
 * @return an XML element representing the given graph.
 */
public static Element convertToXml(Graph graph) {
    Element element = new Element("graph");
    Element variables = new Element("variables");
    element.appendChild(variables);
    for (Node node : graph.getNodes()) {
        Element variable = new Element("variable");
        variable.appendChild(new Text(node.getName()));
        variables.appendChild(variable);
    }
    Element edges = new Element("edges");
    element.appendChild(edges);
    for (Edge edge : graph.getEdges()) {
        Element edgeElement = new Element("edge");
        edgeElement.appendChild(new Text(edge.toString()));
        edges.appendChild(edgeElement);
    }
    // These element names match those that parseGraphXml checks and passes to parseTriples. Previously the
    // ambiguity items were named "ambiguities" and the dotted-underline elements were camelCase, so those triples
    // could not be read back.
    appendTripleElements(element, "ambiguities", "ambiguity", graph.getAmbiguousTriples());
    appendTripleElements(element, "underlines", "underline", graph.getUnderLines());
    appendTripleElements(element, "dottedunderlines", "dottedunderline", graph.getDottedUnderlines());
    return element;
}

/**
 * Appends to {@code parent} a {@code groupName} element holding one {@code itemName} child per triple, with text
 * "x, y, z". Nothing is appended when {@code triples} is empty.
 */
private static void appendTripleElements(Element parent, String groupName, String itemName, Set<Triple> triples) {
    if (triples.isEmpty()) {
        return;
    }
    Element group = new Element(groupName);
    parent.appendChild(group);
    for (Triple triple : triples) {
        Element item = new Element(itemName);
        item.appendChild(new Text(niceTripleString(triple)));
        group.appendChild(item);
    }
}
/**
 * Formats a triple as "x, y, z" using the string form of each of its three nodes.
 */
private static String niceTripleString(Triple triple) {
    return String.join(", ",
            String.valueOf(triple.getX()),
            String.valueOf(triple.getY()),
            String.valueOf(triple.getZ()));
}
/**
 * Serializes a graph to an indented XML string, using {@link #convertToXml(Graph)} for the document body.
 *
 * @param graph a {@link edu.cmu.tetrad.graph.Graph} object
 * @return the XML text.
 * @throws RuntimeException if serialization fails.
 */
public static String graphToXml(Graph graph) {
    Document document = new Document(convertToXml(graph));
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    // XOM's Serializer(OutputStream) encodes as UTF-8.
    Serializer serializer = new Serializer(out);
    serializer.setLineSeparator("\n");
    serializer.setIndent(2);
    try {
        serializer.write(document);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    // Decode with the same charset that was used to encode; the platform default may differ and garble
    // non-ASCII node names.
    return out.toString(java.nio.charset.StandardCharsets.UTF_8);
}
/**
 * Writes the graph as a lavaan model specification, including intercepts and residual variances. Equivalent to
 * {@code graphToLavaan(g, true, true)}.
 *
 * @param g a {@link edu.cmu.tetrad.graph.Graph} object
 * @return the lavaan model text.
 */
public static String graphToLavaan(Graph g) {
    return graphToLavaan(g, true, true);
}

/**
 * Writes the graph as a lavaan model specification. The output has up to four sections, each followed by a blank
 * line where noted:
 * <ol>
 * <li>intercept lines "X ~ 1" (if {@code includeIntercepts}; followed by a blank line);</li>
 * <li>regression lines "X ~ P1 + P2" for each node with parents, where b is a parent of a when the a-b edge has an
 * arrow at a and a tail at b (followed by a blank line if any were written);</li>
 * <li>covariance lines "X ~~ S1 + S2" for edges with arrows at both ends, each pair listed once (followed by a
 * blank line if any were written);</li>
 * <li>residual variance lines "X ~~ X" (if {@code includeErrors}).</li>
 * </ol>
 *
 * @param g                 the graph.
 * @param includeIntercepts whether to emit intercept lines.
 * @param includeErrors     whether to emit residual variance lines.
 * @return the lavaan model text.
 */
public static String graphToLavaan(Graph g, boolean includeIntercepts, boolean includeErrors) {
    Map<Node, List<Node>> parents = new HashMap<>();
    Map<Node, List<Node>> siblings = new HashMap<>();
    StringBuilder lavaan = new StringBuilder();
    List<Node> nodes = g.getNodes();
    for (Node a : nodes) {
        if (includeIntercepts) lavaan.append(a.getName()).append(" ~ 1\n");
        parents.put(a, new ArrayList<>());
        siblings.put(a, new ArrayList<>());
        for (Edge e : g.getEdges(a)) {
            Node b = e.getDistalNode(a);
            if (e.getProximalEndpoint(a) != Endpoint.ARROW) continue;
            if (e.getProximalEndpoint(b) == Endpoint.TAIL) parents.get(a).add(b);
            // If b was visited earlier, any a <-> b edge was already recorded under b; skip to avoid listing it twice.
            if (siblings.containsKey(b)) continue;
            if (e.getProximalEndpoint(b) == Endpoint.ARROW) siblings.get(a).add(b);
        }
    }
    if (includeIntercepts) lavaan.append("\n");
    if (appendLavaanRelations(lavaan, nodes, parents, " ~ ")) lavaan.append("\n");
    if (appendLavaanRelations(lavaan, nodes, siblings, " ~~ ")) lavaan.append("\n");
    if (includeErrors) {
        for (Node a : nodes) {
            lavaan.append(a.getName()).append(" ~~ ").append(a.getName()).append("\n");
        }
    }
    return lavaan.toString();
}

/**
 * Appends one line "A op R1 + R2 ..." for every node A (in {@code nodes} order) whose relation list is non-empty.
 *
 * @return true if at least one line was written.
 */
private static boolean appendLavaanRelations(StringBuilder lavaan, List<Node> nodes,
                                             Map<Node, List<Node>> relations, String operator) {
    boolean wroteAny = false;
    for (Node a : nodes) {
        List<Node> related = relations.get(a);
        if (related.isEmpty()) continue;
        wroteAny = true;
        lavaan.append(a.getName()).append(operator);
        for (int k = 0; k < related.size(); k++) {
            if (k > 0) lavaan.append(" + ");
            lavaan.append(related.get(k).getName());
        }
        lavaan.append("\n");
    }
    return wroteAny;
}
/**
 * Renders the graph as a comma-delimited PCALG-style endpoint matrix. The first row holds the node names. It is
 * followed by n rows in which entry (a, b) encodes the mark at node b on the a-b edge: 0 none, 1 circle, 2 arrow,
 * 3 tail.
 *
 * @param g a {@link edu.cmu.tetrad.graph.Graph} object
 * @return the matrix text.
 * @throws IllegalArgumentException if an edge has an endpoint that has no PCALG code.
 */
public static String graphToPcalg(Graph g) {
    List<Node> nodes = g.getNodes();
    int n = nodes.size();
    // Index lookup via map instead of List.indexOf per edge (which made this O(edges * nodes)).
    Map<Node, Integer> index = new HashMap<>();
    for (int k = 0; k < n; k++) {
        index.put(nodes.get(k), k);
    }
    int[][] A = new int[n][n];
    for (Edge edge : g.getEdges()) {
        int i = index.get(edge.getNode1());
        int j = index.get(edge.getNode2());
        A[j][i] = toPcalgMark(edge.getEndpoint1());
        A[i][j] = toPcalgMark(edge.getEndpoint2());
    }
    TextTable table = new TextTable(n + 1, n);
    table.setDelimiter(TextTable.Delimiter.COMMA);
    for (int j = 0; j < n; j++) {
        table.setToken(0, j, nodes.get(j).getName());
    }
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            table.setToken(i + 1, j, String.valueOf(A[i][j]));
        }
    }
    return table.toString();
}

/**
 * Maps an endpoint to its PCALG mark code (NULL 0, CIRCLE 1, ARROW 2, TAIL 3).
 *
 * @throws IllegalArgumentException for any other endpoint, which previously caused an unboxing NPE.
 */
private static int toPcalgMark(Endpoint endpoint) {
    if (endpoint == Endpoint.NULL) return 0;
    if (endpoint == Endpoint.CIRCLE) return 1;
    if (endpoint == Endpoint.ARROW) return 2;
    if (endpoint == Endpoint.TAIL) return 3;
    throw new IllegalArgumentException("Endpoint cannot be encoded in pcalg format: " + endpoint);
}
/**
 * Converts a graph into an adjacency matrix in PCALG's "amat.cpdag" encoding. This is the encoding read by
 * {@link #loadGraphAmatCpdag(File)}:
 * <ul>
 * <li>for a directed edge i--&gt;j, m[j][i] = 1 and m[i][j] = 0;</li>
 * <li>for an undirected edge i---j, m[i][j] = m[j][i] = 1;</li>
 * <li>all other entries are 0.</li>
 * </ul>
 * The output is a line of quoted node names followed by one line per node: the quoted node name, then that node's
 * row.
 * <p>
 * No validation is performed. Edges that are neither directed nor undirected (e.g. bidirected or circle-marked
 * edges) are silently omitted, so callers should pass an MPDAG (including a CPDAG or DAG).
 *
 * @param g the input graph to be converted
 * @return the adjacency matrix representation of the graph in amat.cpdag format
 */
public static String graphToAmatCpag(Graph g) {
    List<Node> vars = g.getNodes();
    int n = vars.size();
    int[][] m = new int[n][n];
    for (int i = 0; i < n; i++) {
        Node node1 = vars.get(i);
        for (int j = 0; j < n; j++) {
            if (i == j) {
                continue;
            }
            Node node2 = vars.get(j);
            if (!g.isAdjacentTo(node1, node2)) {
                continue;
            }
            Edge edge = g.getEdge(node1, node2);
            if (Edges.isDirectedEdge(edge)) {
                if (edge.pointsTowards(node2)) {
                    m[j][i] = 1;
                }
            } else if (Edges.isUndirectedEdge(edge)) {
                m[i][j] = 1;
                m[j][i] = 1;
            }
        }
    }
    StringBuilder sb = new StringBuilder();
    for (Node node : vars) {
        sb.append("\"").append(node.getName()).append("\" ");
    }
    sb.append("\n");
    for (int i = 0; i < n; i++) {
        sb.append("\"").append(vars.get(i).getName()).append("\" ");
        for (int j = 0; j < n; j++) {
            sb.append(m[i][j]).append(" ");
        }
        sb.append("\n");
    }
    return sb.toString();
}
/**
 * Saves a PAG in the "amat.pag" format of PCALG. We will save it in the form that R would print the matrix to file
 * using write.matrix(mat, path): a line of quoted node names followed by one line per node (the quoted node name,
 * then that node's row). For the amat.pag format, for a matrix m, endpoints are explicitly represented, as follows.
 * 1 is a circle endpoint, 2 is an arrow endpoint, 3 is a tail endpoint, and 0 is a null endpoint (i.e., no edge).
 * Entry m[i][j] is the mark at node j on the i-j edge, so an edge i--&gt;j is written as m[i][j] = 2,
 * m[j][i] = 3. No PAG/MAG validation is performed.
 *
 * @param g a {@link edu.cmu.tetrad.graph.Graph} object
 * @return the matrix text.
 */
public static String graphToAmatPag(Graph g) {
    List<Node> vars = g.getNodes();
    int n = vars.size();
    int[][] m = new int[n][n];
    for (int i = 0; i < n; i++) {
        Node node1 = vars.get(i);
        for (int j = 0; j < n; j++) {
            if (i == j) {
                continue;
            }
            Node node2 = vars.get(j);
            if (!g.isAdjacentTo(node1, node2)) {
                continue;
            }
            Edge edge = g.getEdge(node1, node2);
            // Ask for the endpoint at each specific node. The stored edge's node1/node2 order need not match
            // (vars[i], vars[j]), so getEndpoint1()/getEndpoint2() could write the marks into the wrong cells.
            m[j][i] = endpointToAmatPagCode(edge.getProximalEndpoint(node1));
            m[i][j] = endpointToAmatPagCode(edge.getProximalEndpoint(node2));
        }
    }
    StringBuilder sb = new StringBuilder();
    for (Node node : vars) {
        sb.append("\"").append(node.getName()).append("\" ");
    }
    sb.append("\n");
    for (int i = 0; i < n; i++) {
        sb.append("\"").append(vars.get(i).getName()).append("\" ");
        for (int j = 0; j < n; j++) {
            sb.append(m[i][j]).append(" ");
        }
        sb.append("\n");
    }
    return sb.toString();
}

/**
 * Maps an endpoint to its amat.pag code: CIRCLE 1, ARROW 2, TAIL 3, anything else 0.
 */
private static int endpointToAmatPagCode(Endpoint endpoint) {
    if (endpoint == Endpoint.CIRCLE) {
        return 1;
    } else if (endpoint == Endpoint.ARROW) {
        return 2;
    } else if (endpoint == Endpoint.TAIL) {
        return 3;
    }
    return 0;
}
/**
 * Pattern for one edge in the XML graph format, e.g. {@code X1 --> X2} or {@code X1 o-o X2}.
 * <p>
 * Group 1 is the first node name and group 4 is the second. Groups 2 and 3 are the single endpoint characters on
 * either side of the central dash. The pattern is compiled once rather than once per edge.
 */
private static final Pattern XML_EDGE_PATTERN =
        Pattern.compile("([A-Za-z0-9_-]*:?[A-Za-z0-9_-]*) ?(.)-(.) ?([A-Za-z0-9_-]*:?[A-Za-z0-9_-]*)");

/**
 * Parses a graph from its XML representation.
 * <p>
 * The element must be named {@code graph}.
 * <ul>
 * <li>Its first child must be a {@code variables} element with one {@code variable} child per node.</li>
 * <li>Its second child must be an {@code edges} element with one {@code edge} child per edge. Each edge is written as
 * text such as {@code X1 --> X2}.</li>
 * <li>Optional {@code ambiguities}, {@code underlines} and {@code dottedunderlines} children may follow, in that
 * order, and any of them may be omitted. Each holds comma-separated node triples. Children that break this order
 * are not read.</li>
 * </ul>
 *
 * @param graphElement the {@code graph} element to parse.
 * @param nodes        if non-null, maps each variable name to the existing node to use for it. If null, a new
 *                     {@link GraphNode} is created for every variable.
 * @return the parsed graph.
 * @throws IllegalArgumentException if {@code graphElement} is not named {@code graph}, or a triple is malformed.
 * @throws IllegalStateException    if an edge uses an unrecognized endpoint character.
 * @throws nu.xom.ParsingException  if the child structure does not match the format above, an edge does not match
 *                                  the edge pattern, or a variable or edge names an unknown node.
 */
public static Graph parseGraphXml(Element graphElement, Map<String, Node> nodes) throws ParsingException {
    if (!"graph".equals(graphElement.getLocalName())) {
        throw new IllegalArgumentException("Expecting graph element: " + graphElement.getLocalName());
    }

    // Fetch the child list once rather than on every access below.
    Elements children = graphElement.getChildElements();
    int size = children.size();

    if (size < 1 || !"variables".equals(children.get(0).getLocalName())) {
        throw new ParsingException("Expecting variables element: "
                + (size < 1 ? "no child elements" : children.get(0).getLocalName()));
    }

    List<Node> variables = parseXmlVariables(children.get(0), nodes);
    Graph graph = new EdgeListGraph(variables);

    if (size < 2 || !"edges".equals(children.get(1).getLocalName())) {
        throw new ParsingException("Expecting edges element.");
    }

    Elements edgeElements = children.get(1).getChildElements();
    for (int i = 0; i < edgeElements.size(); i++) {
        graph.addEdge(parseXmlEdge(edgeElements.get(i), graph));
    }

    // Optional triple annotations follow the edges in a fixed order; p tracks the next child to examine.
    int p = 2;
    if (p < size && "ambiguities".equals(children.get(p).getLocalName())) {
        graph.setAmbiguousTriples(parseTriples(variables, children.get(p), "ambiguity"));
        p++;
    }
    if (p < size && "underlines".equals(children.get(p).getLocalName())) {
        graph.setUnderLineTriples(parseTriples(variables, children.get(p), "underline"));
        p++;
    }
    if (p < size && "dottedunderlines".equals(children.get(p).getLocalName())) {
        graph.setDottedUnderLineTriples(parseTriples(variables, children.get(p), "dottedunderline"));
    }

    return graph;
}

/**
 * Reads the {@code variable} children of a {@code variables} element, in document order.
 *
 * @param variablesElement the {@code variables} element.
 * @param nodes            optional name-to-node map; if null, fresh {@link GraphNode}s are created.
 * @return the nodes, one per {@code variable} child.
 * @throws ParsingException if a child is not a {@code variable} element, or if {@code nodes} has no entry for a name.
 */
private static List<Node> parseXmlVariables(Element variablesElement, Map<String, Node> nodes)
        throws ParsingException {
    Elements variableElements = variablesElement.getChildElements();
    List<Node> variables = new ArrayList<>(variableElements.size());

    for (int i = 0; i < variableElements.size(); i++) {
        Element variableElement = variableElements.get(i);
        if (!"variable".equals(variableElement.getLocalName())) {
            throw new ParsingException("Expecting variable element.");
        }

        String name = variableElement.getValue();
        Node node = nodes == null ? new GraphNode(name) : nodes.get(name);
        if (node == null) {
            // A null node would otherwise only fail later, inside the graph constructor.
            throw new ParsingException("Unknown variable: " + name);
        }
        variables.add(node);
    }

    return variables;
}

/**
 * Parses one {@code edge} element against the nodes already in {@code graph}.
 *
 * @param edgeElement the {@code edge} element, whose text is e.g. {@code X1 --> X2}.
 * @param graph       the graph whose nodes the edge must connect; it is not modified here.
 * @return the parsed edge.
 * @throws ParsingException if the element is not an {@code edge}, its text does not match the edge pattern, or it
 *                          names a node that is not in {@code graph}.
 */
private static Edge parseXmlEdge(Element edgeElement, Graph graph) throws ParsingException {
    if (!"edge".equals(edgeElement.getLocalName())) {
        throw new ParsingException("Expecting edge element: " + edgeElement.getLocalName());
    }

    String value = edgeElement.getValue();
    Matcher matcher = XML_EDGE_PATTERN.matcher(value);
    if (!matcher.matches()) {
        throw new ParsingException("Edge doesn't match pattern.");
    }

    Node node1 = graph.getNode(matcher.group(1));
    Node node2 = graph.getNode(matcher.group(4));
    Endpoint endpoint1 = parseXmlEndpoint(matcher.group(2), "<");
    Endpoint endpoint2 = parseXmlEndpoint(matcher.group(3), ">");

    if (node1 == null || node2 == null) {
        throw new ParsingException("Edge refers to an unknown node: " + value);
    }

    return new Edge(node1, node2, endpoint1, endpoint2);
}

/**
 * Maps an endpoint symbol from the XML edge text to an {@link Endpoint}.
 *
 * @param symbol      the single-character endpoint symbol.
 * @param arrowSymbol the symbol that denotes an arrowhead on this side of the edge ({@code "<"} on the left,
 *                    {@code ">"} on the right).
 * @return the endpoint.
 * @throws IllegalStateException if the symbol is not an arrow, {@code "o"} or {@code "-"}.
 */
private static Endpoint parseXmlEndpoint(String symbol, String arrowSymbol) {
    if (arrowSymbol.equals(symbol)) {
        return Endpoint.ARROW;
    }
    switch (symbol) {
        case "o":
            return Endpoint.CIRCLE;
        case "-":
            return Endpoint.TAIL;
        default:
            throw new IllegalStateException("Expecting an endpoint: " + symbol);
    }
}
/**
 * Parses the node triples held by a triples element such as {@code ambiguities}.
 * <p>
 * Only the children whose local name is {@code s} are read. Each holds three comma-separated node names, e.g.
 * {@code X1, X2, X3}, and whitespace around each name is ignored.
 *
 * @param variables      the graph's nodes, used to resolve the names.
 * @param triplesElement the element whose children hold the triples.
 * @param s              local name of the child elements to read, e.g. {@code "ambiguity"}.
 * @return the set of parsed triples.
 * @throws IllegalArgumentException if a child does not hold exactly three names, or names a node that is not in
 *                                  {@code variables}.
 */
private static Set<Triple> parseTriples(List<Node> variables, Element triplesElement, String s) {
    Elements elements = triplesElement.getChildElements(s);
    Set<Triple> triples = new HashSet<>();

    for (int q = 0; q < elements.size(); q++) {
        String value = elements.get(q).getValue();
        String[] tokens = value.split(",");
        if (tokens.length != 3) {
            throw new IllegalArgumentException("Expecting a triple: " + value);
        }

        Node x = getNode(variables, tokens[0].trim());
        Node y = getNode(variables, tokens[1].trim());
        Node z = getNode(variables, tokens[2].trim());

        // getNode returns null for an unknown name; reject it instead of storing a triple with a null member.
        if (x == null || y == null || z == null) {
            throw new IllegalArgumentException("Unknown node in triple: " + value);
        }

        triples.add(new Triple(x, y, z));
    }

    return triples;
}
/**
 * Returns the first node in {@code variables} whose name equals {@code x}.
 *
 * @param variables the nodes to search.
 * @param x         the node name to look for.
 * @return the matching node, or {@code null} if no node has that name.
 */
private static Node getNode(List<Node> variables, String x) {
    for (Node node : variables) {
        if (node.getName().equals(x)) {
            return node;
        }
    }
    return null;
}
/**
 * Parses an XML file and returns the root element of the resulting document.
 *
 * @param file the XML file to read.
 * @return the document's root element.
 * @throws nu.xom.ParsingException if the file's contents cannot be parsed as XML.
 * @throws java.io.IOException     if the file cannot be read.
 */
public static Element getRootElement(File file) throws ParsingException, IOException {
    // NOTE(review): this uses a default-configured XOM Builder. If graph files may come from untrusted sources,
    // confirm how the underlying parser handles DTDs and external entities.
    return new Builder().build(file).getRootElement();
}
/**
 * Builds an edge between two named nodes, adding either node to {@code graph} if it is not already there.
 * <p>
 * {@code edgeType} is read by position. Its first character is the endpoint at {@code nodeNameFrom` and its third
 * character is the endpoint at {@code nodeNameTo}, e.g. {@code "-->"}, {@code "o-o"} or {@code "<->"}. The middle
 * character is not inspected. The edge is returned but not added to the graph.
 *
 * @param nodeNameFrom name of the first node.
 * @param edgeType     the edge symbol, at least three characters long.
 * @param nodeNameTo   name of the second node.
 * @param graph        the graph whose nodes are looked up, or created when missing.
 * @return the new edge.
 * @throws IllegalArgumentException if {@code edgeType} is null or shorter than three characters, or uses an
 *                                  unrecognized endpoint character.
 */
private static Edge getEdge(String nodeNameFrom, String edgeType, String nodeNameTo, Graph graph) {
    // Validate before touching the graph, so a malformed edge type does not leave newly added nodes behind.
    if (edgeType == null || edgeType.length() < 3) {
        throw new IllegalArgumentException("Expecting an edge type of at least three characters, e.g. \"-->\": "
                + edgeType);
    }

    Endpoint endpointFrom = getEndpoint(edgeType.charAt(0));
    Endpoint endpointTo = getEndpoint(edgeType.charAt(2));

    Node nodeFrom = getNode(nodeNameFrom, graph);
    Node nodeTo = getNode(nodeNameTo, graph);
    return new Edge(nodeFrom, nodeTo, endpointFrom, endpointTo);
}
/**
 * Translates a single endpoint character into an {@link Endpoint}.
 * <p>
 * Both {@code '>'} and {@code '<'} denote an arrowhead, so the same mapping works for either end of an edge.
 * {@code 'o'} is a circle, {@code '-'} is a tail, and {@code '.'} is the null endpoint.
 *
 * @param endpoint the endpoint character.
 * @return the corresponding endpoint.
 * @throws IllegalArgumentException if the character is not one of the above.
 */
private static Endpoint getEndpoint(char endpoint) {
    switch (endpoint) {
        case '>':
        case '<':
            return Endpoint.ARROW;
        case 'o':
            return Endpoint.CIRCLE;
        case '-':
            return Endpoint.TAIL;
        case '.':
            return Endpoint.NULL;
        default:
            throw new IllegalArgumentException(String.format("Unrecognized endpoint: %s.", endpoint));
    }
}
/**
 * Looks up a node by name. If {@code graph} has no such node, a new {@link GraphNode} with that name is added first.
 *
 * @param nodeName the name to look up.
 * @param graph    the graph to search, which may be modified.
 * @return the result of {@code graph.getNode(nodeName)} after any needed insertion.
 */
private static Node getNode(String nodeName, Graph graph) {
    Node existing = graph.getNode(nodeName);
    if (existing != null) {
        return existing;
    }

    graph.addNode(new GraphNode(nodeName));
    return graph.getNode(nodeName);
}
/**
 * Captures the center coordinates of each node, keyed by node name.
 *
 * @param nodes the nodes whose positions should be recorded.
 * @return a new map from node name to a {@link PointXy} built from {@code getCenterX()}/{@code getCenterY()}. If
 *         two nodes share a name, the later one in {@code nodes} wins.
 */
public static HashMap<String, PointXy> grabLayout(List<Node> nodes) {
    HashMap<String, PointXy> layout = new HashMap<>();
    for (Node node : nodes) {
        layout.put(node.getName(), new PointXy(node.getCenterX(), node.getCenterY()));
    }
    return layout;
}
/**
 * Finds every collider centered at {@code node}.
 * <p>
 * A collider is an unordered pair of adjacent nodes X and Z whose edges to {@code node} both have an arrowhead at
 * {@code node}.
 *
 * @param node  the candidate collider, which becomes the middle node of each returned triple.
 * @param graph the graph to search.
 * @return a mutable list of triples of the form {@code X*->Y<-*Z} with Y = {@code node}. The list is empty if fewer
 *         than two neighbors point into {@code node}.
 */
public static List<Triple> getCollidersFromGraph(Node node, Graph graph) {
    // Look at each incident edge once, keeping only the neighbors whose edge has an arrowhead at node.
    List<Node> into = new ArrayList<>();
    for (Node adjacent : graph.getAdjacentNodes(node)) {
        if (graph.getEdge(adjacent, node).getProximalEndpoint(node) == Endpoint.ARROW) {
            into.add(adjacent);
        }
    }

    // Every pair of such neighbors forms a collider. Requiring i < j yields each unordered pair exactly once, in
    // adjacency order.
    List<Triple> colliders = new ArrayList<>();
    for (int i = 0; i < into.size(); i++) {
        for (int j = i + 1; j < into.size(); j++) {
            colliders.add(new Triple(into.get(i), node, into.get(j)));
        }
    }
    return colliders;
}
}