edu.cmu.tetrad.algcomparison.statistic.TrueDagTruePositiveArrow Maven / Gradle / Ivy
The newest version!
package edu.cmu.tetrad.algcomparison.statistic;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import java.io.Serial;
/**
* The number of true positive arrows (arrowheads) in the estimated graph, compared to the true DAG.
*
* @author josephramsey
* @version $Id: $Id
*/
public class TrueDagTruePositiveArrow implements Statistic {

    @Serial
    private static final long serialVersionUID = 23L;

    /**
     * Creates a new instance of this statistic. No configuration is needed.
     */
    public TrueDagTruePositiveArrow() {
    }

    /**
     * Returns the short label used for this statistic.
     *
     * @return The abbreviation, "DTPA".
     */
    @Override
    public String getAbbreviation() {
        return "DTPA";
    }

    /**
     * Returns a one-line, human-readable description of this statistic.
     *
     * @return The description text.
     */
    @Override
    public String getDescription() {
        return "True Positives for Arrows compared to true DAG";
    }

    /**
     * Counts the arrowheads of the estimated graph that are true positives with respect to the true graph.
     * <p>
     * Every edge of {@code estGraph} is inspected at both endpoints. An endpoint that is an
     * {@link Endpoint#ARROW} adds one to the count when
     * {@code trueGraph.paths().isAncestorOf(nodeAtArrowhead, otherNode)} returns {@code false}. A single edge can
     * therefore contribute 0, 1, or 2 to the result.
     *
     * @param trueGraph The true graph, queried for ancestor relations.
     * @param estGraph  The estimated graph whose edges are examined.
     * @param dataModel The data model; not used by this statistic.
     * @return The number of counted arrowheads.
     */
    @Override
    public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) {
        int arrowheadHits = 0;

        for (Edge estimated : estGraph.getEdges()) {
            // Arrowhead at node1: counted when node1 is not reported as an ancestor of node2.
            // The && short-circuits, so the ancestor query only runs for actual arrowheads.
            boolean firstHeadCounts = estimated.getEndpoint1() == Endpoint.ARROW
                    && !trueGraph.paths().isAncestorOf(estimated.getNode1(), estimated.getNode2());
            arrowheadHits += firstHeadCounts ? 1 : 0;

            // Arrowhead at node2: the same test with the node roles swapped.
            boolean secondHeadCounts = estimated.getEndpoint2() == Endpoint.ARROW
                    && !trueGraph.paths().isAncestorOf(estimated.getNode2(), estimated.getNode1());
            arrowheadHits += secondHeadCounts ? 1 : 0;
        }

        return arrowheadHits;
    }

    /**
     * Returns the normalized form of a value of this statistic. The value is passed through unchanged.
     *
     * @param value The value of the statistic.
     * @return The same {@code value}.
     */
    @Override
    public double getNormValue(double value) {
        return value;
    }
}