All Downloads are FREE. Search and download functionalities are using the official Maven repository.

scalafix.internal.patch.ImportPatchOps.scala Maven / Gradle / Ivy

package scalafix
package internal.patch

import scala.annotation.tailrec
import scala.collection.immutable.Seq
import scala.collection.mutable
import scala.meta._
import scalafix.internal.util.SymbolOps
import scalafix.patch.Patch
import scalafix.patch.TokenPatch
import scalafix.patch.TreePatch
import scalafix.patch.TreePatch.ImportPatch
import scalafix.rule.RuleCtx
import scalafix.syntax._
import scalafix.util.Newline

object ImportPatchOps {
  object symbols {
    val Scala: Symbol = Symbol("_root_.scala.")
    val Predef: Symbol = Symbol("_root_.scala.Predef.")
    val Java: Symbol = Symbol("_root_.java.lang.")
    val Immutable: Symbol = Symbol("_root_.scala.collection.immutable.")
  }

  def isPredef(symbol: Symbol): Boolean = {
    import symbols._
    symbol match {
      case Symbol.Global(`Immutable`, Signature.Type("List" | "Map" | "Set")) =>
        true
      case Symbol.Global(owner, _) =>
        owner == Scala ||
          owner == Predef ||
          owner == Java
      case _ => false
    }
  }

  private def fallbackToken(ctx: RuleCtx): Token = {
    def loop(tree: Tree): Token = tree match {
      case Source((stat: Pkg) :: _) => loop(stat)
      case Source(_) => ctx.toks(tree).head
      case Pkg(_, stat :: _) => loop(stat)
      case els => ctx.tokenList.prev(ctx.tokenList.prev(ctx.toks(els).head))
    }
    loop(ctx.tree)
  }
  private def extractImports(stats: Seq[Stat]): Seq[Import] = {
    stats
      .takeWhile(_.is[Import])
      .collect { case i: Import => i }
  }

  @tailrec private final def getLastTopLevelPkg(potPkg: Stat): Stat =
    potPkg match {
      case Pkg(_, head +: Nil) => getLastTopLevelPkg(head)
      case Pkg(_, head +: _) => head
      case _ => potPkg
    }

  @tailrec private final def getGlobalImports(ast: Tree): Seq[Import] =
    ast match {
      case Pkg(_, Seq(pkg: Pkg)) => getGlobalImports(pkg)
      case Source(Seq(pkg: Pkg)) => getGlobalImports(pkg)
      case Pkg(_, stats) => extractImports(stats)
      case Source(stats) => extractImports(stats)
      case _ => Nil
    }

  // NOTE(olafur): This method is the simplest/dummest thing I can think of
  private[scalafix] def superNaiveImportPatchToTokenPatchConverter(
      ctx: RuleCtx,
      importPatches: Seq[ImportPatch])(
      implicit index: SemanticdbIndex): Iterable[Patch] = {
    val allImports = ctx.tree.collect { case i: Import => i }
    val allImporters = allImports.flatMap(_.importers)
    lazy val allImportersSyntax = allImporters.map(_.syntax)
    val allImportees = allImporters.flatMap(_.importees)
    val allNamedImports = allImportees.collect {
      case Importee.Name(n) if index.names.contains(n.pos) =>
        n.symbol
      // TODO(olafur) handle rename.
    }
    val allImporteeSymbols = allImportees.flatMap(importee =>
      importee.symbol.map(_.normalized -> importee))
    val globalImports = getGlobalImports(ctx.tree)
    val editToken: Token = {
      if (globalImports.isEmpty) fallbackToken(ctx)
      else ctx.toks(globalImports.last).last
    }
    val isRemovedImportee = mutable.LinkedHashSet.empty[Importee]
    importPatches.foreach {
      case TreePatch.RemoveGlobalImport(sym) =>
        allImporteeSymbols
          .withFilter(_._1 == sym.normalized)
          .foreach { case (_, x) => isRemovedImportee += x }
      case x: TreePatch.RemoveImportee => isRemovedImportee += x.importee
      case _ =>
    }
    val importersToAdd = {
      val isAlreadyImported = mutable.Set.empty[Symbol]
      for { // register global imports
        import_ <- globalImports
        importer <- import_.importers
        importee <- importer.importees
        symbol <- index.symbol(importee).toList
        underlying <- SymbolOps.underlyingSymbols(symbol)
      } {
        isAlreadyImported += underlying
      }
      importPatches.flatMap {
        case TreePatch.AddGlobalSymbol(symbol)
            if !allNamedImports.contains(symbol) &&
              !isAlreadyImported(symbol) &&
              !isPredef(symbol) =>
          isAlreadyImported += symbol
          SymbolOps.toImporter(symbol).toList
        case TreePatch.AddGlobalImport(importer)
            // best effort deduplication for syntactic addGlobalImport(Importer)
            if !allImportersSyntax.contains(importer.syntax) =>
          importer :: Nil
        case _ => Nil
      }
    }
    val grouped: Seq[Importer] =
      if (ctx.config.groupImportsByPrefix)
        importersToAdd
          .groupBy(_.ref.syntax)
          .map {
            case (_, is) =>
              Importer(
                is.head.ref,
                is.flatMap(_.importees)
                  .sortBy({
                    case Importee.Name(n) => n.value
                    case Importee.Rename(n, _) => n.value
                    case Importee.Unimport(n) => n.value
                    case Importee.Wildcard() => '\uFFFF'.toString
                  })
                  .toList
              )
          }
          .toList
      else importersToAdd
    val extraPatches =
      grouped
        .sortBy(_.ref.syntax)
        .map(is => ctx.addRight(editToken, s"\nimport ${is.syntax}"))
    val isRemovedImporter =
      allImporters.toIterator
        .filter(_.importees.forall(isRemovedImportee))
        .toSet
    def removeSpaces(tokens: scala.Seq[Token]): Patch =
      tokens
        .takeWhile {
          case Token.Space() => true
          case _ => false
        }
        .map(ctx.removeToken(_))
        .asPatch
    val curlyBraceRemoves = allImporters.map { importer =>
      val keptImportees = importer.importees.filterNot(isRemovedImportee)
      val hasRemovedImportee = importer.importees.exists(isRemovedImportee)
      keptImportees match {
        case (Importee.Wildcard() | Importee.Name(_)) +: Nil
            if hasRemovedImportee =>
          ctx
            .toks(importer)
            .collectFirst {
              case open @ Token.LeftBrace() =>
                ctx.matchingParens
                  .close(open)
                  .map { close =>
                    ctx.removeToken(open) +
                      removeSpaces(ctx.tokenList.trailing(open)) +
                      ctx.removeToken(close) +
                      removeSpaces(ctx.tokenList.leading(close))
                  }
                  .asPatch
            }
            .asPatch
        case _ => Patch.empty
      }
    }
    // NOTE: keeps track of which comma is removed by which tree to prevent the
    // same comma being removed twice.
    val isRemovedComma = mutable.Map.empty[Token.Comma, Tree]
    val isRemovedImport =
      allImports.filter(_.importers.forall(isRemovedImporter))
    def remove(toRemove: Tree) = {
      val tokens = ctx.toks(toRemove)
      def removeFirstComma(lst: Iterable[Token]): Iterable[Patch] = {
        lst
          .takeWhile {
            case lf @ Token.LF() if ctx.tokenList.prev(lf).is[Token.Comma] =>
              true
            case Token.Space() => true
            case comma @ Token.Comma() =>
              if (!isRemovedComma.contains(comma)) {
                isRemovedComma(comma) = toRemove
              }
              true
            case _ => false
          }
          .map(ctx.removeToken(_))
      }
      val leadingComma =
        removeFirstComma(ctx.tokenList.leading(tokens.head))
      val hadLeadingComma = leadingComma.exists {
        case TokenPatch.Add(comma: Token.Comma, _, _, keepTok @ false) =>
          isRemovedComma.get(comma).contains(toRemove)
        case _ => false
      }
      val trailingComma =
        if (hadLeadingComma) List(Patch.empty)
        else removeFirstComma(ctx.tokenList.trailing(tokens.last))
      ctx.removeTokens(tokens) ++ trailingComma ++ leadingComma
    }

    val leadingNewlines = isRemovedImport.map { i =>
      var newline = false
      ctx.tokenList
        .leading(ctx.toks(i).head)
        .takeWhile(x =>
          !newline && {
            x.is[Token.Space] || {
              val isNewline = x.is[Newline]
              if (isNewline) newline = true
              isNewline
            }
        })
        .map(tok => ctx.removeToken(tok))
        .asPatch
    }

    leadingNewlines ++
      curlyBraceRemoves ++
      extraPatches ++
      (isRemovedImportee ++
        isRemovedImporter ++
        isRemovedImport).map(remove)
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy