All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.twitter.scrooge.backend.StructTemplate.scala Maven / Gradle / Ivy

The newest version!
package com.twitter.scrooge.backend

/**
 * Copyright 2011 Twitter, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License. You may obtain
 * a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import com.twitter.scrooge.ast._
import com.twitter.scrooge.mustache.Dictionary
import com.twitter.scrooge.mustache.Dictionary._
import com.twitter.scrooge.frontend.ScroogeInternalException

trait StructTemplate { self: TemplateGenerator =>

  /**
   * Pairs a name with a field type.
   *
   * @param name      the name being bound
   * @param fieldType the type associated with `name`
   * @tparam FT the concrete `FieldType` subtype carried by this binding
   */
  case class Binding[FT <: FieldType](name: String, fieldType: FT)

  /**
   * Baseline dictionary in which every type-kind flag is `false`.
   * `readWriteInfo` overlays it with the entry matching the actual field type.
   */
  val TypeTemplate = {
    val typeFlags = Seq("isList", "isSet", "isMap", "isStruct", "isEnum", "isBase")
    Dictionary(typeFlags.map(flag => flag -> v(false)): _*)
  }

  /**
   * Generates the const-type fragment used on the wire for `t`.
   * Enum types yield the literal fragment "I32"; every other type is delegated
   * to `genConstType`.
   */
  def genWireConstType(t: FunctionType): CodeFragment =
    t match {
      case _: EnumType => v("I32")
      case other => genConstType(other)
    }

  /**
   * Builds the template dictionary describing how a value of type `t` is read
   * and written.
   *
   * The result is `TypeTemplate` overlaid with a "fieldType" entry and with the
   * single flag matching `t` ("isList", "isSet", "isMap", "isStruct", "isEnum"
   * or "isBase") replaced by a nested dictionary of details. Struct and enum
   * types additionally set "isNamedType" and "isImported". Container types
   * recurse into their element / key / value types, using identifiers derived
   * from `sid` by appending "_element", "_key" or "_value".
   *
   * @param sid identifier of the variable that holds the value
   * @param t   the type to describe
   * @return the merged dictionary
   * @throws ScroogeInternalException if `t` is a `ReferenceType`
   */
  def readWriteInfo[T <: FieldType](sid: SimpleID, t: FieldType): Dictionary = {
    val typeSpecific: Dictionary = t match {
      case listType: ListType =>
        val elementId = sid.append("_element")
        val elementType = listType.eltType
        val fieldTypeEntry = "fieldType" -> genType(listType)
        val listDetails = Dictionary(
          "name" -> genID(sid),
          "eltName" -> genID(elementId),
          "eltConstType" -> genConstType(elementType),
          "eltWireConstType" -> genWireConstType(elementType),
          "eltType" -> genType(elementType),
          "eltReadWriteInfo" -> v(readWriteInfo(elementId, elementType)))
        Dictionary(fieldTypeEntry, "isList" -> v(listDetails))

      case setType: SetType =>
        val elementId = sid.append("_element")
        val elementType = setType.eltType
        val fieldTypeEntry = "fieldType" -> genType(setType)
        val setDetails = Dictionary(
          "name" -> genID(sid),
          "eltName" -> genID(elementId),
          "eltConstType" -> genConstType(elementType),
          "eltWireConstType" -> genWireConstType(elementType),
          "eltType" -> genType(elementType),
          "isEnumSet" -> v(elementType.isInstanceOf[EnumType]),
          "eltReadWriteInfo" -> v(readWriteInfo(elementId, elementType)))
        Dictionary(fieldTypeEntry, "isSet" -> v(setDetails))

      case mapType: MapType =>
        val keyId = sid.append("_key")
        val valueId = sid.append("_value")
        val fieldTypeEntry = "fieldType" -> genType(mapType)
        val mapDetails = Dictionary(
          "name" -> genID(sid),
          "keyConstType" -> genConstType(mapType.keyType),
          "keyWireConstType" -> genWireConstType(mapType.keyType),
          "valueConstType" -> genConstType(mapType.valueType),
          "valueWireConstType" -> genWireConstType(mapType.valueType),
          "keyType" -> genType(mapType.keyType),
          "valueType" -> genType(mapType.valueType),
          "keyName" -> genID(keyId),
          "valueName" -> genID(valueId),
          "keyReadWriteInfo" -> v(readWriteInfo(keyId, mapType.keyType)),
          "valueReadWriteInfo" -> v(readWriteInfo(valueId, mapType.valueType)))
        Dictionary(fieldTypeEntry, "isMap" -> v(mapDetails))

      case structType: StructType =>
        Dictionary(
          "isNamedType" -> v(true),
          "isImported" -> v(structType.scopePrefix.isDefined),
          "fieldType" -> genType(structType),
          "isStruct" -> v(Dictionary("name" -> genID(sid))))

      case enumType: EnumType =>
        // The field type is rendered from a copy whose enum identifier is title-cased.
        val titleCased = enumType.copy(enum = enumType.enum.copy(enumType.enum.sid.toTitleCase))
        Dictionary(
          "isNamedType" -> v(true),
          "isImported" -> v(enumType.scopePrefix.isDefined),
          "fieldType" -> genType(titleCased),
          "isEnum" -> v(Dictionary("name" -> genID(sid))))

      case baseType: BaseType =>
        val fieldTypeEntry = "fieldType" -> genType(baseType)
        val baseDetails = Dictionary(
          "type" -> genType(baseType),
          "name" -> genID(sid),
          "protocolWriteMethod" -> genProtocolWriteMethod(baseType),
          "protocolReadMethod" -> genProtocolReadMethod(baseType))
        Dictionary(fieldTypeEntry, "isBase" -> v(baseDetails))

      case _: ReferenceType =>
        throw new ScroogeInternalException("ReferenceType should have been resolved by now")
    }

    TypeTemplate + typeSpecific
  }

  /**
   * Builds one template dictionary per field, in the order the fields are given.
   *
   * Each dictionary carries the identifiers derived from the field's name
   * (getter/reader/writer names, constants, ...), its type information,
   * requiredness flags, default values, passthrough description, and the
   * partial templates used to render read/write code.
   *
   * @param fields    the fields to describe
   * @param blacklist field names for which "hasGetter" is set to false
   * @param namespace not referenced by this method's body; kept for callers
   * @return one dictionary per field; "index" is the field's zero-based
   *         position in `fields`, "id" is the field's own index value
   */
  def fieldsToDict(
    fields: Seq[Field],
    blacklist: Seq[String],
    namespace: Option[Identifier] = None
  ): Seq[Dictionary] = {
    fields.zipWithIndex map {
      case (field, index) =>
        val valueVariableID = field.sid.append("_item")
        val fieldName = genID(field.sid)
        // Inspect the rendered identifier text (toData) rather than the
        // fragment's toString representation.
        val camelCaseFieldName = if (fieldName.toData.indexOf('_') >= 0)
          genID(field.sid.toCamelCase)
        else
          NoValue

        // Values that were previously recomputed for several entries below.
        val titleCaseID = field.sid.toTitleCase
        val isOptional = field.requiredness.isOptional
        val isRequired = field.requiredness.isRequired
        val defaultFieldValue = genDefaultFieldValue(field)

        Dictionary(
          "index" -> v(index.toString),
          "indexP1" -> v((index + 1).toString),
          "_fieldName" -> genID(field.sid.prepend("_")), // for Java only
          "unsetName" -> genID(titleCaseID.prepend("unset")),
          "readName" -> genID(titleCaseID.prepend("read")),
          "getBlobName" -> genID(titleCaseID.prepend("get").append("Blob")),
          "readBlobName" -> genID(titleCaseID.prepend("read").append("Blob")),
          "getName" -> genID(titleCaseID.prepend("get")), // for Java only
          "isSetName" -> genID(titleCaseID.prepend("isSet")), // for Java only
          "fieldName" -> fieldName,
          "fieldNameForWire" -> v(field.originalName),
          "fieldNameCamelCase" -> camelCaseFieldName,
          "newFieldName" -> genID(titleCaseID.prepend("new")),
          "FieldName" -> genID(titleCaseID),
          "FIELD_NAME" -> genID(field.sid.toUpperCase),
          "gotName" -> genID(field.sid.prepend("_got_")),
          "id" -> v(field.index.toString),
          "fieldConst" -> genID(titleCaseID.append("Field")),
          "constType" -> genConstType(field.fieldType),
          "isPrimitive" -> v(isPrimitive(field.fieldType)),
          "isLazyReadEnabled" -> v(isLazyReadEnabled(field.fieldType, isOptional)),
          "primitiveFieldType" -> genPrimitiveType(field.fieldType),
          "fieldType" -> genType(field.fieldType),
          "fieldKeyType" -> v(field.fieldType match {
            case MapType(keyType, _, _) => Some(genType(keyType))
            case _ => None
          }),
          "fieldValueType" -> v(field.fieldType match {
            case MapType(_, valueType, _) => Some(genType(valueType))
            case ListType(valueType, _) => Some(genType(valueType))
            case SetType(valueType, _) => Some(genType(valueType))
            case _ => None
          }),
          "fieldTypeAnnotations" -> StructTemplate.renderPairs(field.typeAnnotations),
          "fieldFieldAnnotations" -> StructTemplate.renderPairs(field.fieldAnnotations),
          "isImported" -> v(field.fieldType match {
            case n: NamedType => n.scopePrefix.isDefined
            case _ => false
          }),
          "isNamedType" -> v(field.fieldType.isInstanceOf[NamedType]),
          "passthroughFields" -> {
            val insides = buildPassthroughFields(field.fieldType)
            // Optional fields get their passthrough description wrapped in a "ptIter" entry.
            if (isOptional) {
              v(Dictionary(
                "ptIter" -> insides
              ))
            } else {
              insides
            }
          },
          "isEnum" -> v(field.fieldType.isInstanceOf[EnumType]),
          // "qualifiedFieldType" is used to generate qualified type name even if it's not
          // imported, in case other same-named entities are generated in the same file.
          "qualifiedFieldType" -> v(templates("qualifiedFieldType")),
          "hasDefaultValue" -> v(defaultFieldValue.isDefined),
          "defaultFieldValue" -> defaultFieldValue.getOrElse(NoValue),
          "defaultReadValue" -> genDefaultReadValue(field),
          "hasGetter" -> v(!blacklist.contains(field.sid.name)),
          "hasIsDefined" -> v(isOptional || (!isRequired && !isPrimitive(field.fieldType))),
          "required" -> v(isRequired),
          "optional" -> v(isOptional),
          "nullable" -> v(isNullableType(field.fieldType, isOptional)),
          "collection" -> v {
            (field.fieldType match {
              case ListType(eltType, _) => List(genType(eltType))
              case SetType(eltType, _) => List(genType(eltType))
              case MapType(keyType, valueType, _) => List(
                v("(" + genType(keyType).toData + ", " + genType(valueType).toData + ")"))
              case _ => Nil
            }) map { t => Dictionary("elementType" -> t) }
          },
          "readFieldValueName" -> genID(titleCaseID.prepend("read").append("Value")),
          "writeFieldName" -> genID(titleCaseID.prepend("write").append("Field")),
          "writeFieldValueName" -> genID(titleCaseID.prepend("write").append("Value")),
          "readField" -> v(templates("readField")),
          "decodeProtocol" -> genDecodeProtocolMethod(field.fieldType),
          "offsetSkipProtocol" -> genOffsetSkipProtocolMethod(field.fieldType),
          "readUnionField" -> v(templates("readUnionField")),
          "readLazyField" -> v(templates("readLazyField")),
          "readValue" -> v(templates("readValue")),
          "writeField" -> v(templates("writeField")),
          "writeValue" -> v(templates("writeValue")),
          "writeList" -> v(templates("writeList")),
          "writeSet" -> v(templates("writeSet")),
          "writeMap" -> v(templates("writeMap")),
          "writeStruct" -> v(templates("writeStruct")),
          "writeEnum" -> v(templates("writeEnum")),
          "writeBase" -> v(templates("writeBase")),
          "readList" -> v(templates("readList")),
          "readSet" -> v(templates("readSet")),
          "readMap" -> v(templates("readMap")),
          "readStruct" -> v(templates("readStruct")),
          "readEnum" -> v(templates("readEnum")),
          "readBase" -> v(templates("readBase")),
          "optionalType" -> v(templates("optionalType")),
          "withoutPassthrough" -> v(templates("withoutPassthrough")),
          "readWriteInfo" -> v(readWriteInfo(valueVariableID, field.fieldType)),
          "valueVariableName" -> genID(valueVariableID)
        )
    }
  }

  /**
   * Baseline passthrough dictionary with every "pt*" flag set to `false`.
   * `buildPassthroughFields` overlays the entry matching the actual field type.
   */
  val basePassthrough = {
    val passthroughFlags = Seq("ptStruct", "ptIter", "ptMap", "ptPrimitive")
    Dictionary(passthroughFlags.map(flag => flag -> v(false)): _*)
  }
  /**
   * Describes `fieldType` for passthrough handling as `basePassthrough`
   * overlaid with one entry:
   *  - struct types set "ptStruct" (carrying the generated "className"),
   *  - set and list types set "ptIter" to the description of their element type,
   *  - map types set "ptMap" with "ptKey" / "ptValue" descriptions,
   *  - everything else sets "ptPrimitive" to true.
   */
  private def buildPassthroughFields(fieldType: FieldType): Value = {
    // Sets and lists share the same shape: recurse into the element type.
    def iterableOf(elementType: FieldType): Dictionary =
      Dictionary("ptIter" -> buildPassthroughFields(elementType))

    val typeSpecific: Dictionary = fieldType match {
      case _: StructType =>
        val structInfo = Dictionary("className" -> genType(fieldType))
        Dictionary("ptStruct" -> v(structInfo))
      case setType: SetType =>
        iterableOf(setType.eltType)
      case listType: ListType =>
        iterableOf(listType.eltType)
      case mapType: MapType =>
        val keyAndValue = Dictionary(
          "ptKey" -> buildPassthroughFields(mapType.keyType),
          "ptValue" -> buildPassthroughFields(mapType.valueType))
        Dictionary("ptMap" -> v(keyAndValue))
      case _ =>
        Dictionary("ptPrimitive" -> v(true))
    }

    v(basePassthrough + typeSpecific)
  }

  /**
   * Picks the field that serves as an exception's message.
   *
   * Preference order: a field named "message", otherwise the first field whose
   * type is `TString`.
   *
   * @return the chosen field's identifier, or `None` when neither exists
   */
  private def exceptionMsgFieldName(struct: StructLike): Option[SimpleID] = {
    val candidates = struct.fields
    val namedMessage = candidates.find(_.sid.name == "message")
    // orElse takes its argument by name, so the fallback search only runs when needed.
    val chosen = namedMessage.orElse(candidates.find(_.fieldType == TString))
    chosen.map(_.sid)
  }

  /**
   * Generated type of a function result's success field, or the fragment
   * "Unit" when the result has no success field.
   */
  def getSuccessType(result: FunctionResult): CodeFragment =
    result.success
      .map(successField => genType(successField.fieldType))
      .getOrElse(v("Unit"))

  /**
   * Fragment naming the success value of a function result: "success" when a
   * success field exists, otherwise the literal "Some(Unit)".
   */
  def getSuccessValue(result: FunctionResult): CodeFragment =
    if (result.success.isDefined) v("success") else v("Some(Unit)")

  /**
   * Renders the identifiers of a result's exception fields as a fragment of
   * the form `Seq(a, b, ...)`; with no exceptions this is `Seq()`.
   */
  def getExceptionFields(result: FunctionResult): CodeFragment = {
    val exceptionNames = for (field <- result.exceptions) yield genID(field.sid).toData
    v(exceptionNames.mkString("Seq(", ", ", ")"))
  }

  /**
   * Builds the template dictionary for a struct-like definition (struct,
   * exception, union or function result).
   *
   * @param struct         the definition to describe
   * @param namespace      rendered as "package" when present, otherwise an empty fragment
   * @param includes       not referenced by this method's body; kept for callers
   * @param serviceOptions when it contains `WithFinagle`, exceptions additionally
   *                       mix in `com.twitter.finagle.SourcedException`
   * @param toplevel       whether the struct is defined in its own file; controls
   *                       "public" and whether the struct name is title-cased
   * @return the dictionary consumed by the struct templates
   */
  def structDict(
    struct: StructLike,
    namespace: Option[Identifier],
    includes: Seq[Include],
    serviceOptions: Set[ServiceOption],
    toplevel: Boolean = false // True if this struct is defined in its own file. False for internal structs.
  ): Dictionary = {
    val parentType = struct match {
      case _: Exception_ if (serviceOptions contains WithFinagle) =>
        "ThriftException with com.twitter.finagle.SourcedException with ThriftStruct"
      case _: Exception_ => "ThriftException with ThriftStruct"
      case _: Union => "ThriftUnion with ThriftStruct"
      case result: FunctionResult =>
        // Render the fragment's text explicitly instead of relying on its toString.
        val resultType = getSuccessType(result).toData
        s"ThriftResponse[$resultType] with ThriftStruct"
      case _ => "ThriftStruct"
    }
    val arity = struct.fields.size

    val isStruct = struct.isInstanceOf[Struct]
    val isException = struct.isInstanceOf[Exception_]
    val isUnion = struct.isInstanceOf[Union]
    val isResponse = struct.isInstanceOf[FunctionResult]

    // Only exceptions expose a message field.
    val exceptionMsgField: Option[SimpleID] =
      if (isException) exceptionMsgFieldName(struct) else None

    // For exceptions the "message" field is blacklisted, i.e. gets no generated getter flag.
    val fieldDictionaries = fieldsToDict(
      struct.fields,
      if (isException) Seq("message") else Nil,
      namespace
    )

    val structName = if (toplevel) genID(struct.sid.toTitleCase) else genID(struct.sid)

    Dictionary(
      "public" -> v(toplevel),
      "package" -> namespace.map(genID).getOrElse(v("")),
      "docstring" -> v(struct.docstring.getOrElse("")),
      "parentType" -> v(parentType),
      "fields" -> v(fieldDictionaries),
      "defaultFields" -> v(fieldsToDict(struct.fields.filter(!_.requiredness.isOptional), Nil)),
      "alternativeConstructor" -> v(
        struct.fields.exists(_.requiredness.isOptional)
        && struct.fields.exists(_.requiredness.isDefault)),
      "StructNameForWire" -> v(struct.originalName),
      "StructName" ->
        structName,
      "InstanceClassName" -> (if (isStruct) v("Immutable") else structName),
      "underlyingStructName" -> genID(struct.sid.prepend("_underlying_")),
      "arity" -> v(arity.toString),
      "isException" -> v(isException),
      "isResponse" -> v(isResponse),
      "hasExceptionMessage" -> v(exceptionMsgField.isDefined),
      "exceptionMessageField" -> exceptionMsgField.map(genID).getOrElse { v("")},
      "product" -> v(productN(struct.fields, namespace)),
      "arity0" -> v(arity == 0),
      "arity1" -> v((if (arity == 1) fieldDictionaries.take(1) else Nil)),
      "arityN" -> v(arity > 1 && arity <= 22),
      "withFieldGettersAndSetters" -> v(isStruct || isException),
      "withTrait" -> v(isStruct),
      "structAnnotations" -> StructTemplate.renderPairs(struct.annotations)
    )
  }
}

object StructTemplate {

  /**
   * Converts a string map into a template value.
   *
   * An empty map yields `NoValue`. Otherwise the result wraps
   *   Dictionary("pairs" -> <one Dictionary("key" -> ..., "value" -> ...) per entry>)
   * with entries in the map's iteration order.
   */
  private def renderPairs(pairs: Map[String, String]): Dictionary.Value =
    if (pairs.nonEmpty) {
      val entries: Seq[Dictionary] = pairs.toSeq.map {
        case (pairKey, pairValue) =>
          Dictionary("key" -> v(pairKey), "value" -> v(pairValue))
      }
      v(Dictionary("pairs" -> v(entries)))
    } else {
      NoValue
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy