org.broadinstitute.hellbender.tools.spark.pathseq.PSTree Maven / Gradle / Ivy
package org.broadinstitute.hellbender.tools.spark.pathseq;
import com.esotericsoftware.kryo.DefaultSerializer;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.Utils;
import java.util.*;
import java.util.stream.Collectors;
/**
* Represents a taxonomic tree with nodes assigned a name and taxonomic rank (e.g. order, family, genus, species, etc.)
* Nodes are stored in a HashMap and keyed by a String id. Tree keeps track of each node's children as well as its
* parent for efficient traversals top-down and bottom-up.
*
* Designed to be constructed on-the-fly while parsing the NCBI taxonomy dump, which specifies each node and its parent.
* As a result, invalid trees may be constructed (e.g. multiple roots or cycles). Use checkStructure() to confirm its validity.
*
* Note the tree root is initialized in the constructor and cannot be modified (except adding children with addNode()).
*/
@DefaultSerializer(PSTree.Serializer.class)
public class PSTree {
private final int root;
private Map tree;
public static final int NULL_NODE = 0;
public PSTree(final int rootId) {
tree = new HashMap<>();
root = rootId;
tree.put(root, new PSTreeNode());
tree.get(root).setName("root");
tree.get(root).setRank("root");
tree.get(root).setParent(NULL_NODE);
}
@SuppressWarnings("unchecked")
protected PSTree(final Kryo kryo, final Input input) {
final boolean oldReferences = kryo.getReferences();
kryo.setReferences(false);
root = Integer.valueOf(input.readString());
final int treeSize = input.readInt();
tree = new HashMap<>(treeSize);
for (int i = 0; i < treeSize; i++) {
final int key = input.readInt();
final PSTreeNode value = kryo.readObject(input, PSTreeNode.class);
tree.put(key, value);
}
kryo.setReferences(oldReferences);
}
/**
* Returns short String of 20 arbitrarily chosen nodes
*/
private static String getAbbreviatedNodeListString(final Set nodes) {
return nodes.stream().limit(20).map(String::valueOf).collect(Collectors.joining(", ","[","]"));
}
protected void serialize(final Kryo kryo, final Output output) {
final boolean oldReferences = kryo.getReferences();
kryo.setReferences(false);
output.writeString(String.valueOf(root));
output.writeInt(tree.size());
for (final int key : tree.keySet()) {
output.writeInt(key);
kryo.writeObject(output, tree.get(key));
}
kryo.setReferences(oldReferences);
}
/**
* Adds a node to the tree. ID must be unique, and all arguments must be non-null.
* If the node exists, its properties are modified
*/
public void addNode(final int id, final String name, final int parent, final long length, final String rank) {
Utils.validateArg((name != null) && (rank != null), "Passed a null argument to addNode()");
Utils.validateArg(id != NULL_NODE, "Passed invalid node ID to addNode()");
Utils.validateArg(parent != NULL_NODE, "Passed invalid parent ID to addNode()");
Utils.validateArg(root != id, "Tried to set root attributes using addNode()");
//If the node already exists, keep its current children and set everything else
if (!tree.containsKey(id)) {
tree.put(id, new PSTreeNode());
}
final PSTreeNode node = tree.get(id);
node.setName(name);
node.setParent(parent);
node.setLength(length);
node.setRank(rank);
//If the parent doesn't exist, create a new node for it
//We are assuming its attributes will be set later using addNode()
if (!tree.containsKey(parent)) {
tree.put(parent, new PSTreeNode());
}
tree.get(parent).addChild(id);
}
/**
* Deletes nodes unreachable from the root and returns the set of deleted nodes.
*/
public Set removeUnreachableNodes() {
final Set reachable = traverse();
final Set unreachable = new HashSet<>(tree.keySet());
unreachable.removeAll(reachable);
retainNodes(reachable);
return unreachable;
}
/**
* Because of the piecemeal way we allow the tree to be constructed, we can end up with invalid tree structures.
* Check that tree structure contains valid pointers and is connected.
*/
public void checkStructure() {
//Check child-parent pointers are consistent
for (final int id : tree.keySet()) {
final PSTreeNode n = tree.get(id);
for (final int child : n.getChildren()) {
if (!tree.containsKey(child)) {
throw new UserException.BadInput("Malformed tree detected. Node " + id + " has non-existent child " + child);
}
if (tree.get(child).getParent() != id) {
throw new UserException.BadInput("Malformed tree detected. Node " + id + " has child " + child + ", whose parent is " + tree.get(child).getParent() + " instead of " + id);
}
}
final int parent = n.getParent();
if (parent != NULL_NODE) {
if (!tree.containsKey(parent)) {
throw new UserException.BadInput("Malformed tree detected. Node " + id + " has non-existent parent " + parent);
}
if (!tree.get(parent).getChildren().contains(id)) {
throw new UserException.BadInput("Malformed tree detected. Node " + id + " has parent " + parent + ", which does not have child " + id);
}
}
}
//Check disconnected sets of nodes using a traversal from the root
final Set unreachable = removeUnreachableNodes();
if (!unreachable.isEmpty()) {
final String nodesList = getAbbreviatedNodeListString(unreachable);
throw new UserException.BadInput("Malformed tree detected. Tree is disconnected. " + unreachable.size() + " of " + tree.size() + " nodes were unreachable: " + nodesList);
}
}
/**
* Find all nodes reachable from the root
*/
private Set traverse() {
final Queue queue = new LinkedList<>();
final Set visited = new HashSet<>(tree.size());
queue.add(root);
while (!queue.isEmpty()) {
final int id = queue.poll();
if (!visited.contains(id)) { //checked visited in case there are cycles
if (tree.containsKey(id)) {
queue.addAll(tree.get(id).getChildren());
} else {
throw new UserException.BadInput("Could not find node " + id + " while traversing the tree");
}
} else {
throw new UserException.BadInput("Tree contains a cycle. Attempted to visit node " + id + " twice during a breadth-first traversal.");
}
visited.add(id);
}
return visited;
}
/**
* Get lowest common ancester of the set of given nodes.
* Takes the intersection of node-to-root paths of all the nodes and finding the lowest one in the tree.
*/
public int getLCA(final Collection nodes) {
Utils.validateArg(nodes.size() > 0, "Queried lowest common ancestor of a null set");
final List> paths = new ArrayList<>(nodes.size());
for (final int node : nodes) {
paths.add(getPathOf(node));
}
final List firstPath = paths.remove(0);
final Set commonNodes = new HashSet<>(firstPath);
for (final List path : paths) {
commonNodes.retainAll(path);
}
//Return first common node. Note paths are returned in order from lowest to highest (root at the end)
for (final int node : firstPath) {
if (commonNodes.contains(node)) return node;
}
//This should never happen if the tree structure has been checked
throw new GATKException.ShouldNeverReachHereException("Could not find common ancester of node set.");
}
@SuppressWarnings("unchecked")
public Collection getChildrenOf(final int id) {
Utils.validateArg(tree.containsKey(id), "Could not get children of node id " + id + " because it does not exist");
return tree.get(id).getChildren();
}
public Set getNodeIDs() {
return tree.keySet();
}
public String getNameOf(final int id) {
Utils.validateArg(tree.containsKey(id), "Could not get name of node id " + id + " because it does not exist");
return tree.get(id).getName();
}
public int getParentOf(final int id) {
Utils.validateArg(tree.containsKey(id), "Could not get parent of node id " + id + " because it does not exist");
return tree.get(id).getParent();
}
public long getLengthOf(final int id) {
Utils.validateArg(tree.containsKey(id), "Could not get length of node id " + id + " because it does not exist");
return tree.get(id).getLength();
}
public String getRankOf(final int id) {
Utils.validateArg(tree.containsKey(id), "Could not get rank of node id " + id + " because it does not exist");
return tree.get(id).getRank();
}
public boolean hasNode(final int id) {
return tree.containsKey(id);
}
/**
* Removes all nodes not in the given set
*/
public void retainNodes(final Set idsToKeep) {
Utils.validateArg(idsToKeep.contains(root), "Cannot remove root");
final HashMap newTree = new HashMap<>(idsToKeep.size());
for (final int id : tree.keySet()) {
if (idsToKeep.contains(id)) {
final PSTreeNode info = tree.get(id);
final PSTreeNode newInfo = info.copy();
for (final int child : info.getChildren()) {
if (!idsToKeep.contains(child)) {
newInfo.removeChild(child);
}
}
if (!idsToKeep.contains(info.getParent())) {
newInfo.setParent(NULL_NODE);
}
newTree.put(id, newInfo);
}
}
tree = newTree;
}
public void setNameOf(final int id, final String name) {
Utils.validateArg(tree.containsKey(id), "Could not set name of node id " + id + " because it does not exist");
Utils.validateArg(name != null, "Cannot set node name to null");
Utils.validateArg(id != root, "Cannot set name of root");
tree.get(id).setName(name);
}
public void setRankOf(final int id, final String rank) {
Utils.validateArg(tree.containsKey(id), "Could not set rank of node id " + id + " because it does not exist");
Utils.validateArg(rank != null, "Cannot set node rank to null");
Utils.validateArg(id != root, "Cannot set rank of root");
tree.get(id).setRank(rank);
}
public void setLengthOf(final int id, final long length) {
Utils.validateArg(tree.containsKey(id), "Could not set rank of node id " + id + " because it does not exist");
Utils.validateArg(id != root, "Cannot set rank of root");
tree.get(id).setLength(length);
}
/**
* Returns path of node id's from the input id to the root.
*/
public List getPathOf(int id) {
final List path = new ArrayList<>();
final Set visitedNodes = new HashSet<>(tree.size());
while (id != NULL_NODE) {
if (!visitedNodes.contains(id)) {
visitedNodes.add(id);
if (tree.containsKey(id)) {
path.add(id);
id = tree.get(id).getParent();
} else {
throw new UserException.BadInput("Parent node " + id + " not found in tree while getting path");
}
} else {
throw new UserException.BadInput("The tree contains a cycle at node " + id);
}
}
return path;
}
@Override
public boolean equals(final Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final PSTree psTree = (PSTree) o;
return root == psTree.root && tree.equals(psTree.tree);
}
@Override
public int hashCode() {
int result = root;
result = 31 * result + tree.hashCode();
return result;
}
@Override
public String toString() {
return getAbbreviatedNodeListString(tree.keySet());
}
public static final class Serializer extends com.esotericsoftware.kryo.Serializer {
@Override
public void write(final Kryo kryo, final Output output, final PSTree psTree) {
psTree.serialize(kryo, output);
}
@Override
public PSTree read(final Kryo kryo, final Input input,
final Class klass) {
return new PSTree(kryo, input);
}
}
}