All Downloads are FREE. Search and download functionalities are using the official Maven repository.

lombok.javac.handlers.HandleYield Maven / Gradle / Ivy

There is a newer version: 0.11.3
Show newest version
/*
 * Copyright © 2011 Philipp Eichhorn
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 * 
 * Credit where credit is due:
 * ===========================
 * 
 *   Yield is based on the idea, algorithms and sources of
 *   Arthur and Vladimir Nesterovsky (http://www.nesterovsky-bros.com/weblog),
 *   who generously allowed me to use them.
 */
package lombok.javac.handlers;

import static lombok.ast.AST.*;
import static lombok.core.util.ErrorMessages.*;
import static lombok.javac.handlers.Javac.*;

import java.util.*;

import lombok.*;
import lombok.ast.*;
import lombok.core.handlers.YieldHandler;
import lombok.core.handlers.YieldHandler.AbstractYieldDataCollector;
import lombok.core.handlers.YieldHandler.ErrorHandler;
import lombok.core.handlers.YieldHandler.Scope;
import lombok.core.util.As;
import lombok.core.util.Each;
import lombok.core.util.Is;
import lombok.javac.JavacASTAdapter;
import lombok.javac.JavacASTVisitor;
import lombok.javac.JavacNode;
import lombok.javac.handlers.ast.JavacMethod;

import org.mangosdk.spi.ProviderFor;

import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.JCTree.JCTry;
import com.sun.tools.javac.tree.TreeScanner;
import com.sun.tools.javac.tree.JCTree.JCBlock;
import com.sun.tools.javac.tree.JCTree.JCBreak;
import com.sun.tools.javac.tree.JCTree.JCCase;
import com.sun.tools.javac.tree.JCTree.JCCatch;
import com.sun.tools.javac.tree.JCTree.JCContinue;
import com.sun.tools.javac.tree.JCTree.JCDoWhileLoop;
import com.sun.tools.javac.tree.JCTree.JCEnhancedForLoop;
import com.sun.tools.javac.tree.JCTree.JCExpressionStatement;
import com.sun.tools.javac.tree.JCTree.JCForLoop;
import com.sun.tools.javac.tree.JCTree.JCIdent;
import com.sun.tools.javac.tree.JCTree.JCIf;
import com.sun.tools.javac.tree.JCTree.JCLabeledStatement;
import com.sun.tools.javac.tree.JCTree.JCLiteral;
import com.sun.tools.javac.tree.JCTree.JCNewClass;
import com.sun.tools.javac.tree.JCTree.JCParens;
import com.sun.tools.javac.tree.JCTree.JCReturn;
import com.sun.tools.javac.tree.JCTree.JCStatement;
import com.sun.tools.javac.tree.JCTree.JCSwitch;
import com.sun.tools.javac.tree.JCTree.JCVariableDecl;
import com.sun.tools.javac.tree.JCTree.JCCompilationUnit;
import com.sun.tools.javac.tree.JCTree.JCExpression;
import com.sun.tools.javac.tree.JCTree.JCMethodInvocation;
import com.sun.tools.javac.tree.JCTree.JCTypeApply;
import com.sun.tools.javac.tree.JCTree.JCWhileLoop;
import com.sun.tools.javac.util.Name;

@ProviderFor(JavacASTVisitor.class)
public class HandleYield extends JavacASTAdapter {
	/**
	 * Names of the {@code yield} method calls that were handled successfully in the
	 * current compilation unit. The element type is spelled out because the names are
	 * iterated as {@code String}s when the unit has been visited completely.
	 */
	private final Set<String> methodNames = new HashSet<String>();

	/**
	 * Forgets the method names collected so far, so that every compilation unit starts
	 * with an empty set.
	 *
	 * @param top node of the compilation unit (not used here)
	 * @param unit the compilation unit tree (not used here)
	 */
	@Override
	public void visitCompilationUnit(final JavacNode top, final JCCompilationUnit unit) {
		methodNames.clear();
	}

	/**
	 * Looks for {@code yield} method calls and hands the surrounding method over to the
	 * {@code YieldHandler}.
	 * <p>
	 * A {@code yield} call outside of a method body, or inside a constructor, is reported
	 * as an error on the statement node. The name of the call is remembered only when the
	 * handler reports success.
	 *
	 * @param statementNode node wrapping the visited statement
	 * @param statement the visited statement tree
	 */
	@Override
	public void visitStatement(final JavacNode statementNode, final JCTree statement) {
		if (!(statement instanceof JCMethodInvocation)) return;

		final String methodName = ((JCMethodInvocation) statement).meth.toString();
		if (!isMethodCallValid(statementNode, methodName, Yield.class, "yield")) return;

		final JavacMethod method = JavacMethod.methodOf(statementNode, statement);
		if ((method == null) || method.isConstructor()) {
			statementNode.addError(canBeUsedInBodyOfMethodsOnly("yield"));
			return;
		}
		if (new YieldHandler().handle(method, new JavacYieldDataCollector())) {
			methodNames.add(methodName);
		}
	}

	/**
	 * Calls {@code deleteMethodCallImports} once for every remembered {@code yield} call
	 * name after the whole compilation unit has been visited.
	 *
	 * @param top node of the compilation unit the imports are deleted from
	 * @param unit the compilation unit tree (not used here)
	 */
	@Override
	public void endVisitCompilationUnit(final JavacNode top, final JCCompilationUnit unit) {
		for (final String handledName : methodNames) {
			deleteMethodCallImports(top, handledName, Yield.class, "yield");
		}
	}

	private static class JavacYieldDataCollector extends AbstractYieldDataCollector {

		/**
		 * Determines the element type from the declared return type of the method.
		 *
		 * @param method the method whose return type is inspected
		 * @return the string form of the {@code type} of the first type argument when the
		 *         return type is a type application with at least one argument, otherwise
		 *         the name of {@code java.lang.Object}
		 */
		public String elementType(final JavacMethod method) {
			final JCExpression declared = method.get().restype;
			if (!(declared instanceof JCTypeApply)) {
				return Object.class.getName();
			}
			final JCTypeApply parameterized = (JCTypeApply) declared;
			if (parameterized.arguments.isEmpty()) {
				return Object.class.getName();
			}
			return parameterized.arguments.head.type.toString();
		}

		/**
		 * Collects everything the yield transformation needs from the method body.
		 * <p>
		 * A quick scan first checks whether the body contains a yield statement at all. If it
		 * does, a full scan records the scopes; afterwards the scopes that take part in the
		 * transformation are registered in {@code allScopes} and local variables declared in
		 * registered scopes are added to {@code stateVariables} as private field declarations.
		 *
		 * @return {@code false} if the quick scan completes without finding a yield statement,
		 *         {@code true} otherwise
		 */
		public boolean scan() {
			try {
				// YieldQuickScanner throws IllegalStateException at the first yield statement,
				// so reaching the return means the body contains none.
				new YieldQuickScanner().scan(method.get().body);
				return false;
			} catch (final IllegalStateException ignore) {
				// this means there are unhandled yields left
			}

			// Full pass: YieldScanner adds to yields, breaks and variableDecls while walking the body.
			YieldScanner scanner = new YieldScanner();
			scanner.scan(method.get().body);

			// Register every yield scope together with all of its ancestors.
			for (Scope scope : yields) {
				Scope yieldScope = scope;
				do {
					allScopes.put(yieldScope.node, yieldScope);
					yieldScope = yieldScope.parent;
				} while (yieldScope != null);
			}

			// Repeat until a pass registers nothing new: a break/continue scope that is not
			// registered yet is registered (with its ancestors) when it has no target or its
			// target is already registered.
			boolean collected = !breaks.isEmpty();
			while (collected) {
				collected = false;
				for (Scope scope : breaks) {
					Scope target = scope.target;
					if (((target == null) || allScopes.containsKey(target.node)) && !allScopes.containsKey(scope.node)) {
						collected = true;
						Scope breakScope = scope;
						do {
							allScopes.put(breakScope.node, breakScope);
							breakScope = breakScope.parent;
						} while (breakScope != null);
					}
				}
			}

			// A variable declaration becomes a state variable when its parent scope is registered,
			// or when its parent is a catch node and the scope above that one is registered.
			for (Scope scope : variableDecls) {
				boolean stateVariable = false;
				if (allScopes.containsKey(scope.parent.node)) {
					stateVariable = true;
				} else {
					if ((scope.parent.node instanceof JCCatch) && allScopes.containsKey(scope.parent.parent.node)) {
						stateVariable = true;
					}
				}

				if (stateVariable) {
					JCVariableDecl variable = (JCVariableDecl) scope.node;
					allScopes.put(scope.node, scope);
					FieldDecl field = FieldDecl(Type(variable.vartype), As.string(variable.name)).makePrivate();
					// Variables whose parent scope is the try node itself additionally get
					// @SuppressWarnings("unused") on the generated field.
					if (scope.parent.node instanceof JCTry) {
						field.withAnnotation(Annotation(Type(SuppressWarnings.class)).withValue(String("unused")));
					}
					stateVariables.add(field);
				}
			}
			return true;
		}

		/**
		 * Selects as {@code root} the scope that is registered in {@code allScopes} for the
		 * body of the method.
		 */
		public void prepareRefactor() {
			final JCTree methodBody = method.get().body;
			root = allScopes.get(methodBody);
		}

		/**
		 * Extracts the yielded value from a yield call.
		 *
		 * @param expr the expression to inspect
		 * @return the single argument if {@code expr} is a method invocation with exactly one
		 *         argument whose method expression ends with {@code "yield"}, otherwise
		 *         {@code null}
		 */
		private JCExpression getYieldExpression(final JCExpression expr) {
			if (!(expr instanceof JCMethodInvocation)) {
				return null;
			}
			final JCMethodInvocation call = (JCMethodInvocation) expr;
			final boolean isYieldCall = call.meth.toString().endsWith("yield") && (call.args.length() == 1);
			return isYieldCall ? call.args.head : null;
		}

		/**
		 * Tells whether the expression is the literal {@code true}, ignoring any number of
		 * surrounding parentheses.
		 *
		 * @param expression the expression to inspect
		 * @return {@code true} only for a literal whose string form is {@code "true"}
		 */
		private boolean isTrueLiteral(final JCExpression expression) {
			JCExpression unwrapped = expression;
			// strip parentheses layer by layer
			while (unwrapped instanceof JCParens) {
				unwrapped = ((JCParens) unwrapped).expr;
			}
			return (unwrapped instanceof JCLiteral) && "true".equals(unwrapped.toString());
		}

		/**
		 * Walks from {@code scope} towards the outermost scope and returns the first try scope
		 * that has a finalizer, unless the walk arrived at the try directly from that finalizer.
		 *
		 * @param scope the scope to start from
		 * @param top the last scope to inspect; {@code null} to walk up to the outermost scope
		 * @return the matching try scope, or {@code null} if there is none up to {@code top}
		 */
		private Scope getFinallyScope(Scope scope, Scope top) {
			JCTree arrivedFrom = null;
			for (Scope candidate = scope; candidate != null; candidate = candidate.parent) {
				final JCTree tree = candidate.node;
				if (tree instanceof JCTry) {
					final JCBlock finalizer = ((JCTry) tree).finalizer;
					if ((finalizer != null) && (finalizer != arrivedFrom)) {
						return candidate;
					}
				}
				if (candidate == top) {
					return null;
				}
				arrivedFrom = tree;
			}
			return null;
		}

		/**
		 * Scanner that only answers the question whether a tree contains a yield statement:
		 * it aborts with an {@link IllegalStateException} at the first one it encounters.
		 */
		private class YieldQuickScanner extends TreeScanner {
			/**
			 * Throws if the expression statement is a yield call, otherwise continues scanning.
			 *
			 * @throws IllegalStateException if {@code tree} is a yield statement
			 */
			@Override
			public void visitExec(final JCExpressionStatement tree) {
				if (getYieldExpression(tree.expr) != null) {
					throw new IllegalStateException();
				}
				super.visitExec(tree);
			}
		}

		private class YieldScanner extends TreeScanner {

			/**
			 * Opens a scope for the block, scans the block inside that scope and restores the
			 * enclosing scope afterwards.
			 */
			@Override
			public void visitBlock(final JCBlock tree) {
				final Scope blockScope = new Scope(current, tree) {
					/** Refactors the statements of the block in order, then adds the break label of this scope. */
					@Override
					public void refactor() {
						for (final JCStatement child : Each.elementIn(tree.stats)) {
							refactorStatement(child);
						}
						addLabel(getBreakLabel(this));
					}
				};
				current = blockScope;

				super.visitBlock(tree);
				current = current.parent;
			}

			/**
			 * Opens a scope for the labeled statement, scans it inside that scope and restores
			 * the enclosing scope afterwards.
			 */
			@Override
			public void visitLabelled(final JCLabeledStatement tree) {
				final Scope labelledScope = new Scope(current, tree) {
					/** Refactors the statement the label is attached to. */
					@Override
					public void refactor() {
						refactorStatement(tree.body);
					}
				};
				current = labelledScope;

				super.visitLabelled(tree);
				current = current.parent;
			}

			/**
			 * Opens a scope for a {@code for} loop, scans the loop inside that scope and restores
			 * the enclosing scope afterwards.
			 * <p>
			 * The scope's {@code refactor()} adds, in this order: the init statements, a fresh loop
			 * label, the exit test, the body, the iteration label, the step statements, a jump back
			 * to the loop label and the break label.
			 */
			@Override
			public void visitForLoop(final JCForLoop tree) {
				current = new Scope(current, tree) {
					@Override
					public void refactor() {
						for (JCStatement statement : Each.elementIn(tree.init)) {
							refactorStatement(statement);
						}
						// label: start of each iteration (condition check); breakLabel: first position after the loop
						Case label = Case();
						Case breakLabel = getBreakLabel(this);
						addLabel(label);
						// No exit test is added when the loop has no condition or the condition is the literal true.
						if ((tree.cond != null) && !isTrueLiteral(tree.cond)) {
							addStatement(If(Not(Expr(tree.cond))).Then(Block().withStatement(setState(literal(breakLabel))).withStatement(Continue())));
						}
						refactorStatement(tree.body);
						// The iteration label sits in front of the step statements; visitContinue
						// resolves its target through getIterationLabel as well.
						addLabel(getIterationLabel(this));
						for (JCStatement statement : Each.elementIn(tree.step)) {
							refactorStatement(statement);
						}
						// jump back to the condition check
						addStatement(setState(literal(label)));
						addStatement(Continue());
						addLabel(breakLabel);
					}
				};

				super.visitForLoop(tree);
				current = current.parent;
			}

			/**
			 * Opens a scope for an enhanced {@code for} loop, scans the loop inside that scope and
			 * restores the enclosing scope afterwards.
			 * <p>
			 * The scope's {@code refactor()} rewrites the loop on top of an explicit iterator that
			 * is kept in a private {@code java.util.Iterator} field named
			 * {@code $<loop variable>Iter}.
			 */
			@Override
			public void visitForeachLoop(final JCEnhancedForLoop tree) {
				final Scope loopScope = new Scope(current, tree) {
					@Override
					public void refactor() {
						final String elementVar = As.string(tree.var.name);
						final String iteratorVar = "$" + elementVar + "Iter";
						stateVariables.add(FieldDecl(Type("java.util.Iterator"), iteratorVar) //
								.makePrivate() //
								.withAnnotation(Annotation(Type(SuppressWarnings.class)).withValue(String("all"))));

						// obtain the iterator once, in front of the iteration label
						addStatement(Assign(Name(iteratorVar), Call(Expr(tree.expr), "iterator")));
						addLabel(getIterationLabel(this));
						// leave the loop when the iterator is exhausted
						addStatement(If(Not(Call(Name(iteratorVar), "hasNext"))).Then(Block() //
								.withStatement(setState(literal(getBreakLabel(this)))) //
								.withStatement(Continue())));
						// fetch the next element into the loop variable
						addStatement(Assign(Name(elementVar), Cast(Type(tree.var.vartype), Call(Name(iteratorVar), "next"))));

						refactorStatement(tree.body);
						// jump back to the iteration label
						addStatement(setState(literal(getIterationLabel(this))));
						addStatement(Continue());
						addLabel(getBreakLabel(this));
					}
				};
				current = loopScope;

				super.visitForeachLoop(tree);
				current = current.parent;
			}

			/**
			 * Opens a scope for a {@code do-while} loop, scans the loop inside that scope and
			 * restores the enclosing scope afterwards.
			 */
			@Override
			public void visitDoLoop(final JCDoWhileLoop tree) {
				final Scope loopScope = new Scope(current, tree) {
					/** Adds the iteration label, the body, the conditional jump back and the break label. */
					@Override
					public void refactor() {
						addLabel(getIterationLabel(this));
						refactorStatement(tree.body);
						// the condition is checked after the body; jump back while it holds
						addStatement(If(Expr(tree.cond)).Then(Block() //
								.withStatement(setState(literal(getIterationLabel(this)))) //
								.withStatement(Continue())));
						addLabel(getBreakLabel(this));
					}
				};
				current = loopScope;

				super.visitDoLoop(tree);
				current = current.parent;
			}

			/**
			 * Opens a scope for a {@code while} loop, scans the loop inside that scope and
			 * restores the enclosing scope afterwards.
			 */
			@Override
			public void visitWhileLoop(final JCWhileLoop tree) {
				final Scope loopScope = new Scope(current, tree) {
					/** Adds the iteration label, the exit test, the body, the jump back and the break label. */
					@Override
					public void refactor() {
						addLabel(getIterationLabel(this));
						// a condition that is the literal true needs no exit test
						final boolean endless = isTrueLiteral(tree.cond);
						if (!endless) {
							addStatement(If(Not(Expr(tree.cond))).Then(Block() //
									.withStatement(setState(literal(getBreakLabel(this)))) //
									.withStatement(Continue())));
						}
						refactorStatement(tree.body);
						// jump back to the iteration label
						addStatement(setState(literal(getIterationLabel(this))));
						addStatement(Continue());
						addLabel(getBreakLabel(this));
					}
				};
				current = loopScope;

				super.visitWhileLoop(tree);
				current = current.parent;
			}

			/**
			 * Opens a scope for an {@code if} statement, scans the statement inside that scope and
			 * restores the enclosing scope afterwards.
			 */
			@Override
			public void visitIf(final JCIf tree) {
				final Scope ifScope = new Scope(current, tree) {
					/**
					 * Adds a jump that skips the then-part when the condition does not hold. With an
					 * else-part the jump leads to a fresh label in front of it, and the then-part ends
					 * with a jump to the break label; without one the jump leads to the break label.
					 */
					@Override
					public void refactor() {
						final boolean hasElse = tree.elsepart != null;
						final Case skipThen = hasElse ? Case() : getBreakLabel(this);
						addStatement(If(Not(Expr(tree.cond))).Then(Block() //
								.withStatement(setState(literal(skipThen))) //
								.withStatement(Continue())));
						refactorStatement(tree.thenpart);
						if (hasElse) {
							// the then-part must not fall through into the else-part
							addStatement(setState(literal(getBreakLabel(this))));
							addStatement(Continue());
							addLabel(skipThen);
							refactorStatement(tree.elsepart);
						}
						addLabel(getBreakLabel(this));
					}
				};
				current = ifScope;

				super.visitIf(tree);
				current = current.parent;
			}

			/**
			 * Opens a scope for a {@code switch} statement, scans the statement inside that scope
			 * and restores the enclosing scope afterwards.
			 * <p>
			 * The scope's {@code refactor()} adds a new switch over the original selector whose
			 * cases only set the state to a per-case label and continue; the statements of the
			 * original cases are refactored one after another behind those labels.
			 */
			@Override
			public void visitSwitch(final JCSwitch tree) {
				current = new Scope(current, tree) {
					@Override
					public void refactor() {
						Case breakLabel = getBreakLabel(this);
						Switch switchStatement = Switch(Expr(tree.selector));
						// The switch is added first; its cases are attached to the same object below.
						addStatement(switchStatement);
						if (Is.notEmpty(tree.cases)) {
							boolean hasDefault = false;
							for (JCCase item : tree.cases) {
								// a case without a pattern is the default case
								if (item.pat == null) {
									hasDefault = true;
								}
								Case label = Case();
								// NOTE(review): for the default case item.pat is null and is still passed to
								// Expr(...); presumably that produces a default case - confirm.
								switchStatement.withCase(Case(Expr(item.pat)).withStatement(setState(literal(label))).withStatement(Continue()));
								addLabel(label);
								// Case bodies follow each other directly; nothing is added between them here.
								for (JCStatement statement : item.stats) {
									refactorStatement(statement);
								}
							}
							// Without an original default case, an extra case jumps to the break label.
							if (!hasDefault) {
								switchStatement.withCase(Case().withStatement(setState(literal(breakLabel))).withStatement(Continue()));
							}
						}
						addLabel(breakLabel);
					}
				};

				super.visitSwitch(tree);
				current = current.parent;
			}

			/**
			 * Opens a scope for a {@code try} statement, scans the statement inside that scope and
			 * restores the enclosing scope afterwards.
			 * <p>
			 * The scope's {@code refactor()} lays out the try body, the catch bodies and the
			 * finalizer one after another and registers {@code ErrorHandler}s in
			 * {@code errorHandlers}: one for the catch clauses (if any) and one for the finalizer
			 * (if any). The {@code begin}/{@code end} values of a handler are taken from
			 * {@code cases.size()} at the respective point of the layout.
			 */
			@Override
			public void visitTry(final JCTry tree) {
				current = new Scope(current, tree) {
					@Override
					public void refactor() {
						boolean hasFinally = tree.finalizer != null;
						boolean hasCatch = Is.notEmpty(tree.catchers);
						ErrorHandler catchHandler = null;
						ErrorHandler finallyHandler = null;
						Case tryLabel = Case();
						Case finallyLabel;
						Case breakLabel = getBreakLabel(this);
						String finallyErrorName = null;

						if (hasFinally) {
							finallyHandler = new ErrorHandler();
							finallyLabel = getFinallyLabel(this);
							// Each finalizer gets its own numbered pair of fields: a Throwable holding a
							// pending error and an int holding the state to go to after the finalizer.
							finallyBlocks++;
							finallyErrorName = errorName + finallyBlocks;
							// presumably this scope's labelName, which other scopes read as next.labelName
							labelName = "$state" + finallyBlocks;

							stateVariables.add(FieldDecl(Type(Throwable.class), finallyErrorName).makePrivate());
							stateVariables.add(FieldDecl(Type("int"), labelName).makePrivate());

							// initially: no pending error, continue behind the try statement
							addStatement(Assign(Name(finallyErrorName), Null()));
							addStatement(Assign(Name(labelName), literal(breakLabel)));
						} else {
							// without a finalizer, "going to the finalizer" means leaving the try statement
							finallyLabel = breakLabel;
						}

						addStatement(setState(literal(tryLabel)));

						// The handler that guards the try body starts here: the catch handler if there
						// are catch clauses, otherwise the finally handler.
						if (hasCatch) {
							catchHandler = new ErrorHandler();
							catchHandler.begin = cases.size();
						} else if (hasFinally) {
							finallyHandler.begin = cases.size();
						}

						addLabel(tryLabel);
						refactorStatement(tree.body);
						addStatement(setState(literal(finallyLabel)));

						if (hasCatch) {
							addStatement(Continue());
							catchHandler.end = cases.size();
							for (JCCatch catcher : tree.catchers) {
								Case label = Case();
								// The label is only referenced from the handler statements below, so it is
								// added to usedLabels explicitly.
								usedLabels.add(label);
								addLabel(label);
								refactorStatement(catcher.body);
								addStatement(setState(literal(finallyLabel)));
								addStatement(Continue());

								// handler: if the error is an instance of the caught type, store it in the
								// catch parameter and continue at the catch body
								catchHandler.statements.add(If(InstanceOf(Name(errorName), Type(catcher.param.vartype))).Then(Block() //
										.withStatement(Assign(Name(As.string(catcher.param.name)), Cast(Type(catcher.param.vartype), Name(errorName)))) //
										.withStatement(setState(literal(label))).withStatement(Continue())));
							}

							errorHandlers.add(catchHandler);

							// With catch clauses present, the finally handler starts where the catch handler ends.
							if (hasFinally) {
								finallyHandler.begin = catchHandler.end;
							}
						}

						if (hasFinally) {
							finallyHandler.end = cases.size();
							addLabel(finallyLabel);
							refactorStatement(tree.finalizer);

							// after the finalizer: hand a pending error back via errorName and break
							addStatement(If(NotEqual(Name(finallyErrorName), Null())).Then(Block().withStatement(Assign(Name(errorName), Name(finallyErrorName))).withStatement(Break())));

							// next: the closest enclosing try scope with a finalizer, if any
							Scope next = getFinallyScope(parent, null);
							if (next != null) {
								Case label = getFinallyLabel(next);
								// NOTE(review): both setState calls are in the same Then-block, so the second
								// follows the first directly, and no state is set when the condition is false.
								// This looks as if the last statement was meant to be an else branch - confirm.
								addStatement(If(Binary(Name(labelName), ">", literal(label))).Then(Block() //
										.withStatement(Assign(Name(next.labelName), Name(labelName))) //
										.withStatement(setState(literal(label))).withStatement(setState(Name(labelName)))));
							} else {
								addStatement(setState(Name(labelName)));
							}

							addStatement(Continue());

							// handler: remember the error and continue at the finalizer
							finallyHandler.statements.add(Assign(Name(finallyErrorName), Name(errorName)));
							finallyHandler.statements.add(setState(literal(finallyLabel)));
							finallyHandler.statements.add(Continue());

							usedLabels.add(finallyLabel);
							errorHandlers.add(finallyHandler);
						}
						addLabel(breakLabel);
					}
				};
				super.visitTry(tree);
				current = current.parent;
			}

			/**
			 * Opens a scope for a variable declaration, records it in {@code variableDecls},
			 * scans the declaration inside that scope and restores the enclosing scope afterwards.
			 */
			@Override
			public void visitVarDef(final JCVariableDecl tree) {
				final Scope declarationScope = new Scope(current, tree) {
					/** Replaces the declaration by a plain assignment of its initializer, if it has one. */
					@Override
					public void refactor() {
						if (tree.init == null) return;
						addStatement(Assign(Name(As.string(tree.name)), Expr(tree.init)));
					}
				};
				current = declarationScope;

				variableDecls.add(declarationScope);
				super.visitVarDef(tree);
				current = current.parent;
			}

			/**
			 * Reports an error on the method node for every {@code return} statement found in the
			 * scanned body. The statement is not scanned any further.
			 * <p>
			 * The message states that {@code return} is <em>not</em> permitted; the earlier wording
			 * ("is permitted") contradicted the fact that an error is raised.
			 */
			@Override
			public void visitReturn(final JCReturn tree) {
				method.node().addError("The 'return' expression is not permitted.");
			}

			/**
			 * Records a {@code break} statement as a scope of its own and resolves its target.
			 * <p>
			 * A labeled {@code break} targets the enclosing labeled statement with the same label;
			 * an unlabeled one targets the innermost enclosing loop or switch. Problems are
			 * reported as errors on the method node. The break scope is added to {@code breaks};
			 * the enclosing scope stays the current one.
			 */
			@Override
			public void visitBreak(final JCBreak tree) {
				final Name label = tree.label;
				Scope target = null;
				if (label == null) {
					for (Scope enclosing = current; enclosing != null; enclosing = enclosing.parent) {
						if (Is.oneOf(enclosing.node, JCForLoop.class, JCEnhancedForLoop.class, JCWhileLoop.class, JCDoWhileLoop.class, JCSwitch.class)) {
							target = enclosing;
							break;
						}
					}
				} else {
					// all enclosing scopes are inspected, so a second statement with the same label is noticed
					for (Scope enclosing = current; enclosing != null; enclosing = enclosing.parent) {
						if (!(enclosing.node instanceof JCLabeledStatement)) continue;
						if (label != ((JCLabeledStatement) enclosing.node).label) continue;
						if (target != null) {
							method.node().addError("Invalid label.");
						}
						target = enclosing;
					}
				}

				if (target == null) {
					method.node().addError("Invalid break.");
				}

				final Scope breakScope = new Scope(current, tree) {
					/**
					 * Jumps to the break label of the target; if a finalizer lies on the way, the jump
					 * goes to that finalizer first and the break label is stored as its continuation.
					 */
					@Override
					public void refactor() {
						// 'parent' and 'target' are the fields of this scope, not locals of visitBreak
						final Scope finallyScope = getFinallyScope(parent, target);
						final Case exitLabel = getBreakLabel(target);

						if (finallyScope == null) {
							addStatement(setState(literal(exitLabel)));
						} else {
							addStatement(Assign(Name(finallyScope.labelName), literal(exitLabel)));
							addStatement(setState(literal(getFinallyLabel(finallyScope))));
						}
						addStatement(Continue());
					}
				};
				current = breakScope;

				breakScope.target = target;
				breaks.add(breakScope);
				current = current.parent;
			}

			/**
			 * Records a {@code continue} statement as a scope of its own and resolves the loop it
			 * continues.
			 * <p>
			 * A labeled {@code continue} targets the loop that forms the body of the enclosing
			 * labeled statement with the same label; an unlabeled one targets the innermost
			 * enclosing loop. Problems are reported as errors on the method node. The scope is
			 * added to {@code breaks}; the enclosing scope stays the current one.
			 * <p>
			 * For the labeled form the loop check has to look at the statement <em>behind</em> the
			 * label: the labeled statement node itself is never a loop, so testing that node would
			 * reject every labeled {@code continue}.
			 */
			@Override
			public void visitContinue(final JCContinue tree) {
				Scope target = null;
				Name label = tree.label;
				if (label != null) {
					Scope labelScope = current;
					// scope inspected just before labelScope, i.e. the child of labelScope on the path to this statement
					Scope bodyScope = null;
					while (labelScope != null) {
						if (labelScope.node instanceof JCLabeledStatement) {
							JCLabeledStatement labeledStatement = (JCLabeledStatement) labelScope.node;
							if (label == labeledStatement.label) {
								if (target != null) {
									method.node().addError("Invalid label.");
								}
								// The target is the loop scope directly below the label, which is the same
								// kind of target an unlabeled continue resolves to.
								boolean labelsLoop = (bodyScope != null) && (bodyScope.node == labeledStatement.body) //
										&& Is.oneOf(bodyScope.node, JCForLoop.class, JCEnhancedForLoop.class, JCWhileLoop.class, JCDoWhileLoop.class);
								if (labelsLoop) {
									target = bodyScope;
								} else {
									method.node().addError("Invalid continue.");
								}
							}
						}
						bodyScope = labelScope;
						labelScope = labelScope.parent;
					}
				} else {
					Scope labelScope = current;
					while (labelScope != null) {
						if (Is.oneOf(labelScope.node, JCForLoop.class, JCEnhancedForLoop.class, JCWhileLoop.class, JCDoWhileLoop.class)) {
							target = labelScope;
							break;
						}
						labelScope = labelScope.parent;
					}
				}

				if (target == null) {
					method.node().addError("Invalid continue.");
				}

				current = new Scope(current, tree) {
					/**
					 * Jumps to the iteration label of the target; if a finalizer lies on the way, the
					 * jump goes to that finalizer first and the iteration label is stored as its
					 * continuation.
					 */
					@Override
					public void refactor() {
						// 'parent' and 'target' are the fields of this scope, not locals of visitContinue
						Scope next = getFinallyScope(parent, target);
						Case label = getIterationLabel(target);

						if (next == null) {
							addStatement(setState(literal(label)));
							addStatement(Continue());
						} else {
							addStatement(Assign(Name(next.labelName), literal(label)));
							addStatement(setState(literal(getFinallyLabel(next))));
							addStatement(Continue());
						}
					}
				};

				current.target = target;
				breaks.add(current);
				current = current.parent;
			}

			/**
			 * Reports an error for unqualified calls to {@code hasNext}, {@code next},
			 * {@code remove} or {@code close}, then continues scanning the invocation.
			 */
			@Override
			public void visitApply(final JCMethodInvocation tree) {
				// only calls whose method expression is a plain identifier are checked
				if (tree.meth instanceof JCIdent) {
					final String calledName = As.string(tree.meth);
					final boolean hidden = Is.oneOf(calledName, "hasNext", "next", "remove", "close");
					if (hidden) {
						method.node().addError("Cannot call method " + calledName + "(), as it is hidden.");
					}
				}

				super.visitApply(tree);
			}

			/**
			 * Records a yield statement as a scope of its own and scans the yielded expression
			 * inside that scope; any other expression statement is scanned as usual.
			 */
			@Override
			public void visitExec(final JCExpressionStatement tree) {
				final JCExpression expression = getYieldExpression(tree.expr);
				if (expression == null) {
					super.visitExec(tree);
					return;
				}

				current = new Scope(current, tree) {
					/**
					 * Stores the yielded value, sets the state to the label right behind the yield
					 * and returns {@code true}. If a finalizer encloses the yield, a break case is
					 * registered for that label which leads into the finalizer.
					 */
					@Override
					public void refactor() {
						final Case resumeLabel = getBreakLabel(this);
						addStatement(Assign(Name(nextName), Expr(expression)));
						addStatement(setState(literal(resumeLabel)));
						addStatement(Return(True()));
						addLabel(resumeLabel);

						final Scope finallyScope = getFinallyScope(parent, null);
						if (finallyScope == null) return;

						breakCases.add(new Case(literal(resumeLabel)) //
								.withStatement(Assign(Name(finallyScope.labelName), literal(getBreakLabel(root)))) //
								.withStatement(setState(literal(getFinallyLabel(finallyScope)))) //
								.withStatement(Continue()));
					}
				};
				yields.add(current);
				scan(expression);
				current = current.parent;
			}

			/**
			 * Reports an error for the identifiers {@code this} and {@code super} and adds every
			 * visited identifier to {@code names}.
			 */
			@Override
			public void visitIdent(final JCIdent tree) {
				final String identifier = tree.name.toString();
				if ("this".equals(identifier)) {
					method.node().addError("No unqualified 'this' expression is permitted.");
				}
				if ("super".equals(identifier)) {
					method.node().addError("No unqualified 'super' expression is permitted.");
				}
				names.add(identifier);
				super.visitIdent(tree);
			}

			/**
			 * Scans only the enclosing-instance expression, the class expression and the
			 * constructor arguments of a {@code new} expression, in that order.
			 * {@code super.visitNewClass} is not called, so this scanner visits nothing else that
			 * belongs to the node.
			 */
			@Override
			public void visitNewClass(final JCNewClass tree) {
				// NOTE(review): presumably this keeps the scanner out of anonymous class bodies - confirm.
				scan(tree.encl);
				scan(tree.clazz);
				scan(tree.args);
			}
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy