All downloads are free. The search and download features use the official Maven repository.

com.stripe.brushfire.AnnotatedTree.scala Maven / Gradle / Ivy

package com.stripe.brushfire

import com.twitter.algebird._
import com.stripe.bonsai.{ FullBinaryTree, FullBinaryTreeOps, Layout }

import java.lang.Math.{abs, max}

/**
 * Base type for the nodes of a binary tree. A node is either a
 * [[SplitNode]] (a key, a predicate and two children) or a
 * [[LeafNode]] (an index and a target). Every node carries an
 * annotation of type `A`.
 */
sealed abstract class Node[K, V, T, A] {

  /** The annotation stored on this node. */
  def annotation: A

  /**
   * Re-index the leaves of the subtree rooted at this node with
   * consecutive integers starting at `nextId`. Left children are
   * numbered before right children. Splits keep their key, predicate
   * and annotation; leaves keep their target and annotation.
   *
   * @param nextId the index to give to the first leaf encountered.
   * @return the first index that was not used, paired with the
   *         renumbered subtree.
   */
  def renumber(nextId: Int): (Int, Node[K, V, T, A]) =
    this match {
      case leaf @ LeafNode(_, _, _) =>
        (nextId + 1, LeafNode(nextId, leaf.target, leaf.annotation))
      case split @ SplitNode(_, _, _, _, _) =>
        val (afterLeft, left) = split.leftChild.renumber(nextId)
        val (afterRight, right) = split.rightChild.renumber(afterLeft)
        (afterRight, SplitNode(split.key, split.predicate, left, right, split.annotation))
    }

  /**
   * Single-level case analysis on this node. For a split, `f` receives
   * the left child, the right child and the `(key, predicate,
   * annotation)` label. For a leaf, `g` receives the `(index, target,
   * annotation)` label. The children are handed to `f` as-is; this
   * method does not recurse into them.
   */
  def fold[B](f: (Node[K, V, T, A], Node[K, V, T, A], (K, Predicate[V], A)) => B, g: Tuple3[Int, T, A] => B): B =
    this match {
      case leaf @ LeafNode(_, _, _) =>
        g(leaf.leafLabel)
      case split @ SplitNode(_, _, _, _, _) =>
        f(split.leftChild, split.rightChild, split.splitLabel)
    }
}

/**
 * An interior node: tests the value stored under `key` with
 * `predicate` and routes to `leftChild` or `rightChild`.
 */
case class SplitNode[K, V, T, A](key: K, predicate: Predicate[V], leftChild: Node[K, V, T, A], rightChild: Node[K, V, T, A], annotation: A) extends Node[K, V, T, A] {

  /**
   * The children that `row` should be routed to. When `row` has a
   * value for `key`, that is the left child if the predicate holds for
   * the value and the right child otherwise. When `row` has no value
   * for `key`, both children are returned, left first.
   */
  def evaluate(row: Map[K, V])(implicit ord: Ordering[V]): List[Node[K, V, T, A]] =
    row.get(key).fold(List(leftChild, rightChild)) { value =>
      List(if (predicate(value)) leftChild else rightChild)
    }

  /** The `(key, predicate, annotation)` triple describing this split. */
  def splitLabel: (K, Predicate[V], A) =
    Tuple3(key, predicate, annotation)

  // Note: the annotation is not part of the rendered string.
  override def toString: String =
    s"SplitNode($key,${predicate.display(key.toString)},$leftChild,$rightChild)"
}

object SplitNode {
  /**
   * Build a split whose annotation is derived from its children: the
   * `Semigroup` sum of the left child's annotation and the right
   * child's annotation (in that order).
   */
  def apply[K, V, T, A: Semigroup](k: K, p: Predicate[V], lc: Node[K, V, T, A], rc: Node[K, V, T, A]): SplitNode[K, V, T, A] = {
    val combined = implicitly[Semigroup[A]].plus(lc.annotation, rc.annotation)
    SplitNode(k, p, lc, rc, combined)
  }
}

/** A terminal node holding its `index`, a `target` and an annotation. */
case class LeafNode[K, V, T, A](index: Int, target: T, annotation: A) extends Node[K, V, T, A] {
  /** The `(index, target, annotation)` triple describing this leaf. */
  def leafLabel: (Int, T, A) =
    Tuple3(index, target, annotation)
}

object LeafNode {
  /** Build a leaf whose annotation is the `Monoid` zero for `A`. */
  def apply[K, V, T, A: Monoid](index: Int, target: T): LeafNode[K, V, T, A] = {
    val emptyAnnotation: A = implicitly[Monoid[A]].zero
    LeafNode[K, V, T, A](index, target, emptyAnnotation)
  }
}

case class AnnotatedTree[K, V, T, A](root: Node[K, V, T, A]) {

  import AnnotatedTree.AnnotatedTreeTraversal

  /**
   * Rewrite the key and predicate of every split with `f`. Leaves,
   * annotations and the shape of the tree are carried over unchanged.
   * `f` is applied to a split before its children are visited.
   */
  private def mapSplits[K0, V0](f: (K, Predicate[V]) => (K0, Predicate[V0])): AnnotatedTree[K0, V0, T, A] = {
    def convert(node: Node[K, V, T, A]): Node[K0, V0, T, A] =
      node match {
        case LeafNode(index, target, a) =>
          LeafNode(index, target, a)
        case SplitNode(k, p, lc, rc, a) =>
          val (newKey, newPredicate) = f(k, p)
          SplitNode(newKey, newPredicate, convert(lc), convert(rc), a)
      }

    AnnotatedTree(convert(root))
  }

  /**
   * Rewrite the target and annotation of every leaf with `f`, keeping
   * leaf indices, keys, predicates and the shape of the tree. The old
   * split annotations are discarded; each split's annotation is
   * recomputed from its new children using the `Semigroup` for `A0`.
   */
  private def mapLeaves[T0, A0: Semigroup](f: (T, A) => (T0, A0)): AnnotatedTree[K, V, T0, A0] = {
    def rebuild(node: Node[K, V, T, A]): Node[K, V, T0, A0] = node match {
      case LeafNode(index, target, a) =>
        val (newTarget, newAnnotation) = f(target, a)
        LeafNode(index, newTarget, newAnnotation)
      case SplitNode(k, p, lc, rc, _) =>
        val left = rebuild(lc)
        val right = rebuild(rc)
        SplitNode(k, p, left, right)
    }

    AnnotatedTree(rebuild(root))
  }

  /**
   * Replace every annotation in the tree. Each leaf's new annotation
   * is `f` applied to its target (the existing annotation is ignored);
   * split annotations are then rebuilt bottom-up with the `Semigroup`
   * for `A1`. Targets are left untouched.
   */
  def annotate[A1: Semigroup](f: T => A1): AnnotatedTree[K, V, T, A1] =
    mapLeaves[T, A1] { (target, _) => (target, f(target)) }

  /**
   * Transform the annotation of every leaf with `f`, then rebuild the
   * split annotations bottom-up with the `Semigroup` for `A1`. Note
   * that `f` is only ever applied to leaf annotations; when `f` is a
   * semigroup homomorphism the result is (semantically) the same as
   * applying `f` to the annotation of every node.
   */
  def mapAnnotation[A1: Semigroup](f: A => A1): AnnotatedTree[K, V, T, A1] =
    mapLeaves[T, A1] { (target, leafAnnotation) => (target, f(leafAnnotation)) }

  /**
   * Apply `f` to the feature key of every split. Predicates, leaves
   * and annotations are unchanged.
   */
  def mapKeys[K1](f: K => K1): AnnotatedTree[K1, V, T, A] =
    mapSplits[K1, V] { (key, pred) => (f(key), pred) }

  /**
   * Transform the [[Predicate]] of every split by mapping it with `f`.
   * Keys, leaves and annotations are unchanged. The result is only a
   * valid tree if `f` preserves the ordering (ie if
   * `a.compare(b) == f(a).compare(f(b))`).
   */
  def mapPredicates[V1: Ordering](f: V => V1): AnnotatedTree[K, V1, T, A] =
    mapSplits[K, V1] { (key, pred) => (key, pred.map(f)) }

  /**
   * Find the leaf whose index is `leafIndex`, searching the whole tree
   * depth-first from the root. Returns `None` when no leaf has that
   * index.
   */
  def leafAt(leafIndex: Int): Option[LeafNode[K, V, T, A]] =
    leafAt(leafIndex, start = root)

  /**
   * Find the leaf whose index is `leafIndex` in the subtree rooted at
   * `start`. The search is depth-first and looks in the left child
   * before the right child; the right child is only searched when the
   * left one yields nothing, so the first match in that order wins.
   */
  def leafAt(leafIndex: Int, start: Node[K, V, T, A]): Option[LeafNode[K, V, T, A]] =
    start match {
      case leaf @ LeafNode(_, _, _) =>
        if (leaf.index == leafIndex) Some(leaf) else None
      case SplitNode(_, _, lc, rc, _) =>
        leafAt(leafIndex, lc).orElse(leafAt(leafIndex, rc))
    }

  /**
   * Prune the whole tree to minimize validation error.
   *
   * Delegates to `pruneNode` starting at the root, which collapses a
   * split of two leaves into a single leaf whenever the error of the
   * merged leaf is no greater than the combined error of the two
   * children. The updated validation data produced along the way is
   * discarded.
   *
   * @param validationData A Map from leaf index to validation data.
   * @param voter to create predictions from target distributions.
   * @param error to calculate an error statistic given observations (validation) and predictions (training).
   * @return The new, pruned tree.
   */
  def prune[P, E: Ordering](validationData: Map[Int, T], voter: Voter[T, P], error: Error[T, P, E])(implicit m: Monoid[T], s: Semigroup[A]): AnnotatedTree[K, V, T, A] = {
    val (_, prunedRoot) = pruneNode(validationData, root, voter, error)
    AnnotatedTree(prunedRoot)
  }

  /**
   * Prune the subtree rooted at `start` to minimize validation error.
   *
   * The subtree is processed bottom-up: both children of a split are
   * pruned first (left, then right, threading the validation data
   * through), and only then is the split itself considered. A split
   * whose pruned children are both leaves is handed to `pruneLevel`,
   * which may replace it with a single merged leaf. Any other split is
   * rebuilt from its pruned children, with its annotation recomputed
   * from theirs. Leaves are returned unchanged.
   *
   * @param validationData Map from leaf index to validation data.
   * @param start The root node of the subtree to prune.
   * @param voter to create predictions from target distributions.
   * @param error to calculate an error statistic from validation data and a prediction.
   * @return The (possibly extended) validation data and the root of
   *         the pruned subtree.
   */
  def pruneNode[P, E: Ordering](validationData: Map[Int, T], start: Node[K, V, T, A], voter: Voter[T, P], error: Error[T, P, E])(implicit m: Monoid[T], aSemigroup: Semigroup[A]): (Map[Int, T], Node[K, V, T, A]) =
    start match {
      case SplitNode(k, p, lc0, rc0, _) =>
        val (afterLeft, left) = pruneNode(validationData, lc0, voter, error)
        val (afterRight, right) = pruneNode(afterLeft, rc0, voter, error)

        (left, right) match {
          // Only a split sitting directly above two leaves is a
          // candidate for being collapsed.
          case (leftLeaf @ LeafNode(_, _, _), rightLeaf @ LeafNode(_, _, _)) =>
            pruneLevel(SplitNode(k, p, leftLeaf, rightLeaf), leftLeaf, rightLeaf, afterRight, voter, error)
          case _ =>
            (afterRight, SplitNode(k, p, left, right))
        }

      case leaf @ LeafNode(_, _, _) =>
        // Nothing to prune below a leaf; hand it back to the caller.
        (validationData, leaf)
    }

  /**
   * Decide whether a split over two leaves should be collapsed into a
   * single leaf, and do so if it should.
   *
   * The split is collapsed when the error of the merged leaf (merged
   * validation data against a prediction from the merged targets) is
   * less than or equal to the combined error of the two children.
   *
   * The merged leaf gets a negative index: the negation of the larger
   * absolute value of the two child indices. Its validation data is
   * added to the returned map under that index. This relies on a hack
   * that assumes no leaves have negative indices to start out.
   *
   * @param parent the split to return when no merge happens; it also
   *               supplies the annotation of the merged leaf.
   * @param leftChild the left leaf under `parent`.
   * @param rightChild the right leaf under `parent`.
   * @param validationData Map from leaf index to validation data;
   *                       leaves without an entry count as the monoid zero.
   * @param voter to create predictions from target distributions.
   * @param error to calculate an error statistic from validation data and a prediction.
   * @return the validation data (extended with the merged leaf's
   *         entry when a merge happened) and the resulting node.
   */
  def pruneLevel[P, E](
    parent: SplitNode[K, V, T, A],
    leftChild: LeafNode[K, V, T, A],
    rightChild: LeafNode[K, V, T, A],
    validationData: Map[Int, T],
    voter: Voter[T, P],
    error: Error[T, P, E])(implicit targetMonoid: Monoid[T], errorOrdering: Ordering[E]): (Map[Int, T], Node[K, V, T, A]) = {

    // Validation data recorded for a leaf, or the monoid zero if there is none.
    def validationFor(leaf: LeafNode[K, V, T, A]): T =
      validationData.getOrElse(leaf.index, targetMonoid.zero)

    // Validation error of a single leaf against its own prediction.
    def errorFor(leaf: LeafNode[K, V, T, A]): E =
      error.create(validationFor(leaf), voter.combine(Some(leaf.target)))

    val mergedTarget = targetMonoid.plus(leftChild.target, rightChild.target)
    val mergedValidation = targetMonoid.plus(validationFor(leftChild), validationFor(rightChild))

    // Error if the split is kept: the two leaf errors combined.
    val separateError = error.semigroup.plus(errorFor(leftChild), errorFor(rightChild))

    // Error if the split is collapsed into one leaf.
    val mergedPrediction = voter.combine(Some(mergedTarget))
    val mergedError = error.create(mergedValidation, mergedPrediction)

    // Ties are resolved in favour of the smaller (merged) tree.
    if (errorOrdering.gteq(separateError, mergedError)) {
      val mergedIndex = -max(abs(leftChild.index), abs(rightChild.index))
      val mergedLeaf = LeafNode[K, V, T, A](mergedIndex, mergedTarget, parent.annotation)
      (validationData.updated(mergedIndex, mergedValidation), mergedLeaf)
    } else {
      (validationData, parent)
    }
  }

  /**
   * The label `(index, target, annotation)` of the first leaf the
   * implicit traversal reports for `row`, or `None` if it reports
   * none. `id` is passed through to the traversal unchanged.
   */
  def leafFor(row: Map[K, V], id: Option[String] = None)(implicit traversal: AnnotatedTreeTraversal[K, V, T, A]): Option[(Int, T, A)] = {
    val candidates = traversal.search(this, row, id)
    candidates.headOption
  }

  /** The index of the leaf that `leafFor` finds for `row`, if any. */
  def leafIndexFor(row: Map[K, V], id: Option[String] = None)(implicit traversal: AnnotatedTreeTraversal[K, V, T, A]): Option[Int] =
    leafFor(row, id).map { case (index, _, _) => index }

  /**
   * The target of the leaf that `leafFor` finds for `row`, if any,
   * passed through the semigroup's `sumOption`.
   */
  def targetFor(row: Map[K, V], id: Option[String] = None)(implicit traversal: AnnotatedTreeTraversal[K, V, T, A], semigroup: Semigroup[T]): Option[T] = {
    val leafTarget = leafFor(row, id).map { case (_, target, _) => target }
    semigroup.sumOption(leafTarget)
  }

  /**
   * Grow the tree at its leaves. `fn` is called with the index of each
   * leaf: if it returns `Some` [[SplitNode]], that split replaces the
   * leaf; if it returns `None`, the leaf is kept.
   *
   * Leaves of the result are numbered consecutively from 0, left to
   * right: kept leaves receive a fresh index and replacement splits
   * are renumbered to fit. The annotations of the existing splits are
   * recomputed from their children with the `Semigroup` for `A`.
   */
  def growByLeafIndex(fn: Int => Option[SplitNode[K, V, T, A]])(implicit s: Semigroup[A]): AnnotatedTree[K, V, T, A] = {
    // Returns the next unused leaf index together with the grown subtree.
    def grow(node: Node[K, V, T, A], nextIndex: Int): (Int, Node[K, V, T, A]) =
      node match {
        case SplitNode(k, p, lc0, rc0, _) =>
          val (afterLeft, left) = grow(lc0, nextIndex)
          val (afterRight, right) = grow(rc0, afterLeft)
          (afterRight, SplitNode(k, p, left, right))
        case LeafNode(index, target, annotation) =>
          fn(index) match {
            case Some(split) =>
              split.renumber(nextIndex)
            case None =>
              (nextIndex + 1, LeafNode(nextIndex, target, annotation))
          }
      }

    val (_, grownRoot) = grow(root, 0)
    AnnotatedTree(grownRoot)
  }

  /**
   * Replace leaves by arbitrary nodes. `fn` is called with the index
   * of each [[LeafNode]]: if it returns `Some` node, that node takes
   * the leaf's place; if it returns `None`, the leaf is kept as-is.
   *
   * Split annotations are recomputed from their children with the
   * `Semigroup` for `A`, and the leaves of the resulting tree are
   * renumbered.
   */
  def updateByLeafIndex(fn: Int => Option[Node[K, V, T, A]])(implicit s: Semigroup[A]): AnnotatedTree[K, V, T, A] = {
    def replace(node: Node[K, V, T, A]): Node[K, V, T, A] =
      node match {
        case SplitNode(k, p, lc0, rc0, _) =>
          SplitNode(k, p, replace(lc0), replace(rc0))
        case leaf @ LeafNode(_, _, _) =>
          fn(leaf.index).getOrElse(leaf)
      }

    val updated = AnnotatedTree(replace(root))
    updated.renumberLeaves
  }

  /**
   * Renumber all leaves in the tree with consecutive indices starting
   * at 0, in left-to-right order.
   *
   * @return A new tree with leaves renumbered.
   */
  def renumberLeaves: AnnotatedTree[K, V, T, A] = {
    val (_, renumberedRoot) = root.renumber(0)
    AnnotatedTree(renumberedRoot)
  }

  /**
   * Convert this tree into a bonsai `FullBinaryTree` whose branch
   * labels are `(key, predicate, annotation)` and whose leaf labels
   * are `(index, target, annotation)`.
   *
   * NOTE(review): the implicit layouts are not referenced by name
   * here; presumably `FullBinaryTree(...)` picks them up implicitly -
   * confirm against the bonsai API.
   */
  def compress(implicit
    branchLayout: Layout[BranchLabel[K, V, A]],
    leafLayout: Layout[LeafLabel[T, A]]
  ): FullBinaryTree[(K, Predicate[V], A), (Int, T, A)] =
    FullBinaryTree(this)
}

/** Companion holding the traversal type alias and the bonsai type-class instance. */
object AnnotatedTree {

  /** A [[TreeTraversal]] specialised to [[AnnotatedTree]]. */
  type AnnotatedTreeTraversal[K, V, T, A] = TreeTraversal[AnnotatedTree[K, V, T, A], K, V, T, A]

  /**
   * Implicit `FullBinaryTreeOps` instance for any [[AnnotatedTree]],
   * with `(key, predicate, annotation)` branch labels and
   * `(index, target, annotation)` leaf labels.
   */
  implicit def fullBinaryTreeOpsForAnnotatedTree[K, V, T, A]: FullBinaryTreeOps[AnnotatedTree[K, V, T, A], (K, Predicate[V], A), (Int, T, A)] =
    new FullBinaryTreeOpsForAnnotatedTree[K, V, T, A]
}

/**
 * `FullBinaryTreeOps` implementation for [[AnnotatedTree]]. Branches
 * are labelled `(key, predicate, annotation)` and leaves
 * `(index, target, annotation)`.
 */
class FullBinaryTreeOpsForAnnotatedTree[K, V, T, A] extends FullBinaryTreeOps[AnnotatedTree[K, V, T, A], (K, Predicate[V], A), (Int, T, A)] {
  type Node = com.stripe.brushfire.Node[K, V, T, A]

  /** The root node of `t`; always defined. */
  def root(t: AnnotatedTree[K, V, T, A]): Option[Node] =
    Some(t.root)

  /**
   * Single-level case analysis on `node`: `f` is applied to a split's
   * children and label, `g` to a leaf's label.
   */
  def foldNode[B](node: Node)(
    f: (Node, Node, (K, Predicate[V], A)) => B,
    g: Tuple3[Int, T, A] => B): B =
    node match {
      case split @ SplitNode(_, _, _, _, _) =>
        f(split.leftChild, split.rightChild, split.splitLabel)
      case leaf @ LeafNode(_, _, _) =>
        g(leaf.leafLabel)
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy