All downloads are free. Search and download functionality uses the official Maven repository.

io.kaitai.struct.translators.GoTranslator.scala Maven / Gradle / Ivy

package io.kaitai.struct.translators

import io.kaitai.struct.datatype.DataType
import io.kaitai.struct.datatype.DataType._
import io.kaitai.struct.exprlang.Ast
import io.kaitai.struct.format.{ClassSpec, Identifier}
import io.kaitai.struct.languages.GoCompiler
import io.kaitai.struct.precompile.TypeMismatchError
import io.kaitai.struct.{ImportList, StringLanguageOutputWriter, Utils}
import io.kaitai.struct.format.SpecialIdentifier
import io.kaitai.struct.format.NamedIdentifier
import io.kaitai.struct.format.InstanceIdentifier

/** Result of translating a Kaitai Struct expression into Go. */
sealed trait TranslatorResult

/** A Go expression that can be used inline as-is. */
final case class ResultString(s: String) extends TranslatorResult

/**
  * A value held in the generated temporary local variable number `n`.
  * The statements that compute it are expected to have already been emitted
  * to the output by the translator.
  */
final case class ResultLocalVar(n: Int) extends TranslatorResult

/**
  * Translates Kaitai Struct expression ASTs into Go source code.
  *
  * Translating one expression may require auxiliary Go statements:
  * temporary variables, `if err != nil` checks, loops. These statements are
  * written directly to `out` as a side effect of translation. The expression's
  * value is then referred to through a temporary local variable
  * (`ResultLocalVar`). Callers must therefore translate an expression at a point
  * in the generated code where such statements may appear.
  *
  * @param out        output writer that receives auxiliary statements emitted during translation
  * @param provider   type provider used for type detection and name/enum resolution
  * @param importList collects Go import paths required by the generated expressions
  */
class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, importList: ImportList)
  extends TypeDetector(provider)
  with AbstractTranslator
  with CommonLiterals
  with CommonOps
  with CommonArraysAndCast[TranslatorResult]
  with CommonMethods[TranslatorResult]
  with ByteArraysAsTrueArrays[TranslatorResult] {

  import io.kaitai.struct.languages.GoCompiler._

  /**
    * Extra value to return alongside `err` in generated error checks.
    * `None` makes [[outAddErrCheck]] emit `return err`.
    * `Some(r)` makes it emit `return r, err`.
    */
  var returnRes: Option[String] = None

  override def translate(v: Ast.expr): String = resToStr(translateExpr(v))

  /** Renders a translation result as a Go expression string. */
  def resToStr(r: TranslatorResult): String = r match {
    case ResultString(s) => s
    case ResultLocalVar(n) => localVarName(n)
  }

  /**
    * Translates an expression, possibly emitting auxiliary statements to `out`.
    *
    * @throws TypeMismatchError if a comparison or binary operation is applied
    *                           to incompatible operand types
    */
  def translateExpr(v: Ast.expr): TranslatorResult = {
    v match {
      case Ast.expr.IntNum(n) =>
        trIntLiteral(n)
      case Ast.expr.FloatNum(n) =>
        trFloatLiteral(n)
      case Ast.expr.Str(s) =>
        trStringLiteral(s)
      case Ast.expr.Bool(n) =>
        trBoolLiteral(n)
      case Ast.expr.EnumById(enumType, id, inType) =>
        val enumSpec = provider.resolveEnum(inType, enumType.name)
        trEnumById(enumSpec.name, translate(id))
      case Ast.expr.EnumByLabel(enumType, label, inType) =>
        val enumSpec = provider.resolveEnum(inType, enumType.name)
        trEnumByLabel(enumSpec.name, label.name)
      case Ast.expr.Name(name: Ast.identifier) =>
        if (name.name == Identifier.SIZEOF) {
          byteSizeOfClassSpec(provider.nowClass)
        } else {
          trLocalName(name.name)
        }
      case Ast.expr.InternalName(id: Identifier) =>
        trInternalName(id)
      case Ast.expr.UnaryOp(op: Ast.unaryop, inner: Ast.expr) =>
        val opStr = unaryOp(op)
        ResultString((op, inner) match {
          // doIntLiteral has to know when a negative number is being translated. If it
          // doesn't, the result is things like `-uint32(2147483648)`, which will not compile in Go
          // (the error is "constant -2147483648 overflows uint32"). So the literal itself is negated.
          case (Ast.unaryop.Minus, Ast.expr.IntNum(n)) => translate(Ast.expr.IntNum(-n))
          case (_, Ast.expr.IntNum(_) | Ast.expr.FloatNum(_)) =>
            s"$opStr${translate(inner)}"
          case _ =>
            s"$opStr(${translate(inner)})"
        })
      case Ast.expr.Compare(left, op, right) =>
        (detectType(left), detectType(right)) match {
          case (_: NumericType, _: NumericType) =>
            trNumericCompareOp(left, op, right)
          case (_: StrType, _: StrType) =>
            trStrCompareOp(left, op, right)
          case (_: BytesType, _: BytesType) =>
            trBytesCompareOp(left, op, right)
          case (_: BooleanType, _: BooleanType) =>
            trNumericCompareOp(left, op, right)
          case (_: EnumType, _: EnumType) =>
            trNumericCompareOp(left, op, right)
          case (ltype, rtype) =>
            throw new TypeMismatchError(s"can't do $ltype $op $rtype")
        }
      case Ast.expr.BinOp(left: Ast.expr, op: Ast.operator, right: Ast.expr) =>
        (detectType(left), detectType(right), op) match {
          case (_: NumericType, _: NumericType, _) =>
            trNumericBinOp(left, op, right)
          case (_: StrType, _: StrType, Ast.operator.Add) =>
            trStrConcat(left, right)
          case (ltype, rtype, _) =>
            throw new TypeMismatchError(s"can't do $ltype $op $rtype")
        }
      case Ast.expr.BoolOp(op, values) =>
        trBooleanOp(op, values)
      case Ast.expr.IfExp(condition, ifTrue, ifFalse) =>
        trIfExp(condition, ifTrue, ifFalse)
      case Ast.expr.Subscript(container, idx) =>
        arraySubscript(container, idx)
      case call: Ast.expr.Attribute =>
        translateAttribute(call)
      case call: Ast.expr.Call =>
        translateCall(call)
      case Ast.expr.List(elts) =>
        doGuessArrayLiteral(elts)
      case ctt: Ast.expr.CastToType =>
        doCastOrArray(ctt)
      case Ast.expr.ByteSizeOfType(typeName) =>
        doByteSizeOfType(typeName)
      case Ast.expr.BitSizeOfType(typeName) =>
        doBitSizeOfType(typeName)
    }
  }

  def trIntLiteral(n: BigInt): TranslatorResult = ResultString(doIntLiteral(n))
  def trFloatLiteral(n: BigDecimal): TranslatorResult = ResultString(doFloatLiteral(n))
  def trStringLiteral(s: String): TranslatorResult = ResultString(doStringLiteral(s))
  def trBoolLiteral(n: Boolean): TranslatorResult = ResultString(doBoolLiteral(n))

  /** `sizeof<T>`: the size of type `typeName` in bytes, emitted as an integer literal. */
  def doByteSizeOfType(typeName: Ast.typeId): TranslatorResult = trIntLiteral(
    CommonSizeOf.bitToByteSize(
      CommonSizeOf.getBitsSizeOfType(
        typeName.nameAsStr, detectCastType(typeName)
      )
    )
  )
  /** `bitsizeof<T>`: the size of type `typeName` in bits, emitted as an integer literal. */
  def doBitSizeOfType(typeName: Ast.typeId): TranslatorResult = trIntLiteral(
    CommonSizeOf.getBitsSizeOfType(
      typeName.nameAsStr, detectCastType(typeName)
    )
  )
  /** Byte size of the whole class `cs`, emitted as an integer literal. */
  def byteSizeOfClassSpec(cs: ClassSpec): TranslatorResult =
    trIntLiteral(CommonSizeOf.getByteSizeOfClassSpec(cs))

  def trBooleanOp(op: Ast.boolop, values: Seq[Ast.expr]) =
    ResultString(doBooleanOp(op, values))

  /**
    * Translates a binary operation on numeric operands.
    *
    * Integer modulo is special-cased. Go's `%` gives the remainder the sign of
    * the dividend, so a negative remainder is shifted by adding the divisor back.
    * This presumes a positive divisor, which is the only case where the result is
    * made non-negative. The statements go to `out` and a temporary variable is returned.
    */
  def trNumericBinOp(left: Ast.expr, op: Ast.operator, right: Ast.expr): TranslatorResult = {
    (detectType(left), detectType(right), op) match {
      case (_: IntType, _: IntType, Ast.operator.Mod) =>
        val v1 = allocateLocalVar()
        val res = localVarName(v1)
        // Translate each operand exactly once: translation may emit statements
        // (e.g. error-checked calls) that must not be duplicated in the output.
        val leftStr = translate(left)
        val rightStr = translate(right)
        out.puts(s"$res := $leftStr % $rightStr")
        out.puts(s"if $res < 0 {")
        out.inc
        out.puts(s"$res += $rightStr")
        out.dec
        out.puts("}")
        ResultLocalVar(v1)
      case _ =>
        ResultString(numericBinOp(left, op, right))
    }
  }

  def trStrConcat(left: Ast.expr, right: Ast.expr): TranslatorResult =
    ResultString(translate(left) + " + " + translate(right))

  def trNumericCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): TranslatorResult =
    ResultString(doNumericCompareOp(left, op, right))

  def trStrCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): TranslatorResult =
    ResultString(doStrCompareOp(left, op, right))

  /**
    * Compares two byte slices using Go's `bytes` package, whose import is registered here.
    * Equality uses `bytes.Equal`; every other operator compares `bytes.Compare(...)` against 0.
    */
  def trBytesCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): TranslatorResult = {
    importList.add("bytes")
    op match {
      case Ast.cmpop.Eq =>
        ResultString(s"bytes.Equal(${translate(left)}, ${translate(right)})")
      case _ =>
        ResultString(s"(bytes.Compare(${translate(left)}, ${translate(right)}) ${cmpOp(op)} 0)")
    }
  }

  /**
    * Renders an integer literal, adding a Go type conversion where an untyped
    * constant would not fit the default `int` range assumed here (32-bit signed).
    * Values outside the 64-bit range are emitted bare, since no conversion can help.
    */
  override def doIntLiteral(n: BigInt): String = {
    if (n < -9223372036854775808L) {
      s"$n" // too low, no type conversion would help anyway
    } else if (n <= -2147483649L) {
      s"int64($n)" // -9223372036854775808..-2147483649
    } else if (n <= 2147483647L) {
      s"$n" // -2147483648..2147483647
    } else if (n <= 4294967295L) {
      s"uint32($n)" // 2147483648..4294967295
    } else if (n <= 9223372036854775807L) {
      s"int64($n)" // 4294967296..9223372036854775807
    } else if (n <= Utils.MAX_UINT64) {
      s"uint64($n)" // 9223372036854775808..18446744073709551615
    } else {
      s"$n" // too high, no type conversion would help anyway
    }
  }

  override def unaryOp(op: Ast.unaryop): String = op match {
    case Ast.unaryop.Invert => "^"
    case Ast.unaryop.Minus => "-"
    case Ast.unaryop.Not => "!"
  }

  /**
    * Translates a bare identifier.
    *
    * Special names map to `this.` members or loop-local names. Other names map to
    * `this.<UpperCamelCase>`. Lazy members are called as methods, with an error check.
    */
  def trLocalName(s: String): TranslatorResult = {
    s match {
      case Identifier.ROOT |
           Identifier.PARENT |
           Identifier.IO =>
        ResultString(s"this.${specialName(s)}")

      // These can be local only
      case Identifier.ITERATOR |
           Identifier.ITERATOR2 =>
        ResultString(specialName(s))
      case Identifier.INDEX => ResultString("i")

      case _ =>
        if (provider.isLazy(s)) {
          outVarCheckRes(s"this.${Utils.upperCamelCase(s)}()")
        } else {
          ResultString(s"this.${Utils.upperCamelCase(s)}")
        }
    }
  }

  def trInternalName(id: Identifier): TranslatorResult =
    id match {
      case SpecialIdentifier(name) => trLocalName(name)
      case NamedIdentifier(name) => trLocalName(name)
      case InstanceIdentifier(name) => trLocalName(name)
      case _ => ResultString(s"this.${GoCompiler.publicMemberName(id)}")
    }

  /** Go spelling of the special identifiers (`_root`, `_parent`, `_io`, `_`, `_buf`). */
  def specialName(id: String): String = id match {
    case Identifier.ROOT | Identifier.PARENT | Identifier.IO =>
      id
    case Identifier.ITERATOR =>
      "_it"
    case Identifier.ITERATOR2 =>
      "_buf"
  }

  def arraySubscript(container: Ast.expr, idx: Ast.expr) =
    ResultString(s"${translate(container)}[${translate(idx)}]")

  /**
    * Go has no conditional expression, so `cond ? a : b` is lowered to a typed
    * temporary assigned in an `if`/`else`. Each branch's value is translated inside
    * its own branch, so any auxiliary statements only run on that branch.
    * The temporary's type is taken from the `ifTrue` branch.
    */
  def trIfExp(condition: Ast.expr, ifTrue: Ast.expr, ifFalse: Ast.expr): ResultLocalVar = {
    val v1 = allocateLocalVar()
    val typ = detectType(ifTrue)
    out.puts(s"var ${localVarName(v1)} ${GoCompiler.kaitaiType2NativeType(typ)};")
    out.puts(s"if (${translate(condition)}) {")
    out.inc
    out.puts(s"${localVarName(v1)} = ${translate(ifTrue)}")
    out.dec
    out.puts("} else {")
    out.inc
    out.puts(s"${localVarName(v1)} = ${translate(ifFalse)}")
    out.dec
    out.puts("}")
    ResultLocalVar(v1)
  }

  def trEnumByLabel(enumTypeAbs: List[String], label: String) =
    ResultString(GoCompiler.enumToStr(enumTypeAbs, label))
  def trEnumById(enumTypeAbs: List[String], id: String) =
    ResultString(s"${types2class(enumTypeAbs)}($id)")

  /**
    * Byte-array comparison as a plain string.
    *
    * This delegates to [[trBytesCompareOp]] so that every byte comparison produces
    * valid Go using the `bytes` package. (It previously emitted Java's `Arrays.equals`.)
    */
  override def doBytesCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): String =
    resToStr(trBytesCompareOp(left, op, right))

  // NOTE(review): casts are not implemented for Go yet; any cast that reaches here
  // throws scala.NotImplementedError at compile time.
  override def doCast(value: Ast.expr, typeName: DataType): TranslatorResult = ???

  override def doArrayLiteral(t: DataType, value: Seq[Ast.expr]) =
    ResultString(s"[]${GoCompiler.kaitaiType2NativeType(t)}{${value.map(translate).mkString(", ")}}")

  override def doByteArrayLiteral(arr: Seq[Byte]): TranslatorResult =
    ResultString("[]uint8{" + arr.map(_ & 0xff).mkString(", ") + "}")

  override def doByteArrayNonLiteral(elts: Seq[Ast.expr]): TranslatorResult =
    ResultString("[]uint8{" + elts.map(translate).mkString(", ") + "}")

  // Predefined methods of various types

  val IMPORT_CHARMAP = "golang.org/x/text/encoding/charmap"

  /** Lower-cased encoding name -> (Go decoder expression, Go import path providing it). */
  val ENCODINGS = Map(
    "cp437" -> ("charmap.CodePage437", IMPORT_CHARMAP),
    "iso8859-1" -> ("charmap.ISO8859_1", IMPORT_CHARMAP),
    "iso8859-2" -> ("charmap.ISO8859_2", IMPORT_CHARMAP),
    "iso8859-3" -> ("charmap.ISO8859_3", IMPORT_CHARMAP),
    "iso8859-4" -> ("charmap.ISO8859_4", IMPORT_CHARMAP),
    "sjis" -> ("japanese.ShiftJIS", "golang.org/x/text/encoding/japanese"),
    "big5" -> ("traditionalchinese.Big5", "golang.org/x/text/encoding/traditionalchinese"),
    "utf-16le" -> ("unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM)", "golang.org/x/text/encoding/unicode"),
    "utf-16be" -> ("unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM)", "golang.org/x/text/encoding/unicode")
  )

  override def bytesToStr(value: Ast.expr, expr: Ast.expr): TranslatorResult =
    bytesToStr(translate(value), expr)

  /**
    * Decodes the Go byte-slice expression `bytesExpr` into a string.
    *
    * ASCII and UTF-8 use a plain `string(...)` conversion, with no validation.
    * Other encodings are looked up in [[ENCODINGS]]. For those, the import is
    * registered and `kaitai.BytesToStr` is called with an error check.
    *
    * @throws RuntimeException if `encoding` is not a string literal or the encoding is unknown
    */
  def bytesToStr(bytesExpr: String, encoding: Ast.expr): TranslatorResult = {
    val enc = encoding match {
      case Ast.expr.Str(s) => s
      case _ => throw new RuntimeException("Variable encodings are not supported in Go yet")
    }

    enc.toLowerCase match {
      case "ascii" | "utf-8" | "utf8" =>
        // no conversion
        // FIXME: may be add some checks for valid ASCII/UTF-8
        ResultString(s"string($bytesExpr)")
      case encStr =>
        ENCODINGS.get(encStr) match {
          case Some((decoderSrc, importName)) =>
            importList.add(importName)
            outVarCheckRes(s"kaitai.BytesToStr($bytesExpr, $decoderSrc.NewDecoder())")
          case None =>
            throw new RuntimeException(s"encoding '$encStr' is not supported in Go")
        }
    }
  }

  override def arrayFirst(a: Ast.expr): TranslatorResult =
    ResultString(s"${translate(a)}[0]")
  /** Stores the array in a temporary first, so the array expression is evaluated only once. */
  override def arrayLast(a: Ast.expr): ResultString = {
    val v = allocateLocalVar()
    out.puts(s"${localVarName(v)} := ${translate(a)}")
    ResultString(s"${localVarName(v)}[len(${localVarName(v)}) - 1]")
  }
  override def arraySize(a: Ast.expr): TranslatorResult =
    ResultString(s"len(${translate(a)})")

  /**
    * Accesses field `name` of a user-type value. Lazy fields (instances) are
    * getter calls that return an error, so they get an error check.
    */
  override def userTypeField(ut: UserType, value: Ast.expr, name: String): TranslatorResult = {
    val valueStr = translate(value)

    val (call, twoOuts) = name match {
      case Identifier.ROOT |
           Identifier.PARENT |
           Identifier.IO =>
        (specialName(name), false)
      case _ =>
        (Utils.upperCamelCase(name), provider.isLazy(ut.classSpec.get, name))
    }

    if (twoOuts) {
      outVarCheckRes(s"$valueStr.$call()")
    } else {
      ResultString(s"$valueStr.$call")
    }
  }

  /**
    * Accesses a special field (`_root`, `_parent`, `_io`) of a generic KaitaiStruct value.
    *
    * @throws RuntimeException for any other field name. The original code failed
    *                          with an opaque MatchError here.
    */
  def kaitaiStructField(value: Ast.expr, name: String): TranslatorResult = {
    val valueStr = translate(value)

    name match {
      case Identifier.ROOT |
           Identifier.PARENT |
           Identifier.IO =>
        ResultString(s"$valueStr.${specialName(name)}")
      case _ =>
        throw new RuntimeException(s"can't access field '$name' of a generic KaitaiStruct in Go")
    }
  }

  /** Length in Unicode code points (runes), not bytes. */
  override def strLength(s: Ast.expr): TranslatorResult = {
    importList.add("unicode/utf8")
    ResultString(s"utf8.RuneCountInString(${translate(s)})")
  }

  override def strReverse(s: Ast.expr): TranslatorResult = {
    ResultString(s"kaitai.StringReverse(${translate(s)})")
  }

  override def strToInt(s: Ast.expr, base: Ast.expr): TranslatorResult = {
    importList.add("strconv")
    outVarCheckRes(s"strconv.ParseInt(${translate(s)}, ${translate(base)}, 0)")
  }

  override def strSubstring(s: Ast.expr, from: Ast.expr, to: Ast.expr): TranslatorResult = {
    ResultString(s"${translate(s)}[${translate(from)}:${translate(to)}]")
  }

  override def intToStr(value: Ast.expr, base: Ast.expr): TranslatorResult = {
    importList.add("strconv")
    ResultString(s"strconv.FormatInt(int64(${translate(value)}), ${translate(base)})")
  }

  override def floatToInt(value: Ast.expr) =
    ResultString(s"int(${translate(value)})")

  override def kaitaiStreamSize(value: Ast.expr) =
    outVarCheckRes(s"${translate(value)}.Size()")

  override def kaitaiStreamEof(value: Ast.expr) =
    outVarCheckRes(s"${translate(value)}.EOF()")

  override def kaitaiStreamPos(value: Ast.expr) =
    outVarCheckRes(s"${translate(value)}.Pos()")

  override def arrayMin(a: Ast.expr): ResultLocalVar =
    arrayExtremum(a, ">")

  override def arrayMax(a: Ast.expr): ResultLocalVar =
    arrayExtremum(a, "<")

  /**
    * Emits a linear scan over array `a`, keeping the running best element in a temporary.
    * The array is translated only once, so any auxiliary statements are not duplicated.
    * The generated code indexes `[0]` unconditionally, so an empty array panics at runtime.
    *
    * @param replaceIf Go comparison operator; the running best is replaced when
    *                  `best <replaceIf> element` holds (">" yields min, "<" yields max)
    */
  private def arrayExtremum(a: Ast.expr, replaceIf: String): ResultLocalVar = {
    val best = allocateLocalVar()
    val elem = allocateLocalVar()
    val bestName = localVarName(best)
    val elemName = localVarName(elem)
    val arr = translate(a)
    out.puts(s"$bestName := $arr[0]")
    out.puts(s"for _, $elemName := range $arr {")
    out.inc
    out.puts(s"if $bestName $replaceIf $elemName {")
    out.inc
    out.puts(s"$bestName = $elemName")
    out.dec
    out.puts("}")
    out.dec
    out.puts("}")
    ResultLocalVar(best)
  }

  override def enumToInt(value: Ast.expr, et: EnumType) =
    translateExpr(value)

  /** Lowers a boolean to 0/1 through a temporary initialized to 0 and set to 1 under an `if`. */
  override def boolToInt(value: Ast.expr): ResultLocalVar = {
    val v = allocateLocalVar()
    out.puts(s"${localVarName(v)} := 0")
    out.puts(s"if ${translate(value)} {")
    out.inc
    out.puts(s"${localVarName(v)} = 1")
    out.dec
    out.puts("}")
    ResultLocalVar(v)
  }

  /**
    * Emits construction and parsing of user type `t` from stream expression `io`.
    *
    * The parent is `nil` for USER_TYPE_NO_PARENT, the translated forced parent if one
    * is given, and `this` otherwise. The root is `nil` for opaque types and `this._root`
    * otherwise. Presumably `err` is already declared in the enclosing generated scope,
    * since this emits a plain `err =` assignment.
    */
  def userType(t: UserType, io: String) = {
    val v = allocateLocalVar()
    val parent = t.forcedParent match {
      case Some(USER_TYPE_NO_PARENT) => "nil"
      case Some(fp) => translate(fp)
      case None => "this"
    }
    val root = if (t.isOpaque) "nil" else "this._root"
    val addParams = t.args.map(translate).mkString(", ")
    out.puts(s"${localVarName(v)} := New${GoCompiler.types2class(t.classSpec.get.name)}($addParams)")
    out.puts(s"err = ${localVarName(v)}.Read($io, $parent, $root)")
    outAddErrCheck()
    ResultLocalVar(v)
  }

  /** Emits `tmpN, err := expr` followed by an error check, and returns `tmpN`. */
  def outVarCheckRes(expr: String): ResultLocalVar = {
    val v1 = allocateLocalVar()
    out.puts(s"${localVarName(v1)}, err := $expr")
    outAddErrCheck()
    ResultLocalVar(v1)
  }

  /** Emits `id = expr` (reassigning an existing temporary) and returns `id`. */
  def outTransform(id: ResultLocalVar, expr: String): ResultLocalVar = {
    out.puts(s"${resToStr(id)} = $expr")
    id
  }

  // Counter for temporary variable names; numbers are unique per translator instance.
  private var localVarNum = 0

  /** Reserves and returns a fresh temporary variable number. */
  def allocateLocalVar(): Int = {
    localVarNum += 1
    localVarNum
  }

  /** Go identifier for temporary variable number `n`. */
  def localVarName(n: Int): String = s"tmp$n"

  /** Emits `if err != nil { return [returnRes, ]err }`. */
  def outAddErrCheck(): Unit = {
    out.puts("if err != nil {")
    out.inc

    val noValueAndErr = returnRes match {
      case None => "err"
      case Some(r) => s"$r, err"
    }

    out.puts(s"return $noValueAndErr")
    out.dec
    out.puts("}")
  }

  override def byteSizeOfValue(attrName: String, valType: DataType): TranslatorResult =
    trIntLiteral(CommonSizeOf.bitToByteSize(CommonSizeOf.getBitsSizeOfType(attrName, valType)))
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy