All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.kaitai.struct.languages.GoCompiler.scala Maven / Gradle / Ivy

package io.kaitai.struct.languages

import io.kaitai.struct.datatype.DataType._
import io.kaitai.struct.datatype._
import io.kaitai.struct.exprlang.Ast
import io.kaitai.struct.format._
import io.kaitai.struct.languages.components._
import io.kaitai.struct.translators.{GoTranslator, ResultString, TranslatorResult}
import io.kaitai.struct.{ClassTypeProvider, RuntimeConfig, Utils}

class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
  extends LanguageCompiler(typeProvider, config)
    with SingleOutputFile
    with UpperCamelCaseClasses
    with ObjectOrientedLanguage
    with UniversalFooter
    with UniversalDoc
    with AllocateIOLocalVar
    with GoReads {
  import GoCompiler._

  override val translator = new GoTranslator(out, typeProvider, importList)

  override def innerClasses = false

  override def headerComment = "Code generated by kaitai-struct-compiler from a .ksy source file. DO NOT EDIT."

  override def universalFooter: Unit = {
    out.dec
    out.puts("}")
  }

  override def indent: String = "\t"
  override def outFileName(topClassName: String): String =
    s"${config.goPackage}/$topClassName.go"

  override def outImports(topClass: ClassSpec): String = {
    val imp = importList.toList
    imp.size match {
      case 0 => ""
      case 1 => "import \"" + imp.head + "\"\n"
      case _ =>
        "import (\n" +
        imp.map((x) => indent + "\"" + x + "\"").mkString("", "\n", "\n") +
        ")\n"
    }
  }

  override def fileHeader(topClassName: String): Unit = {
    outHeader.puts(s"// $headerComment")
    if (config.goPackage.nonEmpty) {
      outHeader.puts
      outHeader.puts(s"package ${config.goPackage}")
    }
    outHeader.puts

    importList.add("github.com/kaitai-io/kaitai_struct_go_runtime/kaitai")

    out.puts
  }

  override def classHeader(name: List[String]): Unit = {
    out.puts(s"type ${types2class(name)} struct {")
    out.inc
  }

  override def classFooter(name: List[String]): Unit = {
    // TODO(jchw): where should this attribute actually be generated at?
    typeProvider.nowClass.meta.endian match {
      case Some(_: CalcEndian) | Some(InheritedEndian) =>
        out.puts(s"${idToStr(EndianIdentifier)} int")
      case _ =>
    }
    universalFooter
  }

  override def classConstructorHeader(name: List[String], parentType: DataType, rootClassName: List[String], isHybrid: Boolean, params: List[ParamDefSpec]): Unit = {
    val paramsArg = params.map((p) =>
      s"${paramName(p.id)} ${kaitaiType2NativeType(p.dataType)}"
    ).mkString(", ")
    out.puts(s"func New${types2class(name)}($paramsArg) *${types2class(name)} {")
    out.inc
    out.puts(s"return &${types2class(name)}{")
    out.inc
    params.foreach(p => out.puts(s"${idToStr(p.id)}: ${paramName(p.id)},"))
    out.dec
    out.puts("}")
    universalFooter
  }

  override def classConstructorFooter: Unit = {}

  override def runRead(name: List[String]): Unit = {
    out.puts("this.Read()")
  }

  override def runReadCalc(): Unit = {
    out.puts
    out.puts(s"switch ${privateMemberName(EndianIdentifier)} {")
    out.puts("case 0:")
    out.inc
    out.puts("err = this._read_be()")
    out.dec
    out.puts("case 1:")
    out.inc
    out.puts("err = this._read_le()")
    out.dec
    out.puts("default:")
    out.inc
    out.puts(s"err = ${GoCompiler.ksErrorName(UndecidedEndiannessError)}{}")
    out.dec
    out.puts("}")
  }

  override def readHeader(endian: Option[FixedEndian], isEmpty: Boolean): Unit = {
    endian match {
      case None =>
        out.puts
        out.puts(
          s"func (this *${types2class(typeProvider.nowClass.name)}) Read(" +
            s"io *$kstreamName, " +
            s"parent ${kaitaiType2NativeType(typeProvider.nowClass.parentType)}, " +
            s"root *${types2class(typeProvider.topClass.name)}) (err error) {"
        )
        out.inc
        out.puts(s"${privateMemberName(IoIdentifier)} = io")
        out.puts(s"${privateMemberName(ParentIdentifier)} = parent")
        out.puts(s"${privateMemberName(RootIdentifier)} = root")
        typeProvider.nowClass.meta.endian match {
          case Some(_: CalcEndian) =>
            out.puts(s"${privateMemberName(EndianIdentifier)} = -1")
          case Some(InheritedEndian) =>
            out.puts(s"${privateMemberName(EndianIdentifier)} = " +
              s"${privateMemberName(ParentIdentifier)}." +
              s"${idToStr(EndianIdentifier)}")
          case _ =>
        }
        out.puts
      case Some(e) =>
        out.puts
        out.puts(
          s"func (this *${types2class(typeProvider.nowClass.name)}) " +
            s"_read_${e.toSuffix}() (err error) {")
        out.inc
    }

  }
  override def readFooter(): Unit = {
    out.puts("return err")
    universalFooter
  }

  override def attributeDeclaration(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = {
    out.puts(s"${idToStr(attrName)} ${kaitaiType2NativeType(attrType)}")
    translator.returnRes = None
  }

  override def attributeReader(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = {}

  override def universalDoc(doc: DocSpec): Unit = {
    out.puts
    out.puts( "/**")

    doc.summary.foreach(summary => out.putsLines(" * ", summary))

    doc.ref.foreach {
      case TextRef(text) =>
        out.putsLines(" * ", "@see \"" + text + "\"")
      case ref: UrlRef =>
        out.putsLines(" * ", s"@see ${ref.toAhref}")
    }

    out.puts( " */")
  }

  override def attrParseHybrid(leProc: () => Unit, beProc: () => Unit): Unit = {
    out.puts(s"switch ${privateMemberName(EndianIdentifier)} {")
    out.puts("case 0:")
    out.inc
    beProc()
    out.dec
    out.puts("case 1:")
    out.inc
    leProc()
    out.dec
    out.puts("default:")
    out.inc
    out.puts(s"err = ${GoCompiler.ksErrorName(UndecidedEndiannessError)}{}")
    out.dec
    out.puts("}")
  }

  override def attrFixedContentsParse(attrName: Identifier, contents: Array[Byte]): Unit = {
    out.puts(s"${privateMemberName(attrName)}, err = $normalIO.ReadBytes(${contents.length})")

    out.puts(s"if err != nil {")
    out.inc
    out.puts("return err")
    out.dec
    out.puts("}")

    importList.add("bytes")
    importList.add("errors")
    val expected = translator.resToStr(translator.doByteArrayLiteral(contents))
    out.puts(s"if !bytes.Equal(${privateMemberName(attrName)}, $expected) {")
    out.inc
    out.puts("return errors.New(\"Unexpected fixed contents\")")
    out.dec
    out.puts("}")
  }

  override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec): Unit = {
    val srcExpr = getRawIdExpr(varSrc, rep)

    val expr = proc match {
      case ProcessXor(xorValue) =>
        translator.detectType(xorValue) match {
          case _: IntType =>
            s"kaitai.ProcessXOR($srcExpr, []byte{${expression(xorValue)}})"
          case _: BytesType =>
            s"kaitai.ProcessXOR($srcExpr, ${expression(xorValue)})"
        }
      case ProcessZlib =>
        translator.resToStr(translator.outVarCheckRes(s"kaitai.ProcessZlib($srcExpr)"))
      case ProcessRotate(isLeft, rotValue) =>
        val expr = if (isLeft) {
          expression(rotValue)
        } else {
          s"8 - (${expression(rotValue)})"
        }
        s"kaitai.ProcessRotateLeft($srcExpr, int($expr))"
      case ProcessCustom(name, args) =>
        // TODO(jchw): This hack is necessary because Go tests fail catastrophically otherwise...
        s"$srcExpr"
    }
    handleAssignment(varDest, ResultString(expr), rep, false)
  }

  override def allocateIO(varName: Identifier, rep: RepeatSpec): String = {
    val javaName = privateMemberName(varName)

    val ioName = idToStr(IoStorageIdentifier(varName))

    val args = rep match {
      case RepeatUntil(_) => translator.specialName(Identifier.ITERATOR2)
      case _ => getRawIdExpr(varName, rep)
    }

    importList.add("bytes")

    out.puts(s"$ioName := kaitai.NewStream(bytes.NewReader($args))")
    ioName
  }

  def getRawIdExpr(varName: Identifier, rep: RepeatSpec): String = {
    val memberName = privateMemberName(varName)
    rep match {
      case NoRepeat => memberName
      case RepeatExpr(_) => s"$memberName[i]"
      case _ => s"$memberName[len($memberName) - 1]"
    }
  }

  override def useIO(ioEx: Ast.expr): String = {
    out.puts(s"thisIo := ${expression(ioEx)}")
    "thisIo"
  }

  override def pushPos(io: String): Unit = {
    out.puts(s"_pos, err := $io.Pos()")
    translator.outAddErrCheck()
  }

  override def seek(io: String, pos: Ast.expr): Unit = {
    importList.add("io")

    out.puts(s"_, err = $io.Seek(int64(${expression(pos)}), io.SeekStart)")
    translator.outAddErrCheck()
  }

  override def popPos(io: String): Unit = {
    importList.add("io")

    out.puts(s"_, err = $io.Seek(_pos, io.SeekStart)")
    translator.outAddErrCheck()
  }

  override def alignToByte(io: String): Unit =
    out.puts(s"$io.AlignToByte()")

  override def condIfHeader(expr: Ast.expr): Unit = {
    out.puts(s"if (${expression(expr)}) {")
    out.inc
  }

  override def condRepeatCommonInit(id: Identifier, dataType: DataType, needRaw: NeedRaw): Unit = {
    // slices don't have to be manually initialized in Go: the built-in append()
    // function works even on `nil` slices (https://go.dev/tour/moretypes/15)
  }

  override def condRepeatEosHeader(id: Identifier, io: String, dataType: DataType): Unit = {
    out.puts(s"for i := 1;; i++ {")
    out.inc

    val eofVar = translator.allocateLocalVar()
    out.puts(s"${translator.localVarName(eofVar)}, err := this._io.EOF()")
    translator.outAddErrCheck()
    out.puts(s"if ${translator.localVarName(eofVar)} {")
    out.inc
    out.puts("break")
    out.dec
    out.puts("}")
  }

  override def handleAssignmentRepeatEos(id: Identifier, r: TranslatorResult): Unit = {
    val name = privateMemberName(id)
    val expr = translator.resToStr(r)
    out.puts(s"$name = append($name, $expr)")
  }

  override def condRepeatExprHeader(id: Identifier, io: String, dataType: DataType, repeatExpr: Ast.expr): Unit = {
    out.puts(s"for i := 0; i < int(${expression(repeatExpr)}); i++ {")
    out.inc
    // FIXME: Go throws a fatal compile error when the `i` variable is not used (unused variables
    // can only use the blank identifier `_`, see https://go.dev/doc/effective_go#blank), so we have
    // to silence it like this. It would be nice to be able to analyze all expressions that appear
    // in the loop body to decide whether to generate `for _ := range` or `for i := range` here, but
    // that would be really difficult to do properly in KSC with the current architecture.
    out.puts("_ = i")
  }

  override def handleAssignmentRepeatExpr(id: Identifier, r: TranslatorResult): Unit =
    handleAssignmentRepeatEos(id, r)

  override def condRepeatUntilHeader(id: Identifier, io: String, dataType: DataType, untilExpr: Ast.expr): Unit = {
    out.puts(s"for i := 1;; i++ {")
    out.inc
  }

  override def handleAssignmentRepeatUntil(id: Identifier, r: TranslatorResult, isRaw: Boolean): Unit = {
    val expr = translator.resToStr(r)
    val tempVar = translator.specialName(if (isRaw) Identifier.ITERATOR2 else Identifier.ITERATOR)
    out.puts(s"$tempVar := $expr")
    out.puts(s"${privateMemberName(id)} = append(${privateMemberName(id)}, $tempVar)")
  }

  override def condRepeatUntilFooter(id: Identifier, io: String, dataType: DataType, untilExpr: Ast.expr): Unit = {
    typeProvider._currentIteratorType = Some(dataType)
    out.puts(s"if ${expression(untilExpr)} {")
    out.inc
    out.puts("break")
    out.dec
    out.puts("}")
    out.dec
    out.puts("}")
  }

  private def castToType(r: TranslatorResult, dataType: DataType): TranslatorResult = {
    dataType match {
      case t @ (_: IntMultiType | _: FloatMultiType) =>
        ResultString(s"${kaitaiType2NativeType(t)}(${translator.resToStr(r)})")
      case _ =>
        r
    }
  }

  private def combinedType(dataType: DataType) = {
    dataType match {
      case st: SwitchType => st.combinedType
      case _ => dataType
    }
  }

  private def handleCompositeTypeCast(id: Identifier, r: TranslatorResult): TranslatorResult = {
    id match {
      case NamedIdentifier(name) =>
        castToType(r, combinedType(typeProvider.determineType(name)))
      case _ =>
        r
    }
  }

  override def handleAssignmentSimple(id: Identifier, r: TranslatorResult): Unit = {
    val expr = translator.resToStr(handleCompositeTypeCast(id, r))
    out.puts(s"${privateMemberName(id)} = $expr")
  }

  def handleAssignmentTempVar(dataType: DataType, id: String, expr: String): Unit =
    out.puts(s"$id := $expr")

  override def blockScopeHeader: Unit = {
    out.puts("{")
    out.inc
  }
  override def blockScopeFooter: Unit = universalFooter

  override def parseExpr(dataType: DataType, io: String, defEndian: Option[FixedEndian]): String = {
    dataType match {
      case t: ReadableType =>
        s"$io.Read${Utils.capitalize(t.apiCall(defEndian))}()"
      case blt: BytesLimitType =>
        s"$io.ReadBytes(int(${expression(blt.size)}))"
      case _: BytesEosType =>
        s"$io.ReadBytesFull()"
      case BytesTerminatedType(terminator, include, consume, eosError, _) =>
        s"$io.ReadBytesTerm($terminator, $include, $consume, $eosError)"
      case BitsType1(bitEndian) =>
        s"$io.ReadBitsInt${Utils.upperCamelCase(bitEndian.toSuffix)}(1)"
      case BitsType(width: Int, bitEndian) =>
        s"$io.ReadBitsInt${Utils.upperCamelCase(bitEndian.toSuffix)}($width)"
      case t: UserType =>
        val addArgs = if (t.isOpaque) {
          ""
        } else {
          val parent = t.forcedParent match {
            case Some(USER_TYPE_NO_PARENT) => "null"
            case Some(fp) => translator.translate(fp)
            case None => "this"
          }
          s", $parent, _root"
        }
        s"${types2class(t.name)}($io$addArgs)"
    }
  }

//  override def bytesPadTermExpr(expr0: String, padRight: Option[Int], terminator: Option[Int], include: Boolean) = {
//    val expr1 = padRight match {
//      case Some(padByte) => s"$kstreamName.bytesStripRight($expr0, (byte) $padByte)"
//      case None => expr0
//    }
//    val expr2 = terminator match {
//      case Some(term) => s"$kstreamName.bytesTerminate($expr1, (byte) $term, $include)"
//      case None => expr1
//    }
//    expr2
//  }

  override def switchStart(id: Identifier, on: Ast.expr): Unit = {
    out.puts(s"switch (${expression(on)}) {")
  }

  override def switchCaseStart(condition: Ast.expr): Unit = {
    out.puts(s"case ${expression(condition)}:")
    out.inc
  }

  override def switchCaseEnd(): Unit = {
    out.dec
  }

  override def switchElseStart(): Unit = {
    out.puts("default:")
    out.inc
  }

  override def switchEnd(): Unit =
    out.puts("}")

  override def switchShouldUseCompareFn(onType: DataType): (Option[String], () => Unit) = {
    onType match {
      case _: BytesType =>
        (Some("bytes.Equal"), () => importList.add("bytes"))
      case _ =>
        (None, () => {})
    }
  }

  override def switchCaseStartCompareFn(compareFn: String, switchOn: Ast.expr, condition: Ast.expr): Unit = {
    out.puts(s"case ${compareFn}(${expression(switchOn)}, ${expression(condition)}):")
    out.inc
  }

  override def instanceDeclaration(attrName: InstanceIdentifier, attrType: DataType, isNullable: Boolean): Unit = {
    out.puts(s"${calculatedFlagForName(attrName)} bool")
    out.puts(s"${idToStr(attrName)} ${kaitaiType2NativeType(attrType)}")
  }

  override def instanceHeader(className: List[String], instName: InstanceIdentifier, dataType: DataType, isNullable: Boolean): Unit = {
    out.puts(s"func (this *${types2class(className)}) ${publicMemberName(instName)}() (v ${kaitaiType2NativeType(dataType)}, err error) {")
    out.inc
    translator.returnRes = Some(dataType match {
      case _: NumericType => "0"
      case _: BooleanType => "false"
      case _: StrType => "\"\""
      case _ => "nil"
    })
  }

  override def instanceCalculate(instName: Identifier, dataType: DataType, value: Ast.expr): Unit = {
    val r = translator.translate(value)
    val converted = dataType match {
      case _: UserType => r
      case _ => s"${kaitaiType2NativeType(dataType)}($r)"
    }
    out.puts(s"${privateMemberName(instName)} = $converted")
  }

  override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = {
    out.puts(s"if (this.${calculatedFlagForName(instName)}) {")
    out.inc
    instanceReturn(instName, dataType)
    universalFooter
  }

  override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = {
    out.puts(s"return ${privateMemberName(instName)}, nil")
  }

  override def instanceSetCalculated(instName: InstanceIdentifier): Unit =
    out.puts(s"this.${calculatedFlagForName(instName)} = true")

  override def enumDeclaration(curClass: List[String], enumName: String, enumColl: Seq[(Long, EnumValueSpec)]): Unit = {
    val fullEnumName: List[String] = curClass ++ List(enumName)
    val fullEnumNameStr = types2class(fullEnumName)

    out.puts
    out.puts(s"type $fullEnumNameStr int")
    out.puts("const (")
    out.inc

    enumColl.foreach { case (id, label) =>
      out.puts(s"${enumToStr(fullEnumName, label.name)} $fullEnumNameStr = $id")
    }

    out.dec
    out.puts(")")
  }

  override def idToStr(id: Identifier): String = GoCompiler.idToStr(id)

  override def publicMemberName(id: Identifier): String = GoCompiler.publicMemberName(id)

  override def privateMemberName(id: Identifier): String = s"this.${idToStr(id)}"

  override def localTemporaryName(id: Identifier): String = s"_t_${idToStr(id)}"

  override def paramName(id: Identifier): String = Utils.lowerCamelCase(id.humanReadable)

  def calculatedFlagForName(id: Identifier) = s"_f_${idToStr(id)}"

  override def ksErrorName(err: KSError): String = GoCompiler.ksErrorName(err)

  override def attrValidateExpr(
    attrId: Identifier,
    attrType: DataType,
    checkExpr: Ast.expr,
    err: KSError,
    errArgs: List[Ast.expr]
  ): Unit = {
    val errArgsStr = errArgs.map(translator.translate).mkString(", ")
    out.puts(s"if !(${translator.translate(checkExpr)}) {")
    out.inc
    val errInst = s"kaitai.New${err.name}($errArgsStr)"
    val noValueAndErr = translator.returnRes match {
      case None => errInst
      case Some(r) => s"$r, $errInst"
    }
    out.puts(s"return $noValueAndErr")
    out.dec
    out.puts("}")
  }
}

object GoCompiler extends LanguageCompilerStatic
  with UpperCamelCaseClasses
  with StreamStructNames
  with ExceptionNames {

  override def getCompiler(
    tp: ClassTypeProvider,
    config: RuntimeConfig
  ): LanguageCompiler = new GoCompiler(tp, config)

  def idToStr(id: Identifier): String =
    id match {
      case SpecialIdentifier(name) => name
      case NamedIdentifier(name) => Utils.upperCamelCase(name)
      case NumberedIdentifier(idx) => s"_${NumberedIdentifier.TEMPLATE}$idx"
      case InstanceIdentifier(name) => Utils.lowerCamelCase(name)
      case RawIdentifier(innerId) => s"_raw_${idToStr(innerId)}"
      case IoStorageIdentifier(innerId) => s"_io_${idToStr(innerId)}"
    }

  def publicMemberName(id: Identifier): String =
    id match {
      case IoIdentifier => "_IO"
      case RootIdentifier => "_Root"
      case ParentIdentifier => "_Parent"
      case InstanceIdentifier(name) => Utils.upperCamelCase(name)
      case _ => idToStr(id)
    }

  /**
    * Determine Go data type corresponding to a KS data type.
    *
    * @param attrType KS data type
    * @return Go data type
    */
  def kaitaiType2NativeType(attrType: DataType): String = {
    attrType match {
      case Int1Type(false) => "uint8"
      case IntMultiType(false, Width2, _) => "uint16"
      case IntMultiType(false, Width4, _) => "uint32"
      case IntMultiType(false, Width8, _) => "uint64"

      case Int1Type(true) => "int8"
      case IntMultiType(true, Width2, _) => "int16"
      case IntMultiType(true, Width4, _) => "int32"
      case IntMultiType(true, Width8, _) => "int64"

      case FloatMultiType(Width4, _) => "float32"
      case FloatMultiType(Width8, _) => "float64"

      case BitsType(_, _) => "uint64"

      case _: BooleanType => "bool"
      case CalcIntType => "int"
      case CalcFloatType => "float64"

      case _: StrType => "string"
      case _: BytesType => "[]byte"

      case AnyType => "interface{}"
      case KaitaiStreamType | OwnedKaitaiStreamType => "*" + kstreamName
      case KaitaiStructType | CalcKaitaiStructType => kstructName

      case t: UserType => "*" + types2class(t.classSpec match {
        case Some(cs) => cs.name
        case None => t.name
      })
      case t: EnumType => types2class(t.enumSpec.get.name)

      case at: ArrayType => s"[]${kaitaiType2NativeType(at.elType)}"

      case st: SwitchType => kaitaiType2NativeType(st.combinedType)
    }
  }

  def types2class(names: List[String]): String = names.map(x => type2class(x)).mkString("_")

  def enumToStr(enumTypeAbs: List[String]): String = {
    val enumName = enumTypeAbs.last
    val enumClass: List[String] = enumTypeAbs.dropRight(1)
    enumToStr(enumClass, enumName)
  }

  def enumToStr(typeName: List[String], enumName: String): String =
    types2class(typeName) + "__" + type2class(enumName)

  override def kstreamName: String = "kaitai.Stream"
  override def kstructName: String = "interface{}"
  override def ksErrorName(err: KSError): String = s"kaitai.${err.name}"
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy