All downloads are free. Search and download functionality uses the official Maven repository.

org.scalafmt.internal.BestFirstSearch.scala Maven / Gradle / Ivy

package org.scalafmt.internal

import scala.meta.Defn

import org.scalafmt.FormatResult
import org.scalafmt.internal.ExpiresOn.Right
import org.scalafmt.internal.ExpiresOn.Left
import org.scalafmt.internal.Length.StateColumn
import org.scalafmt.internal.Length.Num
import org.scalafmt.Error.CantFormatFile
import org.scalafmt.FormatEvent.CompleteFormat
import org.scalafmt.FormatEvent.Enqueue
import org.scalafmt.FormatEvent.Explored
import org.scalafmt.FormatEvent.VisitToken
import org.scalafmt.util.LoggerOps
import org.scalafmt.util.TokenOps
import org.scalafmt.util.TreeOps
import scala.collection.mutable
import scala.meta.tokens.Token

/**
  * Implements best first search to find optimal formatting.
  *
  * States are explored through a `mutable.PriorityQueue[State]`. Each state
  * records the splits chosen so far for a prefix of `tokens`. Candidate
  * splits for each token come from the [[Router]] (precomputed in `routes`)
  * and are filtered through the current state's policy.
  *
  * @param formatOps    shared formatting context (tokens, tree, style, lookup
  *                     tables); its members are imported into this class.
  * @param range        tokens for which `inside(range)` is false are emitted
  *                     as-is via `provided`.
  * @param formatWriter not referenced inside this class body.
  */
class BestFirstSearch(
    val formatOps: FormatOps, range: Set[Range], formatWriter: FormatWriter) {
  import LoggerOps._
  import Token._
  import TokenOps._
  import TreeOps._
  import formatOps._
  import formatOps.runner.optimizer._

  /**
    * Precomputed table of splits for each token, indexed like `tokens`.
    */
  val routes: Array[Seq[Split]] = {
    val router = new Router(formatOps)
    val result = Array.newBuilder[Seq[Split]]
    tokens.foreach { t =>
      result += router.getSplitsMemo(t)
    }
    result.result()
  }

  /** Tokens (looked up by a format token's left token) in no-optimization zones. */
  val noOptimizations = noOptimizationZones(tree)

  /** Number of states dequeued so far, across top-level and recursive searches. */
  var explored = 0

  /** Depth-0 state covering the most tokens; reported when formatting fails. */
  var deepestYet = State.start

  /** Like `deepestYet`, restricted to states whose policy `isSafe`. */
  var deepestYetSafe = State.start

  var statementCount = 0

  /** First depth-0 newline successor recorded per token; used for pruning. */
  val best = mutable.Map.empty[Token, State]

  /** Number of times the pathological-case escape hatch has fired. */
  var pathologicalEscapes = 0

  /** Visit count per format token; cleared by the escape hatch. */
  val visits = mutable.Map.empty[FormatToken, Int].withDefaultValue(0)

  type StateHash = Long

  /**
    * Returns true if `token` should be treated as lying in a zone where
    * search optimizations are disabled.
    *
    * NOTE(review): when `disableOptimizationsInsideSensitiveAreas` is false
    * this returns true for every token. Confirm this is intended.
    */
  def isInsideNoOptZone(token: FormatToken): Boolean = {
    !disableOptimizationsInsideSensitiveAreas ||
    noOptimizations.contains(token.left)
  }

  /**
    * Returns the left token of the most recently split format token, or of
    * the first format token when `curr` has no splits yet.
    */
  def getLeftLeft(curr: State): Token = {
    tokens(Math.max(0, curr.splits.length - 1)).left
  }

  /**
    * Returns true if `curr` should be explored, false if it can be pruned.
    *
    * A state is pruned only when all of the following hold:
    *  - `pruneSlowStates` is enabled;
    *  - the state is outside an optimization-sensitive zone, i.e. its policy
    *    does not set `noDequeue` and `isInsideNoOptZone` is false;
    *  - the state recorded in `best` for the next token is `alwaysBetter`
    *    than `curr`.
    */
  def shouldEnterState(curr: State): Boolean = {
    val splitToken = tokens(curr.splits.length)
    val insideOptimizationZone =
      curr.policy.noDequeue || isInsideNoOptZone(splitToken)
    !pruneSlowStates || insideOptimizationZone || {
      // TODO(olafur) document why/how this optimization works.
      val dominated = best.get(splitToken.left).exists(_.alwaysBetter(curr))
      if (dominated) {
        // `lastOption`: a state with no splits yet (e.g. State.start) has no last split.
        logger.trace(
            s"Eliminated $curr ${curr.splits.lastOption.getOrElse("")}")
      }
      !dominated
    }
  }

  /**
    * Returns true if the `{ ... }` block opened by the most recently split
    * token should be formatted as an independent sub-search.
    *
    * The block must not end at `stop`. Its body must also exceed
    * `3 * style.maxColumn` characters, and its owner must contain statements.
    */
  def shouldRecurseOnBlock(curr: State, stop: Token): Boolean = {
    val leftLeft = getLeftLeft(curr)
    val splitToken = tokens(curr.splits.length)
    recurseOnBlocks && isInsideNoOptZone(splitToken) &&
    leftLeft.isInstanceOf[`{`] && {
      val close = matchingParentheses(hash(leftLeft))
      // Heuristic: only worth recursing when the body is longer than roughly
      // three full lines' worth of characters.
      // TODO(olafur) magic number
      close != stop && close.start - leftLeft.end > style.maxColumn * 3
    } && extractStatementsIfAny(ownersMap(hash(leftLeft))).nonEmpty
  }

  /**
    * A split that reproduces the original text between the two tokens of
    * `formatToken`. After a `{` it also indents by 2 until the matching
    * closing brace.
    */
  def provided(formatToken: FormatToken): Split = {
    // TODO(olafur) the indentation is not correctly set.
    val split = Split(Provided(formatToken.between.map(_.code).mkString), 0)
    if (formatToken.left.isInstanceOf[`{`])
      split.withIndent(
          Num(2), matchingParentheses(hash(formatToken.left)), Right)
    else split
  }

  /**
    * Packs a state's column and indentation into a single `memo` key.
    *
    * Each value gets its own 32 bits, so distinct (column, indentation) pairs
    * always map to distinct keys. An 8-bit shift in `Int` arithmetic would
    * let an indentation of 256 or more collide with the column bits.
    */
  def stateColumnKey(state: State): StateHash = {
    (state.column.toLong << 32) | (state.indentation.toLong & 0xFFFFFFFFL)
  }

  /**
    * True when the exploration budget (`runner.maxStateVisits`) is exhausted,
    * or when every token has been assigned a split.
    */
  def hasReachedEof(state: State): Boolean = {
    explored > runner.maxStateVisits || state.splits.length == tokens.length
  }

  /**
    * True if `state` has an unsplit token left and that token's left side is
    * `stop`. Returns false instead of indexing past the end of `tokens`.
    */
  private def nextLeftIs(state: State, stop: Token): Boolean =
    state.splits.length < tokens.length &&
      tokens(state.splits.length).left == stop

  /**
    * Cached `shortestPathMemo` results, keyed by
    * (number of splits, stateColumnKey).
    * NOTE(review): the key ignores the state's policy and cost.
    */
  val memo = mutable.Map.empty[(Int, StateHash), State]

  /**
    * Memoized `shortestPath` from `start` to `stop`.
    *
    * Only results that actually reached `stop` are cached.
    */
  def shortestPathMemo(start: State, stop: Token, depth: Int, maxCost: Int)(
      implicit line: sourcecode.Line): State = {
    val key = (start.splits.length, stateColumnKey(start))
    memo.get(key) match {
      case Some(state) => state
      case None =>
        val nextState = shortestPath(start, stop, depth, maxCost)
        if (nextLeftIs(nextState, stop)) {
          memo.update(key, nextState)
        }
        nextState
    }
  }

  /**
    * Advances `state` using `provided` splits (original text). It stops when
    * the next token's left side is a statement start or `hasReachedEof`
    * holds.
    */
  def untilNextStatement(state: State): State = {
    var curr = state
    while (!hasReachedEof(curr) &&
           !statementStarts.contains(hash(tokens(curr.splits.length).left))) {
      val tok = tokens(curr.splits.length)
      curr = State.next(curr, style, provided(tok), tok)
    }
    curr
  }

  /**
    * Runs best first search to find lowest penalty split.
    *
    * Starting from `start`, the search repeatedly dequeues the next state in
    * the priority queue's order and enqueues its successors. It ends at the
    * first state for which `hasReachedEof` holds or whose next non-empty
    * token starts at or after `stop`, and returns that state. If the queue
    * drains first, `start` is returned.
    *
    * @param start   state to search from.
    * @param stop    token marking where this (sub-)search ends.
    * @param depth   recursion depth; 0 for the top-level search. Only depth-0
    *                searches update `deepestYet`, `deepestYetSafe` and `best`.
    * @param maxCost a successor whose cost increase over its parent exceeds
    *                this is not enqueued directly.
    */
  def shortestPath(start: State,
                   stop: Token,
                   depth: Int = 0,
                   maxCost: Int = Integer.MAX_VALUE)(
      implicit line: sourcecode.Line): State = {
    val Q = new mutable.PriorityQueue[State]()
    var result = start
    Q += start
    // TODO(olafur) this while loop is waaaaaaaaaaaaay tooo big.
    while (Q.nonEmpty) {
      val curr = Q.dequeue()
      explored += 1
      runner.eventCallback(Explored(explored, depth, Q.size))
      if (hasReachedEof(curr) || {
            val token = tokens(curr.splits.length)
            // If token is empty we can take one more split before reaching stop.
            token.left.code.nonEmpty && token.left.start >= stop.start
          }) {
        result = curr
        Q.clear()
      } else if (shouldEnterState(curr)) {
        val splitToken = tokens(curr.splits.length)
        if (depth == 0 && curr.splits.length > deepestYet.splits.length) {
          deepestYet = curr
        }
        if (depth == 0 && curr.policy.isSafe &&
            curr.splits.length > deepestYetSafe.splits.length) {
          deepestYetSafe = curr
        }
        runner.eventCallback(VisitToken(splitToken))
        visits.put(splitToken, visits(splitToken) + 1)

        // At a dequeue spot reached via a newline, drop all competing states
        // and continue only from `curr`. A state without splits has no
        // previous newline.
        if (dequeueOnNewStatements &&
            dequeueSpots.contains(hash(splitToken.left)) &&
            (depth > 0 || !isInsideNoOptZone(splitToken)) &&
            curr.splits.lastOption.exists(_.modification.isNewline)) {
          Q.clear()
        }

        if (shouldRecurseOnBlock(curr, stop)) {
          // Format the block as a sub-search up to its closing brace and
          // continue from the result only if it actually got there.
          val close = matchingParentheses(hash(getLeftLeft(curr)))
          val nextState = shortestPathMemo(
              curr, close, depth = depth + 1, maxCost = maxCost)
          if (nextLeftIs(nextState, close)) {
            Q.enqueue(nextState)
          }
        } else {
          if (escapeInPathologicalCases &&
              visits(splitToken) > MaxVisitsPerToken) {
            // Danger zone: escape hatch for pathological cases.
            Q.clear()
            best.clear()
            visits.clear()
            if (pathologicalEscapes >= MaxEscapes) {
              // Last resort. No other optimization has worked.
              Q.enqueue(untilNextStatement(curr))
            } else {
              // We are stuck: restart from the deepest safe state. The
              // successors of `curr` are still enqueued below.
              Q.enqueue(deepestYetSafe)
              pathologicalEscapes += 1
            }
          }

          val splits: Seq[Split] =
            if (curr.formatOff) List(provided(splitToken))
            else if (splitToken.inside(range)) routes(curr.splits.length)
            else List(provided(splitToken))

          // Candidate splits after the policy, without ignored ones, cheapest first.
          val actualSplit = {
            curr.policy
              .execute(Decision(splitToken, splits))
              .splits
              .filter(!_.ignoreIf)
              .sortBy(_.cost)
          }
          var optimalNotFound = true
          actualSplit.foreach { split =>
            val nextState = State.next(curr, style, split, splitToken)
            if (depth == 0 && split.modification.isNewline &&
                !best.contains(splitToken.left)) {
              best.update(splitToken.left, nextState)
            }
            runner.eventCallback(Enqueue(split))
            split.optimalAt match {
              case Some(OptimalToken(token, killOnFail))
                  if acceptOptimalAtHints && actualSplit.length > 1 &&
                  depth < MaxDepth && split.cost == 0 =>
                // Look ahead with a zero-cost budget: if `token` is reachable
                // at no extra cost, commit to that path and skip the more
                // expensive alternatives that follow.
                val nextNextState =
                  shortestPath(nextState, token, depth + 1, maxCost = 0)(
                      sourcecode.Line.generate)
                if (hasReachedEof(nextNextState) ||
                    (nextNextState.splits.length < tokens.length && tokens(
                            nextNextState.splits.length).left.start >= token.start)) {
                  optimalNotFound = false
                  Q.enqueue(nextNextState)
                } else if (!killOnFail &&
                           nextState.cost - curr.cost <= maxCost) {
                  // TODO(olafur) DRY. This solution can still be optimal.
                  Q.enqueue(nextState)
                } // else kill branch
              case _
                  if optimalNotFound &&
                  nextState.cost - curr.cost <= maxCost =>
                Q.enqueue(nextState)
              case _ => // Kill branch.
            }
          }
        }
      }
    }
    result
  }

  /**
    * Runs the top-level search over the whole tree.
    *
    * If every token received a split, the result has `reachedEOF = true`.
    * Otherwise an error is logged with diagnostics about `deepestYet`, and
    * its (partial) splits are returned with `reachedEOF = false`.
    */
  def getBestPath: SearchResult = {
    val state = shortestPath(State.start, tree.tokens.last)
    if (state.splits.length == tokens.length) {
      runner.eventCallback(CompleteFormat(explored, state, tokens))
      SearchResult(state.splits, reachedEOF = true)
    } else {
      val nextSplits = routes(deepestYet.splits.length)
      val tok = tokens(deepestYet.splits.length)
      val splitsAfterPolicy =
        deepestYet.policy.execute(Decision(tok, nextSplits))
      val msg = s"""UNABLE TO FORMAT,
                   |tok=$tok
                   |state.length=${state.splits.length}
                   |toks.length=${tokens.length}
                   |deepestYet.length=${deepestYet.splits.length}
                   |policies=${deepestYet.policy.policies}
                   |nextSplits=$nextSplits
                   |splitsAfterPolicy=$splitsAfterPolicy""".stripMargin
      logger.error(s"""Failed to format
                      |$msg""".stripMargin)
      runner.eventCallback(CompleteFormat(explored, deepestYet, tokens))
      SearchResult(deepestYet.splits, reachedEOF = false)
    }
  }
}

/**
  * Outcome of `BestFirstSearch.getBestPath`.
  *
  * @param splits     the chosen split for each token, in token order.
  * @param reachedEOF true when every token received a split; false when the
  *                   search gave up and `splits` covers only a prefix.
  */
case class SearchResult(
    splits: Vector[Split],
    reachedEOF: Boolean
)




© 2015 - 2025 Weber Informatics LLC | Privacy Policy