All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.kaitai.struct.languages.RustCompiler.scala Maven / Gradle / Ivy

package io.kaitai.struct.languages

import io.kaitai.struct.datatype.DataType._
import io.kaitai.struct.datatype.{DataType, FixedEndian, InheritedEndian, KSError, NeedRaw}
import io.kaitai.struct.exprlang.Ast
import io.kaitai.struct.format.{NoRepeat, RepeatEos, RepeatExpr, RepeatSpec, _}
import io.kaitai.struct.languages.components._
import io.kaitai.struct.translators.RustTranslator
import io.kaitai.struct.{ClassTypeProvider, RuntimeConfig, Utils}

/**
  * Code generator that emits Rust source for a Kaitai Struct specification.
  *
  * All classes of a spec go into a single `.rs` file (see `outFileName`). Nested classes
  * and enums are emitted at top level (`innerClasses` / `innerEnums` are `false`). Each
  * class becomes a `#[derive(Default)]` struct followed by an `impl KaitaiStruct` block
  * that holds the generated `new` and `read` functions.
  *
  * NOTE(review): several emitted type strings look truncated (e.g. `Option>` with no opening
  * `<`, bare `Vec`, `&mut S`, `Result` without parameters). They are kept verbatim here —
  * confirm against upstream before relying on the generated Rust compiling.
  */
class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
  extends LanguageCompiler(typeProvider, config)
    with ObjectOrientedLanguage
    with UpperCamelCaseClasses
    with SingleOutputFile
    with AllocateIOLocalVar
    with UniversalFooter
    with UniversalDoc
    with FixedContentsUsingArrayByteLiteral
    with EveryReadIsExpression {

  import RustCompiler._

  // Output is flat: nested types and enums are emitted at module level, not inside parents.
  override def innerClasses = false

  override def innerEnums = false

  override val translator: RustTranslator = new RustTranslator(typeProvider, config)

  /** Closes the innermost generated block: dedents and emits `}`. */
  override def universalFooter: Unit = {
    out.dec
    out.puts("}")
  }

  /** Renders every collected import as a `use <path>;` line, newline-separated, with a trailing newline. */
  override def outImports(topClass: ClassSpec) =
    importList.toList.map((x) => s"use $x;").mkString("", "\n", "\n")

  override def indent: String = "    "
  override def outFileName(topClassName: String): String = s"$topClassName.rs"

  /** Writes the header comment and registers the std / runtime imports every generated file uses. */
  override def fileHeader(topClassName: String): Unit = {
    outHeader.puts(s"// $headerComment")
    outHeader.puts

    importList.add("std::option::Option")
    importList.add("std::boxed::Box")
    importList.add("std::io::Result")
    importList.add("std::io::Cursor")
    importList.add("std::vec::Vec")
    importList.add("std::default::Default")
    importList.add("kaitai_struct::KaitaiStream")
    importList.add("kaitai_struct::KaitaiStruct")

    out.puts
  }

  /** Opaque (externally defined) types are only brought in with a `use <pkg>::<Name>` import. */
  override def opaqueClassDeclaration(classSpec: ClassSpec): Unit = {
    val name = type2class(classSpec.name.last)
    val pkg = type2classAbs(classSpec.name)

    importList.add(s"$pkg::$name")
  }

  override def classHeader(name: List[String]): Unit =
    classHeader(name, Some(kstructName))

  /** Opens `pub struct <Name> {` with `#[derive(Default)]`; `parentClass` is currently unused. */
  def classHeader(name: List[String], parentClass: Option[String]): Unit = {
    out.puts("#[derive(Default)]")
    out.puts(s"pub struct ${type2class(name)} {")
  }

  override def classFooter(name: List[String]): Unit = universalFooter

  /**
    * Closes the struct declaration and opens `impl KaitaiStruct for <Name>`. Inside it, emits
    * a `new` function that default-initializes the struct, stores the stream and delegates
    * to `read`.
    *
    * The `impl` block is left open for the members that follow. `classFooter` (or
    * `instanceDeclHeader`) closes it.
    */
  override def classConstructorHeader(name: List[String], parentType: DataType, rootClassName: List[String], isHybrid: Boolean, params: List[ParamDefSpec]): Unit = {
    out.puts("}")
    out.puts

    out.puts(s"impl KaitaiStruct for ${type2class(name)} {")
    out.inc

    out.puts(s"fn new(stream: &mut S,")
    out.puts(s"                        _parent: &Option>,")
    out.puts(s"                        _root: &Option>)")
    out.puts(s"                        -> Result")
    out.inc
    out.puts(s"where Self: Sized {")

    out.puts(s"let mut s: Self = Default::default();")
    out.puts

    out.puts(s"s.stream = stream;")

    out.puts(s"s.read(stream, _parent, _root)?;")
    out.puts

    out.puts("Ok(s)")
    out.dec
    out.puts("}")
    out.puts
  }

  // Reading is triggered from the generated `new` function, so no separate call is emitted.
  override def runRead(name: List[String]): Unit = {

  }

  override def runReadCalc(): Unit = {

  }

  /** Opens the generated `fn read(&mut self, stream, _parent, _root) -> Result<()>` body. */
  override def readHeader(endian: Option[FixedEndian], isEmpty: Boolean) = {
    out.puts
    out.puts(s"fn read(&mut self,")
    out.puts(s"                         stream: &mut S,")
    out.puts(s"                         _parent: &Option>,")
    out.puts(s"                         _root: &Option>)")
    out.puts(s"                         -> Result<()>")
    out.inc
    out.puts(s"where Self: Sized {")
  }

  /** Ends the generated `read` body with `Ok(())` and closes it. */
  override def readFooter(): Unit = {
    out.puts
    out.puts("Ok(())")
    out.dec
    out.puts("}")
  }

  /**
    * Emits a struct field for an attribute.
    *
    * `_parent`, `_root` and `_io` get no field. Every other attribute becomes
    * `pub <name>: <rust type>,`.
    *
    * NOTE(review): the generated code assigns `s.stream` and reads `self.stream`
    * (classConstructorHeader, privateMemberName, parseExpr), but no `stream` field is
    * declared here. Decide whether `_io` should emit `stream: <type>,` instead of being skipped.
    */
  override def attributeDeclaration(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = {
    attrName match {
      case ParentIdentifier | RootIdentifier | IoIdentifier =>
        // No struct field for these yet.
      case _ =>
        out.puts(s"    pub ${idToStr(attrName)}: ${kaitaiType2NativeType(attrType)},")
    }
  }

  // Fields are `pub`, so no accessor methods are generated.
  override def attributeReader(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = {

  }

  /** Emits the doc summary (if any) as a `/* ... */` block comment. */
  override def universalDoc(doc: DocSpec): Unit = {
    if (doc.summary.isDefined) {
      out.puts
      out.puts("/*")
      doc.summary.foreach((summary) => out.putsLines(" * ", summary))
      out.puts(" */")
    }
  }

  /**
    * Emits an `if/else` that runs the little-endian or big-endian parse depending on the
    * runtime endianness member.
    *
    * The member is referenced through `privateMemberName(EndianIdentifier)`, the same
    * rendering used for every other member. The previous version emitted PHP
    * (`$this->_m__is_le`), which is not valid Rust.
    */
  override def attrParseHybrid(leProc: () => Unit, beProc: () => Unit): Unit = {
    out.puts(s"if ${privateMemberName(EndianIdentifier)} {")
    out.inc
    leProc()
    out.dec
    out.puts("} else {")
    out.inc
    beProc()
    out.dec
    out.puts("}")
  }

  override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit =
    out.puts(s"${privateMemberName(attrName)} = $normalIO.ensureFixedContents($contents);")

  /**
    * Emits the post-processing step (xor / zlib / rotate / custom) for a raw byte attribute.
    *
    * Each branch produces a bare expression (no trailing `;`). `handleAssignment` adds the
    * statement terminator.
    */
  override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec): Unit = {
    val srcExpr = getRawIdExpr(varSrc, rep)

    val expr = proc match {
      case ProcessXor(xorValue) =>
        val procName = translator.detectType(xorValue) match {
          case _: IntType => "processXorOne"
          case _: BytesType => "processXorMany"
        }
        s"$kstreamName::$procName($srcExpr, ${expression(xorValue)})"
      case ProcessZlib =>
        s"$kstreamName::processZlib($srcExpr)"
      case ProcessRotate(isLeft, rotValue) =>
        // Only a left-rotate helper is called; a right rotation by n is expressed as a left rotation by 8 - n.
        val rotAmount = if (isLeft) {
          expression(rotValue)
        } else {
          s"8 - (${expression(rotValue)})"
        }
        s"$kstreamName::processRotateLeft($srcExpr, $rotAmount, 1)"
      case ProcessCustom(name, args) =>
        val procClass = if (name.length == 1) {
          val onlyName = name.head
          val className = type2class(onlyName)
          importList.add(s"$onlyName::$className")
          className
        } else {
          val pkgName = type2classAbs(name.init)
          val className = type2class(name.last)
          importList.add(s"$pkgName::$className")
          s"$pkgName::$className"
        }

        out.puts(s"let _process = $procClass::new(${args.map(expression).mkString(", ")});")
        s"_process.decode($srcExpr)"
    }
    handleAssignment(varDest, expr, rep, false)
  }

  /** Wraps the raw bytes of `id` in a `Cursor` bound to local `io` and returns that name. */
  override def allocateIO(id: Identifier, rep: RepeatSpec): String = {
    val args = rep match {
      case RepeatUntil(_) => translator.doLocalName(Identifier.ITERATOR2)
      case _ => getRawIdExpr(id, rep)
    }

    out.puts(s"let mut io = Cursor::new($args);")
    "io"
  }

  /** Expression for the raw value of `varName`; for repeated attributes this is the last element. */
  def getRawIdExpr(varName: Identifier, rep: RepeatSpec): String = {
    val memberName = privateMemberName(varName)
    rep match {
      case NoRepeat => memberName
      case _ => s"$memberName.last()"
    }
  }

  override def useIO(ioEx: Ast.expr): String = {
    out.puts(s"let mut io = ${expression(ioEx)};")
    "io"
  }

  override def pushPos(io: String): Unit =
    out.puts(s"let _pos = $io.pos();")

  override def seek(io: String, pos: Ast.expr): Unit =
    out.puts(s"$io.seek(${expression(pos)});")

  override def popPos(io: String): Unit =
    out.puts(s"$io.seek(_pos);")

  override def alignToByte(io: String): Unit =
    out.puts(s"$io.alignToByte();")

  override def condIfHeader(expr: Ast.expr): Unit = {
    out.puts(s"if ${expression(expr)} {")
    out.inc
  }

  /** Resets the target vector (and any raw / double-raw vectors) before a repeat loop. */
  override def condRepeatCommonInit(id: Identifier, dataType: DataType, needRaw: NeedRaw): Unit = {
    if (needRaw.level >= 1)
      out.puts(s"${privateMemberName(RawIdentifier(id))} = vec!();")
    if (needRaw.level >= 2)
      out.puts(s"${privateMemberName(RawIdentifier(RawIdentifier(id)))} = vec!();")
    out.puts(s"${privateMemberName(id)} = vec!();")
  }

  override def condRepeatEosHeader(id: Identifier, io: String, dataType: DataType): Unit = {
    out.puts(s"while !$io.isEof() {")
    out.inc
  }

  /**
    * Adds one parsed element to the target vector.
    *
    * Uses `Vec::push`; `Vec::append` takes another `&mut Vec` and cannot add a single element.
    */
  override def handleAssignmentRepeatEos(id: Identifier, expr: String): Unit = {
    out.puts(s"${privateMemberName(id)}.push($expr);")
  }

  override def condRepeatEosFooter: Unit = {
    super.condRepeatEosFooter
  }

  override def condRepeatExprHeader(id: Identifier, io: String, dataType: DataType, repeatExpr: Ast.expr): Unit = {
    out.puts(s"for i in 0..${expression(repeatExpr)} {")
    out.inc
  }

  override def handleAssignmentRepeatExpr(id: Identifier, expr: String): Unit =
    handleAssignmentRepeatEos(id, expr)

  // Rust has no do-while. The loop is emitted as `while { body; cond } { }` (closed in condRepeatUntilFooter).
  override def condRepeatUntilHeader(id: Identifier, io: String, dataType: DataType, untilExpr: Ast.expr): Unit = {
    out.puts("while {")
    out.inc
  }

  /**
    * Binds the parsed element to the iterator temporary and pushes it onto the target vector.
    *
    * NOTE(review): if the until-expression refers to this temporary, it is used after being
    * moved into the vector — confirm whether a clone or a reference is needed.
    */
  override def handleAssignmentRepeatUntil(id: Identifier, expr: String, isRaw: Boolean): Unit = {
    val tempVar = if (isRaw) {
      translator.doLocalName(Identifier.ITERATOR2)
    } else {
      translator.doLocalName(Identifier.ITERATOR)
    }
    out.puts(s"let $tempVar = $expr;")
    out.puts(s"${privateMemberName(id)}.push($tempVar);")
  }

  /** Emits the negated until-condition as the value of the `while { ... }` block, then an empty loop body. */
  override def condRepeatUntilFooter(id: Identifier, io: String, dataType: DataType, untilExpr: Ast.expr): Unit = {
    typeProvider._currentIteratorType = Some(dataType)
    out.puts(s"!(${expression(untilExpr)})")
    out.dec
    out.puts("} { }")
  }

  override def handleAssignmentSimple(id: Identifier, expr: String): Unit = {
    out.puts(s"${privateMemberName(id)} = $expr;")
  }

  /**
    * Returns the Rust expression that reads a value of `dataType` from stream `io`.
    *
    * Primitive and byte reads are followed by `?` to propagate read errors.
    *
    * NOTE(review): for user types, `t.args`, `t.forcedParent` and inherited endianness are not
    * passed to the child yet. The child always receives `self.stream`, `self` and `_root`.
    */
  override def parseExpr(dataType: DataType, assignType: DataType, io: String, defEndian: Option[FixedEndian]): String = {
    dataType match {
      case t: ReadableType =>
        s"$io.read_${t.apiCall(defEndian)}()?"
      case blt: BytesLimitType =>
        s"$io.read_bytes(${expression(blt.size)})?"
      case _: BytesEosType =>
        s"$io.read_bytes_full()?"
      case BytesTerminatedType(terminator, include, consume, eosError, _) =>
        s"$io.read_bytes_term($terminator, $include, $consume, $eosError)?"
      case BitsType1(_) =>
        s"$io.read_bits_int(1)? != 0"
      case BitsType(width: Int, _) =>
        s"$io.read_bits_int($width)?"
      case t: UserType =>
        s"Box::new(${translator.types2classAbs(t.classSpec.get.name)}::new(self.stream, self, _root)?)"
    }
  }

  /** Wraps `expr0` in optional right-padding strip and terminator truncation calls, in that order. */
  override def bytesPadTermExpr(expr0: String, padRight: Option[Int], terminator: Option[Int], include: Boolean): String = {
    val expr1 = padRight match {
      case Some(padByte) => s"$kstreamName::bytesStripRight($expr0, $padByte)"
      case None => expr0
    }
    val expr2 = terminator match {
      case Some(term) => s"$kstreamName::bytesTerminate($expr1, $term, $include)"
      case None => expr1
    }
    expr2
  }

  // True while emitting a switch over byte arrays / arrays, which is rendered as an if/else-if
  // chain instead of a `match`.
  var switchIfs = false
  val NAME_SWITCH_ON = Ast.expr.Name(Ast.identifier(Identifier.SWITCH_ON))

  /**
   * Starts a switch. Arrays and byte arrays use an if/else-if chain (no opening block is
   * emitted). All other types use `match <on> {`.
   *
   * NOTE(review): in if-chain mode, `switchCmpExpr` compares against `NAME_SWITCH_ON`, but
   * nothing in this class binds that name in the generated code — confirm.
   */
  override def switchStart(id: Identifier, on: Ast.expr): Unit = {
    val onType = translator.detectType(on)

    switchIfs = onType match {
      case _: ArrayTypeInStream | _: BytesType => true
      case _ => false
    }

    if (!switchIfs) {
      out.puts(s"match ${expression(on)} {")
      out.inc
    }
  }

  /** Renders `<switch-on> == condition` through the translator. */
  def switchCmpExpr(condition: Ast.expr): String =
    expression(
      Ast.expr.Compare(
        NAME_SWITCH_ON,
        Ast.cmpop.Eq,
        condition
      )
    )

  override def switchCaseFirstStart(condition: Ast.expr): Unit = {
    if (switchIfs) {
      out.puts(s"if ${switchCmpExpr(condition)} {")
      out.inc
    } else {
      switchCaseStart(condition)
    }
  }

  override def switchCaseStart(condition: Ast.expr): Unit = {
    if (switchIfs) {
      out.puts(s"else if ${switchCmpExpr(condition)} {")
      out.inc
    } else {
      out.puts(s"${expression(condition)} => {")
      out.inc
    }
  }

  override def switchCaseEnd(): Unit = {
    if (switchIfs) {
      out.dec
      out.puts("}")
    } else {
      out.dec
      out.puts("},")
    }
  }

  override def switchElseStart(): Unit = {
    if (switchIfs) {
      out.puts("else {")
      out.inc
    } else {
      out.puts("_ => {")
      out.inc
    }
  }

  override def switchElseEnd(): Unit = {
    out.dec
    out.puts("}")
  }

  /**
    * Closes the `match` opened by `switchStart`.
    *
    * In if-chain mode `switchStart` opens nothing and every branch closes itself, so nothing
    * is emitted. Closing there would add an unbalanced `}` and dedent.
    */
  override def switchEnd(): Unit =
    if (!switchIfs) universalFooter

  /** Instances are cached in `Option` fields on the struct. */
  override def instanceDeclaration(attrName: InstanceIdentifier, attrType: DataType, isNullable: Boolean): Unit = {
    out.puts(s"    pub ${idToStr(attrName)}: Option<${kaitaiType2NativeType(attrType)}>,")
  }

  /** Closes the current block and opens an inherent `impl <Name>` that holds the instance methods. */
  override def instanceDeclHeader(className: List[String]): Unit = {
    out.dec
    out.puts("}")
    out.puts

    out.puts(s"impl ${type2class(className)} {")
    out.inc
  }

  override def instanceHeader(className: List[String], instName: InstanceIdentifier, dataType: DataType, isNullable: Boolean): Unit = {
    out.puts(s"fn ${idToStr(instName)}(&mut self) -> ${kaitaiType2NativeType(dataType)} {")
    out.inc
  }

  /** Emits an early return of the cached instance value when the `Option` field is populated. */
  override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = {
    out.puts(s"if let Some(x) = ${privateMemberName(instName)} {")
    out.inc
    out.puts("return x;")
    out.dec
    out.puts("}")
    out.puts
  }

  override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = {
    out.puts(s"return ${privateMemberName(instName)};")
  }

  /** Emits a Rust `enum` with one UPPER_SNAKE_CASE variant per label; numeric values are not emitted. */
  override def enumDeclaration(curClass: List[String], enumName: String, enumColl: Seq[(Long, EnumValueSpec)]): Unit = {
    val enumClass = type2class(curClass ::: List(enumName))

    out.puts(s"enum $enumClass {")
    out.inc

    enumColl.foreach { case (_, label) =>
      universalDoc(label.doc)
      out.puts(s"${value2Const(label.name)},")
    }

    out.dec
    out.puts("}")
  }

  def value2Const(label: String) = Utils.upperUnderscoreCase(label)

  /** Rust-side name for an identifier; named and instance identifiers are lowerCamelCased. */
  def idToStr(id: Identifier): String = {
    id match {
      case SpecialIdentifier(name) => name
      case NamedIdentifier(name) => Utils.lowerCamelCase(name)
      case NumberedIdentifier(idx) => s"_${NumberedIdentifier.TEMPLATE}$idx"
      case InstanceIdentifier(name) => Utils.lowerCamelCase(name)
      case RawIdentifier(innerId) => "_raw_" + idToStr(innerId)
    }
  }

  /** `_io` maps to `self.stream`, `_root` / `_parent` to the `read` parameters, everything else to `self.<name>`. */
  override def privateMemberName(id: Identifier): String = {
    id match {
      case IoIdentifier => s"self.stream"
      case RootIdentifier => s"_root"
      case ParentIdentifier => s"_parent"
      case _ => s"self.${idToStr(id)}"
    }
  }

  override def publicMemberName(id: Identifier) = idToStr(id)

  // Plain `_t_<name>`: the former `$_t_` prefix was a PHP-ism, and `$` is not valid in Rust identifiers.
  override def localTemporaryName(id: Identifier): String = s"_t_${idToStr(id)}"

  override def paramName(id: Identifier): String = s"${idToStr(id)}"

  /** Maps a Kaitai data type to the Rust type name used in struct fields and signatures. */
  def kaitaiType2NativeType(attrType: DataType): String = {
    attrType match {
      case Int1Type(false) => "u8"
      case IntMultiType(false, Width2, _) => "u16"
      case IntMultiType(false, Width4, _) => "u32"
      case IntMultiType(false, Width8, _) => "u64"

      case Int1Type(true) => "i8"
      case IntMultiType(true, Width2, _) => "i16"
      case IntMultiType(true, Width4, _) => "i32"
      case IntMultiType(true, Width8, _) => "i64"

      case FloatMultiType(Width4, _) => "f32"
      case FloatMultiType(Width8, _) => "f64"

      case BitsType(_, _) => "u64"

      case _: BooleanType => "bool"
      case CalcIntType => "i32"
      case CalcFloatType => "f64"

      case _: StrType => "String"
      case _: BytesType => "Vec"

      case t: UserType => t.classSpec match {
        case Some(cs) => s"Box<${type2class(cs.name)}>"
        case None => s"Box<${type2class(t.name)}>"
      }

      case t: EnumType => t.enumSpec match {
        case Some(cs) => s"Box<${type2class(cs.name)}>"
        case None => s"Box<${type2class(t.name)}>"
      }

      case at: ArrayType => s"Vec<${kaitaiType2NativeType(at.elType)}>"

      case KaitaiStreamType | OwnedKaitaiStreamType => s"Option>"
      case KaitaiStructType | CalcKaitaiStructType => s"Option>"

      case st: SwitchType => kaitaiType2NativeType(st.combinedType)
    }
  }

  /**
    * Rust literal / expression used as the default value for a field of `attrType`.
    *
    * Float types use `0.0` (Rust does not coerce integer literals to floats). Strings use
    * `String::new()` to match the `String` type chosen by `kaitaiType2NativeType` (a `""`
    * literal is a `&str`).
    */
  def kaitaiType2Default(attrType: DataType): String = {
    attrType match {
      case Int1Type(false) => "0"
      case IntMultiType(false, Width2, _) => "0"
      case IntMultiType(false, Width4, _) => "0"
      case IntMultiType(false, Width8, _) => "0"

      case Int1Type(true) => "0"
      case IntMultiType(true, Width2, _) => "0"
      case IntMultiType(true, Width4, _) => "0"
      case IntMultiType(true, Width8, _) => "0"

      case FloatMultiType(Width4, _) => "0.0"
      case FloatMultiType(Width8, _) => "0.0"

      case BitsType(_, _) => "0"

      case _: BooleanType => "false"
      case CalcIntType => "0"
      case CalcFloatType => "0.0"

      case _: StrType => "String::new()"
      case _: BytesType => "vec!()"

      case _: UserType => "Default::default()"
      case _: EnumType => "Default::default()"

      // Every array flavour maps to `Vec<...>` in kaitaiType2NativeType, so all default to an empty vec.
      case _: ArrayType => "vec!()"

      case KaitaiStreamType | OwnedKaitaiStreamType => "None"
      case KaitaiStructType | CalcKaitaiStructType => "None"

      case _: SwitchType => ""
      // TODO
    }
  }

  def type2class(names: List[String]) = types2classRel(names)

  def type2classAbs(names: List[String]) =
    names.mkString("::")

  override def ksErrorName(err: KSError): String = RustCompiler.ksErrorName(err)
}

/**
  * Companion for the Rust code generator: provides the compiler factory, the stream/struct
  * type names used in generated signatures, and flat `__`-joined class naming.
  */
object RustCompiler extends LanguageCompilerStatic
  with StreamStructNames
  with UpperCamelCaseClasses
  with ExceptionNames {

  /** Builds a fresh [[RustCompiler]] for the given type provider and runtime configuration. */
  override def getCompiler(
    tp: ClassTypeProvider,
    config: RuntimeConfig
  ): LanguageCompiler = new RustCompiler(tp, config)

  // NOTE(review): these type strings look truncated (unbalanced `>`); kept verbatim — confirm upstream.
  override def kstructName: String = "&Option>"
  override def kstreamName: String = "&mut S"

  /** Not implemented for Rust yet: any call throws `scala.NotImplementedError`. */
  override def ksErrorName(err: KSError): String = ???

  /**
    * Flattens a type id into one class name: each component is UpperCamelCased and the
    * components are joined with `__`. Absolute ids get one more `__` in front.
    */
  def types2class(typeName: Ast.typeId): String = {
    val joined = typeName.names.map(type2class).mkString("__")
    if (typeName.absolute) "__" + joined else joined
  }

  /** Relative variant of [[types2class]]: UpperCamelCase each name and join with `__`, no prefix. */
  def types2classRel(names: List[String]): String =
    names.iterator.map(part => type2class(part)).mkString("__")
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy