All Downloads are FREE. The search and download functionality uses the official Maven repository.

org.sonar.python.checks.TorchModuleModeShouldBeSetAfterLoadingCheck Maven / Gradle / Ivy

There is a newer version: 4.23.0.17664
Show newest version
/*
 * SonarQube Python Plugin
 * Copyright (C) 2011-2024 SonarSource SA
 * mailto:info AT sonarsource DOT com
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the Sonar Source-Available License Version 1, as published by SonarSource SA.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the Sonar Source-Available License for more details.
 *
 * You should have received a copy of the Sonar Source-Available License
 * along with this program; if not, see https://sonarsource.com/license/ssal/
 */
package org.sonar.python.checks;

import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import org.sonar.check.Rule;
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
import org.sonar.plugins.python.api.symbols.Symbol;
import org.sonar.plugins.python.api.symbols.Usage;
import org.sonar.plugins.python.api.tree.CallExpression;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.QualifiedExpression;
import org.sonar.plugins.python.api.tree.Tree;
import org.sonar.plugins.python.api.types.InferredType;
import org.sonar.python.tree.TreeUtils;

/**
 * Rule S6982: reports a {@code load_state_dict(...)} call made on a receiver whose inferred type
 * must be or extend the torch {@code Module} class, when no usage of that receiver located on a
 * later line sits inside a call expression (for example {@code model.eval()},
 * {@code model.train()}, or the model being handed to another function).
 */
@Rule(key = "S6982")
public class TorchModuleModeShouldBeSetAfterLoadingCheck extends PythonSubscriptionCheck {
  // Despite the "FQNS" suffix, these are simple names: they are compared against Symbol.name().
  private static final Set<String> STATE_SETTING_FUNCTION_FQNS = Set.of("eval", "train");
  private static final String LOAD_STATE_DICT_NAME = "load_state_dict";
  private static final String MESSAGE = "Set the module in training or evaluation mode.";

  /**
   * Subscribes to every call expression and raises an issue on the callee of a qualifying
   * {@code load_state_dict} call.
   *
   * @param context the registration context supplied by the analyzer
   */
  @Override
  public void initialize(Context context) {
    context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, ctx -> {
      CallExpression callExpr = (CallExpression) ctx.syntaxNode();
      // Cheap structural/type check first, so symbol usages are only collected for
      // load_state_dict calls instead of for every call expression in the file.
      if (!isLoadStateDictCall(callExpr)) {
        return;
      }
      List<Usage> receiverUsages = getForwardUsagesOfReceiver(callExpr);
      if (!hasEvalOrTrainUsage(receiverUsages) && !isModelPassedOn(receiverUsages)) {
        ctx.addIssue(callExpr.callee(), MESSAGE);
      }
    });
  }

  /**
   * Tells whether the call has the shape {@code <qualifier>.load_state_dict(...)} where the
   * qualifier's inferred type must be or extend the torch {@code Module} class.
   *
   * @param callExpr the call expression to inspect
   * @return {@code true} when the callee is a qualified expression named {@code load_state_dict}
   *         on a torch module; {@code false} otherwise
   */
  private static boolean isLoadStateDictCall(CallExpression callExpr) {
    if (callExpr.callee() instanceof QualifiedExpression qualifiedExpr) {
      InferredType qualifierType = qualifiedExpr.qualifier().type();
      // Both the defining module path and the re-exported short path are accepted.
      boolean isModule = qualifierType.mustBeOrExtend("torch.nn.modules.module.Module")
        || qualifierType.mustBeOrExtend("torch.nn.Module");
      return isModule && LOAD_STATE_DICT_NAME.equals(qualifiedExpr.name().name());
    }
    return false;
  }

  /**
   * Collects the usages of the call's receiver symbol that start on a line strictly after the
   * line on which the call starts. Usages on the same line as the call are not included.
   *
   * @param callExpr the {@code load_state_dict} call whose receiver is examined
   * @return the later usages of the receiver; empty when the receiver is not a plain name or has
   *         no symbol
   */
  private static List<Usage> getForwardUsagesOfReceiver(CallExpression callExpr) {
    List<Usage> usages = getFunctionCallReceiverName(callExpr)
      .flatMap(name -> Optional.ofNullable(name.symbol()))
      .map(Symbol::usages)
      .orElse(Collections.emptyList());

    int callLine = callExpr.firstToken().line();
    return usages.stream()
      .filter(usage -> usage.tree().firstToken().line() > callLine)
      .toList();
  }

  /**
   * Extracts the receiver of a call of the form {@code name.method(...)}.
   *
   * @param callExpr the call expression to inspect
   * @return the qualifier as a {@link Name}, or empty when the callee is not a qualified
   *         expression or its qualifier is not a plain name
   */
  private static Optional<Name> getFunctionCallReceiverName(CallExpression callExpr) {
    return Optional.ofNullable(callExpr.callee())
      .flatMap(TreeUtils.toOptionalInstanceOfMapper(QualifiedExpression.class))
      .flatMap(qualifiedExpr -> Optional.ofNullable(qualifiedExpr.qualifier()))
      .flatMap(TreeUtils.toOptionalInstanceOfMapper(Name.class));
  }

  /**
   * @param usages usages of the receiver symbol
   * @return {@code true} when at least one usage satisfies {@link #isEvalOrTrain(Usage)}
   */
  private static boolean hasEvalOrTrainUsage(List<Usage> usages) {
    return usages.stream().anyMatch(TorchModuleModeShouldBeSetAfterLoadingCheck::isEvalOrTrain);
  }

  /**
   * Tells whether the nearest enclosing call expression of the usage has a callee symbol named
   * {@code eval} or {@code train}.
   *
   * @param usage a usage of the receiver symbol
   * @return {@code true} when such an enclosing call exists; {@code false} when the usage is not
   *         inside a call or the callee symbol is unknown or differently named
   */
  private static boolean isEvalOrTrain(Usage usage) {
    Tree callTree = TreeUtils.firstAncestorOfKind(usage.tree(), Tree.Kind.CALL_EXPR);
    if (callTree != null) {
      CallExpression usageCall = (CallExpression) callTree;
      Symbol usageCallSymbol = usageCall.calleeSymbol();
      return usageCallSymbol != null && STATE_SETTING_FUNCTION_FQNS.contains(usageCallSymbol.name());
    }
    return false;
  }

  /**
   * @param usages usages of the receiver symbol
   * @return {@code true} when at least one usage satisfies {@link #isPassingModel(Usage)}
   */
  private static boolean isModelPassedOn(List<Usage> usages) {
    return usages.stream().anyMatch(TorchModuleModeShouldBeSetAfterLoadingCheck::isPassingModel);
  }

  /**
   * Tells whether the usage is located anywhere inside a call expression. The mode may then be
   * set by the callee, so no issue is raised in that case.
   *
   * @param usage a usage of the receiver symbol
   * @return {@code true} when the usage has a call expression among its ancestors
   */
  private static boolean isPassingModel(Usage usage) {
    return TreeUtils.firstAncestorOfKind(usage.tree(), Tree.Kind.CALL_EXPR) != null;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy