All Downloads are FREE. Search and download functionalities are using the official Maven repository.

scala.scalanative.optimizer.pass.InstCombine.scala Maven / Gradle / Ivy

The newest version!
package scala.scalanative
package optimizer
package pass

import analysis.ClassHierarchy.Top
import analysis.ControlFlow

import nir._
import Inst._
import Bin._
import Comp._

import scala.None
import scala.collection.mutable

class InstCombine()(implicit fresh: Fresh) extends Pass {
  import InstCombine._
  import ConstantFolding._

  override def onInsts(insts: Seq[Inst]): Seq[Inst] = {
    // This stores the encountered definitions for non-params locals
    val buf      = new nir.Buffer
    val defop    = mutable.HashMap.empty[Local, Op]
    val resolver = new DefOp(defop)
    val cfg      = ControlFlow.Graph(insts)

    /* Because of the pre-order traversal of the CFG, the definitions will be
     * seen before their use
     */
    cfg.foreach { b =>
      buf += b.label
      b.insts.foreach { inst =>
        val simplifiedSeq = simplify(inst, resolver)
        simplifiedSeq.foreach {
          case Let(n, op) => defop += (n -> op)
          case _          =>
        }
        buf ++= simplifiedSeq
      }
    }

    buf.toSeq
  }

  private def simplify(inst: Inst, defop: DefOp): Seq[Inst] = {
    val singleSimp = inst match {

      case Let(z, Op.Bin(Iadd, ty, y, IVal(a))) =>
        defop(y) match {
          // (x + b) + a = x + (a + b)
          case Some(Op.Bin(Iadd, _, x, IVal(b))) =>
            Let(z, Op.Bin(Iadd, ty, x, IVal(a + b, ty)))

          // (x - b) + a = x + (a - b)
          case Some(Op.Bin(Isub, _, x, IVal(b))) =>
            Let(z, Op.Bin(Iadd, ty, x, IVal(a - b, ty)))

          case _ => inst
        }

      case Let(z, Op.Bin(Isub, ty, y, IVal(a))) =>
        defop(y) match {
          // (x - b) - a = x - (a + b)
          case Some(Op.Bin(Isub, _, x, IVal(b))) =>
            Let(z, Op.Bin(Isub, ty, x, IVal(a + b, ty)))

          // (x + b) - a = x - (a - b)
          case Some(Op.Bin(Iadd, _, x, IVal(b))) =>
            Let(z, Op.Bin(Isub, ty, x, IVal(a - b, ty)))

          case _ => inst
        }

      case Let(z, Op.Bin(Imul, ty, y, IVal(a))) =>
        defop(y) match {
          // (x * b) * a = x * (a * b)
          case Some(Op.Bin(Imul, _, x, IVal(b))) =>
            Let(z, Op.Bin(Imul, ty, x, IVal(a * b, ty)))

          // (x << b) * a = x * (a * 2^b)
          case Some(Op.Bin(Shl, _, x, IVal(b))) =>
            Let(z, Op.Bin(Imul, ty, x, IVal(a * math.pow(2, b).toLong, ty)))

          case _ => inst
        }

      case Let(n, Op.Comp(Ieq, tyIeq, compared, IVal(0))) =>
        defop(compared) match {
          // ((x xor y) == 0) = (x == y)
          case Some(Op.Bin(Xor, tyXor, x, y)) if (tyIeq == tyXor) =>
            Let(n, Op.Comp(Ieq, tyIeq, x, y))

          case _ => inst
        }

      case Let(n, Op.Comp(Ine, tyIne, compared, IVal(0))) =>
        defop(compared) match {
          // ((x xor y) != 0) = (x != y)
          case Some(Op.Bin(Xor, tyXor, x, y)) if (tyIne == tyXor) =>
            Let(n, Op.Comp(Ine, tyIne, x, y))

          case _ => inst
        }

      case Let(n, Op.Select(cond, thenv, elsev)) =>
        defop(cond) match {
          // select (c xor true) then a else b  =  select c then b else a
          case Some(Op.Bin(Xor, _, negCond, Val.True)) =>
            Let(n, Op.Select(negCond, elsev, thenv))

          case _ => inst
        }

      case If(cond, thenp, elsep) =>
        defop(cond) match {
          // if (c xor true) then a else b  =  if c then b else a
          case Some(Op.Bin(Xor, _, negCond, Val.True)) =>
            If(negCond, elsep, thenp)

          case _ => inst
        }

      case Let(n, op @ Op.Bin(_, _, lhs, rhs)) =>
        Let(n, simplifyBin(defop(lhs), defop(rhs), op))

      // Nothing
      case _ => inst
    }

    simplifyExt(singleSimp, defop)
  }

  def simplifyBin(lhsDef: Option[Op], rhsDef: Option[Op], op: Op.Bin): Op = {
    (lhsDef, rhsDef, op) match {

      // (x << a) * b = x * (2^a * b)  [i32]
      case (Some(Op.Bin(Shl, Type.Int, x, Val.Int(a))),
            _,
            Op.Bin(Imul, Type.Int, _, Val.Int(b))) =>
        Op.Bin(Imul, Type.Int, x, Val.Int(math.pow(2, a & 31).toInt * b))

      // (x << a) * b = x * (2^a * b)  [i64]
      case (Some(Op.Bin(Shl, Type.Long, x, Val.Long(a))),
            _,
            Op.Bin(Imul, Type.Long, _, Val.Long(b))) =>
        Op.Bin(Imul, Type.Long, x, Val.Long(math.pow(2, a & 63).toLong * b))

      // (x * a) << b = x * (a * 2^b)  [i32]
      case (Some(Op.Bin(Imul, Type.Int, x, Val.Int(a))),
            _,
            Op.Bin(Shl, Type.Int, _, Val.Int(b))) =>
        Op.Bin(Imul, Type.Int, x, Val.Int(a * math.pow(2, b & 31).toInt))

      // (x * a) << b = x * (a * 2^b)  [i64]
      case (Some(Op.Bin(Imul, Type.Long, x, Val.Long(a))),
            _,
            Op.Bin(Shl, Type.Long, _, Val.Long(b))) =>
        Op.Bin(Imul, Type.Long, x, Val.Long(a * math.pow(2, b & 63).toLong))

      // x + (0 - y) = x - y
      case (_, Some(Op.Bin(Isub, _, IVal(0), y)), Op.Bin(Iadd, ty, x, _)) =>
        Op.Bin(Isub, ty, x, y)

      // x - (0 - y) = x + y
      case (_, Some(Op.Bin(Isub, _, IVal(0), y)), Op.Bin(Isub, ty, x, _)) =>
        Op.Bin(Iadd, ty, x, y)

      // (0 - x) + y = y - x
      case (Some(Op.Bin(Isub, _, IVal(0), x)), _, Op.Bin(Iadd, ty, _, y)) =>
        Op.Bin(Isub, ty, y, x)

      // (x - y) + y = x
      case (Some(Op.Bin(Isub, _, x, y)), _, Op.Bin(Iadd, _, _, c))
          if (y == c) =>
        Op.Copy(x)

      // (x + y) - y = x
      case (Some(Op.Bin(Iadd, _, x, y)), _, Op.Bin(Isub, _, _, c))
          if (y == c) =>
        Op.Copy(x)

      case _ => op
    }
  }

  def simplifyExt(inst: Inst, defop: DefOp)(implicit fresh: Fresh): Seq[Inst] = {
    inst match {
      // (x * z) + (y * z) = (x + y) * z
      case Let(n,
               Op.Bin(Iadd,
                      ty,
                      defop(Op.Bin(Imul, _, x, z)),
                      defop(Op.Bin(Imul, _, y, d)))) if (z == d) =>
        val tmp = fresh()
        Seq(Let(tmp, Op.Bin(Iadd, ty, x, y)),
            Let(n, Op.Bin(Imul, ty, Val.Local(tmp, ty), z)))

      // unbox (box x) = x
      case Let(n, Op.Unbox(_, defop(Op.Box(_, x)))) =>
        Seq(Let(n, Op.Copy(x)))

      // (x or a) or b = x or (a or b)
      case Let(n, Op.Bin(Or, ty, defop(Op.Bin(Or, _, x, IVal(a))), IVal(b))) =>
        Seq(Let(n, Op.Bin(Or, ty, x, IVal(a | b, ty))))

      // (x and a) and b = x and (a and b)
      case Let(n,
               Op.Bin(And, ty, defop(Op.Bin(And, _, x, IVal(a))), IVal(b))) =>
        Seq(Let(n, Op.Bin(And, ty, x, IVal(a & b, ty))))

      // (x xor a) xor b = x xor (a xor b)
      case Let(n,
               Op.Bin(Xor, ty, defop(Op.Bin(Xor, _, x, IVal(a))), IVal(b))) =>
        Seq(Let(n, Op.Bin(Xor, ty, x, IVal(a ^ b, ty))))

      // (x << a) << b = x << (a + b)  [i32]
      case Let(n,
               Op.Bin(Shl,
                      Type.Int,
                      defop(Op.Bin(Shl, Type.Int, x, Val.Int(a))),
                      Val.Int(b))) =>
        val totShift = (a & 31) + (b & 31)
        if (totShift >= 32)
          Seq(Let(n, Op.Copy(Val.Int(0))))
        else
          Seq(Let(n, Op.Bin(Shl, Type.Int, x, Val.Int(totShift))))

      // (x << a) << b = x << (a + b)  [i64]
      case Let(n,
               Op.Bin(Shl,
                      Type.Long,
                      defop(Op.Bin(Shl, Type.Long, x, Val.Long(a))),
                      Val.Long(b))) =>
        val totShift = (a & 63) + (b & 63)
        if (totShift >= 64)
          Seq(Let(n, Op.Copy(Val.Long(0))))
        else
          Seq(Let(n, Op.Bin(Shl, Type.Long, x, Val.Long(totShift))))

      // (x >>> a) >>> b = x >>> (a + b)  [i32]
      case Let(n,
               Op.Bin(Lshr,
                      Type.Int,
                      defop(Op.Bin(Lshr, Type.Int, x, Val.Int(a))),
                      Val.Int(b))) =>
        val totShift = (a & 31) + (b & 31)
        if (totShift >= 32)
          Seq(Let(n, Op.Copy(Val.Int(0))))
        else
          Seq(Let(n, Op.Bin(Lshr, Type.Int, x, Val.Int(totShift))))

      // (x >>> a) >>> b = x >>> (a + b)  [i64]
      case Let(n,
               Op.Bin(Lshr,
                      Type.Long,
                      defop(Op.Bin(Lshr, Type.Long, x, Val.Long(a))),
                      Val.Long(b))) =>
        val totShift = (a & 63) + (b & 63)
        if (totShift >= 64)
          Seq(Let(n, Op.Copy(Val.Long(0))))
        else
          Seq(Let(n, Op.Bin(Lshr, Type.Long, x, Val.Long(totShift))))

      // (x >> a) >> b = x >> (a + b)  [i32]
      case Let(n,
               Op.Bin(Lshr,
                      Type.Int,
                      defop(Op.Bin(Lshr, Type.Int, x, Val.Int(a))),
                      Val.Int(b))) =>
        val shiftSum = (a & 31) + (b & 31)
        val totShift = math.min(shiftSum, 31)
        Seq(Let(n, Op.Bin(Lshr, Type.Int, x, Val.Int(totShift))))

      // (x >> a) >> b = x >> (a + b)  [i64]
      case Let(n,
               Op.Bin(Lshr,
                      Type.Long,
                      defop(Op.Bin(Lshr, Type.Long, x, Val.Long(a))),
                      Val.Long(b))) =>
        val shiftSum = (a & 63) + (b & 63)
        val totShift = math.min(shiftSum, 63)
        Seq(Let(n, Op.Bin(Lshr, Type.Long, x, Val.Long(totShift))))

      // ((x + a) == b) = (x == (b - a))
      case Let(
          n,
          Op.Comp(Ieq, ty, defop(Op.Bin(Iadd, _, x, IVal(a))), IVal(b))) =>
        Seq(Let(n, Op.Comp(Ieq, ty, x, IVal(b - a, ty))))

      // ((x - a) == b) = (x == (a + b))
      case Let(
          n,
          Op.Comp(Ieq, ty, defop(Op.Bin(Isub, _, x, IVal(a))), IVal(b))) =>
        Seq(Let(n, Op.Comp(Ieq, ty, x, IVal(a + b, ty))))

      // ((a - x) == b) = (x == (a - b))
      case Let(
          n,
          Op.Comp(Ieq, ty, defop(Op.Bin(Isub, _, IVal(a), x)), IVal(b))) =>
        Seq(Let(n, Op.Comp(Ieq, ty, x, IVal(a - b, ty))))

      // ((x xor a) == b) = (x == (a xor b))
      case Let(n,
               Op.Comp(Ieq, ty, defop(Op.Bin(Xor, _, x, IVal(a))), IVal(b))) =>
        Seq(Let(n, Op.Comp(Ieq, ty, x, IVal(a ^ b, ty))))

      // ((x xor true) == y) = (x != y)
      case Let(n,
               Op.Comp(Ieq,
                       Type.Bool,
                       defop(Op.Bin(Xor, Type.Bool, x, Val.True)),
                       y)) =>
        Seq(Let(n, Op.Comp(Ine, Type.Bool, x, y)))

      // (x == (y xor true)) = (x != y)
      case Let(n,
               Op.Comp(Ieq,
                       Type.Bool,
                       x,
                       defop(Op.Bin(Xor, Type.Bool, y, Val.True)))) =>
        Seq(Let(n, Op.Comp(Ine, Type.Bool, x, y)))

      case _ => Seq(inst)
    }
  }

}

object InstCombine extends PassCompanion {
  override def apply(config: tools.Config, top: Top) =
    new InstCombine()(top.fresh)

  class DefOp(val defops: mutable.HashMap[Local, Op]) {

    def apply(value: Val): Option[Op] = {
      unapply(value)
    }

    def unapply(value: Val): Option[Op] = {
      value match {
        case Val.Local(l, _) => defops.get(l)
        case _               => None
      }
    }

  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy