All Downloads are FREE. Search and download functionalities are using the official Maven repository.

geb.transform.implicitassertions.ImplicitAssertionsTransformationVisitor.groovy Maven / Gradle / Ivy

Go to download

Geb (pronounced "jeb") implicit assertion transform (used in conjunction with geb-core).

There is a newer version: 7.0
Show newest version
/*
 * Copyright 2012 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package geb.transform.implicitassertions

import org.codehaus.groovy.ast.ClassCodeVisitorSupport
import org.codehaus.groovy.ast.ClassHelper
import org.codehaus.groovy.ast.ClassNode
import org.codehaus.groovy.ast.FieldNode
import org.codehaus.groovy.control.SourceUnit
import org.codehaus.groovy.ast.expr.*
import org.codehaus.groovy.ast.stmt.*
import static org.codehaus.groovy.syntax.Types.ASSIGNMENT_OPERATOR
import static org.codehaus.groovy.syntax.Types.ofType
import static geb.transform.implicitassertions.ImplicitAssertionsTransformationUtil.*

class ImplicitAssertionsTransformationVisitor extends ClassCodeVisitorSupport {

    private static final int LAST = -1
    private static final String WAIT_FOR_METHOD_NAME = "waitFor"

    SourceUnit sourceUnit

    ImplicitAssertionsTransformationVisitor(SourceUnit sourceUnit) {
        this.sourceUnit = sourceUnit
    }

    @Override
    void visitField(FieldNode node) {
        if (node.static && node.initialExpression in ClosureExpression) {
            switch (node.name) {
                case 'at':
                    transformEachStatement(node.initialExpression)
                    break
                case 'content':
                    visitContentDsl(node.initialExpression)
                    break
            }
        }
    }

    private boolean lastArgumentIsClosureExpression(ArgumentListExpression arguments) {
        arguments.expressions && arguments.expressions[LAST] in ClosureExpression
    }

    @Override
    void visitExpressionStatement(ExpressionStatement statement) {
        if (statement.expression in MethodCallExpression) {
            MethodCallExpression expression = statement.expression
            if (expression.methodAsString == WAIT_FOR_METHOD_NAME && expression.arguments in ArgumentListExpression) {
                ArgumentListExpression arguments = expression.arguments
                if (lastArgumentIsClosureExpression(arguments)) {
                    transformEachStatement(arguments.expressions[LAST])
                }
            } else {
                compensateForSpockIfNecessary(expression)
            }
        }
    }

    void compensateForSpockIfNecessary(MethodCallExpression expression) {
        if (expression.objectExpression in ClassExpression && expression.method in ConstantExpression) {
            ClassExpression classExpression = expression.objectExpression as ClassExpression
            ConstantExpression method = expression.method as ConstantExpression

            if (classExpression.type.name == "org.spockframework.runtime.SpockRuntime" && method.value == "verifyMethodCondition") {
                if (expression.arguments in ArgumentListExpression) {
                    ArgumentListExpression arguments = expression.arguments as ArgumentListExpression
                    List argumentExpressions = arguments.expressions

                    if (argumentExpressions.size() >= 8) {
                        Expression verifyMethodConditionMethodArg = argumentExpressions.get(6)
                        String methodName = getConstantValueOfType(extractRecordedValueExpression(verifyMethodConditionMethodArg), String)

                        if (methodName) {
                            Expression verifyMethodConditionArgsArgument = argumentExpressions.get(7)
                            if (verifyMethodConditionArgsArgument in ArrayExpression) {

                                List values = (verifyMethodConditionArgsArgument as ArrayExpression).expressions.collect { Expression argumentExpression ->
                                    extractRecordedValueExpression(argumentExpression)
                                }

                                visitSpockValueRecordMethodCall(methodName, values)
                            }
                        }
                    }
                }
            }
        }
    }

    Expression extractRecordedValueExpression(Expression valueRecordExpression) {
        if (valueRecordExpression in MethodCallExpression) {
            MethodCallExpression methodCallExpression = valueRecordExpression as MethodCallExpression

            if (methodCallExpression.arguments in ArgumentListExpression) {
                ArgumentListExpression arguments = methodCallExpression.arguments as ArgumentListExpression

                if (arguments.expressions.size() >= 2) {
                    return arguments.expressions.get(1)
                }
            }
        }

        null
    }

    def getConstantValueOfType(Expression expression, Class type) {
        if (expression != null && expression in ConstantExpression) {
            Object value = ((ConstantExpression) expression).value
            type.isInstance(value) ? value : null
        } else {
            null
        }
    }

    void visitSpockValueRecordMethodCall(String name, List arguments) {
        if (name == WAIT_FOR_METHOD_NAME) {
            if (!arguments.empty) {
                Expression lastArg = arguments.last()
                if (lastArg instanceof ClosureExpression) {
                    transformEachStatement(lastArg as ClosureExpression)
                }
            }
        }
    }

    @Override
    protected SourceUnit getSourceUnit() {
        sourceUnit
    }

    private boolean requiredOptionSpecifiedAsFalse(ArgumentListExpression arguments) {
        MapExpression paramMap = arguments.expressions.find { it in MapExpression }
        paramMap?.mapEntryExpressions.any {
            if (it.keyExpression in ConstantExpression && it.valueExpression in ConstantExpression) {
                ConstantExpression key = it.keyExpression
                ConstantExpression value = it.valueExpression
                key.value == 'required' && value.value == false
            }
        }
    }

    private boolean waitOptionIsSpecified(ArgumentListExpression arguments) {
        MapExpression paramMap = arguments.expressions.find { it in MapExpression }
        paramMap?.mapEntryExpressions.any {
            if (it.keyExpression in ConstantExpression) {
                ConstantExpression key = it.keyExpression
                key.value == 'wait'
            }
        }
    }

    private void visitContentDsl(ClosureExpression closureExpression) {
        BlockStatement blockStatement = closureExpression.code
        blockStatement.statements.each { Statement statement ->
            if (statement in ExpressionStatement) {
                ExpressionStatement expressionStatement = statement
                if (expressionStatement.expression in MethodCallExpression) {
                    MethodCallExpression methodCall = expressionStatement.expression
                    if (methodCall.arguments in ArgumentListExpression) {
                        ArgumentListExpression arguments = methodCall.arguments
                        if (lastArgumentIsClosureExpression(arguments) && waitOptionIsSpecified(arguments) && !requiredOptionSpecifiedAsFalse(arguments)) {
                            transformEachStatement(arguments.expressions[LAST])
                        }
                    }
                }
            }
        }
    }

    private void transformEachStatement(ClosureExpression closureExpression) {
        BlockStatement blockStatement = closureExpression.code
        ListIterator iterator = blockStatement.statements.listIterator()
        while (iterator.hasNext()) {
            iterator.set(maybeTransform(iterator.next()))
        }
    }

    private Statement maybeTransform(Statement statement) {
        Statement result = statement
        Expression expression = getTransformableExpression(statement)
        if (expression) {
            result = transform(expression, statement)
        }
        result
    }

    private Expression getTransformableExpression(Statement statement) {
        if (statement in ExpressionStatement) {
            ExpressionStatement expressionStatement = statement
            if (!(expressionStatement.expression in DeclarationExpression)
                && isTransformable(expressionStatement)) {
                return expressionStatement.expression
            }
        }
    }

    boolean isTransformable(ExpressionStatement statement) {
        if (statement.expression in BinaryExpression) {
            BinaryExpression binaryExpression = statement.expression
            if (ofType(binaryExpression.operation.type, ASSIGNMENT_OPERATOR)) {
                reportError(statement, "Expected a condition, but found an assignment. Did you intend to write '==' ?", sourceUnit)
                false
            }
        }
        true
    }

    private Statement transform(Expression expression, Statement statement) {
        Statement replacement

        Expression recordedValueExpression = createRuntimeCall("recordValue", expression)
        BooleanExpression booleanExpression = new BooleanExpression(recordedValueExpression)

        Statement retrieveRecordedValueStatement = new ExpressionStatement(createRuntimeCall("retrieveRecordedValue"))

        Statement withAssertion = new AssertStatement(booleanExpression)
        withAssertion.setSourcePosition(expression)
        withAssertion.setStatementLabel((String) expression.getNodeMetaData("statementLabel"))

        BlockStatement assertAndRetrieveRecordedValue = new BlockStatement()
        assertAndRetrieveRecordedValue.addStatement(withAssertion)
        assertAndRetrieveRecordedValue.addStatement(retrieveRecordedValueStatement)

        if (expression in MethodCallExpression) {
            MethodCallExpression rewrittenMethodCall = expression

            Statement noAssertion = new ExpressionStatement(expression)
            StaticMethodCallExpression isVoidMethod = createRuntimeCall(
                "isVoidMethod",
                rewrittenMethodCall.objectExpression,
                rewrittenMethodCall.method,
                toArgumentArray(rewrittenMethodCall.arguments)
            )

            replacement = new IfStatement(new BooleanExpression(isVoidMethod), noAssertion, assertAndRetrieveRecordedValue)
        } else {
            replacement = assertAndRetrieveRecordedValue
        }

        replacement.setSourcePosition(statement)
        replacement
    }

    private StaticMethodCallExpression createRuntimeCall(String methodName, Expression... argumentExpressions) {
        ArgumentListExpression argumentListExpression = new ArgumentListExpression()
        for (Expression expression in argumentExpressions) {
            argumentListExpression.addExpression(expression)
        }

        new StaticMethodCallExpression(new ClassNode(Runtime), methodName, argumentListExpression)
    }

    private Expression toArgumentArray(Expression arguments) {
        List argumentList
        if (arguments instanceof NamedArgumentListExpression) {
            argumentList = [arguments]
        } else {
            TupleExpression tuple = arguments
            argumentList = tuple.expressions
        }
        List spreadExpressions = argumentList.findAll { it in SpreadExpression }
        if (spreadExpressions) {
            spreadExpressions.each { reportError(it, 'Spread expressions are not allowed here', sourceUnit) }
            null
        } else {
            new ArrayExpression(ClassHelper.OBJECT_TYPE, argumentList)
        }
    }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy