// scala.scalanative.nir.serialization.BinaryDeserializer.scala (Maven / Gradle / Ivy)
package scala.scalanative
package nir
package serialization
import java.net.URI
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import scala.collection.mutable
import scala.scalanative.nir.serialization.{Tags => T}
import scala.reflect.NameTransformer
/** Reads NIR definitions back from their binary representation.
 *
 *  The constructor consumes the file header, in this order: a [[Prelude]]
 *  read from offset 0, an `Int`-prefixed table of source-file URIs, and an
 *  `Int`-prefixed table of `(Global, Int)` pairs whose second component is
 *  the buffer offset at which that definition is encoded.
 *
 *  All readers use the buffer's relative `get*` methods (brought into scope
 *  by `import buffer._`), so each one advances the buffer position. The order
 *  of reads in every method must therefore mirror the order in which the
 *  values were written. Instances hold mutable decoding state (the buffer
 *  position and `lastPosition`).
 *
 *  @param buffer     the serialized NIR; this class moves its position
 *  @param bufferName a label for the buffer, passed to `Prelude.readFrom` and
 *                    included in unknown-tag error messages
 */
final class BinaryDeserializer(buffer: ByteBuffer, bufferName: String) {
  import buffer._

  /** The most recently decoded non-empty position. Compact position formats
   *  are encoded as differences against it (see `getPosition`).
   */
  private var lastPosition: Position = Position.NoPosition

  // The header fields are initialized one at a time, in buffer order.
  // `usesEncodedMemberNames` MUST be initialized before `header` is read:
  // decoding a member `Global` goes through `getSig`, which consults the
  // flag. Reading the whole header in one tuple initializer declared before
  // the flag would make `getSig` observe the flag's uninitialized default
  // (`false`). `prelude` cannot be read lazily instead, because the flag
  // depends on it.
  private val prelude: Prelude = {
    buffer.position(0)
    Prelude.readFrom(buffer, bufferName)
  }

  /** True when the binary's revision is 9 or newer. In that case member ids
   *  are used as stored; otherwise `getSig` re-encodes them.
   */
  private val usesEncodedMemberNames: Boolean = prelude.revision >= 9

  /** Source-file URIs, indexed by the file number stored in full positions. */
  private val files: Array[URI] =
    Array.fill(getInt())(new URI(getUTF8String()))

  /** For each definition: its name and the buffer offset where it is encoded. */
  private val header: Seq[(Global, Int)] =
    getSeq((getGlobal(), getInt()))

  /** Decodes every definition listed in the header.
   *
   *  Seeks to each recorded offset, in header order, and decodes one
   *  definition there.
   *
   *  @return the decoded definitions, in header order
   */
  final def deserialize(): Seq[Defn] = {
    val allDefns = mutable.UnrolledBuffer.empty[Defn]
    header.foreach {
      case (_, offset) =>
        buffer.position(offset)
        allDefns += getDefn()
    }
    allDefns.toSeq
  }

  // ---------------------------------------------------------------------------
  // Primitive encodings
  // ---------------------------------------------------------------------------

  /** Reads an `Int` element count, then that many elements via `getT`. */
  private def getSeq[T](getT: => T): Seq[T] =
    (1 to getInt).map(_ => getT).toSeq

  /** Reads a presence byte (`0` = absent), then one element if present. */
  private def getOpt[T](getT: => T): Option[T] =
    if (get == 0) None else Some(getT)

  private def getInts(): Seq[Int] = getSeq(getInt)

  /** Reads an `Int` byte length, then that many bytes decoded as UTF-8. */
  private def getUTF8String(): String = {
    new String(getBytes(), StandardCharsets.UTF_8)
  }

  /** Reads an `Int` length, then that many raw bytes. */
  private def getBytes(): Array[Byte] = {
    val arr = new Array[Byte](getInt)
    get(arr)
    arr
  }

  /** Reads one byte; any non-zero value decodes as `true`. */
  private def getBool(): Boolean = get != 0

  /** Fails decoding on a tag value this deserializer does not recognize.
   *
   *  Uses `util.unsupported`, as `getLinktimeCondition` does, and names
   *  both the kind of element and the buffer being read.
   */
  private def unknownTag(kind: String, tag: Int): Nothing =
    util.unsupported(s"Unknown $kind tag: $tag in $bufferName")

  // ---------------------------------------------------------------------------
  // Attributes and operators
  // ---------------------------------------------------------------------------

  private def getAttrs(): Attrs = Attrs.fromSeq(getSeq(getAttr()))

  private def getAttr(): Attr = getInt match {
    case T.MayInlineAttr        => Attr.MayInline
    case T.InlineHintAttr       => Attr.InlineHint
    case T.NoInlineAttr         => Attr.NoInline
    case T.AlwaysInlineAttr     => Attr.AlwaysInline
    case T.MaySpecialize        => Attr.MaySpecialize
    case T.NoSpecialize         => Attr.NoSpecialize
    case T.UnOptAttr            => Attr.UnOpt
    case T.NoOptAttr            => Attr.NoOpt
    case T.DidOptAttr           => Attr.DidOpt
    case T.BailOptAttr          => Attr.BailOpt(getUTF8String())
    case T.DynAttr              => Attr.Dyn
    case T.StubAttr             => Attr.Stub
    case T.ExternAttr           => Attr.Extern
    case T.LinkAttr             => Attr.Link(getUTF8String())
    case T.AbstractAttr         => Attr.Abstract
    case T.LinktimeResolvedAttr => Attr.LinktimeResolved
    case tag                    => unknownTag("attribute", tag)
  }

  private def getBin(): Bin = getInt match {
    case T.IaddBin => Bin.Iadd
    case T.FaddBin => Bin.Fadd
    case T.IsubBin => Bin.Isub
    case T.FsubBin => Bin.Fsub
    case T.ImulBin => Bin.Imul
    case T.FmulBin => Bin.Fmul
    case T.SdivBin => Bin.Sdiv
    case T.UdivBin => Bin.Udiv
    case T.FdivBin => Bin.Fdiv
    case T.SremBin => Bin.Srem
    case T.UremBin => Bin.Urem
    case T.FremBin => Bin.Frem
    case T.ShlBin  => Bin.Shl
    case T.LshrBin => Bin.Lshr
    case T.AshrBin => Bin.Ashr
    case T.AndBin  => Bin.And
    case T.OrBin   => Bin.Or
    case T.XorBin  => Bin.Xor
    case tag       => unknownTag("binary operator", tag)
  }

  private def getComp(): Comp = getInt match {
    case T.IeqComp => Comp.Ieq
    case T.IneComp => Comp.Ine
    case T.UgtComp => Comp.Ugt
    case T.UgeComp => Comp.Uge
    case T.UltComp => Comp.Ult
    case T.UleComp => Comp.Ule
    case T.SgtComp => Comp.Sgt
    case T.SgeComp => Comp.Sge
    case T.SltComp => Comp.Slt
    case T.SleComp => Comp.Sle
    case T.FeqComp => Comp.Feq
    case T.FneComp => Comp.Fne
    case T.FgtComp => Comp.Fgt
    case T.FgeComp => Comp.Fge
    case T.FltComp => Comp.Flt
    case T.FleComp => Comp.Fle
    case tag       => unknownTag("comparison", tag)
  }

  private def getConv(): Conv = getInt match {
    case T.TruncConv    => Conv.Trunc
    case T.ZextConv     => Conv.Zext
    case T.SextConv     => Conv.Sext
    case T.FptruncConv  => Conv.Fptrunc
    case T.FpextConv    => Conv.Fpext
    case T.FptouiConv   => Conv.Fptoui
    case T.FptosiConv   => Conv.Fptosi
    case T.UitofpConv   => Conv.Uitofp
    case T.SitofpConv   => Conv.Sitofp
    case T.PtrtointConv => Conv.Ptrtoint
    case T.InttoptrConv => Conv.Inttoptr
    case T.BitcastConv  => Conv.Bitcast
    case tag            => unknownTag("conversion", tag)
  }

  // ---------------------------------------------------------------------------
  // Instructions and definitions
  // ---------------------------------------------------------------------------

  private def getInsts(): Seq[Inst] = getSeq(getInst())

  /** Reads a position, then a tagged instruction. The position is supplied
   *  implicitly to the instruction constructors.
   */
  private def getInst(): Inst = {
    implicit val pos: nir.Position = getPosition()
    getInt() match {
      case T.LabelInst => Inst.Label(getLocal(), getParams())
      // A plain `Let` has no stored unwind target; `LetUnwind` stores one.
      case T.LetInst       => Inst.Let(getLocal(), getOp(), Next.None)
      case T.LetUnwindInst => Inst.Let(getLocal(), getOp(), getNext())
      case T.RetInst       => Inst.Ret(getVal())
      case T.JumpInst      => Inst.Jump(getNext())
      case T.IfInst        => Inst.If(getVal(), getNext(), getNext())
      case T.SwitchInst    => Inst.Switch(getVal(), getNext(), getNexts())
      case T.ThrowInst     => Inst.Throw(getVal(), getNext())
      case T.UnreachableInst => Inst.Unreachable(getNext())
      case T.LinktimeIfInst =>
        Inst.LinktimeIf(getLinktimeCondition(), getNext(), getNext())
      case tag => unknownTag("instruction", tag)
    }
  }

  private def getDefns(): Seq[Defn] = getSeq(getDefn())

  /** Reads a position, then a tagged definition. The position is supplied
   *  implicitly to the definition constructors.
   */
  private def getDefn(): Defn = {
    implicit val pos: nir.Position = getPosition()
    getInt() match {
      case T.VarDefn =>
        Defn.Var(getAttrs(), getGlobal(), getType(), getVal())
      case T.ConstDefn =>
        Defn.Const(getAttrs(), getGlobal(), getType(), getVal())
      case T.DeclareDefn =>
        Defn.Declare(getAttrs(), getGlobal(), getType())
      case T.DefineDefn =>
        Defn.Define(getAttrs(), getGlobal(), getType(), getInsts())
      case T.TraitDefn =>
        Defn.Trait(getAttrs(), getGlobal(), getGlobals())
      case T.ClassDefn =>
        Defn.Class(getAttrs(), getGlobal(), getGlobalOpt(), getGlobals())
      case T.ModuleDefn =>
        Defn.Module(getAttrs(), getGlobal(), getGlobalOpt(), getGlobals())
      case tag => unknownTag("definition", tag)
    }
  }

  // ---------------------------------------------------------------------------
  // Names and signatures
  // ---------------------------------------------------------------------------

  private def getGlobals(): Seq[Global] = getSeq(getGlobal())

  private def getGlobalOpt(): Option[Global] = getOpt(getGlobal())

  /** Reads a tagged global name. A member's owner is stored as a bare
   *  top-level name string, followed by the member's signature.
   */
  private def getGlobal(): Global = getInt match {
    case T.NoneGlobal =>
      Global.None
    case T.TopGlobal =>
      Global.Top(getUTF8String())
    case T.MemberGlobal =>
      Global.Member(Global.Top(getUTF8String()), getSig())
    case tag =>
      unknownTag("global", tag)
  }

  /** Reads a signature stored as its mangled UTF-8 string.
   *
   *  For binaries older than revision 9 (`usesEncodedMemberNames == false`),
   *  the signature is unmangled and the `id` of field and method signatures
   *  is passed through `NameTransformer.encode`. Every other unmangled
   *  signature is returned as-is.
   */
  private def getSig(): Sig = {
    val sig = new Sig(getUTF8String())
    if (usesEncodedMemberNames) sig
    else
      sig.unmangled match {
        case s: Sig.Field  => s.copy(id = NameTransformer.encode(s.id))
        case s: Sig.Method => s.copy(id = NameTransformer.encode(s.id))
        case other         => other
      }
  }

  /** Reads a local identifier stored as a `Long`. */
  private def getLocal(): Local =
    Local(getLong)

  // ---------------------------------------------------------------------------
  // Control-flow successors and operations
  // ---------------------------------------------------------------------------

  private def getNexts(): Seq[Next] = getSeq(getNext())

  private def getNext(): Next = getInt match {
    case T.NoneNext   => Next.None
    case T.UnwindNext => Next.Unwind(getParam(), getNext())
    case T.CaseNext   => Next.Case(getVal(), getNext())
    case T.LabelNext  => Next.Label(getLocal(), getVals())
    case tag          => unknownTag("next", tag)
  }

  private def getOp(): Op = getInt match {
    case T.CallOp       => Op.Call(getType(), getVal(), getVals())
    case T.LoadOp       => Op.Load(getType(), getVal())
    case T.StoreOp      => Op.Store(getType(), getVal(), getVal())
    case T.ElemOp       => Op.Elem(getType(), getVal(), getVals())
    case T.ExtractOp    => Op.Extract(getVal(), getInts())
    case T.InsertOp     => Op.Insert(getVal(), getVal(), getInts())
    case T.StackallocOp => Op.Stackalloc(getType(), getVal())
    case T.BinOp        => Op.Bin(getBin(), getType(), getVal(), getVal())
    case T.CompOp       => Op.Comp(getComp(), getType(), getVal(), getVal())
    case T.ConvOp       => Op.Conv(getConv(), getType(), getVal())
    case T.ClassallocOp => Op.Classalloc(getGlobal())
    case T.FieldloadOp  => Op.Fieldload(getType(), getVal(), getGlobal())
    case T.FieldstoreOp =>
      Op.Fieldstore(getType(), getVal(), getGlobal(), getVal())
    case T.FieldOp      => Op.Field(getVal(), getGlobal())
    case T.MethodOp     => Op.Method(getVal(), getSig())
    case T.DynmethodOp  => Op.Dynmethod(getVal(), getSig())
    case T.ModuleOp     => Op.Module(getGlobal())
    case T.AsOp         => Op.As(getType(), getVal())
    case T.IsOp         => Op.Is(getType(), getVal())
    case T.CopyOp       => Op.Copy(getVal())
    case T.SizeofOp     => Op.Sizeof(getType())
    case T.BoxOp        => Op.Box(getType(), getVal())
    case T.UnboxOp      => Op.Unbox(getType(), getVal())
    case T.VarOp        => Op.Var(getType())
    case T.VarloadOp    => Op.Varload(getVal())
    case T.VarstoreOp   => Op.Varstore(getVal(), getVal())
    case T.ArrayallocOp => Op.Arrayalloc(getType(), getVal())
    case T.ArrayloadOp  => Op.Arrayload(getType(), getVal(), getVal())
    case T.ArraystoreOp =>
      Op.Arraystore(getType(), getVal(), getVal(), getVal())
    case T.ArraylengthOp => Op.Arraylength(getVal())
    case tag             => unknownTag("operation", tag)
  }

  private def getParams(): Seq[Val.Local] = getSeq(getParam())

  /** Reads a local identifier followed by its type. */
  private def getParam(): Val.Local = Val.Local(getLocal(), getType())

  // ---------------------------------------------------------------------------
  // Types and values
  // ---------------------------------------------------------------------------

  private def getTypes(): Seq[Type] = getSeq(getType())

  private def getType(): Type = getInt match {
    case T.VarargType      => Type.Vararg
    case T.PtrType         => Type.Ptr
    case T.BoolType        => Type.Bool
    case T.CharType        => Type.Char
    case T.ByteType        => Type.Byte
    case T.ShortType       => Type.Short
    case T.IntType         => Type.Int
    case T.LongType        => Type.Long
    case T.FloatType       => Type.Float
    case T.DoubleType      => Type.Double
    case T.ArrayValueType  => Type.ArrayValue(getType(), getInt)
    case T.StructValueType => Type.StructValue(getTypes())
    case T.FunctionType    => Type.Function(getTypes(), getType())
    case T.NullType        => Type.Null
    case T.NothingType     => Type.Nothing
    case T.VirtualType     => Type.Virtual
    case T.VarType         => Type.Var(getType())
    case T.UnitType        => Type.Unit
    case T.ArrayType       => Type.Array(getType(), getBool())
    case T.RefType         => Type.Ref(getGlobal(), getBool(), getBool())
    case tag               => unknownTag("type", tag)
  }

  private def getVals(): Seq[Val] = getSeq(getVal())

  private def getVal(): Val = getInt match {
    case T.TrueVal        => Val.True
    case T.FalseVal       => Val.False
    case T.NullVal        => Val.Null
    case T.ZeroVal        => Val.Zero(getType())
    case T.CharVal        => Val.Char(getShort.toChar)
    case T.ByteVal        => Val.Byte(get)
    case T.ShortVal       => Val.Short(getShort)
    case T.IntVal         => Val.Int(getInt)
    case T.LongVal        => Val.Long(getLong)
    case T.FloatVal       => Val.Float(getFloat)
    case T.DoubleVal      => Val.Double(getDouble)
    case T.StructValueVal => Val.StructValue(getVals())
    case T.ArrayValueVal  => Val.ArrayValue(getType(), getVals())
    case T.CharsVal       => Val.Chars(getBytes().toIndexedSeq)
    case T.LocalVal       => Val.Local(getLocal(), getType())
    case T.GlobalVal      => Val.Global(getGlobal(), getType())
    case T.UnitVal        => Val.Unit
    case T.ConstVal       => Val.Const(getVal())
    case T.StringVal =>
      // Stored as an `Int` count followed by that many 2-byte `Char`s
      // (not as UTF-8 like `getUTF8String`).
      Val.String {
        val chars = Array.fill(getInt)(getChar)
        new String(chars)
      }
    case T.VirtualVal => Val.Virtual(getLong)
    case T.ClassOfVal => Val.ClassOf(getGlobal())
    case tag          => unknownTag("value", tag)
  }

  /** Reads a tagged link-time condition. The fields come first, then the
   *  condition's position (second argument list). A complex condition
   *  combines two recursively decoded sub-conditions with a `Bin` operator.
   */
  private def getLinktimeCondition(): LinktimeCondition = getInt() match {
    case LinktimeCondition.Tag.SimpleCondition =>
      LinktimeCondition.SimpleCondition(
        propertyName = getUTF8String(),
        comparison = getComp(),
        value = getVal()
      )(getPosition())
    case LinktimeCondition.Tag.ComplexCondition =>
      LinktimeCondition.ComplexCondition(
        op = getBin(),
        left = getLinktimeCondition(),
        right = getLinktimeCondition()
      )(getPosition())
    case n => util.unsupported(s"Unknown linktime condition tag: ${n}")
  }

  // ---------------------------------------------------------------------------
  // Positions
  // ---------------------------------------------------------------------------

  // Ported from Scala.js
  /** Reads a source position, using a compact encoding relative to
   *  `lastPosition` when possible.
   *
   *  The first byte selects the format:
   *  - `FormatNoPositionValue`: no position.
   *  - full: a file index into `files`, then `Int` line and `Int` column.
   *  - format 1: same file and line; the column delta is the first byte
   *    arithmetically shifted by `Format1Shift`.
   *  - format 2: the line delta is the first byte shifted by `Format2Shift`,
   *    followed by one unsigned column byte.
   *  - format 3: a `Short` line delta, followed by one unsigned column byte.
   *
   *  Every decoded position other than `NoPosition` becomes the new
   *  `lastPosition`. The relative formats assert that a full position was
   *  decoded earlier.
   */
  def getPosition(): Position = {
    import PositionFormat._
    def readPosition(): Position = {
      val first = get()
      if (first == FormatNoPositionValue) {
        Position.NoPosition
      } else {
        val result = if ((first & FormatFullMask) == FormatFullMaskValue) {
          val file = files(getInt())
          val line = getInt()
          val column = getInt()
          Position(file, line, column)
        } else {
          assert(
            lastPosition != Position.NoPosition,
            "Position format error: first position must be full"
          )
          if ((first & Format1Mask) == Format1MaskValue) {
            // `first` is a signed Byte, so `>>` yields a signed delta.
            val columnDiff = first >> Format1Shift
            Position(
              lastPosition.source,
              lastPosition.line,
              lastPosition.column + columnDiff
            )
          } else if ((first & Format2Mask) == Format2MaskValue) {
            val lineDiff = first >> Format2Shift
            val column = get() & 0xff // unsigned
            Position(lastPosition.source, lastPosition.line + lineDiff, column)
          } else {
            assert(
              (first & Format3Mask) == Format3MaskValue,
              s"Position format error: first byte $first does not match any format"
            )
            val lineDiff = getShort()
            val column = get() & 0xff // unsigned
            Position(lastPosition.source, lastPosition.line + lineDiff, column)
          }
        }
        lastPosition = result
        result
      }
    }
    readPosition()
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy