All downloads are free. Search and download functionality uses the official Maven repository.

com.jetbrains.python.refactoring.introduce.field.PyIntroduceFieldHandler Maven / Gradle / Ivy

Go to download

A packaging of the IntelliJ Community Edition python-community library. This is release number 1 of trunk branch 142.

The newest version!
/*
 * Copyright 2000-2014 JetBrains s.r.o.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.jetbrains.python.refactoring.introduce.field;

import com.intellij.lang.ASTNode;
import com.intellij.openapi.actionSystem.DataContext;
import com.intellij.openapi.application.AccessToken;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.editor.CaretModel;
import com.intellij.openapi.editor.Document;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.editor.SelectionModel;
import com.intellij.openapi.project.Project;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.PsiReference;
import com.intellij.psi.search.LocalSearchScope;
import com.intellij.psi.search.searches.ReferencesSearch;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.refactoring.RefactoringBundle;
import com.intellij.refactoring.introduce.inplace.InplaceVariableIntroducer;
import com.intellij.refactoring.util.CommonRefactoringUtil;
import com.intellij.util.Function;
import com.intellij.util.FunctionUtil;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
import com.jetbrains.python.inspections.quickfix.AddFieldQuickFix;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.PyFunctionBuilder;
import com.jetbrains.python.refactoring.PyReplaceExpressionUtil;
import com.jetbrains.python.refactoring.introduce.IntroduceHandler;
import com.jetbrains.python.refactoring.introduce.IntroduceOperation;
import com.jetbrains.python.refactoring.introduce.variable.PyIntroduceVariableHandler;
import com.jetbrains.python.testing.PythonUnitTestUtil;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import javax.swing.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;

/**
 * Handler for the Python "Introduce Field" refactoring.
 *
 * <p>Extracts the selected expression (or an assigned local variable) into an instance attribute of the
 * containing class, i.e. {@code <self>.<name>}. Depending on the chosen {@link InitPlace}, the attribute is
 * initialized in the current method, in {@code __init__}, or — for test case classes — in {@code setUp}.
 *
 * @author Dennis.Ushakov
 */
public class PyIntroduceFieldHandler extends IntroduceHandler {

  public PyIntroduceFieldHandler() {
    super(new IntroduceFieldValidator(), RefactoringBundle.message("introduce.field.title"));
  }

  /**
   * Starts the refactoring. The constructor is always offered as an initialization place; {@code setUp} is
   * additionally offered when the selection/caret lies inside a test case class.
   */
  public void invoke(@NotNull Project project, Editor editor, PsiFile file, DataContext dataContext) {
    final IntroduceOperation operation = new IntroduceOperation(project, editor, file, null);
    operation.addAvailableInitPlace(InitPlace.CONSTRUCTOR);
    if (isTestClass(file, editor)) {
      operation.addAvailableInitPlace(InitPlace.SET_UP);
    }
    performAction(operation);
  }

  /**
   * Checks whether the selection start (or, without a selection, the start of the caret's line) is inside a
   * class accepted by {@link PythonUnitTestUtil#isTestCaseClass}.
   */
  private static boolean isTestClass(PsiFile file, Editor editor) {
    PsiElement element = null;
    final SelectionModel selectionModel = editor.getSelectionModel();
    if (selectionModel.hasSelection()) {
      element = file.findElementAt(selectionModel.getSelectionStart());
    }
    else {
      final CaretModel caretModel = editor.getCaretModel();
      final Document document = editor.getDocument();
      final int lineNumber = document.getLineNumber(caretModel.getOffset());
      if (lineNumber >= 0 && lineNumber < document.getLineCount()) {
        element = file.findElementAt(document.getLineStartOffset(lineNumber));
      }
    }
    if (element == null) {
      return false;
    }
    final PyClass clazz = PyUtil.getContainingClassOrSelf(element);
    return clazz != null && PythonUnitTestUtil.isTestCaseClass(clazz);
  }

  /**
   * When the field is initialized outside the current method, occurrences are replaced directly via
   * {@link PyReplaceExpressionUtil}; otherwise the base behavior is used.
   */
  @Override
  protected PsiElement replaceExpression(PsiElement expression, PyExpression newExpression, IntroduceOperation operation) {
    if (operation.getInitPlace() != InitPlace.SAME_METHOD) {
      return PyReplaceExpressionUtil.replaceExpression(expression, newExpression);
    }
    return super.replaceExpression(expression, newExpression, operation);
  }

  /**
   * Rejects the refactoring outside of a class. If the initializer references values local to its scope,
   * moving it to the constructor or {@code setUp} would break it, so those places are withdrawn.
   */
  @Override
  protected boolean checkEnabled(IntroduceOperation operation) {
    if (PyUtil.getContainingClassOrSelf(operation.getElement()) == null) {
      CommonRefactoringUtil.showErrorHint(operation.getProject(), operation.getEditor(), "Cannot introduce field: not in class", myDialogTitle,
                                          getHelpId());
      return false;
    }
    if (dependsOnLocalScopeValues(operation.getElement())) {
      operation.removeAvailableInitPlace(InitPlace.CONSTRUCTOR);
      operation.removeAvailableInitPlace(InitPlace.SET_UP);
    }
    return true;
  }

  /**
   * Returns {@code true} if some reference inside {@code initializer} resolves to an element declared
   * directly in the initializer's nearest enclosing {@link ScopeOwner}.
   */
  private static boolean dependsOnLocalScopeValues(PsiElement initializer) {
    final ScopeOwner scope = PsiTreeUtil.getParentOfType(initializer, ScopeOwner.class);
    final ResolvingVisitor visitor = new ResolvingVisitor(scope);
    initializer.accept(visitor);
    return visitor.hasLocalScopeDependencies;
  }

  /**
   * Records whether any visited reference resolves to something whose nearest enclosing scope owner is
   * {@code myScope}. The first parameter of a non-static method (the implicit {@code self}/{@code cls}) does
   * not count as a local dependency.
   */
  private static class ResolvingVisitor extends PyRecursiveElementVisitor {
    private boolean hasLocalScopeDependencies = false;
    private final ScopeOwner myScope;

    public ResolvingVisitor(ScopeOwner scope) {
      myScope = scope;
    }

    @Override
    public void visitPyReferenceExpression(PyReferenceExpression node) {
      super.visitPyReferenceExpression(node);
      final PsiElement result = node.getReference().resolve();
      if (result != null && PsiTreeUtil.getParentOfType(result, ScopeOwner.class) == myScope) {
        if (result instanceof PyParameter && myScope instanceof PyFunction) {
          final PyFunction function = (PyFunction)myScope;
          final PyParameter[] parameters = function.getParameterList().getParameters();
          if (parameters.length > 0 && result == parameters[0]) {
            final PyFunction.Modifier modifier = function.getModifier();
            if (modifier != PyFunction.Modifier.STATICMETHOD) {
              // 'self' is not a local scope dependency
              return;
            }
          }
        }
        hasLocalScopeDependencies = true;
      }
    }
  }

  /**
   * Inserts the field declaration according to the selected init place: into {@code __init__} (unless the
   * expression is already in the constructor), into {@code setUp}, or into the current method as a plain
   * variable introduction.
   *
   * @return the inserted declaration, or {@code null} if none could be inserted
   */
  @Nullable
  @Override
  protected PsiElement addDeclaration(@NotNull PsiElement expression, @NotNull PsiElement declaration, @NotNull IntroduceOperation operation) {
    final PsiElement expr = expression instanceof PyClass ? expression : expression.getParent();
    final PsiElement anchor = PyUtil.getContainingClassOrSelf(expr);
    assert anchor instanceof PyClass;
    final PyClass clazz = (PyClass)anchor;
    final Project project = anchor.getProject();
    if (operation.getInitPlace() == InitPlace.CONSTRUCTOR && !inConstructor(expression)) {
      return AddFieldQuickFix.addFieldToInit(project, clazz, "", new AddFieldDeclaration(declaration));
    }
    else if (operation.getInitPlace() == InitPlace.SET_UP) {
      return addFieldToSetUp(clazz, new AddFieldDeclaration(declaration));
    }
    return PyIntroduceVariableHandler.doIntroduceVariable(expression, declaration, operation.getOccurrences(), operation.isReplaceAll());
  }

  /**
   * Returns {@code true} if the function concealing {@code expression} is the {@code __init__} method of the
   * containing class.
   */
  private static boolean inConstructor(@NotNull PsiElement expression) {
    final PsiElement expr = expression instanceof PyClass ? expression : expression.getParent();
    final PyClass clazz = PyUtil.getContainingClassOrSelf(expr);
    final PsiElement current = PyUtil.getConcealingParent(expression);
    if (clazz == null || !(current instanceof PyFunction)) {
      return false;
    }
    return current == clazz.findMethodByName(PyNames.INIT, false);
  }

  /**
   * Appends the statement produced by {@code callback} to the class's {@code setUp} method, first creating
   * {@code setUp(self)} at the top of the class body if it does not exist yet.
   */
  @Nullable
  private static PsiElement addFieldToSetUp(PyClass clazz, final Function<String, PyStatement> callback) {
    final PyFunction existingSetUp = clazz.findMethodByName(PythonUnitTestUtil.TESTCASE_SETUP_NAME, false);
    if (existingSetUp != null) {
      return AddFieldQuickFix.appendToMethod(existingSetUp, callback);
    }
    final PyFunctionBuilder builder = new PyFunctionBuilder(PythonUnitTestUtil.TESTCASE_SETUP_NAME);
    builder.parameter(PyNames.CANONICAL_SELF);
    // Parse with the class's own language level, consistent with createDeclaration().
    final PyFunction template = builder.buildFunction(clazz.getProject(), LanguageLevel.forElement(clazz));
    final PyStatementList statements = clazz.getStatementList();
    final PyFunction setUp = (PyFunction)statements.addBefore(template, statements.getFirstChild());
    return AddFieldQuickFix.appendToMethod(setUp, callback);
  }

  /**
   * For an assigned local variable the occurrences are all references to it inside the enclosing function,
   * excluding the assignment target itself; otherwise the base implementation decides.
   */
  @Override
  protected List<PsiElement> getOccurrences(PsiElement element, @NotNull PyExpression expression) {
    if (isAssignedLocalVariable(element)) {
      final PyFunction function = PsiTreeUtil.getParentOfType(element, PyFunction.class);
      final Collection<PsiReference> references = ReferencesSearch.search(element, new LocalSearchScope(function)).findAll();
      final List<PsiElement> result = new ArrayList<PsiElement>();
      for (PsiReference reference : references) {
        final PsiElement refElement = reference.getElement();
        if (refElement != element) {
          result.add(refElement);
        }
      }
      return result;
    }
    return super.getOccurrences(element, expression);
  }

  /**
   * Builds the {@code <receiver>.<name>} expression, reusing the receiver name from the declaration.
   */
  @Override
  protected PyExpression createExpression(Project project, String name, PsiElement declaration) {
    final String selfName = getReceiverName(declaration.getText());
    return PyElementGenerator.getInstance(project).createExpressionFromText(selfName + "." + name);
  }

  /**
   * Returns the receiver of a {@code <receiver>.<field> = ...} declaration text: everything before the first
   * dot, or {@link PyNames#CANONICAL_SELF} if the text has no dot.
   */
  private static String getReceiverName(@NotNull String declarationText) {
    final int dot = declarationText.indexOf('.');
    return dot >= 0 ? declarationText.substring(0, dot) : PyNames.CANONICAL_SELF;
  }

  /**
   * Creates {@code <first parameter of the enclosing function>.<assignmentText>} as an assignment statement,
   * parsed with the anchor's language level.
   */
  @Override
  protected PyAssignmentStatement createDeclaration(Project project, String assignmentText, PsiElement anchor) {
    final PyFunction container = PsiTreeUtil.getParentOfType(anchor, PyFunction.class);
    final String selfName = PyUtil.getFirstParameterName(container);
    final LanguageLevel langLevel = LanguageLevel.forElement(anchor);
    return PyElementGenerator.getInstance(project).createFromText(langLevel, PyAssignmentStatement.class, selfName + "." + assignmentText);
  }

  /**
   * Removes the original local assignment once it has been converted into a field.
   */
  @Override
  protected void postRefactoring(PsiElement element) {
    if (isAssignedLocalVariable(element)) {
      element.getParent().delete();
    }
  }

  /**
   * Returns {@code true} for the single target of an assignment statement located inside a function.
   */
  private static boolean isAssignedLocalVariable(PsiElement element) {
    if (element instanceof PyTargetExpression && element.getParent() instanceof PyAssignmentStatement &&
        PsiTreeUtil.getParentOfType(element, PyFunction.class) != null) {
      final PyAssignmentStatement stmt = (PyAssignmentStatement)element.getParent();
      return stmt.getTargets().length == 1;
    }
    return false;
  }

  @Override
  protected String getHelpId() {
    return "python.reference.introduceField";
  }

  /**
   * Refuses to start inside a static method, which has no instance to attach the field to.
   */
  @Override
  protected boolean checkIntroduceContext(PsiFile file, Editor editor, PsiElement element) {
    if (element != null && isInStaticMethod(element)) {
      // Use this refactoring's own help topic (previously pointed at the Extract Method topic by mistake).
      CommonRefactoringUtil.showErrorHint(file.getProject(), editor, "Introduce Field refactoring cannot be used in static methods",
                                          RefactoringBundle.message("introduce.field.title"),
                                          getHelpId());
      return false;
    }
    return super.checkIntroduceContext(file, editor, element);
  }

  /**
   * Returns {@code true} if the nearest function enclosing {@code element} (not crossing a class boundary) is
   * a static method.
   */
  private static boolean isInStaticMethod(PsiElement element) {
    final PyFunction containingMethod = PsiTreeUtil.getParentOfType(element, PyFunction.class, false, PyClass.class);
    return containingMethod != null && containingMethod.getModifier() == PyFunction.Modifier.STATICMETHOD;
  }

  /**
   * Valid only inside a function (without crossing a class boundary), outside decorator lists, and not in a
   * static method.
   */
  @Override
  protected boolean isValidIntroduceContext(PsiElement element) {
    return super.isValidIntroduceContext(element) &&
           PsiTreeUtil.getParentOfType(element, PyFunction.class, false, PyClass.class) != null &&
           PsiTreeUtil.getParentOfType(element, PyDecoratorList.class) == null &&
           !isInStaticMethod(element);
  }

  /**
   * Produces the field declaration adapted to the receiver name used by the target method (e.g. the
   * first parameter of {@code __init__}).
   */
  private static class AddFieldDeclaration implements Function<String, PyStatement> {
    private final PsiElement myDeclaration;

    private AddFieldDeclaration(PsiElement declaration) {
      myDeclaration = declaration;
    }

    @Override
    public PyStatement fun(String selfName) {
      final String text = myDeclaration.getText();
      final int dot = text.indexOf('.');
      // The declaration's receiver is whatever precedes the first dot (not necessarily "self"), so compare
      // and replace that prefix rather than searching for a literal "self.".
      if (dot < 0 || text.substring(0, dot).equals(selfName)) {
        return (PyStatement)myDeclaration;
      }
      final Project project = myDeclaration.getProject();
      return PyElementGenerator.getInstance(project).createFromText(LanguageLevel.forElement(myDeclaration), PyStatement.class,
                                                                    selfName + text.substring(dot));
    }
  }

  /**
   * Performs the refactoring in the current method, then starts in-place renaming of the new field with the
   * caret placed on the field name.
   */
  @Override
  protected void performInplaceIntroduce(IntroduceOperation operation) {
    final PsiElement statement = performRefactoring(operation);
    if (!(statement instanceof PyAssignmentStatement)) {
      return;
    }
    final List<PsiElement> occurrences = operation.getOccurrences();
    final PsiElement occurrence = findOccurrenceUnderCaret(occurrences, operation.getEditor());
    final PyTargetExpression target = (PyTargetExpression)((PyAssignmentStatement)statement).getTargets()[0];
    // put caret on identifier after "self."
    putCaretOnFieldName(operation.getEditor(), occurrence != null ? occurrence : target);
    final InplaceVariableIntroducer<PsiElement> introducer = new PyInplaceFieldIntroducer(target, operation, occurrences);
    introducer.performInplaceRefactoring(new LinkedHashSet<String>(operation.getSuggestedNames()));
  }

  /**
   * Moves the caret to the name part of the qualified expression enclosing {@code occurrence}.
   */
  private static void putCaretOnFieldName(Editor editor, PsiElement occurrence) {
    PyQualifiedExpression qExpr = PsiTreeUtil.getParentOfType(occurrence, PyQualifiedExpression.class, false);
    if (qExpr != null && !qExpr.isQualified()) {
      qExpr = PsiTreeUtil.getParentOfType(qExpr, PyQualifiedExpression.class);
    }
    if (qExpr != null) {
      final ASTNode nameElement = qExpr.getNameElement();
      if (nameElement != null) {
        final int offset = nameElement.getTextRange().getStartOffset();
        editor.getCaretModel().moveToOffset(offset);
      }
    }
  }

  /**
   * In-place renamer for the introduced field. When several init places are available it shows a panel to
   * choose one, and on completion relocates the initializer accordingly.
   */
  private static class PyInplaceFieldIntroducer extends InplaceVariableIntroducer<PsiElement> {
    private final PyTargetExpression myTarget;
    private final IntroduceOperation myOperation;
    private final PyIntroduceFieldPanel myPanel;

    public PyInplaceFieldIntroducer(PyTargetExpression target,
                                    IntroduceOperation operation,
                                    List<PsiElement> occurrences) {
      super(target, operation.getEditor(), operation.getProject(), "Introduce Field",
            occurrences.toArray(new PsiElement[occurrences.size()]), null);
      myTarget = target;
      myOperation = operation;
      if (operation.getAvailableInitPlaces().size() > 1) {
        myPanel = new PyIntroduceFieldPanel(myProject, operation.getAvailableInitPlaces());
      }
      else {
        myPanel = null;
      }
    }

    @Override
    protected PsiElement checkLocalScope() {
      return myTarget.getContainingFile();
    }

    @Override
    protected JComponent getComponent() {
      return myPanel == null ? null : myPanel.getRootPanel();
    }

    /**
     * Moves the introduced initializer to the constructor or {@code setUp} if requested.
     *
     * <p>Relocation happens when the user confirmed a non-{@code SAME_METHOD} choice in the panel, or when
     * the operation's in-place init place is not {@code SAME_METHOD}. In every case the place actually used
     * must not be {@code SAME_METHOD}: otherwise the initializer would be deleted without being re-added.
     */
    @Override
    protected void moveOffsetAfter(boolean success) {
      final InitPlace initPlace = myPanel != null ? myPanel.getInitPlace() : myOperation.getInplaceInitPlace();
      final boolean requested = (success && myPanel != null) || myOperation.getInplaceInitPlace() != InitPlace.SAME_METHOD;
      if (!requested || initPlace == InitPlace.SAME_METHOD) {
        return;
      }
      final AccessToken accessToken = ApplicationManager.getApplication().acquireWriteActionLock(getClass());
      try {
        final PyAssignmentStatement initializer = PsiTreeUtil.getParentOfType(myTarget, PyAssignmentStatement.class);
        assert initializer != null;
        final Function<String, PyStatement> callback = FunctionUtil.<String, PyStatement>constant(initializer);
        final PyClass pyClass = PyUtil.getContainingClassOrSelf(initializer);
        if (initPlace == InitPlace.CONSTRUCTOR) {
          AddFieldQuickFix.addFieldToInit(myProject, pyClass, "", callback);
        }
        else if (initPlace == InitPlace.SET_UP) {
          addFieldToSetUp(pyClass, callback);
        }
        if (!myOperation.getOccurrences().isEmpty()) {
          initializer.delete();
        }
        else {
          // No other usages: leave a bare reference to the field where the initializer was.
          final PyExpression copy =
            PyElementGenerator.getInstance(myProject).createExpressionFromText(LanguageLevel.forElement(myTarget), myTarget.getText());
          initializer.replace(copy);
        }
        // The initializer has already been deleted or replaced above; it must not be deleted a second time.
      }
      finally {
        accessToken.finish();
      }
    }
  }

  @Override
  protected String getRefactoringId() {
    return "refactoring.python.introduce.field";
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy