All downloads are free. Search and download functionality uses the official Maven repository.

scala.scalanative.optimizer.pass.GlobalValueNumbering.scala Maven / Gradle / Ivy

The newest version!
package scala.scalanative
package optimizer
package pass

import scala.collection.mutable
import scala.util.hashing.MurmurHash3

import analysis.ClassHierarchy.Top
import analysis.ControlFlow
import analysis.ControlFlow.Block
import analysis.DominatorTree

import nir._

class GlobalValueNumbering extends Pass {
  import GlobalValueNumbering._

  /** Runs value numbering over the instructions of one method body.
   *
   *  Builds the control-flow graph and dominator relation for `insts`, then
   *  rewrites redundant idempotent `let`s into copies of an equal, dominating
   *  definition.
   */
  override def onInsts(insts: Seq[Inst]): Seq[Inst] = {
    val cfg        = ControlFlow.Graph(insts)
    val domination = DominatorTree.build(cfg)

    performSimpleValueNumbering(cfg, domination)
  }

  /** Replaces each idempotent `let` whose op is deeply equal to an earlier
   *  `let` defined in a dominating block with `Op.Copy` of that earlier value.
   *
   *  @param cfg        control-flow graph; blocks are visited in `cfg.map` order
   *  @param domination for each block, the set of blocks that dominate it
   *  @return the rewritten instructions, each block preceded by its label
   */
  private def performSimpleValueNumbering(
      cfg: ControlFlow.Graph,
      domination: Map[Block, Set[Block]]): Seq[Inst] = {

    val variableVN   = mutable.HashMap.empty[Local, Hash]
    val instructions = mutable.HashMap.empty[Hash, List[Inst.Let]]
    val localDefs    = mutable.HashMap.empty[Local, Inst]
    // Block in which each local (block parameter or `let` name) was defined.
    // Assumes SSA form: every local has exactly one definition.
    val defBlocks = mutable.HashMap.empty[Local, Block]

    val hash       = new HashFunction(variableVN)
    val deepEquals = new DeepEquals(localDefs)

    /** True when `dominatingDef` is defined in a block that dominates
     *  `dominatedBlock`. Uses the recorded defining block instead of
     *  rescanning the parameters and instructions of every dominator.
     */
    def blockDominatedByDef(dominatedBlock: Block,
                            dominatingDef: Local): Boolean =
      defBlocks.get(dominatingDef).exists { defBlock =>
        // `exists` with `==` instead of `contains`, so blocks are not hashed.
        domination(dominatedBlock).exists(_ == defBlock)
      }

    val newInsts = cfg.map { block =>
      // Block parameters get an identity-based value number and are
      // "defined" by the block's label.
      variableVN ++= block.params.map(lval =>
        (lval.name, HashFunction.rawLocal(lval.name)))
      localDefs ++= block.params.map(lval => (lval.name, block.label))
      defBlocks ++= block.params.map(lval => (lval.name, block))

      val newBlockInsts = block.insts.map {
        case inst: Inst.Let =>
          val idempotent = isIdempotent(inst.op)

          // Non-idempotent lets hash the assigned variable as well, so that
          // a = op(b) and c = op(b) don't get the same value number.
          val instHash =
            if (idempotent) hash(inst.op)
            else inst.hashCode

          variableVN += (inst.name -> instHash)
          localDefs += (inst.name  -> inst)
          defBlocks += (inst.name  -> block)

          if (idempotent) {
            val hashEqualInstrs = instructions.getOrElse(instHash, Nil)
            instructions += (instHash -> (inst :: hashEqualInstrs))

            val equalInstrs = hashEqualInstrs.filter(otherInst =>
              deepEquals.eqInst(inst, otherInst))
            // Only redundant if the equal instruction's block dominates the
            // current block.
            val redundantInstrs = equalInstrs.filter(eqInst =>
              blockDominatedByDef(block, eqInst.name))

            redundantInstrs.headOption
              .map(redInst =>
                Inst.Let(inst.name,
                         Op.Copy(Val.Local(redInst.name, redInst.op.resty))))
              .getOrElse(inst)
          } else {
            inst
          }

        case otherInst =>
          otherInst
      }

      block.label +: newBlockInsts
    }

    newInsts.flatten
  }

}

object GlobalValueNumbering extends PassCompanion {

  /** Tells whether `op` may be deduplicated by value numbering.
   *
   *  Ops classified as idempotent are treated as producing the same result
   *  whenever their operands are equal. Loads, stores, stack/class
   *  allocations, calls and closures are never deduplicated. There is no
   *  fallback case: every op kind is classified explicitly.
   */
  def isIdempotent(op: Op): Boolean = {
    import Op._
    op match {
      // Always idempotent:
      case (_: Pure | _: Method | _: Dynmethod | _: As | _: Is | _: Copy |
          _: Sizeof | _: Module | _: Field | _: Box | _: Unbox) =>
        true

      // Never idempotent:
      case (_: Load | _: Store | _: Stackalloc | _: Classalloc | _: Call |
          _: Closure) =>
        false
    }
  }

  /** Structural (deep) equality on NIR instructions, ops and values.
   *
   *  Two locals compare equal when they are the same local, or when both are
   *  defined by `let`s whose ops are deeply equal; locals defined by labels
   *  (block parameters) only equal themselves. Ops rejected by
   *  `isIdempotent` never compare equal.
   *
   *  @param localDefs defining instruction of each local: an `Inst.Let` or an
   *                   `Inst.Label`
   */
  class DeepEquals(localDefs: Local => Inst) {

    /** Two lets are equal when they bind the same name or their ops are deeply equal. */
    def eqInst(instA: Inst.Let, instB: Inst.Let): Boolean = {
      (instA.name == instB.name) || eqOp(instA.op, instB.op)
    }

    /** Deep equality on ops; false unless both ops are idempotent and of the same kind. */
    def eqOp(opA: Op, opB: Op): Boolean = {
      import Op._
      if (!(isIdempotent(opA) && isIdempotent(opB)))
        false
      else {
        (opA, opB) match {

          case (Elem(tyA, ptrA, indexesA), Elem(tyB, ptrB, indexesB)) =>
            eqType(tyA, tyB) && eqVal(ptrA, ptrB) && eqVals(indexesA, indexesB)

          case (Extract(aggrA, indexesA), Extract(aggrB, indexesB)) =>
            eqVal(aggrA, aggrB) && (indexesA == indexesB)

          case (Insert(aggrA, valueA, indexesA),
                Insert(aggrB, valueB, indexesB)) =>
            eqVal(aggrA, aggrB) && eqVal(valueA, valueB) && (indexesA == indexesB)

          // TODO handle commutativity of some bins (the hash in
          // HashFunction.hashOp is order-sensitive and would need to match)
          case (Bin(binA, tyA, lA, rA), Bin(binB, tyB, lB, rB)) =>
            eqBin(binA, binB) && eqType(tyA, tyB) &&
              eqVal(lA, lB) && eqVal(rA, rB)

          case (Comp(compA, tyA, lA, rA), Comp(compB, tyB, lB, rB)) =>
            eqComp(compA, compB) && eqType(tyA, tyB) &&
              eqVal(lA, lB) && eqVal(rA, rB)

          case (Conv(convA, tyA, valueA), Conv(convB, tyB, valueB)) =>
            eqConv(convA, convB) && eqType(tyA, tyB) && eqVal(valueA, valueB)

          case (Select(condA, thenvA, elsevA),
                Select(condB, thenvB, elsevB)) =>
            eqVals(Seq(condA, thenvA, elsevA), Seq(condB, thenvB, elsevB))

          case (Field(objA, nameA), Field(objB, nameB)) =>
            eqVal(objA, objB) && eqGlobal(nameA, nameB)

          case (Method(objA, nameA), Method(objB, nameB)) =>
            eqVal(objA, objB) && eqGlobal(nameA, nameB)

          case (Dynmethod(objA, signatureA), Dynmethod(objB, signatureB)) =>
            eqVal(objA, objB) && signatureA == signatureB

          case (Module(nameA, _), Module(nameB, _)) =>
            eqGlobal(nameA, nameB)

          case (As(tyA, objA), As(tyB, objB)) =>
            eqType(tyA, tyB) && eqVal(objA, objB)

          case (Is(tyA, objA), Is(tyB, objB)) =>
            eqType(tyA, tyB) && eqVal(objA, objB)

          case (Copy(valueA), Copy(valueB)) =>
            eqVal(valueA, valueB)

          case (Sizeof(tyA), Sizeof(tyB)) =>
            eqType(tyA, tyB)

          case (Box(tyA, objA), Box(tyB, objB)) =>
            tyA == tyB && eqVal(objA, objB)

          case (Unbox(tyA, objA), Unbox(tyB, objB)) =>
            tyA == tyB && eqVal(objA, objB)

          case _ => false // non-matching pairs of ops, or not idempotent ones
        }
      }
    }

    /** Deep equality on values; locals are compared through their definitions. */
    def eqVal(valueA: Val, valueB: Val): Boolean = {
      import Val._
      (valueA, valueB) match {
        case (Struct(nameA, valuesA), Struct(nameB, valuesB)) =>
          eqGlobal(nameA, nameB) && eqVals(valuesA, valuesB)

        case (Array(elemtyA, valuesA), Array(elemtyB, valuesB)) =>
          eqType(elemtyA, elemtyB) && eqVals(valuesA, valuesB)

        case (Const(valueA), Const(valueB)) =>
          eqVal(valueA, valueB)

        case (Local(nameA, valtyA), Local(nameB, valtyB)) =>
          // Lazy so the definitions are only looked up when the names differ.
          lazy val eqDefs = (localDefs(nameA), localDefs(nameB)) match {
            case (_: Inst.Label, _: Inst.Label)     => (nameA == nameB)
            case (instA: Inst.Let, instB: Inst.Let) => eqInst(instA, instB)
            case _                                  => false
          }
          eqType(valtyA, valtyB) && ((nameA == nameB) || eqDefs)

        case _ =>
          valueA == valueB
      }
    }

    /** Pairwise deep equality; sequences of different lengths are unequal. */
    def eqVals(valsA: Seq[Val], valsB: Seq[Val]): Boolean = {
      val sizeEqual = (valsA.size == valsB.size)
      lazy val contentEqual =
        valsA.zip(valsB).forall { case (a, b) => eqVal(a, b) }
      sizeEqual && contentEqual
    }

    def eqType(tyA: Type, tyB: Type): Boolean = {
      tyA == tyB
    }

    def eqGlobal(globalA: Global, globalB: Global): Boolean = {
      globalA == globalB
    }

    def eqBin(binA: Bin, binB: Bin): Boolean = {
      binA == binB
    }

    def eqComp(compA: Comp, compB: Comp): Boolean = {
      compA == compB
    }

    def eqConv(convA: Conv, convB: Conv): Boolean = {
      convA == convB
    }

  }

  /** A value number. */
  type Hash = Int

  /** Hashes NIR ops, values and their components.
   *
   *  Locals are hashed through `hashLocal` (their value number) rather than
   *  by name, so it is intended that ops deemed equal by `DeepEquals` hash
   *  equally. Hash equality is only a bucket key; `DeepEquals` decides
   *  actual equality.
   *
   *  @param hashLocal value number of each local
   */
  class HashFunction(hashLocal: Local => Hash) extends (Any => Hash) {

    import HashFunction._

    /** Dispatches on the runtime type of `obj`.
     *
     *  @throws IllegalArgumentException for a type with no hashing rule
     */
    def apply(obj: Any): Hash = {
      obj match {
        case op: Op     => hashOp(op)
        case value: Val => hashVal(value)

        case local: Local => hashLocal(local)

        case ty: Type   => hashType(ty)
        case g: Global  => hashGlobal(g)
        case bin: Bin   => hashBin(bin)
        case comp: Comp => hashComp(comp)
        case conv: Conv => hashConv(conv)

        case b: Boolean  => b.hashCode
        case i: Int      => i.hashCode
        case d: Double   => d.hashCode
        case str: String => str.hashCode

        case _ =>
          throw new IllegalArgumentException(
            s"Unable to hash value {${obj}} of type ${obj.getClass.getName}")
      }
    }

    /** Hashes an op as the ordered combination of its kind tag and fields. */
    def hashOp(op: Op): Hash = {
      import Op._
      val opFields: Seq[Any] = op match {
        case Call(ty, ptr, args, _)    => "Call" +: ty +: ptr +: args
        case Load(ty, ptr, isVolatile) => Seq("Load", ty, ptr, isVolatile)
        case Store(ty, ptr, value, isVolatile) =>
          Seq("Store", ty, ptr, value, isVolatile)
        case Elem(ty, ptr, indexes) => "Elem" +: ty +: ptr +: indexes
        case Extract(aggr, indexes) => "Extract" +: aggr +: indexes
        case Insert(aggr, value, indexes) =>
          "Insert" +: aggr +: value +: indexes

        case Stackalloc(ty, n)     => Seq("Stackalloc", ty, n)
        case Bin(bin, ty, l, r)    => Seq("Bin", bin, ty, l, r)
        case Comp(comp, ty, l, r)  => Seq("Comp", comp, ty, l, r)
        // Include the conversion kind, as Bin/Comp include their operator,
        // so different conversions of the same value don't share a bucket.
        case Conv(conv, ty, value) => Seq("Conv", conv, ty, value)
        case Select(cond, thenv, elsev) => Seq("Select", cond, thenv, elsev)

        case Field(obj, name)           => Seq("Field", obj, name)
        case Method(obj, name)          => Seq("Method", obj, name)
        case Dynmethod(obj, signature)  => Seq("Dynmethod", obj, signature)
        case As(ty, obj)                => Seq("As", ty, obj)
        case Is(ty, obj)                => Seq("Is", ty, obj)
        case Copy(value)                => Seq("Copy", value)
        case Closure(ty, fun, captures) => "Closure" +: ty +: fun +: captures

        case Classalloc(name) => Seq("Classalloc", name)
        case Module(name, _)  => Seq("Module", name)
        case Sizeof(ty)       => Seq("Sizeof", ty)
        case Box(code, obj)   => Seq("Box", code.toString, obj)
        case Unbox(code, obj) => Seq("Unbox", code.toString, obj)
      }

      combineHashes(opFields.map(this.apply))
    }

    /** Hashes a value; locals hash to their value number. */
    def hashVal(value: Val): Hash = {
      import Val._
      val fields: Seq[Any] = value match {
        case Struct(name, values)  => "Struct" +: name +: values
        case Array(elemty, values) => "Array" +: elemty +: values
        case Const(value)          => Seq("Const", value)

        case Local(name, _) => Seq(hashLocal(name))

        // the other val kinds can't have another Val in them
        case _ => Seq(value.hashCode)
      }

      combineHashes(fields.map(this.apply))
    }

    def hashType(ty: Type): Hash = {
      ty.hashCode
    }

    def hashGlobal(global: Global): Hash = {
      global.hashCode
    }

    def hashBin(bin: Bin): Hash = {
      bin.hashCode
    }

    def hashComp(comp: Comp): Hash = {
      comp.hashCode
    }

    def hashConv(conv: Conv): Hash = {
      conv.hashCode
    }

  }

  object HashFunction {

    /** Order-sensitive combination of hashes (MurmurHash3 ordered hash). */
    def combineHashes(hashes: Seq[Hash]): Hash = {
      MurmurHash3.orderedHash(hashes)
    }

    /** Identity-based hash of a local, built from its scope and id. */
    def rawLocal(local: Local): Hash = {
      combineHashes(Seq(local.scope.hashCode, local.id.hashCode))
    }

  }

  /** Creates the pass; `config` and `top` are not used. */
  override def apply(config: tools.Config, top: Top) =
    new GlobalValueNumbering()
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy