All downloads are free. Search and download functionality uses the official Maven repository.

scalaz.StrictTree.scala Maven / Gradle / Ivy

package scalaz

import scala.collection.mutable
import std.vector.{vectorInstance, vectorMonoid}

/**
  * A strict rose tree: every node holds a label and an already-evaluated `Vector` of children.
  *
  * Several operations (map, flatMap, foldMap, scanr, flatten, size, zip and hashCode) walk the
  * tree with an explicit `mutable.Stack` rather than by recursion.
  *
  * @param rootLabel The label at the root of this tree.
  * @param subForest The child nodes of this tree.
  * @tparam A The type of the labels stored in this tree.
  */
case class StrictTree[A](
  rootLabel: A,
  subForest: Vector[StrictTree[A]]
) {

  import StrictTree._

  /**
    * Run a bottom-up algorithm.
    *
    * This is the framework for several stackless methods, such as map. Children are visited
    * in order, and `reduce` is applied exactly once per node, after all of that node's
    * children have been reduced (post-order).
    *
    * @param reduce is a function from a label and its mapped children to the new result.
    * @return the result of `reduce` for the root of this tree.
    */
  private[scalaz] def runBottomUp[B](
    reduce: A => mutable.Buffer[B] => B
  ): B = {
    val root = BottomUpStackElem[A, B](None, this)
    val stack = mutable.Stack[BottomUpStackElem[A, B]](root)

    while (stack.nonEmpty) {
      val here = stack.head
      if (here.hasNext) {
        val child = here.next()
        stack.push(BottomUpStackElem[A, B](Some(here), child))
      } else {
        // Every child of "here" is reduced, so reduce "here" and add the result to its parent's
        // completed children. The root has no parent; it is reduced once, after the loop, instead
        // of being reduced here, discarded, and reduced again (which called `reduce` twice on it).
        here.parent.foreach { parent =>
          parent.mappedSubForest += reduce(here.rootLabel)(here.mappedSubForest)
        }
        stack.pop()
      }
    }

    reduce(root.rootLabel)(root.mappedSubForest)
  }

  /** Maps the elements of the StrictTree into a Monoid and folds the resulting StrictTree. */
  def foldMap[B: Monoid](f: A => B): B =
    runBottomUp(foldMapReducer(f))

  /** Right fold over the labels in pre-order, i.e. the order produced by `flatten`. */
  def foldRight[B](z: B)(f: (A, => B) => B): B =
    Foldable[Vector].foldRight(flatten, z)(f)

  /** A 2D String representation of this StrictTree. */
  def drawTree(implicit sh: Show[A]): String = {
    toTree.drawTree
  }

  /** A histomorphic transform. Each element in the resulting tree
    * is a function of the corresponding element in this tree
    * and the histomorphic transform of its children.
    */
  def scanr[B](g: (A, Vector[StrictTree[B]]) => B): StrictTree[B] =
    runBottomUp(scanrReducer(g))

  /** Pre-order traversal. */
  def flatten: Vector[A] = {
    val stack = mutable.Stack(this)

    val result = mutable.Buffer.empty[A]

    while (stack.nonEmpty) {
      val popped = stack.pop()
      result += popped.rootLabel
      // Push children in reverse so the first child is popped (visited) first.
      popped.subForest.reverseIterator.foreach(stack.push)
    }

    result.toVector
  }

  /** The number of nodes in this tree, including the root. */
  def size: Int = {
    val stack = mutable.Stack(this.subForest)

    var result = 1

    while (stack.nonEmpty) {
      val popped = stack.pop()
      result += popped.size
      stack.pushAll(popped.map(_.subForest))
    }

    result
  }

  /** Breadth-first traversal: the labels of each level, starting with the root's level. */
  def levels: Vector[Vector[A]] = {
    val nextLevel = (s: Vector[StrictTree[A]]) => {
      Foldable[Vector].foldMap(s)((_: StrictTree[A]).subForest)
    }
    // Generate levels lazily and stop at the first empty one, instead of eagerly building
    // `size` levels (most of them empty) and computing `size` first.
    Iterator
      .iterate(Vector(this))(nextLevel)
      .takeWhile(_.nonEmpty)
      .map(_.map(_.rootLabel))
      .toVector
  }

  /** Converts this tree into a [[Tree]] with the same shape and labels (recursive in the tree's depth). */
  def toTree: Tree[A] = {
    Tree.Node[A](rootLabel, subForest.foldRight(EphemeralStream.emptyEphemeralStream[Tree[A]])((t, b) => t.toTree ##:: b))
  }

  /** Binds the given function across all the subtrees of this tree. */
  def cobind[B](f: StrictTree[A] => B): StrictTree[B] = unfoldTree(this)(t => (f(t), t.subForest))

  /** Applies `f` to the root label and the list of direct children. */
  def foldNode[Z](f: A => Vector[StrictTree[A]] => Z): Z =
    f(rootLabel)(subForest)

  /** Applies `f` to every label, preserving the tree's shape. */
  def map[B](f: A => B): StrictTree[B] = {
    runBottomUp(mapReducer(f))
  }

  /**
    * Replaces every label with the tree `f` produces for it. The node's transformed children
    * are appended after the children of that produced tree.
    */
  def flatMap[B](f: A => StrictTree[B]): StrictTree[B] = {
    runBottomUp(flatMapReducer(f))
  }

  /** Traverses the labels with an `Apply` effect, root first and then children in order. */
  def traverse1[G[_] : Apply, B](f: A => G[B]): G[StrictTree[B]] = {
    val G = Apply[G]

    subForest match {
      case Vector() => G.map(f(rootLabel))(Leaf(_))
      case x +: xs => G.apply2(f(rootLabel), NonEmptyList.nel(x, IList.fromFoldable(xs)).traverse1(_.traverse1(f))) {
        case (h, t) => Node(h, t.list.toVector)
      }
    }
  }

  /**
    * Pairs up labels of this tree and `b` position by position. At each node only as many
    * children are kept as the shorter of the two child lists has.
    */
  def zip[B](b: StrictTree[B]): StrictTree[(A, B)] = {
    val root = ZipStackElem[A, B](None, this, b)
    val stack = mutable.Stack[ZipStackElem[A, B]](root)

    while (stack.nonEmpty) {
      val here = stack.head
      if (here.hasNext) {
        val (childA, childB) = here.next()
        stack.push(ZipStackElem[A, B](Some(here), childA, childB))
      } else {
        // Every child of "here" is zipped; add its node to the parent's completed children.
        // The root's node is built once, after the loop.
        here.parent.foreach { parent =>
          parent.mappedSubForest += StrictTree((here.a.rootLabel, here.b.rootLabel), here.mappedSubForest.toVector)
        }
        stack.pop()
      }
    }

    StrictTree((rootLabel, b.rootLabel), root.mappedSubForest.toVector)
  }

  /**
    * This implementation is 24x faster than the trampolined implementation for StrictTreeTestJVM's hashCode test.
    *
    * @return a hash combining each label's `hashCode` with the hash of its children's hash codes,
    *         computed bottom-up without recursion.
    */
  override def hashCode(): Int = {
    runBottomUp(hashCodeReducer)
  }

  /**
    * Structural equality: same shape, and corresponding labels equal under universal `==`.
    *
    * The `StrictTree[A]` type test is unchecked because of erasure; this is harmless here
    * because labels are only ever compared with `==`.
    */
  override def equals(obj: scala.Any): Boolean = {
    obj match {
      case other: StrictTree[A] =>
        StrictTree.badEqInstance[A].equal(this, other)
      case _ =>
        false
    }
  }
}

sealed abstract class StrictTreeInstances {

  implicit val strictTreeIsCovariant: IsCovariant[StrictTree] =
    IsCovariant.force[StrictTree]

  implicit val strictTreeInstance: Traverse1[StrictTree] & Monad[StrictTree] & Comonad[StrictTree] & Align[StrictTree] & Zip[StrictTree] = new Traverse1[StrictTree] with Monad[StrictTree] with Comonad[StrictTree] with Align[StrictTree] with Zip[StrictTree] {
    def point[A](a: => A): StrictTree[A] = StrictTree.Leaf(a)
    def cobind[A, B](fa: StrictTree[A])(f: StrictTree[A] => B): StrictTree[B] = fa cobind f
    def copoint[A](p: StrictTree[A]): A = p.rootLabel
    override def map[A, B](fa: StrictTree[A])(f: A => B) = fa map f
    def bind[A, B](fa: StrictTree[A])(f: A => StrictTree[B]): StrictTree[B] = fa flatMap f
    def traverse1Impl[G[_]: Apply, A, B](fa: StrictTree[A])(f: A => G[B]): G[StrictTree[B]] = fa traverse1 f
    override def foldRight[A, B](fa: StrictTree[A], z: => B)(f: (A, => B) => B): B = fa.foldRight(z)(f)
    // `flatten` always contains at least the root label, so the `+:` match below cannot fail.
    override def foldMapRight1[A, B](fa: StrictTree[A])(z: A => B)(f: (A, => B) => B): B = (fa.flatten.reverse: @unchecked) match {
      case h +: t => t.foldLeft(z(h))((b, a) => f(a, b))
    }
    override def foldLeft[A, B](fa: StrictTree[A], z: B)(f: (B, A) => B): B =
      fa.flatten.foldLeft(z)(f)
    // `flatten` always contains at least the root label, so the `+:` match below cannot fail.
    override def foldMapLeft1[A, B](fa: StrictTree[A])(z: A => B)(f: (B, A) => B): B = (fa.flatten: @unchecked) match {
      case h +: t => t.foldLeft(z(h))(f)
    }
    override def foldMap[A, B](fa: StrictTree[A])(f: A => B)(implicit F: Monoid[B]): B = fa foldMap f

    //This implementation is 14x faster than the trampolined implementation for StrictTreeTestJVM's align test.
    /**
      * Aligns two trees position by position. Where only one tree has a node, `f` receives a
      * `This`/`That`; where both do, it receives a `Both`. `f` is applied exactly once per
      * node of the result.
      */
    override def alignWith[A, B, C](f: (\&/[A, B]) => C): (StrictTree[A], StrictTree[B]) => StrictTree[C] = {
      (a, b) =>
        import StrictTree.AlignStackElem
        val root = AlignStackElem[A, B, C](None, \&/(a, b))
        val stack = mutable.Stack(root)

        while (stack.nonEmpty) {
          val here = stack.head
          if (here.hasNext) {
            val nextChildren = here.next()
            stack.push(AlignStackElem[A, B, C](Some(here), nextChildren))
          } else {
            // Every child of "here" is aligned; add its node to the parent's completed children.
            // The root is built once, after the loop, so `f` is not applied to it twice.
            here.parent.foreach { parent =>
              parent.mappedSubForest += StrictTree[C](f(here.trees.bimap(_.rootLabel, _.rootLabel)), here.mappedSubForest.toVector)
            }
            stack.pop()
          }
        }

        StrictTree(f(root.trees.bimap(_.rootLabel, _.rootLabel)), root.mappedSubForest.toVector)
    }

    override def zip[A, B](a: => StrictTree[A], b: => StrictTree[B]): StrictTree[(A, B)] = {
      a.zip(b)
    }
  }

  implicit def treeEqual[A](implicit A0: Equal[A]): Equal[StrictTree[A]] =
    new StrictTreeEqual[A] { def A = A0 }

  /** Orders trees by root label first, then lexicographically by their child lists. */
  implicit def treeOrder[A](implicit A0: Order[A]): Order[StrictTree[A]] =
    new Order[StrictTree[A]] with StrictTreeEqual[A] {
      override def A: Order[A] = A0
      import std.vector._
      override def order(x: StrictTree[A], y: StrictTree[A]) =
        A.order(x.rootLabel, y.rootLabel) match {
          case Ordering.EQ =>
            Order[Vector[StrictTree[A]]].order(x.subForest, y.subForest)
          case rootOrdering => rootOrdering
        }
    }



  /* TODO
  def applic[A, B](f: StrictTree[A => B]) = a => StrictTree.node((f.rootLabel)(a.rootLabel), implicitly[Applic[newtypes.ZipVector]].applic(f.subForest.map(applic[A, B](_)).?)(a.subForest ?).value)
   */
}

object StrictTree extends StrictTreeInstances {
  /**
   * Node represents a tree node that may have children.
   *
   * You can use Node for tree construction or pattern matching.
   */
  object Node {
    def apply[A](root: A, forest: Vector[StrictTree[A]]): StrictTree[A] = {
      StrictTree[A](root, forest)
    }

    def unapply[A](t: StrictTree[A]): Some[(A, Vector[StrictTree[A]])] = Some((t.rootLabel, t.subForest))
  }

  /**
   *  Leaf represents a tree node with no children.
   *
   *  You can use Leaf for tree construction or pattern matching.
   */
  object Leaf {
    def apply[A](root: A): StrictTree[A] = {
      Node(root, Vector.empty)
    }

    def unapply[A](t: StrictTree[A]): Option[A] = {
      t match {
        case Node(root, Vector()) =>
          Some(root)
        case _ =>
          None
      }
    }
  }

  /** Unfolds each seed in `s` into a tree with `unfoldTree`, keeping the seeds' order. */
  def unfoldForest[A, B](s: Vector[A])(f: A => (B, Vector[A])): Vector[StrictTree[B]] =
    s.map(unfoldTree(_)(f))

  /**
   * Builds a tree from the seed `v`: `f` yields a node's label and the seeds of its children.
   * This is recursive (through `unfoldForest`), so very deep unfoldings may overflow the stack.
   */
  def unfoldTree[A, B](v: A)(f: A => (B, Vector[A])): StrictTree[B] =
    f(v) match {
      case (a, bs) => Node(a, unfoldForest(bs)(f))
    }

  //Only used for .equals: compares labels with universal `==`.
  private def badEqInstance[A] = new StrictTreeEqual[A] {
    override def A: Equal[A] = (a1: A, a2: A) => a1 == a2
  }

  /**
    * This implementation is 16x faster than the trampolined implementation for StrictTreeTestJVM's scanr test.
    */
  private def scanrReducer[A, B](
    f: (A, Vector[StrictTree[B]]) => B
  )(rootLabel: A
  )(subForest: mutable.Buffer[StrictTree[B]]
  ): StrictTree[B] = {
    val subForestVector = subForest.toVector
    StrictTree[B](f(rootLabel, subForestVector), subForestVector)
  }

  /**
    * This implementation is 10x faster than mapTrampoline for StrictTreeTestJVM's map test.
    */
  private def mapReducer[A, B](
    f: A => B
  )(rootLabel: A
  )(subForest: scala.collection.Seq[StrictTree[B]]
  ): StrictTree[B] = {
    StrictTree[B](f(rootLabel), subForest.toVector)
  }

  /**
    * This implementation is 9x faster than flatMapTrampoline for StrictTreeTestJVM's flatMap test.
    *
    * The already-transformed children are appended after the children of `f(root)`.
    */
  private def flatMapReducer[A, B](
    f: A => StrictTree[B]
  )(root: A
  )(subForest: scala.collection.Seq[StrictTree[B]]
  ): StrictTree[B] = {
    val StrictTree(rootLabel0, subForest0) = f(root)
    StrictTree(rootLabel0, subForest0 ++ subForest)
  }

  /**
    * This implementation is 9x faster than the trampolined implementation for StrictTreeTestJVM's foldMap test.
    *
    * The root's image is appended before the folded results of its children.
    */
  private def foldMapReducer[A, B: Monoid](
    f: A => B
  )(rootLabel: A
  )(subForest: mutable.Buffer[B]
  ): B = {
    val mappedRoot = f(rootLabel)
    val foldedForest = Foldable[Vector].fold[B](subForest.toVector)

    Monoid[B].append(mappedRoot, foldedForest)
  }

  /** Combines a label's hash with the sequence hash of its children's hash codes. */
  private def hashCodeReducer[A](root: A)(subForest: scala.collection.Seq[Int]): Int = {
    root.hashCode ^ subForest.hashCode
  }

  /**
    * A stack frame for `runBottomUp`: a node, a cursor over its children, and a buffer
    * collecting the results already computed for those children.
    */
  private case class BottomUpStackElem[A, B](
    parent: Option[BottomUpStackElem[A, B]],
    tree: StrictTree[A]
  ) extends Iterator[StrictTree[A]] {
    private[this] val subIterator = tree.subForest.iterator

    def rootLabel = tree.rootLabel

    val mappedSubForest: mutable.Buffer[B] = mutable.Buffer.empty

    override def hasNext: Boolean = subIterator.hasNext

    override def next(): StrictTree[A] = subIterator.next()
  }

  /**
    * A stack frame for `StrictTree.zip`: a pair of nodes, a cursor over their children paired
    * position by position (stopping at the shorter list), and the zipped children built so far.
    */
  private case class ZipStackElem[A, B](
    parent: Option[ZipStackElem[A, B]],
    a: StrictTree[A],
    b: StrictTree[B]
  ) extends Iterator[(StrictTree[A], StrictTree[B])] {
    private[this] val zippedSubIterator =
      a.subForest.iterator.zip(b.subForest.iterator)

    val mappedSubForest: mutable.Buffer[StrictTree[(A, B)]] = mutable.Buffer.empty

    override def hasNext: Boolean = zippedSubIterator.hasNext

    override def next(): (StrictTree[A], StrictTree[B]) = zippedSubIterator.next()
  }

  /**
    * A stack frame for `alignWith`: one or two nodes, cursors over their children, and the
    * aligned children built so far. Once one side runs out of children, the remaining
    * children of the other side are yielded as `This`/`That`.
    */
  private[scalaz] case class AlignStackElem[A, B, C](
    parent: Option[AlignStackElem[A, B, C]],
    trees: \&/[StrictTree[A], StrictTree[B]]
  ) extends Iterator[\&/[StrictTree[A], StrictTree[B]]] {
    private[this] val iterators =
      trees.bimap(_.subForest.iterator, _.subForest.iterator)

    val mappedSubForest: mutable.Buffer[StrictTree[C]] = mutable.Buffer.empty

    def whichHasNext: \&/[Boolean, Boolean] =
      iterators.bimap(_.hasNext, _.hasNext)

    override def hasNext: Boolean =
      whichHasNext.fold(identity, identity, _ || _)

    // Match on the iterators themselves, so no `Option.get` is needed, whatever the
    // `Both`/`This`/`That` shape and however many children each side still has.
    override def next(): \&/[StrictTree[A], StrictTree[B]] =
      iterators match {
        case \&/.Both(as, bs) if as.hasNext && bs.hasNext =>
          \&/.Both(as.next(), bs.next())

        case \&/.Both(as, _) if as.hasNext =>
          \&/.This(as.next())

        case \&/.Both(_, bs) if bs.hasNext =>
          \&/.That(bs.next())

        case \&/.This(as) if as.hasNext =>
          \&/.This(as.next())

        case \&/.That(bs) if bs.hasNext =>
          \&/.That(bs.next())

        case _ =>
          throw new NoSuchElementException("reached iterator end")
      }
  }

  implicit def ToStrictTreeUnzip[A1, A2](root: StrictTree[(A1, A2)]): StrictTreeUnzip[A1, A2] =
    new StrictTreeUnzip[A1, A2](root)

}

/** Stack-based structural equality for StrictTree, parameterised by an `Equal` for the labels. */
private trait StrictTreeEqual[A] extends Equal[StrictTree[A]] {
  /** The equality used to compare labels. */
  def A: Equal[A]

  /** A pair of nodes being compared, with cursors over their respective children. */
  private case class EqualStackElem(
    a: StrictTree[A],
    b: StrictTree[A]
  ) {
    val aSubIterator =
      a.subForest.iterator

    val bSubIterator =
      b.subForest.iterator
  }

  /**
    * Two trees are equal when they have the same shape and `A` considers every pair of
    * corresponding labels equal. Stops at the first mismatch.
    *
    * This implementation is 4.5x faster than the trampolined implementation for
    * StrictTreeTestJVM's equal test.
    */
  override final def equal(a1: StrictTree[A], a2: StrictTree[A]): Boolean =
    A.equal(a1.rootLabel, a2.rootLabel) && {
      val stack = mutable.Stack[EqualStackElem](EqualStackElem(a1, a2))
      var equalSoFar = true

      while (equalSoFar && stack.nonEmpty) {
        val here = stack.head
        val aNext = here.aSubIterator.hasNext
        val bNext = here.bSubIterator.hasNext
        if (aNext && bNext) {
          val childA = here.aSubIterator.next()
          val childB = here.bSubIterator.next()
          // Each label pair is compared once, when it is first reached, rather than
          // again every time its node comes back to the top of the stack.
          if (A.equal(childA.rootLabel, childB.rootLabel)) stack.push(EqualStackElem(childA, childB))
          else equalSoFar = false
        } else if (aNext || bNext) {
          // One node has more children than the other: the shapes differ.
          equalSoFar = false
        } else {
          stack.pop()
        }
      }

      equalSoFar
    }
}

/** Extension syntax for a tree of pairs; provided by the implicit `StrictTree.ToStrictTreeUnzip`. */
final class StrictTreeUnzip[A1, A2](private val root: StrictTree[(A1, A2)]) extends AnyVal {
  /**
    * Builds the two output nodes for one input node.
    *
    * @param rootLabel   the pair stored at the input node.
    * @param accumulator the already-unzipped children of that node, in order.
    */
  private def unzipCombiner(rootLabel: (A1, A2))(accumulator: scala.collection.Seq[(StrictTree[A1], StrictTree[A2])]): (StrictTree[A1], StrictTree[A2]) = {
    val (leftChildren, rightChildren) = accumulator.toVector.unzip
    val (leftLabel, rightLabel) = rootLabel
    (StrictTree(leftLabel, leftChildren), StrictTree(rightLabel, rightChildren))
  }

  /** Splits a tree of pairs into two trees of the same shape: first components, then second components. */
  def unzip: (StrictTree[A1], StrictTree[A2]) =
    root.runBottomUp[(StrictTree[A1], StrictTree[A2])](unzipCombiner)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy