All downloads are FREE. Search and download functionality uses the official Maven repository.

org.codehaus.groovy.transform.NewifyASTTransformation Maven / Gradle / Ivy

The newest version!
/*
 *  Licensed to the Apache Software Foundation (ASF) under one
 *  or more contributor license agreements.  See the NOTICE file
 *  distributed with this work for additional information
 *  regarding copyright ownership.  The ASF licenses this file
 *  to you under the Apache License, Version 2.0 (the
 *  "License"); you may not use this file except in compliance
 *  with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing,
 *  software distributed under the License is distributed on an
 *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 *  KIND, either express or implied.  See the License for the
 *  specific language governing permissions and limitations
 *  under the License.
 */
package org.codehaus.groovy.transform;

import groovy.lang.Newify;
import org.codehaus.groovy.GroovyBugError;
import org.codehaus.groovy.ast.ASTNode;
import org.codehaus.groovy.ast.AnnotatedNode;
import org.codehaus.groovy.ast.AnnotationNode;
import org.codehaus.groovy.ast.ClassCodeExpressionTransformer;
import org.codehaus.groovy.ast.ClassNode;
import org.codehaus.groovy.ast.FieldNode;
import org.codehaus.groovy.ast.MethodNode;
import org.codehaus.groovy.ast.expr.ClassExpression;
import org.codehaus.groovy.ast.expr.ClosureExpression;
import org.codehaus.groovy.ast.expr.ConstantExpression;
import org.codehaus.groovy.ast.expr.ConstructorCallExpression;
import org.codehaus.groovy.ast.expr.DeclarationExpression;
import org.codehaus.groovy.ast.expr.Expression;
import org.codehaus.groovy.ast.expr.ListExpression;
import org.codehaus.groovy.ast.expr.MethodCallExpression;
import org.codehaus.groovy.ast.expr.VariableExpression;
import org.codehaus.groovy.control.CompilePhase;
import org.codehaus.groovy.control.SourceUnit;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import static org.codehaus.groovy.ast.ClassHelper.make;
import static org.codehaus.groovy.ast.tools.GeneralUtils.callX;
import static org.codehaus.groovy.ast.tools.GeneralUtils.classX;

/**
 * Handles generation of code for the {@code @Newify} AST transform.
 * <p>
 * The transformation walks the annotated class, method, field or local declaration and
 * rewrites matching method calls into {@link ConstructorCallExpression}s. Two call shapes
 * are rewritten (see {@link #transformMethodCall}):
 * <ul>
 *   <li>{@code SomeClass.new(args)} when the {@code auto} flag is enabled;</li>
 *   <li>{@code SomeClass(args)} when {@code SomeClass} is listed in the annotation value.</li>
 * </ul>
 * The instance keeps per-visit state in fields ({@code classesToNewify}, {@code auto},
 * {@code candidate}); each entry point saves the previous values and restores them when done.
 */
@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION)
public class NewifyASTTransformation extends ClassCodeExpressionTransformer implements ASTTransformation {
    private static final ClassNode MY_TYPE = make(Newify.class);
    private static final String MY_NAME = MY_TYPE.getNameWithoutPackage();
    private static final String BASE_BAD_PARAM_ERROR = "Error during @" + MY_NAME +
            " processing. Annotation parameter must be a class or list of classes but found ";

    /** Source unit handed to {@link #visit}; exposed through {@link #getSourceUnit()}. */
    private SourceUnit source;
    /** Classes named in the annotation currently being processed; may be {@code null}. */
    private ListExpression classesToNewify;
    /**
     * The annotated local declaration still to be located, or {@code null}. While this is
     * non-null, {@link #transform} does not rewrite method calls.
     */
    private DeclarationExpression candidate;
    /** Value of the annotation's {@code auto} flag for the scope currently being processed. */
    private boolean auto;

    /**
     * Entry point of the transformation.
     *
     * @param nodes  expected to be exactly {@code [AnnotationNode, AnnotatedNode]}
     * @param source the source unit being compiled
     * @throws GroovyBugError if {@code nodes} has an unexpected shape or the annotation is not {@code @Newify}
     */
    public void visit(ASTNode[] nodes, SourceUnit source) {
        this.source = source;
        if (nodes.length != 2 || !(nodes[0] instanceof AnnotationNode) || !(nodes[1] instanceof AnnotatedNode)) {
            internalError("Expecting [AnnotationNode, AnnotatedClass] but got: " + Arrays.asList(nodes));
        }

        AnnotatedNode parent = (AnnotatedNode) nodes[1];
        AnnotationNode node = (AnnotationNode) nodes[0];
        if (!MY_TYPE.equals(node.getClassNode())) {
            internalError("Transformation called from wrong annotation: " + node.getClassNode().getName());
        }

        boolean autoFlag = determineAutoFlag(node.getMember("auto"));
        Expression value = node.getMember("value");

        if (parent instanceof ClassNode) {
            newifyClass((ClassNode) parent, autoFlag, determineClasses(value, false));
        } else if (parent instanceof MethodNode || parent instanceof FieldNode) {
            newifyMethodOrField(parent, autoFlag, determineClasses(value, false));
        } else if (parent instanceof DeclarationExpression) {
            // Only local declarations look up unresolved names in the source unit.
            newifyDeclaration((DeclarationExpression) parent, autoFlag, determineClasses(value, true));
        }
    }

    /**
     * Processes an annotated local declaration by visiting its declaring class with
     * {@code candidate} set to that declaration.
     */
    private void newifyDeclaration(DeclarationExpression de, boolean autoFlag, ListExpression list) {
        ClassNode cNode = de.getDeclaringClass();
        final ListExpression oldClassesToNewify = classesToNewify;
        final boolean oldAuto = auto;
        candidate = de;
        classesToNewify = list;
        auto = autoFlag;
        try {
            super.visitClass(cNode);
        } finally {
            // A candidate that was never reached must not leak into later visits: while it is
            // non-null, transform() skips every method call.
            candidate = null;
            classesToNewify = oldClassesToNewify;
            auto = oldAuto;
        }
    }

    /**
     * @return {@code false} only when {@code autoExpr} is a constant whose value is
     *         {@code Boolean.FALSE}; {@code true} otherwise (including a missing member)
     */
    private boolean determineAutoFlag(Expression autoExpr) {
        if (autoExpr instanceof ConstantExpression) {
            // Constant-first comparison so a null constant value cannot throw.
            return !Boolean.FALSE.equals(((ConstantExpression) autoExpr).getValue());
        }
        return true;
    }

    /**
     * Normalises the annotation's {@code value} member into a list of class expressions.
     * Allow non-strict mode in scripts because parsing is not complete at that point: when
     * {@code searchSourceUnit} is set, a {@link VariableExpression} is matched by simple name
     * against the classes of the current source unit.
     * <p>
     * A {@link ListExpression} argument is validated and updated in place and returned as is.
     * Invalid content is reported through {@code addError}.
     *
     * @param expr             the annotation member; may be {@code null}
     * @param searchSourceUnit whether variable references may be resolved against the source unit
     * @return the list of classes; never {@code null}
     */
    private ListExpression determineClasses(Expression expr, boolean searchSourceUnit) {
        ListExpression list = new ListExpression();
        if (expr instanceof ClassExpression) {
            list.addExpression(expr);
        } else if (expr instanceof VariableExpression && searchSourceUnit) {
            VariableExpression ve = (VariableExpression) expr;
            ClassExpression found = resolveInSourceUnit(ve);
            if (found != null) {
                list.addExpression(found);
            } else {
                addError(BASE_BAD_PARAM_ERROR + "an unresolvable reference to '" + ve.getName() + "'.", expr);
            }
        } else if (expr instanceof ListExpression) {
            list = (ListExpression) expr;
            final List<Expression> expressions = list.getExpressions();
            for (int i = 0; i < expressions.size(); i++) {
                Expression next = expressions.get(i);
                if (next instanceof VariableExpression && searchSourceUnit) {
                    VariableExpression ve = (VariableExpression) next;
                    ClassExpression found = resolveInSourceUnit(ve);
                    if (found != null) {
                        expressions.set(i, found);
                    } else {
                        addError(BASE_BAD_PARAM_ERROR + "a list containing an unresolvable reference to '" + ve.getName() + "'.", next);
                    }
                } else if (!(next instanceof ClassExpression)) {
                    addError(BASE_BAD_PARAM_ERROR + "a list containing type: " + next.getType().getName() + ".", next);
                }
            }
            checkDuplicateNameClashes(list);
        } else if (expr != null) {
            addError(BASE_BAD_PARAM_ERROR + "a type: " + expr.getType().getName() + ".", expr);
        }
        return list;
    }

    /**
     * Builds a class expression, positioned at {@code ve}, for the source-unit class whose
     * simple name equals the variable's name.
     *
     * @return the class expression, or {@code null} when no such class exists
     */
    private ClassExpression resolveInSourceUnit(VariableExpression ve) {
        ClassNode fromSourceUnit = getSourceUnitClass(ve);
        if (fromSourceUnit == null) {
            return null;
        }
        ClassExpression found = classX(fromSourceUnit);
        found.setSourcePosition(ve);
        return found;
    }

    /**
     * @return the first class of the current source unit whose name without package equals
     *         the variable's name, or {@code null} if there is none
     */
    private ClassNode getSourceUnitClass(VariableExpression ve) {
        List<ClassNode> classes = source.getAST().getClasses();
        for (ClassNode classNode : classes) {
            if (classNode.getNameWithoutPackage().equals(ve.getName())) {
                return classNode;
            }
        }
        return null;
    }

    /**
     * Rewrites newify-style method calls into constructor calls and makes sure closures,
     * anonymous inner classes and declarations are traversed as well.
     *
     * @param expr the expression to transform; may be {@code null}
     * @return the replacement expression, or {@code null} when {@code expr} is {@code null}
     */
    public Expression transform(Expression expr) {
        if (expr == null) return null;
        if (expr instanceof MethodCallExpression && candidate == null) {
            MethodCallExpression mce = (MethodCallExpression) expr;
            // Arguments are transformed first so nested newify calls are handled either way.
            Expression args = transform(mce.getArguments());
            if (isNewifyCandidate(mce)) {
                Expression transformed = transformMethodCall(mce, args);
                transformed.setSourcePosition(mce);
                return transformed;
            }
            Expression method = transform(mce.getMethod());
            Expression object = transform(mce.getObjectExpression());
            MethodCallExpression transformed = callX(object, method, args);
            transformed.setImplicitThis(mce.isImplicitThis());
            transformed.setSafe(mce.isSafe());
            // Carry over the spread-safe flag too; otherwise a*.b() would be rebuilt as a.b().
            transformed.setSpreadSafe(mce.isSpreadSafe());
            transformed.setSourcePosition(mce);
            return transformed;
        } else if (expr instanceof ClosureExpression) {
            ClosureExpression ce = (ClosureExpression) expr;
            ce.getCode().visit(this);
        } else if (expr instanceof ConstructorCallExpression) {
            ConstructorCallExpression cce = (ConstructorCallExpression) expr;
            if (cce.isUsingAnonymousInnerClass()) {
                cce.getType().visitContents(this);
            }
        } else if (expr instanceof DeclarationExpression) {
            DeclarationExpression de = (DeclarationExpression) expr;
            if (shouldTransform(de)) {
                // From here on method calls are rewritten (see the candidate check above).
                candidate = null;
                Expression left = de.getLeftExpression();
                Expression right = transform(de.getRightExpression());
                DeclarationExpression newDecl = new DeclarationExpression(left, de.getOperation(), right);
                // Keep the original location so diagnostics on the rebuilt node point at the source.
                newDecl.setSourcePosition(de);
                newDecl.addAnnotations(de.getAnnotations());
                return newDecl;
            }
            return de;
        }
        return expr.transformExpression(this);
    }

    /** A declaration is rewritten when it is the awaited candidate, or when any newify setting is active. */
    private boolean shouldTransform(DeclarationExpression exp) {
        return exp == candidate || auto || hasClassesToNewify();
    }

    private boolean hasClassesToNewify() {
        return classesToNewify != null && !classesToNewify.getExpressions().isEmpty();
    }

    /**
     * Processes a class-level annotation. Interfaces are reported as an error; the class is
     * still visited afterwards.
     */
    private void newifyClass(ClassNode cNode, boolean autoFlag, ListExpression list) {
        String cName = cNode.getName();
        if (cNode.isInterface()) {
            addError("Error processing interface '" + cName + "'. @"
                    + MY_NAME + " not allowed for interfaces.", cNode);
        }
        final ListExpression oldClassesToNewify = classesToNewify;
        final boolean oldAuto = auto;
        classesToNewify = list;
        auto = autoFlag;
        try {
            super.visitClass(cNode);
        } finally {
            classesToNewify = oldClassesToNewify;
            auto = oldAuto;
        }
    }

    /**
     * Processes a method- or field-level annotation. The clash checks run against the
     * previously active settings, so they must happen before the fields are overwritten.
     */
    private void newifyMethodOrField(AnnotatedNode parent, boolean autoFlag, ListExpression list) {
        final ListExpression oldClassesToNewify = classesToNewify;
        final boolean oldAuto = auto;
        checkClassLevelClashes(list);
        checkAutoClash(autoFlag, parent);
        classesToNewify = list;
        auto = autoFlag;
        try {
            if (parent instanceof FieldNode) {
                super.visitField((FieldNode) parent);
            } else {
                super.visitMethod((MethodNode) parent);
            }
        } finally {
            classesToNewify = oldClassesToNewify;
            auto = oldAuto;
        }
    }

    /**
     * Reports every class expression in {@code list} whose simple name was already seen
     * earlier in the list.
     */
    private void checkDuplicateNameClashes(ListExpression list) {
        final Set<String> seen = new HashSet<String>();
        for (Expression element : list.getExpressions()) {
            // determineClasses has already reported elements that are not class expressions
            // and still passes the list here; skip them instead of casting.
            if (!(element instanceof ClassExpression)) {
                continue;
            }
            final String name = element.getType().getNameWithoutPackage();
            if (!seen.add(name)) {
                addError("Duplicate name '" + name + "' found during @" + MY_NAME + " processing.", element);
            }
        }
    }

    /** Reports an attempt to switch {@code auto} off while it is on in the currently active settings. */
    private void checkAutoClash(boolean autoFlag, AnnotatedNode parent) {
        if (auto && !autoFlag) {
            addError("Error during @" + MY_NAME + " processing. The 'auto' flag can't be false at " +
                    "method/constructor/field level if it is true at the class level.", parent);
        }
    }

    /** Reports every class in {@code list} whose simple name is already in the active class list. */
    private void checkClassLevelClashes(ListExpression list) {
        for (Expression element : list.getExpressions()) {
            if (!(element instanceof ClassExpression)) {
                continue; // invalid elements are reported by determineClasses
            }
            final String name = element.getType().getNameWithoutPackage();
            if (findClassWithMatchingBasename(name)) {
                addError("Error during @" + MY_NAME + " processing. Class '" + name + "' can't appear at " +
                        "method/constructor/field level if it already appears at the class level.", element);
            }
        }
    }

    /** @return whether the active class list contains a class with the given simple name */
    private boolean findClassWithMatchingBasename(String nameWithoutPackage) {
        if (classesToNewify == null) return false;
        for (Expression element : classesToNewify.getExpressions()) {
            if (element instanceof ClassExpression
                    && element.getType().getNameWithoutPackage().equals(nameWithoutPackage)) {
                return true;
            }
        }
        return false;
    }

    /**
     * A call is a candidate when its object expression is the shared
     * {@code VariableExpression.THIS_EXPRESSION} instance (identity comparison), or when
     * {@code auto} is on and the call has the {@code SomeClass.new(...)} shape.
     */
    private boolean isNewifyCandidate(MethodCallExpression mce) {
        return mce.getObjectExpression() == VariableExpression.THIS_EXPRESSION
                || (auto && isNewMethodStyle(mce));
    }

    /** @return whether the call is a constant method name {@code "new"} invoked on a class expression */
    private boolean isNewMethodStyle(MethodCallExpression mce) {
        final Expression obj = mce.getObjectExpression();
        final Expression meth = mce.getMethod();
        // Constant-first comparison so a null constant value cannot throw.
        return obj instanceof ClassExpression && meth instanceof ConstantExpression
                && "new".equals(((ConstantExpression) meth).getValue());
    }

    /**
     * Converts a candidate call into a constructor call when a target class can be determined.
     *
     * @param mce  the original call
     * @param args the already transformed arguments
     * @return a new {@link ConstructorCallExpression}, or {@code mce} itself (with its
     *         arguments replaced by {@code args}) when no class matches
     */
    private Expression transformMethodCall(MethodCallExpression mce, Expression args) {
        ClassNode classType;
        if (isNewMethodStyle(mce)) {
            classType = mce.getObjectExpression().getType();
        } else {
            classType = findMatchingCandidateClass(mce);
        }
        if (classType != null) {
            return new ConstructorCallExpression(classType, args);
        }
        // set the args as they might have gotten Newify transformed GROOVY-3491
        mce.setArguments(args);
        return mce;
    }

    /**
     * @return the type of the first active class whose simple name equals the call's method
     *         name, or {@code null} if none matches
     */
    private ClassNode findMatchingCandidateClass(MethodCallExpression mce) {
        if (classesToNewify == null) return null;
        for (Expression element : classesToNewify.getExpressions()) {
            if (!(element instanceof ClassExpression)) {
                continue; // invalid elements are reported by determineClasses
            }
            final ClassNode type = element.getType();
            if (type.getNameWithoutPackage().equals(mce.getMethodAsString())) {
                return type;
            }
        }
        return null;
    }

    /** Aborts with a {@link GroovyBugError} carrying the given message. */
    private void internalError(String message) {
        throw new GroovyBugError("Internal error: " + message);
    }

    protected SourceUnit getSourceUnit() {
        return source;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy