All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tambapps.marcel.semantic.transform.CachedAstTransformation.kt Maven / Gradle / Ivy

The newest version!
package com.tambapps.marcel.semantic.transform

import com.tambapps.marcel.parser.cst.CstNode
import com.tambapps.marcel.parser.cst.MethodCstNode
import com.tambapps.marcel.semantic.SemanticPurpose
import com.tambapps.marcel.semantic.SemanticHelper
import com.tambapps.marcel.semantic.Visibility
import com.tambapps.marcel.semantic.ast.AnnotationNode
import com.tambapps.marcel.semantic.ast.AstNode
import com.tambapps.marcel.semantic.ast.ClassNode
import com.tambapps.marcel.semantic.ast.MethodNode
import com.tambapps.marcel.semantic.ast.expression.ExpressionNode
import com.tambapps.marcel.semantic.ast.statement.ExpressionStatementNode
import com.tambapps.marcel.semantic.exception.MarcelSyntaxTreeTransformationException
import com.tambapps.marcel.semantic.extensions.javaType

import com.tambapps.marcel.semantic.method.JavaMethod
import com.tambapps.marcel.semantic.type.JavaType
import com.tambapps.marcel.semantic.type.SourceJavaType
import java.util.concurrent.ConcurrentHashMap

/**
 * AST Transformations caching results of the annotated method
 */
class CachedAstTransformation : GenerateMethodAstTransformation() {

  companion object {
    const val THREAD_SAFE_OPTION = "threadSafe"
  }

  override fun generateSignatures(
    node: CstNode,
    javaType: SourceJavaType,
    annotation: AnnotationNode
  ): List {
    if ((node as? MethodCstNode)?.isAsync == true) {
      throw MarcelSyntaxTreeTransformationException(this, node.token, "Cannot cache async methods")
    }
    // as the generated method is private and only called for this caching usecase, no need to define it here
    //  (it would be anyway hard to do so as we don't have the AST node here)
    return emptyList()
  }

  /**
   * Generates a new method doCompute whose body is the body of the original annotated method and modifies the original method
   * so that it uses the cache and call doCompute if the result wasn't already computed.
   *
   * When the 'threadSafe' option is not provided we use a simple HashMap. When this option is enabled, we use a ConcurrentHashMap
   * and call map.computeIfAbsent(key, (k) -> doCompute(k)).
   *
   * In case the method has multiple parameter, the key becomes the list of all those parameters
   */
  override fun generateMethodNodes(node: AstNode, classNode: ClassNode, annotation: AnnotationNode): List {
    val originalMethod = node as MethodNode
    if (originalMethod.isConstructor) throw MarcelSyntaxTreeTransformationException(
      this,
      node.token,
      "Annotated method is a constructor"
    )
    if (originalMethod.returnType == JavaType.void) throw MarcelSyntaxTreeTransformationException(
      this,
      node.token,
      "Cannot cache void methods"
    )
    if (originalMethod.parameters.isEmpty()) throw MarcelSyntaxTreeTransformationException(
      this,
      node.token,
      "Cannot cache methods with no parameters"
    )
    val newMethodName = "doCompute" + originalMethod.name.first().uppercase() + originalMethod.name.substring(1)

    val doComputeMethod = methodNode(
      visibility = Visibility.INTERNAL,
      name = newMethodName,
      parameters = originalMethod.parameters,
      returnType = originalMethod.returnType
    ) {
      addAllStmt(originalMethod.blockStatement.statements)
    }

    val threadSafe = annotation.getAttribute(THREAD_SAFE_OPTION)?.value == true
    val cacheExpr = getCacheExpression(classNode, originalMethod, threadSafe)

    // rewrite the original method
    originalMethod.blockStatement.statements.clear()
    addStatements(originalMethod) {
      val isMultiParams = originalMethod.parameters.size > 1
      val cacheKeyRef = if (!isMultiParams) argRef(0)
      else {
        val lv = currentMethodScope.addLocalVariable(List::class.javaType)
        varAssignStmt(
          lv,
          cast(
            array(JavaType.objectArray, originalMethod.parameters.map { cast(ref(it), JavaType.Object) }),
            List::class.javaType
          )
        )
        ref(lv)
      }
      val cacheKeyObjectRef = cast(cacheKeyRef, JavaType.Object)
      if (threadSafe) {
        // generate code cache.computeIfAbsent(key, k -> doCompute(k))
        val ciaLambda = newLambda(
          classNode,
          parameters = listOf(parameter(cacheKeyObjectRef.type, "param")),
          returnType = doComputeMethod.returnType.objectType,
          interfaceType = java.util.function.Function::class.javaType,
        ) {
          val doComputeMethodCall = if (isMultiParams) {
            val listLv = currentMethodScope.addLocalVariable(List::class.javaType)
            varAssignStmt(listLv, argRef(0))

            fCall(owner = outerRef(), method = doComputeMethod,
              arguments = doComputeMethod.parameters.mapIndexed { index, methodParameter ->
                cast(
                  fCall(
                    owner = ref(
                      listLv
                    ), name = "get", arguments = listOf(int(index))
                  ), methodParameter.type
                )
              })
          } else fCall(
            owner = outerRef(),
            method = doComputeMethod,
            arguments = listOf(cast(argRef(0), cacheKeyRef.type))
          )
          returnStmt(doComputeMethodCall)
        }
        returnStmt(
          fCall(
            owner = cacheExpr,
            name = "computeIfAbsent",
            arguments = listOf(cacheKeyObjectRef, ciaLambda)
          )
        )
      } else {
        // generate code if (!cache.containsKey(key)) cache[key] = doCompute(key); return cache[key]
        ifStmt(notExpr(fCall(owner = cacheExpr, name = "containsKey", arguments = listOf(cacheKeyObjectRef)))) {
          stmt(
            fCall(
              owner = cacheExpr, name = "put",
              arguments = listOf(
                cacheKeyObjectRef,
                cast(
                  fCall(
                    owner = thisRef(),
                    method = doComputeMethod,
                    arguments = originalMethod.parameters.map { ref(it) }), JavaType.Object
                )
              )
            )
          )
        }
        returnStmt(fCall(owner = cacheExpr, name = "get", arguments = listOf(cacheKeyObjectRef)))
      }
    }
    return listOf(doComputeMethod)
  }

  private fun getCacheExpression(
    classNode: ClassNode,
    originalMethod: MethodNode,
    threadSafe: Boolean
  ): ExpressionNode {
    val fieldName = originalMethod.name + "\$cache"
    val cacheInitValueExpr = constructorCall(
      method = symbolResolver.findMethodOrThrow(
        if (threadSafe) ConcurrentHashMap::class.javaType
        else HashMap::class.javaType, JavaMethod.CONSTRUCTOR_NAME, emptyList()
      ), arguments = emptyList()
    )
    if (purpose == SemanticPurpose.REPL && classNode.type.isScript) {
      // init field in constructor
      classNode.constructors.forEach {
        SemanticHelper.addStatementLast(
          ExpressionStatementNode(fCall("setVariable", listOf(string(fieldName), cacheInitValueExpr), thisRef())),
          it.blockStatement
        )
      }
      return fCall(name = "getVariable", arguments = listOf(string(fieldName)), owner = thisRef(), castType = java.util.Map::class.javaType)
    } else {
      val cacheField = fieldNode(Map::class.javaType, fieldName)
      addField(classNode, cacheField, cacheInitValueExpr)
      return ref(cacheField)
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy