All Downloads are FREE. Search and download functionalities are using the official Maven repository.

speed.impl.ConstantFolding.scala Maven / Gradle / Ivy

package speed
package impl

import scala.util.control.NonFatal

trait ConstantFolding { self: WithContext ⇒
  import c.universe._

  /**
   * Folds constants in `tree`. Local attributes are reset both before and after folding
   * (presumably so that the rewritten tree can be typechecked again — TODO confirm).
   */
  def finish(tree: Tree): Tree = c.resetLocalAttrs(foldConstants(c.resetLocalAttrs(tree)))

  /** Runs a fresh [[ConstantFolder]] with an empty initial environment over `tree`. */
  def foldConstants(tree: Tree): Tree = {
    trace(s"Input to partially: $tree")
    new ConstantFolder().transform(tree)
  }

  /**
   * Transformer that inlines `val`s bound to literals, folds `Int`/`Long` arithmetic and
   * comparisons, simplifies boolean operators and prunes `if`/`match` branches whose condition or
   * selector is a constant. Operations that cannot be evaluated at compile time (unsupported
   * operand types, integer division by zero) are left in the tree unchanged.
   *
   * This constant folder has only a crude notion of lexical scopes, so be careful: only `Block`s
   * open a new scope, and binders other than constant `val`s (non-constant `val`s, function
   * parameters, pattern variables) do not hide an outer constant binding of the same name.
   *
   * @param initEnv constant bindings visible from the start, below the outermost scope
   */
  class ConstantFolder(initEnv: Map[Name, Constant] = Map.empty) extends Transformer {
    /** Scopes of constant bindings, innermost first; the bottom frame holds `initEnv`. */
    var environmentStack: collection.immutable.Stack[Map[Name, Constant]] =
      collection.immutable.Stack[Map[Name, Constant]](initEnv).push(Map.empty[Name, Constant])

    /** Binds `sym` to `value` in the innermost scope; a name may be bound only once per scope. */
    def createBinding(sym: Name, value: Constant): Unit = {
      val top = environmentStack.head
      assert(!top.contains(sym))
      environmentStack = environmentStack.pop.push(top + (sym -> value))
    }

    /** Whether any enclosing scope has a constant binding for `sym`. */
    def envContains(sym: Name): Boolean = environmentStack.exists(_.contains(sym))

    /**
     * The constant bound to `sym` by the innermost scope that binds it.
     *
     * @throws NoSuchElementException if `sym` is unbound (check [[envContains]] first)
     */
    def lookup(sym: Name): Constant =
      environmentStack.find(_.contains(sym)) match {
        case Some(scope) ⇒ scope(sym)
        case None        ⇒ throw new NoSuchElementException(s"No constant binding for $sym")
      }

    /** Opens a new, empty innermost scope. */
    def pushContext(): Unit = environmentStack = environmentStack.push(Map.empty)

    /** Discards the innermost scope. */
    def popContext(): Unit = environmentStack = environmentStack.pop

    override def transform(tree: Tree): Tree = {
      // Evaluates `calc` at compile time. Yields None when `calc` has no case for `args`
      // (unsupported operand types) or throws (e.g. integer division by zero), so the operation
      // is kept for run time instead of aborting the macro expansion.
      def evaluate[A](args: A)(calc: PartialFunction[A, Any]): Option[Any] =
        if (!calc.isDefinedAt(args)) None
        else
          try Some(calc(args))
          catch {
            case NonFatal(e) ⇒
              trace(s"Not folding $args: $e")
              None
          }

      // Turns a successfully folded value into a literal, otherwise builds the unfolded tree.
      def literalOr(folded: Option[Any])(unfolded: ⇒ Tree): Tree = folded match {
        case Some(value) ⇒ Literal(Constant(value))
        case None        ⇒ unfolded
      }

      // Folds `e1 op e2`. With `ifOneIsConstant`, `calc` is also offered (constant, tree) and
      // (tree, constant) pairs so it can simplify when only one operand is known.
      def binaryOp(e1: Tree, e2: Tree, op: TermName, ifOneIsConstant: Boolean = false)(calc: PartialFunction[(Any, Any), Any]): Tree = {
        val x1 = transform(e1)
        val x2 = transform(e2)
        val folded = (x1, x2) match {
          case (Literal(Constant(a)), Literal(Constant(b))) ⇒ evaluate[(Any, Any)]((a, b))(calc)
          case (Literal(Constant(a)), b) if ifOneIsConstant  ⇒ evaluate[(Any, Any)]((a, b))(calc)
          case (a, Literal(Constant(b))) if ifOneIsConstant  ⇒ evaluate[(Any, Any)]((a, b))(calc)
          case _                                             ⇒ None
        }
        literalOr(folded)(q"$x1 $op $x2")
      }

      // Folds `e1.op` when `e1` is a literal that `calc` supports.
      def unaryOp(e1: Tree, op: TermName)(calc: PartialFunction[Any, Any]): Tree = {
        val x = transform(e1)
        val folded = x match {
          case Literal(Constant(a)) ⇒ evaluate(a)(calc)
          case _                    ⇒ None
        }
        literalOr(folded)(q"$x.$op")
      }

      tree match {
        case q"$x: ($t @speed.dontfold)" ⇒
          trace(s"Matched type ascription annotation: $x")
          val x_1 = RemoveDontFold.transform(x)
          val t_1 = RemoveDontFold.transform(t)
          q"$x_1: $t_1"
        case q"$x: @speed.dontfold" ⇒
          trace(s"Matched plain ascription: $x")
          RemoveDontFold.transform(x)
        case q"$x: ${ tpe: TypeTree }" ⇒
          trace(s"Matched TypeTree ascription: $x $tpe ")
          // `original` is null for TypeTrees that were not written in source
          Option(tpe.original) match {
            case Some(tq"$t @speed.dontfold()") ⇒
              val x_1 = RemoveDontFold.transform(x)
              val t_1 = RemoveDontFold.transform(t)
              q"$x_1: $t_1"
            case _ ⇒
              val inner = transform(x)
              q"$inner: $tpe"
          }

        case _: Block ⇒
          pushContext()
          // pop in `finally` so an exception cannot leave the scope stack unbalanced
          val res =
            try super.transform(tree)
            finally popContext()
          res match {
            // no statements left (e.g. every `val` was folded away): `{ 42 }` becomes `42`
            case Block(Nil, l: Literal) ⇒ l
            case _                      ⇒ res
          }
        case Ident(name) if envContains(name) ⇒
          trace(s"Replaced constant binding for $name")
          Literal(lookup(name))

        case v @ q"val $x = $expr" ⇒
          val vd = v.asInstanceOf[ValDef]
          trace(s"Modifiers for $v: ${vd.mods} ${expr.productPrefix}")
          //trace(s"Trying to figure out value of $x ($expr), env has values for ${env.keys.mkString(", ")}")
          transform(expr) match {
            case Literal(constant) ⇒
              trace(s"Found literal binding for $x (${v.symbol}): $constant")
              createBinding(x, constant)
              // drop the definition; later uses of `x` are replaced by the literal
              q""
            case newExpr ⇒
              // treeCopy keeps the original position, modifiers and declared type
              treeCopy.ValDef(v, vd.mods, x, vd.tpt, newExpr)
          }
        case q"- $expr" ⇒
          unaryOp(expr, newTermName("unary_$minus")) {
            case i: Int  ⇒ (-i): Int
            case l: Long ⇒ -l
          }
        case q"$expr.toLong" ⇒
          unaryOp(expr, newTermName("toLong")) {
            case i: Int  ⇒ i.toLong
            case l: Long ⇒ l
          }
        case q"$expr.toInt" ⇒
          unaryOp(expr, newTermName("toInt")) {
            case i: Long ⇒ i.toInt
            case i: Int  ⇒ i
          }
        // There's a lot of code duplication coming. The reason is that
        // this is the easiest way of making sure that we exactly match
        // the primitive implicit conversions Scala is also doing.
        //
        // It would be possible to use a much more general algorithm here:
        // for every binary operation that is pure
        //   * generate the tree representing the operation
        //   * evaluate the operation
        //   * insert the result
        case q"$e1 + $e2" ⇒
          binaryOp(e1, e2, newTermName("$plus")) {
            case (i1: Int, i2: Int)   ⇒ i1 + i2
            case (i1: Int, i2: Long)  ⇒ i1 + i2
            case (i1: Long, i2: Int)  ⇒ i1 + i2
            case (i1: Long, i2: Long) ⇒ i1 + i2
          }

        case q"$e1 - $e2" ⇒
          binaryOp(e1, e2, newTermName("$minus")) {
            case (i1: Int, i2: Int)   ⇒ i1 - i2
            case (i1: Int, i2: Long)  ⇒ i1 - i2
            case (i1: Long, i2: Int)  ⇒ i1 - i2
            case (i1: Long, i2: Long) ⇒ i1 - i2
          }
        case q"$e1 * $e2" ⇒
          binaryOp(e1, e2, newTermName("$times")) {
            case (i1: Int, i2: Int)   ⇒ i1 * i2
            case (i1: Int, i2: Long)  ⇒ i1 * i2
            case (i1: Long, i2: Int)  ⇒ i1 * i2
            case (i1: Long, i2: Long) ⇒ i1 * i2
          }
        case q"$e1 / $e2" ⇒
          // division by zero throws inside `calc` and is therefore left unfolded
          binaryOp(e1, e2, newTermName("$div")) {
            case (i1: Int, i2: Int)   ⇒ i1 / i2
            case (i1: Int, i2: Long)  ⇒ i1 / i2
            case (i1: Long, i2: Int)  ⇒ i1 / i2
            case (i1: Long, i2: Long) ⇒ i1 / i2
          }
        case q"$e1 % $e2" ⇒
          binaryOp(e1, e2, newTermName("$percent"), ifOneIsConstant = true) {
            // typed cases first, so e.g. `5L % 1` keeps its Long type and `5.5 % 1` is not folded
            case (i1: Int, i2: Int)   ⇒ i1 % i2
            case (i1: Int, i2: Long)  ⇒ i1 % i2
            case (i1: Long, i2: Int)  ⇒ i1 % i2
            case (i1: Long, i2: Long) ⇒ i1 % i2
            // `x % 1` is zero for integral `x`; these only fire when the first operand is a
            // non-constant tree. This doesn't check the pureness or the type of that expression.
            case (_: Tree, i2: Int) if i2 == 1   ⇒ 0
            case (_: Tree, i2: Long) if i2 == 1L ⇒ 0L
          }
        case q"$e1 == $e2" ⇒
          binaryOp(e1, e2, newTermName("$eq$eq")) {
            case (i1: Int, i2: Int)   ⇒ i1 == i2
            case (i1: Int, i2: Long)  ⇒ i1 == i2
            case (i1: Long, i2: Int)  ⇒ i1 == i2
            case (i1: Long, i2: Long) ⇒ i1 == i2
          }
        case q"$e1 > $e2" ⇒
          binaryOp(e1, e2, newTermName("$greater")) {
            case (i1: Int, i2: Int)   ⇒ i1 > i2
            case (i1: Int, i2: Long)  ⇒ i1 > i2
            case (i1: Long, i2: Int)  ⇒ i1 > i2
            case (i1: Long, i2: Long) ⇒ i1 > i2
          }
        case q"$e1 < $e2" ⇒
          binaryOp(e1, e2, newTermName("$less")) {
            case (i1: Int, i2: Int)   ⇒ i1 < i2
            case (i1: Int, i2: Long)  ⇒ i1 < i2
            case (i1: Long, i2: Int)  ⇒ i1 < i2
            case (i1: Long, i2: Long) ⇒ i1 < i2
          }
        case q"$e1 || $e2" ⇒
          // do a bit of peephole optimization
          (transform(e1), transform(e2)) match {
            case (Literal(Constant(true)), _)      ⇒ Literal(Constant(true))
            case (Literal(Constant(false)), other) ⇒ other

            // this one is wrong if the first operand does side-effects
            case (_, Literal(Constant(true)))      ⇒ Literal(Constant(true))
            case (other, Literal(Constant(false))) ⇒ other
            case (x1, x2)                          ⇒ q"$x1 || $x2"
          }
        case q"$e1 && $e2" ⇒
          // do a bit of peephole optimization
          (transform(e1), transform(e2)) match {
            case (Literal(Constant(true)), other) ⇒ other
            case (Literal(Constant(false)), _)    ⇒ Literal(Constant(false))

            // this one is wrong if the first operand does side-effects
            case (_, Literal(Constant(false)))    ⇒ Literal(Constant(false))
            case (other, Literal(Constant(true))) ⇒ other
            case (x1, x2)                         ⇒ q"$x1 && $x2"
          }
        case q"! $expr" ⇒
          unaryOp(expr, newTermName("unary_$bang")) {
            case b: Boolean ⇒ !b
          }
        case q"if ($cond) $rawThenB else $rawElseB" ⇒
          // branches are transformed lazily: a constant condition only transforms the taken one
          def thenB = transform(rawThenB)
          def elseB = transform(rawElseB)
          transform(cond) match {
            case Literal(Constant(true))  ⇒ thenB
            case Literal(Constant(false)) ⇒ elseB
            case x                        ⇒ q"if ($x) $thenB else $elseB"
          }

        case m @ Match(selector, cases) ⇒
          val sel = transform(selector)
          // rebuilds the match around the already-transformed selector instead of transforming it again
          def unfolded: Tree = treeCopy.Match(m, sel, transformCaseDefs(cases))
          sel match {
            case Literal(selectorValue) ⇒
              trace(s"Found literal binding for match selector: $selectorValue ($selector), trying to run match")
              val allConstant = cases.forall(cs ⇒ cs.pat match {
                case Literal(_) if (cs.guard == EmptyTree)          ⇒ true
                case Ident(nme.WILDCARD) if (cs.guard == EmptyTree) ⇒ true
                case _ if (cs.guard != EmptyTree) ⇒
                  c.warning(cs.guard.pos, s"Constant folding doesn't support guards right now ($m)")
                  false
                case _ ⇒ false
              })
              val lastPattern = cases.lastOption.map(cs ⇒ c.universe.showRaw(cs.pat)).getOrElse("")
              trace(s"Only constant pattern branches: $allConstant: $lastPattern")

              if (allConstant)
                cases.find(_.pat match {
                  case Literal(`selectorValue`) ⇒ true
                  case Ident(nme.WILDCARD)      ⇒ true
                  case _                        ⇒ false
                }) match {
                  case Some(matchingCase) ⇒ transform(matchingCase.body)
                  case None ⇒
                    // no case matches: keep the match so the run-time MatchError is preserved
                    trace(s"No case matches constant selector $selectorValue")
                    unfolded
                }
              else unfolded

            case other ⇒
              trace(s"Got selector $other")
              unfolded
          }

        case other ⇒ super.transform(other)
      }
    }
  }

  /** Removes the `@speed.dontfold` annotation, including from nested trees. */
  object RemoveDontFold extends Transformer {
    override def transform(tree: Tree): Tree = tree match {
      case tq"$t @speed.dontfold()"    ⇒ transform(t)
      case q"$x: ($t @speed.dontfold)" ⇒ q"${transform(x)}: ${transform(t)}"
      // a TypeTree without `original` has no written type to clean; let the default traversal handle it
      case q"$x: ${ tpe: TypeTree }" if tpe.original != null ⇒
        q"${transform(x)}: ${transform(tpe.original)}"
      case _ ⇒ super.transform(tree)
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy