org.codehaus.groovy.transform.tailrec.TailRecursiveASTTransformation.groovy Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of groovy-all Show documentation
Groovy: A powerful, dynamic language for the JVM
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.codehaus.groovy.transform.tailrec
import groovy.transform.CompileStatic
import groovy.transform.Memoized
import groovy.transform.TailRecursive
import org.codehaus.groovy.ast.*
import org.codehaus.groovy.ast.expr.*
import org.codehaus.groovy.ast.stmt.BlockStatement
import org.codehaus.groovy.ast.stmt.ReturnStatement
import org.codehaus.groovy.ast.stmt.Statement
import org.codehaus.groovy.classgen.ReturnAdder
import org.codehaus.groovy.classgen.VariableScopeVisitor
import org.codehaus.groovy.control.CompilePhase
import org.codehaus.groovy.control.SourceUnit
import org.codehaus.groovy.transform.AbstractASTTransformation
import org.codehaus.groovy.transform.GroovyASTTransformation
/**
 * Handles generation of code for the {@code @TailRecursive} annotation.
 * <p>
 * The transformation runs during {@link CompilePhase#SEMANTIC_ANALYSIS}, the earliest compile phase
 * in which it can do its work. For a non-void annotated method it:
 * <ol>
 *   <li>adds missing default {@code return} statements and turns {@code return} statements whose
 *       expression is a ternary into {@code if} statements,</li>
 *   <li>wraps the method body in a while loop,</li>
 *   <li>redirects every access to a parameter to a local iteration variable ({@code _paramName_})
 *       that is initialised from that parameter,</li>
 *   <li>replaces every recursive {@code return}: outside closures with a continue block, inside
 *       closures with a throw statement,</li>
 *   <li>re-runs variable scope resolution for the declaring class.</li>
 * </ol>
 * The following are reported as compile errors:
 * <ul>
 *   <li>abstract methods,</li>
 *   <li>methods without recursive calls,</li>
 *   <li>void methods,</li>
 *   <li>a {@code @Memoized} annotation placed before {@code @TailRecursive},</li>
 *   <li>recursive calls that could not be transformed.</li>
 * </ul>
 *
 * @author Johannes Link
 */
@CompileStatic
@GroovyASTTransformation(phase = CompilePhase.SEMANTIC_ANALYSIS)
class TailRecursiveASTTransformation extends AbstractASTTransformation {

    private static final Class MY_CLASS = TailRecursive.class;
    private static final ClassNode MY_TYPE = new ClassNode(MY_CLASS);
    static final String MY_TYPE_NAME = "@" + MY_TYPE.getNameWithoutPackage()

    private HasRecursiveCalls hasRecursiveCalls = new HasRecursiveCalls()
    private TernaryToIfStatementConverter ternaryToIfStatement = new TernaryToIfStatementConverter()

    /**
     * Entry point called by the compiler for a method annotated with {@code @TailRecursive}.
     *
     * @param nodes  AST nodes handed to the transformation; {@code nodes[1]} is the annotated method
     * @param source the source unit being compiled
     */
    @Override
    public void visit(ASTNode[] nodes, SourceUnit source) {
        init(nodes, source);
        MethodNode method = nodes[1] as MethodNode

        if (method.isAbstract()) {
            addError("Annotation " + MY_TYPE_NAME + " cannot be used for abstract methods.", method);
            return;
        }

        // Build the @Memoized class node once; it is used both for the presence check and for
        // the ordering scan.
        ClassNode memoizedClassNode = ClassHelper.make(Memoized)
        if (hasAnnotation(method, memoizedClassNode)) {
            for (AnnotationNode annotationNode in method.annotations) {
                // @TailRecursive appears first, so the ordering is acceptable and the scan can stop.
                if (annotationNode.classNode == MY_TYPE)
                    break
                if (annotationNode.classNode == memoizedClassNode) {
                    addError("Annotation " + MY_TYPE_NAME + " must be placed before annotation @Memoized.", annotationNode)
                    return
                }
            }
        }

        if (!hasRecursiveMethodCalls(method)) {
            // Report the error against the @TailRecursive annotation itself, since that is what must be removed.
            AnnotationNode annotationNode = method.getAnnotations(MY_TYPE)[0]
            addError("No recursive calls detected. You must remove annotation " + MY_TYPE_NAME + ".", annotationNode)
            return;
        }

        transformToIteration(method, source)
        ensureAllRecursiveCallsHaveBeenTransformed(method)
    }

    /**
     * Tells whether {@code methodNode} carries at least one annotation of type {@code annotation}.
     */
    private boolean hasAnnotation(MethodNode methodNode, ClassNode annotation) {
        List<AnnotationNode> annots = methodNode.getAnnotations(annotation)
        return annots != null && !annots.isEmpty()
    }

    /**
     * Dispatches to the void or non-void rewrite, depending on the method's return type.
     */
    private void transformToIteration(MethodNode method, SourceUnit source) {
        if (method.isVoidMethod()) {
            transformVoidMethodToIteration(method, source)
        } else {
            transformNonVoidMethodToIteration(method, source)
        }
    }

    /**
     * Void methods are not supported yet. A compile error is reported instead of transforming.
     */
    private void transformVoidMethodToIteration(MethodNode method, SourceUnit source) {
        addError("Void methods are not supported by " + MY_TYPE_NAME + " yet.", method)
    }

    /**
     * Rewrites a non-void method into its iterative form. The steps are order-dependent.
     */
    private void transformNonVoidMethodToIteration(MethodNode method, SourceUnit source) {
        addMissingDefaultReturnStatement(method)
        replaceReturnsWithTernariesToIfStatements(method)
        wrapMethodBodyWithWhileLoop(method)

        Map nameAndTypeMapping = name2VariableMappingFor(method)
        replaceAllAccessToParams(method, nameAndTypeMapping)
        // Must happen after replacing access to params: the new locals are initialised from the
        // parameters themselves, and those references must stay untouched.
        addLocalVariablesForAllParameters(method, nameAndTypeMapping)

        Map positionMapping = position2VariableMappingFor(method)
        replaceAllRecursiveReturnsWithIteration(method, positionMapping)
        repairVariableScopes(source, method)
    }

    /**
     * Re-runs variable scope resolution over the whole declaring class. This is needed because the
     * rewrite introduced new local variables and restructured the method body.
     */
    private void repairVariableScopes(SourceUnit source, MethodNode method) {
        new VariableScopeVisitor(source).visitClass(method.declaringClass)
    }

    /**
     * Replaces every {@code return cond ? a : b} in the method body with an equivalent {@code if}
     * statement. Presumably this lets recursive calls in either branch be handled as ordinary returns.
     */
    private void replaceReturnsWithTernariesToIfStatements(MethodNode method) {
        Closure whenReturnWithTernary = { ASTNode node ->
            node instanceof ReturnStatement && ((ReturnStatement) node).expression instanceof TernaryExpression
        }
        Closure replaceWithIfStatement = { ReturnStatement statement ->
            ternaryToIfStatement.convert(statement)
        }
        StatementReplacer replacer = new StatementReplacer(when: whenReturnWithTernary, replaceWith: replaceWithIfStatement)
        replacer.replaceIn(method.code)
    }

    /**
     * Prepends one local variable definition per parameter to the method body. Each variable is
     * named and typed according to {@code nameAndTypeMapping} and initialised with the parameter's value.
     */
    private void addLocalVariablesForAllParameters(MethodNode method, Map nameAndTypeMapping) {
        BlockStatement code = method.code as BlockStatement
        nameAndTypeMapping.each { String paramName, Map localNameAndType ->
            code.statements.add(0, AstHelper.createVariableDefinition(
                    (String) localNameAndType['name'],
                    (ClassNode) localNameAndType['type'],
                    new VariableExpression(paramName, (ClassNode) localNameAndType['type'])
            ))
        }
    }

    /**
     * Redirects every access to a parameter to its iteration variable, using {@code nameAndTypeMapping}.
     */
    private void replaceAllAccessToParams(MethodNode method, Map nameAndTypeMapping) {
        new VariableAccessReplacer(nameAndTypeMapping: nameAndTypeMapping).replaceIn(method.code)
    }

    /**
     * Maps each parameter name to {@code [name: iterationVariableName, type: parameterType]}.
     * Public because there are tests for this method.
     */
    Map name2VariableMappingFor(MethodNode method) {
        Map nameAndTypeMapping = [:]
        method.parameters.each { Parameter param ->
            nameAndTypeMapping[param.name] = [name: iterationVariableName(param.name), type: param.type]
        }
        return nameAndTypeMapping
    }

    /**
     * Maps each parameter's zero-based position to {@code [name: iterationVariableName, type: parameterType]}.
     * Public because there are tests for this method.
     */
    Map position2VariableMappingFor(MethodNode method) {
        Map positionMapping = [:]
        method.parameters.eachWithIndex { Parameter param, int index ->
            positionMapping[index] = [name: iterationVariableName(param.name), type: param.type]
        }
        return positionMapping
    }

    /**
     * Returns the name of the local iteration variable that stands in for {@code paramName}
     * ({@code _paramName_}).
     */
    private String iterationVariableName(String paramName) {
        '_' + paramName + '_'
    }

    /**
     * Replaces recursive returns with iteration, first outside closures and then inside closures.
     */
    private void replaceAllRecursiveReturnsWithIteration(MethodNode method, Map positionMapping) {
        replaceRecursiveReturnsOutsideClosures(method, positionMapping)
        replaceRecursiveReturnsInsideClosures(method, positionMapping)
    }

    /**
     * Replaces each recursive {@code return} that is not inside a closure with the continue block
     * produced by {@code ReturnStatementToIterationConverter}.
     */
    private void replaceRecursiveReturnsOutsideClosures(MethodNode method, Map positionMapping) {
        Closure whenRecursiveReturn = { Statement statement, boolean inClosure ->
            !inClosure && isRecursiveReturn(statement, method)
        }
        Closure replaceWithContinueBlock = { ReturnStatement statement ->
            new ReturnStatementToIterationConverter().convert(statement, positionMapping)
        }
        StatementReplacer replacer = new StatementReplacer(when: whenRecursiveReturn, replaceWith: replaceWithContinueBlock)
        replacer.replaceIn(method.code)
    }

    /**
     * Replaces each recursive {@code return} inside a closure. The replacement is built by
     * {@code ReturnStatementToIterationConverter}, using the throw statement from
     * {@code AstHelper.recurByThrowStatement()} as the recur statement.
     */
    private void replaceRecursiveReturnsInsideClosures(MethodNode method, Map positionMapping) {
        Closure whenRecursiveReturn = { Statement statement, boolean inClosure ->
            inClosure && isRecursiveReturn(statement, method)
        }
        Closure replaceWithThrowLoopException = { ReturnStatement statement ->
            new ReturnStatementToIterationConverter(recurStatement: AstHelper.recurByThrowStatement()).convert(statement, positionMapping)
        }
        StatementReplacer replacer = new StatementReplacer(when: whenRecursiveReturn, replaceWith: replaceWithThrowLoopException)
        replacer.replaceIn(method.code)
    }

    /**
     * Tells whether {@code statement} is a {@code return} whose expression is a (static) method call
     * that recursively calls {@code method}.
     */
    private boolean isRecursiveReturn(Statement statement, MethodNode method) {
        if (!(statement instanceof ReturnStatement)) {
            return false
        }
        Expression inner = ((ReturnStatement) statement).expression
        // isRecursiveIn yields false for anything that is not a (static) method call.
        return isRecursiveIn(inner, method)
    }

    /**
     * Wraps the whole method body in a while loop.
     */
    private void wrapMethodBodyWithWhileLoop(MethodNode method) {
        new InWhileLoopWrapper().wrap(method)
    }

    /**
     * Makes implicit returns explicit, both in the method itself and inside its closures.
     */
    private void addMissingDefaultReturnStatement(MethodNode method) {
        new ReturnAdder().visitMethod(method)
        new ReturnAdderForClosures().visitMethod(method)
    }

    /**
     * Reports a compile error for every recursive call that is still present after the rewrite.
     */
    private void ensureAllRecursiveCallsHaveBeenTransformed(MethodNode method) {
        List<Expression> remainingRecursiveCalls = new CollectRecursiveCalls().collect(method)
        for (Expression expression : remainingRecursiveCalls) {
            addError("Recursive call could not be transformed by " + MY_TYPE_NAME + ". Maybe it's not a tail call.", expression)
        }
    }

    /**
     * Tells whether the method body contains at least one recursive call.
     */
    private boolean hasRecursiveMethodCalls(MethodNode method) {
        hasRecursiveCalls.test(method)
    }

    /**
     * Tells whether {@code methodCall} is a (static) method call that recursively calls {@code method}.
     * Any other kind of expression yields {@code false}.
     */
    private boolean isRecursiveIn(Expression methodCall, MethodNode method) {
        if (methodCall instanceof MethodCallExpression)
            return new RecursivenessTester().isRecursive(method, (MethodCallExpression) methodCall)
        if (methodCall instanceof StaticMethodCallExpression)
            return new RecursivenessTester().isRecursive(method, (StaticMethodCallExpression) methodCall)
        // Previously an implicit fall-through; make the non-call case explicit.
        return false
    }
}