
// dotty.tools.dotc.evaluation.ExtractExpression.scala Maven / Gradle / Ivy
package dotty.tools.dotc.evaluation
import dotty.tools.dotc.ExpressionContext
import dotty.tools.dotc.ast.tpd.*
import dotty.tools.dotc.core.Constants.Constant
import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.core.Flags.*
import dotty.tools.dotc.core.Names.*
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.transform.SymUtils.*
import dotty.tools.dotc.core.DenotTransformers.DenotTransformer
import dotty.tools.dotc.core.Denotations.SingleDenotation
import dotty.tools.dotc.core.SymDenotations.SymDenotation
import dotty.tools.dotc.transform.MacroTransform
import dotty.tools.dotc.core.Phases.*
import dotty.tools.dotc.report
import dotty.tools.dotc.util.SrcPos
class ExtractExpression(using exprCtx: ExpressionContext) extends MacroTransform with DenotTransformer:
/** Name of this phase, as defined by `ExtractExpression.name`. */
override def phaseName: String = ExtractExpression.name
/**
 * Denotation transformer of this phase.
 *
 * A symbol denotation whose owner is the expression val (e.g. a local val
 * extracted out of the expression) is copied with `exprCtx.evaluateMethod`
 * as its new owner. Every other denotation is returned unchanged.
 */
override def transform(ref: SingleDenotation)(using Context): SingleDenotation =
  ref match
    case symDenot: SymDenotation if isExpressionVal(symDenot.symbol.maybeOwner) =>
      // re-own the symbol: it now lives in the `evaluate` method
      symDenot.copySymDenotation(owner = exprCtx.evaluateMethod)
    case other =>
      other
/** Returns the phase following this one (`this.next`). */
// NOTE(review): presumably this makes the owner change done by `transform(ref)`
// visible only from the next phase on — confirm against DenotTransformer.
override def transformPhase(using Context): Phase = this.next
/**
 * Creates the tree transformer of this phase.
 *
 * The transformer
 *  - replaces the expression val by a unit literal, remembering its right-hand
 *    side and storing its symbol in `exprCtx`;
 *  - replaces the body of `exprCtx.evaluateMethod` by the remembered right-hand
 *    side, rewritten by `ExpressionTransformer`.
 */
override protected def newTransformer(using Context): Transformer =
  new Transformer:
    // right-hand side of the expression val; assigned when that val is visited
    var expressionTree: Tree = _

    override def transform(tree: Tree)(using Context): Tree =
      tree match
        case PackageDef(pid, stats) =>
          // Transform the expression class after all the other statements.
          // Presumably this guarantees that `expressionTree` has been captured
          // before the `evaluate` method is rewritten — TODO confirm.
          val isExpressionClass = (stat: Tree) => stat.symbol == exprCtx.expressionClass
          val reordered = stats.filterNot(isExpressionClass) ++ stats.find(isExpressionClass)
          PackageDef(pid, reordered.map(transform))

        case valDef: ValDef if isExpressionVal(valDef.symbol) =>
          expressionTree = valDef.rhs
          exprCtx.store(valDef.symbol)
          unitLiteral

        case defDef: DefDef if defDef.symbol == exprCtx.evaluateMethod =>
          cpy.DefDef(defDef)(rhs = ExpressionTransformer.transform(expressionTree))

        case other =>
          super.transform(other)
private object ExpressionTransformer extends TreeMap:
/**
 * Rewrites a tree of the extracted expression.
 *
 * References that the helper predicates classify as static objects,
 * inaccessible non-static objects, local variables, inaccessible fields,
 * `this`/outer `this`, inaccessible constructors or inaccessible methods are
 * replaced by the corresponding `reflectEval`-based tree. Ascriptions to an
 * inaccessible type are dropped. Everything else is traversed by `TreeMap`.
 *
 * The cases are tried in order, so the first matching category wins.
 */
override def transform(tree: Tree)(using Context): Tree =
  // identifiers go through `desugarIdent` before being classified
  val candidate = tree match
    case ident: Ident => desugarIdent(ident)
    case other => other

  candidate match
    case imp: ImportOrExport => imp

    // static object
    case ref: (Ident | Select) if isStaticObject(ref.symbol) =>
      getStaticObject(ref)(ref.symbol.moduleClass)

    // non-static object
    case ref: (Ident | Select) if isInaccessibleNonStaticObject(ref.symbol) =>
      val qual = getTransformedQualifier(ref)
      callMethod(ref)(qual, ref.symbol.asTerm, List.empty)

    // local variable
    case ident: Ident if isLocalVariable(ident.symbol) =>
      val local = ident.symbol
      if !local.is(Lazy) then
        getCapturer(local.asTerm) match
          case Some(capturer) if capturer.isClass =>
            getClassCapture(ident)(local, capturer.asClass)
          case Some(capturer) =>
            getMethodCapture(ident)(local, capturer.asTerm)
          case None =>
            getLocalValue(ident)(local)
      else
        // the error is reported and the identifier is left untouched
        report.error(s"Evaluation of local lazy val not supported", ident.srcPos)
        ident

    // assignment to local variable
    case assign @ Assign(lhs, _) if isLocalVariable(lhs.symbol) =>
      val target = lhs.symbol.asTerm
      val newValue = transform(assign.rhs)
      getCapturer(target) match
        case Some(capturer) if capturer.isClass =>
          setClassCapture(assign)(target, capturer.asClass, newValue)
        case Some(capturer) =>
          setMethodCapture(assign)(target, capturer.asTerm, newValue)
        case None =>
          setLocalValue(assign)(target, newValue)

    // inaccessible fields
    case select: Select if isInaccessibleField(select.symbol) =>
      val qual = getTransformedQualifier(select)
      getField(select)(qual, select.symbol.asTerm)

    // assignment to inaccessible fields
    case assign @ Assign(lhs, rhs) if isInaccessibleField(lhs.symbol) =>
      val qual = getTransformedQualifier(lhs)
      setField(assign)(qual, lhs.symbol.asTerm, transform(rhs))

    // this or outer this
    case ths @ This(Ident(_)) if !isOwnedByExpression(ths.symbol) =>
      thisOrOuterValue(ths)(ths.symbol.enclosingClass.asClass)

    // inaccessible constructors
    case call: (Select | Apply | TypeApply) if isInaccessibleConstructor(call.symbol) =>
      val ctorArgs = getTransformedArgs(call)
      val prefix = getTransformedQualifierOfNew(call)
      callConstructor(call)(prefix, call.symbol.asTerm, ctorArgs)

    // inaccessible methods
    case call: (Ident | Select | Apply | TypeApply) if isInaccessibleMethod(call.symbol) =>
      val methodArgs = getTransformedArgs(call)
      val qual = getTransformedQualifier(call)
      callMethod(call)(qual, call.symbol.asTerm, methodArgs)

    // ascription to an inaccessible type: keep only the ascribed expression
    case Typed(expr, tpt) if tpt.symbol.isType && !isTypeAccessible(tpt.symbol.asType) =>
      transform(expr)

    case other =>
      super.transform(other)
/**
 * Finds the class or method that captures the local `variable`, if any.
 *
 * Only the owners of the expression symbol that come before `variable.owner`
 * in the owner chain are considered. Among them the last class (in chain
 * order) is preferred; when there is no class the first method is returned.
 */
private def getCapturer(variable: TermSymbol)(using Context): Option[Symbol] =
  val enclosing = exprCtx.expressionSymbol.ownersIterator
    .takeWhile(owner => owner != variable.owner)
    .filter(owner => owner.isClass || owner.is(Method))
    .toSeq
  val capturingClass = enclosing.findLast(_.isClass)
  capturingClass.orElse(enclosing.find(_.is(Method)))
/**
 * Collects, from the innermost application outwards, the transformed value
 * arguments of a (possibly curried, possibly type-applied) call.
 * A bare `Ident` or `Select` has no arguments.
 */
private def getTransformedArgs(tree: Tree)(using Context): List[Tree] =
  tree match
    case Apply(fun, args) => getTransformedArgs(fun) ++ args.map(transform)
    case TypeApply(fun, _) => getTransformedArgs(fun)
    case _: (Ident | Select) => Nil
/**
 * Computes the tree to use as the receiver of the member referenced by `tree`.
 *
 * Applications and type applications delegate to their function part.
 * For an `Ident` or a `Select`:
 *  - if the enclosing class of the symbol is a static object, that object is used;
 *  - otherwise a `Select` uses its transformed qualifier and an `Ident` uses
 *    the `this` or outer value of the enclosing class.
 */
private def getTransformedQualifier(tree: Tree)(using Context): Tree =
  tree match
    case Apply(fun, _) => getTransformedQualifier(fun)
    case TypeApply(fun, _) => getTransformedQualifier(fun)
    case _: (Ident | Select) =>
      val cls = tree.symbol.enclosingClass.asClass
      if isStaticObject(cls) then getStaticObject(tree)(cls)
      else
        tree match
          case Select(qual, _) => transform(qual)
          case _ => thisOrOuterValue(tree)(cls)
/**
 * Computes the prefix value of a constructor call: peels off applications and
 * type applications until the `new T` selection is reached, then resolves the
 * prefix of the type tree `T`.
 */
private def getTransformedQualifierOfNew(tree: Tree)(using Context): Tree =
  tree match
    case Apply(fun, _) => getTransformedQualifierOfNew(fun)
    case TypeApply(fun, _) => getTransformedQualifierOfNew(fun)
    case Select(New(tpt), _) => getTransformedPrefix(tpt)
/**
 * Resolves the prefix of a type tree:
 *  - an applied type delegates to its type constructor;
 *  - a selection uses its transformed qualifier;
 *  - an identifier uses the `this` or outer value of the class enclosing the
 *    owner of the type's symbol.
 */
private def getTransformedPrefix(typeTree: Tree)(using Context): Tree =
  typeTree match
    case AppliedTypeTree(tycon, _) => getTransformedPrefix(tycon)
    case Select(qual, _) => transform(qual)
    case Ident(_) =>
      val enclosing = typeTree.symbol.owner.enclosingClass.asClass
      thisOrOuterValue(typeTree)(enclosing)
end ExpressionTransformer
/**
 * True when `sym` has the name `exprCtx.expressionTermName`.
 * The check is done on the name only.
 */
private def isExpressionVal(sym: Symbol)(using Context): Boolean =
sym.name == exprCtx.expressionTermName
/**
 * Builds the tree that evaluates to the instance of `cls` reachable from the
 * expression.
 *
 * Starting from `this` of `exprCtx.classOwners.head`, `getOuter` is applied
 * once for each following class owner, up to and including `cls`.
 * When `cls` is not one of `exprCtx.classOwners` the result is a null literal.
 */
private def thisOrOuterValue(tree: Tree)(cls: ClassSymbol)(using Context): Tree =
  reportErrorIfLocalInsideValueClass(exprCtx.expressionSymbol.owner, tree.srcPos)
  val owners = exprCtx.classOwners
  val startThis = getThis(tree)(owners.head)
  owners.indexOf(cls) match
    case -1 => nullLiteral
    case depth =>
      // class owners 1..depth are the successive outer classes to walk through
      owners
        .slice(1, depth + 1)
        .foldLeft(startThis)((inner, outerCls) => getOuter(tree)(inner, outerCls))
/**
 * Builds the `reflectEval` call that evaluates to `this` of class `cls`.
 *
 * The call has no qualifier (null literal) and no arguments. Its result is
 * typed with the type reference of `cls`, the same class the strategy is
 * built from, so that strategy and type cannot diverge.
 */
private def getThis(tree: Tree)(cls: ClassSymbol)(using Context): Tree =
  reflectEval(tree)(
    nullLiteral,
    EvaluationStrategy.This(cls),
    List.empty,
    cls.typeRef
  )
/**
 * Builds the `reflectEval` call that evaluates to the outer instance, of class
 * `outerCls`, of the object produced by `qualifier`.
 */
private def getOuter(tree: Tree)(qualifier: Tree, outerCls: ClassSymbol)(using Context): Tree =
  reflectEval(tree)(
    qualifier,
    EvaluationStrategy.Outer(outerCls),
    List.empty,
    outerCls.typeRef
  )
/**
 * Builds the `reflectEval` call that reads the local `variable`.
 * The strategy records whether the variable's info is a by-name type.
 */
private def getLocalValue(tree: Tree)(variable: Symbol)(using Context): Tree =
  val readLocal =
    EvaluationStrategy.LocalValue(variable.asTerm, isByNameParam(variable.info))
  reflectEval(tree)(nullLiteral, readLocal, List.empty, tree.tpe)
/**
 * True when `tpe` is an `ExprType`, or a term reference whose symbol's info
 * (followed recursively) is one.
 */
private def isByNameParam(tpe: Type)(using Context): Boolean =
  tpe match
    case termRef: TermRef => isByNameParam(termRef.symbol.info)
    case _: ExprType => true
    case _ => false
/** Builds the `reflectEval` call that assigns `rhs` to the local `variable`. */
private def setLocalValue(tree: Tree)(variable: Symbol, rhs: Tree)(using Context): Tree =
  reflectEval(tree)(
    nullLiteral,
    EvaluationStrategy.LocalValueAssign(variable.asTerm),
    List(rhs),
    tree.tpe
  )
/**
 * Builds the `reflectEval` call that reads `variable` as captured by class `cls`.
 * The qualifier is the `this` or outer value of `cls`.
 */
private def getClassCapture(tree: Tree)(variable: Symbol, cls: ClassSymbol)(using Context): Tree =
  reportErrorIfLocalInsideValueClass(cls, tree.srcPos)
  val readCapture =
    EvaluationStrategy.ClassCapture(variable.asTerm, cls, isByNameParam(variable.info))
  val instance = thisOrOuterValue(tree)(cls)
  reflectEval(tree)(instance, readCapture, List.empty, tree.tpe)
/**
 * Builds the `reflectEval` call that assigns `value` to `variable` as captured
 * by class `cls`. The qualifier is the `this` or outer value of `cls`.
 */
private def setClassCapture(tree: Tree)(variable: Symbol, cls: ClassSymbol, value: Tree)(using Context) =
  reportErrorIfLocalInsideValueClass(cls, tree.srcPos)
  val writeCapture = EvaluationStrategy.ClassCaptureAssign(variable.asTerm, cls)
  val instance = thisOrOuterValue(tree)(cls)
  reflectEval(tree)(instance, writeCapture, List(value), tree.tpe)
/**
 * Builds the `reflectEval` call that reads `variable` as captured by `method`.
 * There is no qualifier (null literal).
 */
private def getMethodCapture(tree: Tree)(variable: Symbol, method: TermSymbol)(using Context): Tree =
  reportErrorIfLocalInsideValueClass(method, tree.srcPos)
  val readCapture = EvaluationStrategy.MethodCapture(
    variable.asTerm,
    method.asTerm,
    isByNameParam(variable.info)
  )
  reflectEval(tree)(nullLiteral, readCapture, List.empty, tree.tpe)
/**
 * Builds the `reflectEval` call that assigns `value` to `variable` as captured
 * by `method`. There is no qualifier (null literal).
 */
private def setMethodCapture(tree: Tree)(variable: Symbol, method: Symbol, value: Tree)(using Context) =
  reportErrorIfLocalInsideValueClass(method, tree.srcPos)
  val writeCapture =
    EvaluationStrategy.MethodCaptureAssign(variable.asTerm, method.asTerm)
  reflectEval(tree)(nullLiteral, writeCapture, List(value), tree.tpe)
/**
 * Builds the `reflectEval` call that evaluates to the static object whose
 * class is `obj`. The result is typed with the type reference of `obj`.
 */
private def getStaticObject(tree: Tree)(obj: Symbol)(using ctx: Context): Tree =
  reflectEval(tree)(
    nullLiteral,
    EvaluationStrategy.StaticObject(obj.asClass),
    List.empty,
    obj.typeRef
  )
/**
 * Builds the `reflectEval` call that reads `field` on `qualifier`.
 * The strategy records whether the field's info is a by-name type.
 */
private def getField(tree: Tree)(qualifier: Tree, field: TermSymbol)(using Context): Tree =
  reportErrorIfLocalInsideValueClass(field, tree.srcPos)
  val readField = EvaluationStrategy.Field(field, isByNameParam(field.info))
  reflectEval(tree)(qualifier, readField, List.empty, tree.tpe)
/** Builds the `reflectEval` call that assigns `rhs` to `field` on `qualifier`. */
private def setField(tree: Tree)(qualifier: Tree, field: TermSymbol, rhs: Tree)(using Context): Tree =
  reportErrorIfLocalInsideValueClass(field, tree.srcPos)
  reflectEval(tree)(
    qualifier,
    EvaluationStrategy.FieldAssign(field),
    List(rhs),
    tree.tpe
  )
/** Builds the `reflectEval` call that invokes `method` on `qualifier` with `args`. */
private def callMethod(tree: Tree)(qualifier: Tree, method: TermSymbol, args: List[Tree])(using Context): Tree =
  reportErrorIfLocalInsideValueClass(method, tree.srcPos)
  reflectEval(tree)(
    qualifier,
    EvaluationStrategy.MethodCall(method),
    args,
    tree.tpe
  )
/**
 * Builds the `reflectEval` call that invokes the constructor `ctr` of the
 * class owning it, with `qualifier` as prefix value and `args` as arguments.
 */
private def callConstructor(tree: Tree)(qualifier: Tree, ctr: TermSymbol, args: List[Tree])(using Context): Tree =
  reportErrorIfLocalInsideValueClass(ctr, tree.srcPos)
  val instantiate = EvaluationStrategy.ConstructorCall(ctr, ctr.owner.asClass)
  reflectEval(tree)(qualifier, instantiate, args, tree.tpe)
/**
 * Builds a call to the `reflectEval` member of the expression class, as a
 * copy of `tree`.
 *
 * The call receives the `qualifier`, the string form of `strategy`, and `args`
 * packed in a Java sequence literal of `Object`. The `strategy` itself is
 * attached to the call under the `EvaluationStrategy` key.
 *
 * @param tree      the tree being replaced
 * @param qualifier receiver of the reflective operation
 * @param strategy  description of the reflective operation
 * @param args      arguments of the reflective operation
 * @param tpe       expected type of the result
 * @return the call, cast to the widened and dealiased `tpe` when that type is
 *         accessible from the expression class, the bare call otherwise
 */
private def reflectEval(tree: Tree)(
    qualifier: Tree,
    strategy: EvaluationStrategy,
    args: List[Tree],
    tpe: Type
)(using Context): Tree =
  val call =
    cpy.Apply(tree)(
      Select(This(exprCtx.expressionClass), termName("reflectEval")),
      List(
        qualifier,
        Literal(Constant(strategy.toString)),
        JavaSeqLiteral(args, TypeTree(ctx.definitions.ObjectType))
      )
    )
  call.putAttachment(EvaluationStrategy, strategy)
  val widenDealiasTpe = tpe.widenDealias
  val typeSym = widenDealiasTpe.typeSymbol
  // `asType` is only valid on a type symbol: when the type has none, the call
  // is returned without a cast instead of failing on the conversion
  if typeSym.isType && isTypeAccessible(typeSym.asType)
  then call.cast(widenDealiasTpe)
  else call
/**
 * In the [[ResolveReflectEval]] phase we cannot find the symbol of a local method
 * or local class inside a value class. So we report an error early.
 *
 * An error is reported at `srcPos` when some owner of `symbol` is a class or
 * a method enclosed in a method, and that owner itself has a value class
 * among its owners.
 */
private def reportErrorIfLocalInsideValueClass(symbol: Symbol, srcPos: SrcPos)(using Context): Unit =
  symbol.ownersIterator
    .find(sym => (sym.isClass || sym.is(Method)) && sym.enclosure.is(Method))
    .foreach { localClassOrMethod =>
      localClassOrMethod.ownersIterator
        .find(_.isValueClass)
        .foreach { valueClass =>
          report.error(
            s"""|Evaluation involving a local method or local class in a value class not supported:
                |$symbol belongs to $localClassOrMethod which is local inside value $valueClass""".stripMargin,
            srcPos
          )
        }
    }
/** The symbol is a static module that is neither Java-defined nor the root. */
private def isStaticObject(symbol: Symbol)(using Context): Boolean =
  symbol.is(Module) && symbol.isStatic && !(symbol.is(JavaDefined) || symbol.isRoot)
/**
 * The symbol is a module that is not static, not the root, and not defined
 * inside the expression.
 */
private def isInaccessibleNonStaticObject(symbol: Symbol)(using Context): Boolean =
  symbol.is(Module) && !(symbol.isStatic || symbol.isRoot || isOwnedByExpression(symbol))
/**
 * The symbol is a field and the expression class cannot access it,
 * either because it is private or because it belongs to an inaccessible type.
 */
private def isInaccessibleField(symbol: Symbol)(using Context): Boolean =
  if symbol.isField && symbol.owner.isType then
    !isTermAccessible(symbol.asTerm, symbol.owner.asType)
  else false
/**
 * The symbol is a real method, defined outside the expression, that the
 * expression class cannot access: its owner is not a type, or it is private,
 * or it belongs to an inaccessible type.
 */
private def isInaccessibleMethod(symbol: Symbol)(using Context): Boolean =
  if isOwnedByExpression(symbol) || !symbol.isRealMethod then false
  else !(symbol.owner.isType && isTermAccessible(symbol.asTerm, symbol.owner.asType))
/**
 * The symbol is a constructor, defined outside the expression, that the
 * expression class cannot call directly: either it is an inaccessible method
 * or it belongs to a nested (non-static) type.
 */
private def isInaccessibleConstructor(symbol: Symbol)(using Context): Boolean =
  if isOwnedByExpression(symbol) || !symbol.isConstructor then false
  else isInaccessibleMethod(symbol) || !symbol.owner.isStatic
/**
 * The symbol is not a method, is local to a block, and is defined outside
 * the expression.
 */
private def isLocalVariable(symbol: Symbol)(using Context): Boolean =
  if symbol.is(Method) then false
  else symbol.isLocalToBlock && !isOwnedByExpression(symbol)
/**
 * Checks whether a term is accessible from the expression class: it is
 * defined inside the expression, or it is not private and its `owner` type
 * is accessible.
 */
private def isTermAccessible(symbol: TermSymbol, owner: TypeSymbol)(using Context): Boolean =
  if isOwnedByExpression(symbol) then true
  else !symbol.isPrivate && isTypeAccessible(owner)
/**
 * Checks whether a type is accessible from the expression class: it is
 * defined inside the expression, or it is not local and every symbol of its
 * owner chain is public or qualified-private to a package.
 */
private def isTypeAccessible(symbol: TypeSymbol)(using Context): Boolean =
  if isOwnedByExpression(symbol) then true
  else if symbol.isLocal then false
  else
    symbol.ownersIterator.forall { owner =>
      owner.isPublic || owner.privateWithin.is(PackageClass)
    }
/**
 * True when `exprCtx.evaluateMethod` appears in the owner chain of `symbol`
 * (as given by `ownersIterator`), i.e. the symbol is defined inside the expression.
 */
private def isOwnedByExpression(symbol: Symbol)(using Context): Boolean =
  val evaluate = exprCtx.evaluateMethod
  symbol.ownersIterator.contains(evaluate)
/** Companion object holding the name of the [[ExtractExpression]] phase. */
object ExtractExpression:
// value returned by `ExtractExpression#phaseName`
val name: String = "extract-expression"
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy