All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.stripe.brushfire.TreeTraversal.scala Maven / Gradle / Ivy

package com.stripe.brushfire

import com.stripe.bonsai.FullBinaryTreeOps

/**
 * Strategy for locating the leaves of a tree that a given row can evaluate
 * to. A single split may accept a row along more than one branch (for
 * example, when the feature being tested is missing from the row), so an
 * implementation decides which of the candidate paths are followed (possibly
 * all of them) and in which order they are visited.
 *
 * @tparam Tree the concrete tree representation
 * @tparam K    the feature key type used to look values up in a row
 * @tparam V    the feature value type
 * @tparam T    the target type carried by leaf labels
 * @tparam A    the annotation type carried by branch and leaf labels
 */
trait TreeTraversal[Tree, K, V, T, A] {

  /**
   * The tree operations used to reach the root of a `Tree` and to inspect
   * its nodes.
   */
  val treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]]

  /**
   * Wrap this traversal so that at most `n` leaves are returned by a search.
   *
   * @param n the maximum number of leaves to return; the wrapping
   *          `LimitedTreeTraversal` requires it to be greater than 0
   * @return a traversal that delegates to this one and truncates its results
   */
  def limitTo(n: Int): TreeTraversal[Tree, K, V, T, A] =
    LimitedTreeTraversal[Tree, K, V, T, A](this, n)

  /**
   * Find the leaves of `tree` that best fit `row`, starting at the root.
   * An empty stream is returned when the tree has no root; otherwise the
   * work is delegated to `searchNode`.
   *
   * In general the path from the root to a returned leaf follows only
   * predicates that hold for `row`. When a split admits more than one
   * branch, which ones are followed is up to the implementation.
   *
   * @param tree the decision tree to search in
   * @param row  the row/instance we're trying to match with a leaf
   * @param id   optional identifier, passed through unchanged to `searchNode`
   * @return the leaf labels that best match the row
   */
  def search(tree: Tree, row: Map[K, V], id: Option[String]): Stream[LeafLabel[T, A]] =
    treeOps.root(tree)
      .map(rootNode => searchNode(rootNode, row, id))
      .getOrElse(Stream.empty)

  /**
   * Find the leaves that best fit `row` in the subtree rooted at `node`.
   *
   * In general the path from `node` to a returned leaf follows only
   * predicates that hold for `row`. When a split admits more than one
   * branch, which ones are followed is up to the implementation.
   *
   * @param node the node to start searching from
   * @param row  the row/instance we're trying to match with a leaf
   * @param id   optional identifier for this search; how it is used is up to
   *             the implementation (`DepthFirstTreeTraversal` hands it to
   *             `Reorder.setSeed`)
   * @return the leaf labels that match the row
   */
  def searchNode(node: treeOps.Node, row: Map[K, V], id: Option[String]): Stream[LeafLabel[T, A]]
}

object TreeTraversal {

  /**
   * Use the explicitly supplied seed when there is one; otherwise derive one
   * from the current `System.nanoTime` reading, truncated to an `Int`.
   */
  private def seedOrClock(seed: Option[Int]): Int =
    seed.getOrElse(System.nanoTime().toInt)

  /**
   * Search `tree` for the leaves matching `row`, using whichever
   * `TreeTraversal` instance is implicitly available.
   *
   * @param tree the decision tree to search in
   * @param row  the row/instance we're trying to match with a leaf
   * @param id   optional identifier forwarded to the traversal; defaults to
   *             `None`
   * @return the leaf labels produced by the traversal
   */
  def search[Tree, K, V, T, A](tree: Tree, row: Map[K, V], id: Option[String] = None)(implicit ev: TreeTraversal[Tree, K, V, T, A]): Stream[LeafLabel[T, A]] =
    ev.search(tree, row, id)

  /**
   * The implicitly available traversal: a depth-first walk of the tree that
   * returns every matching leaf, built on `Reorder.unchanged`.
   */
  implicit def depthFirst[Tree, K, V: Ordering, T, A](implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]]): TreeTraversal[Tree, K, V, T, A] =
    DepthFirstTreeTraversal(Reorder.unchanged)

  /**
   * A depth-first search for matching leaves in which the candidate children
   * of a parent are visited in reverse order of their annotations: when
   * several children are valid candidates, the one with the largest
   * annotation is traversed first.
   */
  def weightedDepthFirst[Tree, K, V: Ordering, T, A: Ordering](implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]]): TreeTraversal[Tree, K, V, T, A] =
    DepthFirstTreeTraversal(Reorder.weightedDepthFirst)

  /**
   * A depth-first search for matching leaves that picks the order of the
   * candidate children at random at each step. Because the walk is
   * depth-first, once a child is chosen all matching leaves beneath it are
   * visited before the walk moves on to that child's sibling.
   *
   * @param seed seed for the shuffling; when `None`, a seed is derived from
   *             `System.nanoTime`
   */
  def randomDepthFirst[Tree, K, V: Ordering, T, A](seed: Option[Int] = None)(implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]]): TreeTraversal[Tree, K, V, T, A] =
    DepthFirstTreeTraversal(Reorder.shuffled(seedOrClock(seed)))

  /**
   * A depth-first search for matching leaves in which the candidate children
   * of a parent are shuffled at random, with higher-weight nodes being more
   * likely to come earlier. This is a blend of `randomDepthFirst` and
   * `weightedDepthFirst`.
   *
   * The algorithm can best be thought of as random sampling without
   * replacement from a set of weighted elements: an element's weight is
   * directly proportional to its probability of being drawn, relative to
   * the other elements still in the set.
   *
   * @param seed seed for the sampling; when `None`, a seed is derived from
   *             `System.nanoTime`
   */
  def probabilisticWeightedDepthFirst[Tree, K, V: Ordering, T, A](seed: Option[Int] = None)(implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]], conversion: A => Double): TreeTraversal[Tree, K, V, T, A] = {
    val rngSeed = seedOrClock(seed)
    DepthFirstTreeTraversal(Reorder.probabilisticWeightedDepthFirst(rngSeed, conversion))
  }
}

/**
 * A `TreeTraversal` that walks the tree depth-first. At a branch whose key
 * is present in the row, the branch predicate picks exactly one child (left
 * when it holds, right otherwise). When the key is absent, both children
 * are candidates and `reorder` decides how they are visited.
 *
 * @param reorder strategy applied to the two children of a branch whose key
 *                is missing from the row
 */
case class DepthFirstTreeTraversal[Tree, K, V, T, A](reorder: Reorder[A])(implicit val treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]], ord: Ordering[V]) extends TreeTraversal[Tree, K, V, T, A] {

  import treeOps.{Node, foldNode}

  /**
   * Collect the leaves matching `row` in the subtree rooted at `start`.
   *
   * @param start the node to begin the walk from
   * @param row   the row/instance we're trying to match with a leaf
   * @param id    optional identifier handed to `reorder.setSeed`
   * @return the matching leaf labels, in traversal order
   */
  def searchNode(start: Node, row: Map[K, V], id: Option[String]): Stream[LeafLabel[T, A]] = {

    // Derive the reorder instance used for this call only. This is meant to
    // be a no-op unless an id is given and the reorder strategy relies on
    // randomness; it gives every searchNode call an independent RNG when
    // repeatability matters (i.e. when `id` is defined).
    val seeded = reorder.setSeed(id)

    // Extract the annotation (third label component) from either node kind.
    val annotationOf: Node => A =
      n => foldNode(n)((_, _, branch) => branch._3, leaf => leaf._3)

    // Wrap a single leaf label in a one-element stream.
    val singleLeaf: LeafLabel[T, A] => Stream[LeafLabel[T, A]] =
      leaf => Stream.cons(leaf, Stream.empty)

    // Visits two sibling nodes one after the other. Assigned right after
    // `walk` is defined; a var is used in place of a lazy val for speed.
    var visitBoth: (Node, Node) => Stream[LeafLabel[T, A]] = null

    // Handles a branch node by descending left, right, or into both,
    // according to the branch predicate. Assigned right after `walk` is
    // defined; a var is used in place of a lazy val for speed.
    var descend: (Node, Node, BranchLabel[K, V, A]) => Stream[LeafLabel[T, A]] = null

    // Process one node; foldNode dispatches to the branch or leaf handler.
    def walk(node: Node): Stream[LeafLabel[T, A]] =
      foldNode(node)(descend, singleLeaf)

    // `walk` exists now, so the two handlers can be filled in.
    visitBoth = (first, second) => walk(first) #::: walk(second)

    descend = (left, right, label) => {
      val (key, predicate, _) = label
      row.get(key) match {
        case Some(value) if predicate(value) => walk(left)
        case Some(_) => walk(right)
        // key missing from the row: let the reorder strategy arrange the visit
        case None => seeded(left, right, annotationOf, visitBoth)
      }
    }

    // Kick off the walk. The handlers above are built once per call and held
    // in variables so that no fresh function instances have to be created at
    // each level of the tree.
    walk(start)
  }
}

/**
 * A `TreeTraversal` that delegates to another traversal and keeps at most
 * `limit` of the leaves it produces.
 *
 * @param traversal the traversal that performs the actual search
 * @param limit     the maximum number of leaves to return; must be greater
 *                  than 0, otherwise construction fails via `require`
 */
case class LimitedTreeTraversal[Tree, K, V, T, A](traversal: TreeTraversal[Tree, K, V, T, A], limit: Int) extends TreeTraversal[Tree, K, V, T, A] {
  require(limit >= 1, "limit must be greater than 0")

  // Reuse the wrapped traversal's tree operations, typed as its singleton
  // type so that `treeOps.Node` here is the same type as
  // `traversal.treeOps.Node`.
  val treeOps: traversal.treeOps.type = traversal.treeOps

  /**
   * Run the wrapped traversal from `node` and keep only the first `limit`
   * leaves of its result.
   *
   * @param node the node to start searching from
   * @param row  the row/instance we're trying to match with a leaf
   * @param id   optional identifier, passed through to the wrapped traversal
   * @return at most `limit` leaf labels, in the wrapped traversal's order
   */
  def searchNode(node: treeOps.Node, row: Map[K, V], id: Option[String]): Stream[LeafLabel[T, A]] = {
    val matches = traversal.searchNode(node, row, id)
    matches.take(limit)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy