All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.softwaremill.macmemo.memoizeMacro.scala Maven / Gradle / Ivy

The newest version!
package com.softwaremill.macmemo

import scala.concurrent.duration.FiniteDuration
import scala.reflect.macros._

/**
 * Macro implementation that rewrites an annotated method so that its results are
 * served from a generated cache value.
 *
 * For an annotated `def`, the expansion consists of:
 *  - the original method, whose body is wrapped so that it either runs directly
 *    (when the `macmemo.disable` system property is set) or goes through the cache;
 *  - a generated `lazy val` holding the `com.softwaremill.macmemo.Cache` instance.
 */
object memoizeMacro {
  private val debug = new Debug()

  /**
   * Expands the annotation.
   *
   * @param c         macro context supplied by the compiler
   * @param annottees the annotated definitions; the first one is expected to be a method
   * @return a block containing the rewritten method, any remaining annottees and the
   *         generated cache value; when the first annottee is not a method an error is
   *         reported and the annottees are returned unchanged
   */
  def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
    import c.universe._

    /** Annotation arguments evaluated at compile time. */
    case class MacroArgs(maxSize: Long, expireAfter: FiniteDuration, concurrencyLevel: Option[Int] = None)

    /** Pairs the annotated method's name with the fresh name of its generated cache value. */
    case class MemoIdentifier(methodName: TermName, generatedMemoValName: TermName)

    def reportInvalidAnnotationTarget(): Unit = {
      c.error(c.enclosingPosition, "This annotation can only be used on methods")
    }

    /**
     * Builds the replacement method body: the original body is moved into a local
     * `callRealBody()` and is invoked either directly (caching disabled through the
     * `macmemo.disable` system property) or through the generated cache, keyed by the
     * list of all parameter values.
     */
    def prepareInjectedBody(cachedMethodId: MemoIdentifier, valDefs: List[List[ValDef]], bodyTree: Tree, returnTypeTree: Tree): Tree = {
      // Parameter names from every parameter list, flattened, form the cache key.
      val names = valDefs.flatten.map(_.name)
      q"""
      def callRealBody(): $returnTypeTree = { $bodyTree }
      if (System.getProperty("macmemo.disable") != null) {
        callRealBody()
      }
      else {
        ${cachedMethodId.generatedMemoValName}.get($names, {
          List(
            callRealBody()
          )
        }).head.asInstanceOf[$returnTypeTree]
      }"""
    }

    /**
     * Builds the `lazy val` that holds the cache for the annotated method. The cache
     * builder is resolved and invoked with a bucket id made of the enclosing owner's
     * full name and the method name.
     */
    def createMemoVal(cachedMethodId: MemoIdentifier, macroArgs: MacroArgs): Tree = {
      // Computed once; the same literal is spliced into both `resolve` and `build`.
      val cacheBucketId: Tree = {
        val owner = c.internal.enclosingOwner
        val separator = if (owner.isModuleClass) "$." else "."
        Literal(Constant(owner.fullName + separator + cachedMethodId.methodName.toString))
      }

      val params: Tree = {
        val maxSize = macroArgs.maxSize
        val ttlMillis = macroArgs.expireAfter.toMillis
        val concurrencyLevelOpt = macroArgs.concurrencyLevel
        q"""com.softwaremill.macmemo.MemoizeParams($maxSize, $ttlMillis, $concurrencyLevelOpt)"""
      }

      q"""lazy val ${cachedMethodId.generatedMemoValName}: com.softwaremill.macmemo.Cache[List[Any]] =
         com.softwaremill.macmemo.BuilderResolver.resolve($cacheBucketId).build($cacheBucketId, $params)"""
    }

    /** Returns a copy of `function` whose body goes through the cache. */
    def injectCacheUsage(cachedMethodId: MemoIdentifier, function: DefDef): DefDef = {
      val DefDef(mods, name, tparams, valDefs, returnTypeTree, bodyTree) = function
      val injectedBody = prepareInjectedBody(cachedMethodId, valDefs, bodyTree, returnTypeTree)
      DefDef(mods, name, tparams, valDefs, returnTypeTree, injectedBody)
    }

    /**
     * Reads the annotation arguments from the macro application tree.
     *
     * Arguments are read by position: index 0 of the extracted children is the
     * constructor reference, index 1 is the max size, index 2 the expiry duration and
     * the optional index 3 the concurrency level.
     */
    def extractMacroArgs(application: Tree): MacroArgs = {
      // Use the macro context's printer, consistently with the final `show` debug line.
      debug(s"RAW application = ${showRaw(application)}")
      val argsTree = application.children.head.children.head.children
      val maxSize = extractMaxSize(argsTree(1))
      val ttl = extractTtl(argsTree(2))
      val concurrencyLevelOpt = argsTree match {
        case List(_, _, _, concurrencyLevelTree) => extractConcurrencyLevel(concurrencyLevelTree)
        case _ => None
      }
      val args = MacroArgs(maxSize, ttl, concurrencyLevelOpt)
      debug(s"Macro args: $args")
      args
    }

    /** Accepts both `maxSize = <expr>` and a bare `<expr>`. */
    def extractMaxSize(tree: Tree): Long = tree match {
      case q"maxSize=$x" => evalLongExpr(x)
      case _ => evalLongExpr(tree)
    }

    /**
     * Evaluates `tree` at compile time and widens the result to `Long`.
     * Aborts the expansion with a positioned error when the expression does not
     * evaluate to an `Int` or a `Long`, instead of failing with a `MatchError`.
     */
    def evalLongExpr(tree: Tree): Long =
      c.eval(c.Expr[Any](tree)) match {
        case intLength: Int => intLength.toLong
        case longLength: Long => longLength
        case other =>
          c.abort(tree.pos, s"maxSize must evaluate to an Int or a Long, but got: $other")
      }

    /** Accepts both `expiresAfter = <expr>` and a bare `<expr>`. */
    def extractTtl(tree: Tree): FiniteDuration = tree match {
      case q"expiresAfter=$x" => evalFiniteDurationExpr(x)
      case _ => evalFiniteDurationExpr(tree)
    }

    /** Evaluates a duration expression with `scala.concurrent.duration._` imported. */
    def evalFiniteDurationExpr(tree: Tree): FiniteDuration = {
      val withDurationImport = q"import scala.concurrent.duration._; $tree"
      c.eval(c.Expr[FiniteDuration](withDurationImport))
    }

    /** Accepts both `concurrencyLevel = <expr>` and a bare `<expr>`. */
    def extractConcurrencyLevel(tree: Tree): Option[Int] = tree match {
      case q"concurrencyLevel=$x" => evalOptionInt(x)
      case _ => evalOptionInt(tree)
    }

    def evalOptionInt(tree: Tree): Option[Int] =
      c.eval(c.Expr[Option[Int]](tree))

    val inputs = annottees.map(_.tree).toList
    val expandees: List[Tree] = inputs match {
      case (functionDefinition: DefDef) :: rest =>
        val name = functionDefinition.name
        debug(s"Found annotated function [$name]")
        val cachedMethodIdentifier = MemoIdentifier(name, TermName(c.freshName(s"memo_${name}_")))
        val macroArgs = extractMacroArgs(c.macroApplication)
        val memoVal = createMemoVal(cachedMethodIdentifier, macroArgs)
        val newFunctionDef = injectCacheUsage(cachedMethodIdentifier, functionDefinition)
        // Rewritten method first, then the untouched annottees, then the cache value.
        (newFunctionDef :: rest) :+ memoVal
      case _ =>
        reportInvalidAnnotationTarget()
        inputs
    }

    debug(s"final method = ${show(expandees)}")

    c.Expr[Any](Block(expandees, Literal(Constant(()))))
  }

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy