All downloads are free. The search and download functionality uses the official Maven repository.

scalafix.internal.rule.Disable.scala Maven / Gradle / Ivy

package scalafix.internal.rule

import metaconfig.{Conf, Configured}

import scala.meta._
import scala.meta.transversers.Traverser
import scalafix.internal.config.{DisableConfig, DisabledSymbol}
import scalafix.internal.util.SymbolOps
import scalafix.lint.{LintCategory, LintMessage}
import scalafix.rule.{Rule, RuleCtx, SemanticRule}
import scalafix.util.SemanticdbIndex

object Disable {

  /**
    * Tree traverser that gathers values while threading a caller-defined
    * context through the traversal.
    *
    * For every visited node, `fn` is consulted with the node and the context
    * currently in effect:
    *  - `Left(value)`: the value is recorded and the default traversal is not
    *    invoked for that node;
    *  - `Right(context)`: the default traversal runs for that node with
    *    `context` installed, after which the previous context is put back;
    *  - not defined: the default traversal runs with the context unchanged.
    *
    * Call `result(tree)` to run the traversal and obtain the recorded values.
    */
  class ContextTraverser[Value, Context](initContext: Context)(
      fn: PartialFunction[(Tree, Context), Either[Value, Context]])
      extends Traverser {
    // Total view of `fn`: yields None for inputs on which `fn` is undefined.
    private val decide = fn.lift
    private val collected = scala.collection.mutable.ListBuffer.empty[Value]
    private var current: Context = initContext

    override def apply(tree: Tree): Unit =
      decide((tree, current)) match {
        case None =>
          super.apply(tree)
        case Some(Left(value)) =>
          collected += value
        case Some(Right(inner)) =>
          // Install the new context only for the duration of this node.
          val outer = current
          current = inner
          super.apply(tree)
          current = outer
      }

    /**
      * Traverses `tree` starting from the initial context and returns the
      * values recorded during that traversal. State left over from an earlier
      * call is discarded first.
      */
    def result(tree: Tree): List[Value] = {
      collected.clear()
      current = initContext
      apply(tree)
      collected.toList
    }
  }

  /**
    * Matches symbols (or the symbols that trees resolve to in `index`) against
    * a list of [[DisabledSymbol]]s. Usable as a pattern-match extractor.
    */
  final class DisableSymbolMatcher(symbols: List[DisabledSymbol])(
      implicit index: SemanticdbIndex) {

    /** First configured entry that matches `symbol`, if there is one. */
    def findMatch(symbol: Symbol): Option[DisabledSymbol] =
      symbols.find(candidate => candidate.matches(symbol))

    /** Extractor for trees: succeeds when the tree's symbol matches an entry. */
    def unapply(tree: Tree): Option[(Tree, DisabledSymbol)] =
      for {
        resolved <- index.symbol(tree)
        disabled <- findMatch(resolved)
      } yield (tree, disabled)

    /** Extractor for symbols: succeeds when the symbol matches an entry. */
    def unapply(symbol: Symbol): Option[(Symbol, DisabledSymbol)] =
      findMatch(symbol).map(disabled => symbol -> disabled)
  }
}

/**
  * Semantic lint rule that reports errors for symbols listed in the supplied
  * [[DisableConfig]].
  *
  * Two sources are inspected: the syntax tree of the checked file
  * (`checkTree`) and the synthetic names of the documents in the semantic
  * index (`checkSynthetics`).
  */
final case class Disable(index: SemanticdbIndex, config: DisableConfig)
    extends SemanticRule(index, "Disable") {

  import Disable._

  // Base category shared by all messages; the id is replaced per symbol.
  private lazy val errorCategory: LintCategory =
    LintCategory.error(
      "Some constructs are unsafe to use and should be avoided")

  override def description: String =
    "Linter that reports an error on a configurable set of symbols."

  /** Builds a new rule instance from the decoded `disable` / `Disable` section. */
  override def init(config: Conf): Configured[Rule] = {
    val decoded = config.getOrElse("disable", "Disable")(DisableConfig.default)
    decoded.map(parsed => Disable(index, parsed))
  }

  private val safeBlock = new DisableSymbolMatcher(config.allSafeBlocks)
  private val disabledSymbolInSynthetics =
    new DisableSymbolMatcher(config.ifSynthetic)

  /**
    * Creates the lint message for one offending symbol.
    *
    * The configured message and id take precedence; otherwise both are derived
    * from the symbol's name. `details` is only appended to the derived message.
    */
  private def createLintMessage(
      symbol: Symbol.Global,
      disabled: DisabledSymbol,
      pos: Position,
      details: String = ""): LintMessage = {
    val symbolName = symbol.signature.name
    val text = disabled.message.getOrElse(s"$symbolName is disabled$details")
    val categoryId = disabled.id.getOrElse(symbolName)
    errorCategory.copy(id = categoryId).at(text, pos)
  }

  /** Walks the syntax tree and reports names that resolve to a blocked symbol. */
  private def checkTree(ctx: RuleCtx): Seq[LintMessage] = {
    type Blocked = List[DisabledSymbol]
    type Step = Either[LintMessage, Blocked]

    // Removes from `blocked` the symbols that the configuration allows inside
    // the safe block that `block` resolves to.
    def filterBlockedSymbolsInBlock(blocked: Blocked, block: Tree): Blocked =
      ctx.index.symbol(block) match {
        case Some(blockSymbol: Symbol.Global) =>
          val permittedHere = config.unlessInside
            .filter(_.safeBlock.matches(blockSymbol))
            .flatMap(_.symbols)
          blocked.filterNot(permittedHere.contains)
        case _ =>
          blocked
      }

    // True when `term` is a plain name or a selection chain made of names only.
    def isNameChain(term: Term): Boolean = term match {
      case Term.Select(qualifier, _) => isNameChain(qualifier)
      case _: Term.Name => true
      case _ => false
    }

    // Left(message) when `name` resolves to a blocked symbol whose normalized
    // form is a global symbol with a non-empty name; otherwise the blocked
    // list is handed back unchanged so the traversal continues.
    def handleName(name: Name, blocked: Blocked): Step = {
      val matcher = new DisableSymbolMatcher(blocked)
      val matched = ctx.index.symbol(name).collect {
        case matcher(global: Symbol.Global, disabled) =>
          SymbolOps.normalize(global) -> disabled
      }
      val violation = matched.collect {
        case (normalized: Symbol.Global, disabled)
            if normalized.signature.name != "" =>
          createLintMessage(normalized, disabled, name.pos)
      }
      violation.toLeft(blocked)
    }

    // Case order matters: selections are tried before applications, and bare
    // names come last.
    val visit: PartialFunction[(Tree, Blocked), Step] = {
      case (_: Import, _) =>
        // Nothing is blocked while inside an import.
        Right(Nil)
      case (Term.Select(qualifier, name), blocked) if isNameChain(qualifier) =>
        handleName(name, blocked)
      case (Type.Select(qualifier, name), blocked) if isNameChain(qualifier) =>
        handleName(name, blocked)
      case (
          Term.Apply(
            Term.Select(block @ safeBlock(_, _), Term.Name("apply")),
            _),
          blocked) =>
        // safe block invoked through an explicit `.apply`
        Right(filterBlockedSymbolsInBlock(blocked, block))
      case (Term.Apply(block @ safeBlock(_, _), _), blocked) =>
        // safe block invoked as `block(...)`
        Right(filterBlockedSymbolsInBlock(blocked, block))
      case (_: Defn.Def | _: Term.Function, _) =>
        // Method definitions and function literals start again from the
        // complete list of blocked symbols.
        Right(config.allDisabledSymbols)
      case (name: Name, blocked) =>
        handleName(name, blocked)
    }

    new ContextTraverser[LintMessage, Blocked](config.allDisabledSymbols)(visit)
      .result(ctx.tree)
  }

  /** Reports blocked symbols that occur in synthetic names of the index documents. */
  private def checkSynthetics(ctx: RuleCtx): Seq[LintMessage] = {
    // For a synthetic input the caret points at the original position while
    // the message shows the inferred code; other positions are used as they are.
    def locate(pos: Position): (String, Position) = pos.input match {
      case synthetic @ Input.Synthetic(_, input, start, end) =>
        val details = s" and it got inferred as `${synthetic.text}`"
        (details, Position.Range(input, start, end))
      case _ =>
        ("", pos)
    }

    for {
      document <- ctx.index.documents.view
      ResolvedName(
        pos,
        disabledSymbolInSynthetics(symbol @ Symbol.Global(_, _), disabled),
        false
      ) <- document.synthetics.view.flatMap(_.names)
    } yield {
      val (details, caret) = locate(pos)
      createLintMessage(symbol, disabled, caret, details)
    }
  }

  /** Messages from the syntax tree, followed by those from synthetics. */
  override def check(ctx: RuleCtx): Seq[LintMessage] = {
    val fromTree = checkTree(ctx)
    val fromSynthetics = checkSynthetics(ctx)
    fromTree ++ fromSynthetics
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy