All downloads are free. Search and download functionality uses the official Maven repository.

com.github.cerst.auto_require.internal.Macros.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2020 Constantin Gerstberger
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package com.github.cerst.auto_require.internal

import com.github.cerst.auto_require.{DisplayConfig, InfixOnlyOperators, SimpleNameSegments, StripTypeNameSuffix}
import com.github.ghik.silencer.silent

import scala.reflect.macros.blackbox

final class Macros(val c: blackbox.Context) {
  import Macros._
  import c.universe._

  /**
   * Rendering of a (sub-)expression plus the trees needed to print the values of its parts.
   *
   * @param pretty       source-like rendering used in the error message
   * @param declarations `val <fresh> = <sub-expression>` definitions, spliced into the failure branch only
   * @param debugs       string-building trees of the form `"<pretty> = " + <fresh>`
   * @param precedence   one of the `Precedence` constants; compared to decide where parentheses are needed
   * @param infixOp      the decoded operator/method name when this expression is rendered as `left op right`
   */
  case class Result(pretty: String,
                    declarations: List[c.Tree],
                    debugs: List[c.Tree],
                    precedence: Int,
                    infixOp: Option[String] = None)

  // Build the rendering by recursively descending the tree.
  // To avoid evaluating a sub-expression more than once when printing it, sub-expressions are bound to fresh
  // variables bottom-up (leaves first); these bindings are the `declarations` of the result.
  @silent // pattern match variables are incorrectly reported as never used
  private def process(tree: c.Tree)(implicit infixOnlyOperators: Boolean): Result = tree match {
    // method invocation: `recv.method(args)`
    case q"$recv.$method(..$args)" =>
      val methodNiceName = method.decodedName.toString
      val isApply = methodNiceName == "apply"
      val recvResult =
        if (isApply) {
          // skip transforming the receiver of an apply method (e.g. a value class constructor) as evaluating it
          // simply returns the object hash; the call itself is not bound to a variable either because its
          // arguments still have to be traversed to output their values
          Result(showCode(recv), List.empty, List.empty, Precedence.Object)
        } else {
          process(recv)
        }
      val argResults = args.map(arg => process(arg))
      val argsPretty = argResults.map(_.pretty).mkString("(", ",", ")")
      val declarations = recvResult.declarations ++ argResults.flatMap(_.declarations)
      // receiver first, then the arguments in source order (the same order as `declarations`)
      val debugs = recvResult.debugs ++ argResults.flatMap(_.debugs)

      argResults match {
        case _ if isApply =>
          // hide the apply method: `Foo.apply(x)` is rendered as `Foo(x)`
          Result(recvResult.pretty + argsPretty, declarations, debugs, Precedence.NonInfixMethod)

        case List(argResult) if !infixOnlyOperators || Operators.contains(methodNiceName) =>
          // single argument: use infix syntax `left op right`
          val precedence =
            if (Operators.contains(methodNiceName)) Precedence.InfixOperator else Precedence.InfixMethod
          val left = renderOperand(recvResult, methodNiceName, precedence, isRightOperand = false)
          val right = renderOperand(argResult, methodNiceName, precedence, isRightOperand = true)
          Result(left + " " + methodNiceName + " " + right, declarations, debugs, precedence, Some(methodNiceName))

        case _ =>
          // default method display `recv.method(args)`; an infix receiver must be parenthesized,
          // otherwise `.method` would read as applying to its right operand only
          val receiver = if (recvResult.infixOp.isDefined) parenthesize(recvResult.pretty) else recvResult.pretty
          Result(receiver + "." + methodNiceName + argsPretty, declarations, debugs, Precedence.NonInfixMethod)
      }

    // supported unary methods: `!recv`
    case q"$recv.$unary" if unary == TermName("unary_$bang") =>
      val recvResult = process(recv)
      // strict `>`: a nested negation needs parentheses because '!!x' is not a valid expression;
      // an infix operand always needs them because prefix `!` binds tighter than any infix operator
      val pretty =
        if (recvResult.infixOp.isEmpty && recvResult.precedence > Precedence.Not) "!" + recvResult.pretty
        else "!(" + recvResult.pretty + ")"
      Result(pretty, recvResult.declarations, recvResult.debugs, Precedence.Not)

    // field selection or unsupported unary method
    // in case of the former, de-constructing the receiver further does not work and shouldn't be needed
    case q"$recv.$fieldOrUnary" =>
      val evaluated = q"$recv.$fieldOrUnary"
      val pretty = showCode(evaluated)
      val (declaration, debug) = refDeclDebug(evaluated, pretty)
      Result(pretty, List(declaration), List(debug), Precedence.FieldSelection)

    // literal constants (e.g. '10')
    case Literal(_) =>
      // no declaration + debug for literals: it would only print redundant entries such as '10 = 10'
      Result(showCode(tree), List.empty, List.empty, Precedence.LiteralConstant)

    // plain identifier (e.g. 'itemId')
    case _: Ident =>
      val pretty = showCode(tree)
      val (declaration, debug) = refDeclDebug(tree, pretty)
      Result(pretty, List(declaration), List(debug), Precedence.Identifier)

    // leave the rest as is and don't derive sub-expressions (mostly because we don't know what it is)
    case _ =>
      Result(showCode(tree), List.empty, List.empty, Precedence.Other)
  }

  /**
   * Renders `operand` as one side of the infix expression `left parentOp right`, adding parentheses where the
   * bare text would be read differently.
   *
   * If the operand is itself an infix expression, Scala's operator precedence decides: it stays bare when its
   * operator binds tighter than `parentOp`, or binds equally and it is the left operand (left associativity).
   * Right-associative operators (ending in ':') and assignment operators are not treated specially.
   * Otherwise the `Precedence` ranks are compared (bare if `operand.precedence >= parentPrecedence`).
   */
  private def renderOperand(operand: Result,
                            parentOp: String,
                            parentPrecedence: Int,
                            isRightOperand: Boolean): String = {
    val keepBare = operand.infixOp match {
      case Some(operandOp) =>
        val operandRank = operatorRank(operandOp)
        val parentRank = operatorRank(parentOp)
        operandRank > parentRank || (operandRank == parentRank && !isRightOperand)
      case None =>
        operand.precedence >= parentPrecedence
    }
    if (keepBare) operand.pretty else parenthesize(operand.pretty)
  }

  /** Wraps a rendered expression in parentheses. */
  private def parenthesize(pretty: String): String = "(" + pretty + ")"

  /** Precedence of an infix operator, determined by its first character; a higher rank binds tighter. */
  private def operatorRank(op: String): Int = op.headOption match {
    case None                                                 => 0
    case Some(ch) if ch.isLetter || ch == '$' || ch == '_'    => 0
    case Some('|')                                            => 1
    case Some('^')                                            => 2
    case Some('&')                                            => 3
    case Some('=') | Some('!')                                => 4
    case Some('<') | Some('>')                                => 5
    case Some(':')                                            => 6
    case Some('+') | Some('-')                                => 7
    case Some('*') | Some('/') | Some('%')                    => 8
    case Some(_)                                              => 9
  }

  /**
   * Binds `tree` to a fresh variable and builds the debug string tree `"<pretty> = " + <variable>`.
   *
   * @return the `val` declaration and the debug tree referencing the declared variable
   */
  private def refDeclDebug(tree: Tree, pretty: String): (c.Tree, c.Tree) = {
    val name: c.TermName = c.freshName(TermName("temp"))
    val reference = q"$name"
    val declaration = q"val $name = $tree"
    val debug: c.Tree = q""" $pretty + " = " + $reference """

    (declaration, debug)
  }

  /**
   * Generates `if (!expression) { <declarations>; val printedValues = ...; val msg = ...; onError } else onSuccess`.
   *
   * @tparam A         type whose derived simple name appears in the error message
   * @param expression the requirement
   * @param onError    spliced into the failure branch; may refer to the generated `msg`
   * @param onSuccess  spliced into the success branch
   * @param overrides  display settings, evaluated at compile time via `c.eval`
   */
  private def genResult[A: c.WeakTypeTag](expression: c.Tree,
                                          onError: c.Tree,
                                          onSuccess: c.Tree,
                                          overrides: Seq[c.Expr[DisplayConfig]]): c.Tree = {

    var infixOnlyOperators = true
    var simpleNameSegments = 1
    var stripTypeNameSuffix = true

    overrides.map(exp => c.eval(c.Expr(c.untypecheck(exp.tree.duplicate)))).foreach {
      case InfixOnlyOperators(value)                => infixOnlyOperators = value
      case SimpleNameSegments(value) if (value < 1) => c.abort(c.enclosingPosition, "SimpleNameSegment must be >= 1")
      case SimpleNameSegments(value)                => simpleNameSegments = value
      case StripTypeNameSuffix(value)               => stripTypeNameSuffix = value
    }

    val rawFullName = weakTypeTag[A].tpe.typeSymbol.fullName
    val simpleName = DeriveSimpleName(rawFullName, simpleNameSegments, stripTypeNameSuffix)
      .fold(left => c.abort(c.enclosingPosition, left), identity)
    val result = process(expression)(infixOnlyOperators)

    // standard-library names are fully qualified so that definitions at the call site
    // (e.g. a local `List`) cannot change the meaning of the expansion
    q"""
       if (!${expression}) {
          ..${result.declarations}
          val printedValues = _root_.scala.collection.immutable.List(..${result.debugs})
          val msg = "Requirement failed for '" + $simpleName + "': '" + ${result.pretty} + "' " + printedValues.mkString("{ ", ", ", " }")
          $onError
        } else {
          $onSuccess
        }
     """
  }

  /** Expands to a check that throws an `IllegalArgumentException` (without stack trace) carrying `msg`. */
  def autoRequire[A: c.WeakTypeTag](expression: c.Tree, overrides: c.Expr[DisplayConfig]*): c.Tree = {
    val onError: c.Tree =
      q"throw new _root_.java.lang.IllegalArgumentException(msg) with _root_.scala.util.control.NoStackTrace"
    val onSuccess: c.Tree = q"()"
    genResult[A](expression, onError, onSuccess, overrides)
  }

  /** Expands to a check yielding `Left(msg)` on failure and `Right(())` on success. */
  def autoRequireEither[A: c.WeakTypeTag](expression: c.Tree, overrides: c.Expr[DisplayConfig]*): c.Tree = {
    val onError: c.Tree = q"_root_.scala.util.Left(msg)"
    val onSuccess: c.Tree = q"_root_.scala.util.Right(())"
    genResult[A](expression, onError, onSuccess, overrides)
  }

}

object Macros {

  /** Two-space indentation unit; not referenced by the macro code in this file. */
  val Indentation: String = "  "

  /**
   * Decoded method names treated as operators by the `Macros` class.
   *
   * They are rendered infix even when infix display is restricted to operators. Equality (`==`, `!=`) is
   * included so that, e.g., `x == 5` is not displayed as `x.==(5)`.
   */
  val Operators: Set[String] = Set("<", "<=", ">", ">=", "==", "!=", "&&", "||")

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy