// com.eharmony.aloha.semantics.compiled.plugin.proto.CompiledSemanticsProtoPlugin.scala (Maven / Gradle / Ivy)
package com.eharmony.aloha.semantics.compiled.plugin.proto
import com.eharmony.aloha.semantics.compiled.plugin.MorphableCompiledSemanticsPlugin
import scalaz.ValidationNel
import scalaz.syntax.validation.ToValidationV // scalaz.syntax.validation.ToValidationOps for latest scalaz
import com.google.protobuf.GeneratedMessage
import com.google.protobuf.Descriptors.{ FieldDescriptor, Descriptor }
import com.eharmony.aloha.semantics.compiled.{ OptionalAccessorCode, RequiredAccessorCode, VariableAccessorCode, CompiledSemanticsPlugin }
import com.eharmony.aloha.semantics.compiled.plugin.proto.accessor._
import com.eharmony.aloha.semantics.compiled.plugin.proto.codegen.CodeGenerators
import com.eharmony.aloha.semantics.compiled.plugin.proto.codegen.MapType._
import com.eharmony.aloha.reflect.{ RefInfo, RefInfoOps }
import com.eharmony.aloha.util.EitherHelpers
/**
 * A compiled semantics plugin for protocol buffer generated messages.
 *
 * {{{
 * // Idiomatic scala construction. Implicitly inject the RefInfo.
 * val scala = CompiledSemanticsProtoPlugin[UserProto]
 *
 * // Idiomatic java construction. Construct the RefInfo from the Class object.
 * CompiledSemanticsPlugin<UserProto> java = new CompiledSemanticsProtoPlugin<UserProto>(UserProto.class);
 *
 * // Java construction via prototype. Construct the RefInfo from the Class object extracted from the
 * // prototype instance.
 * CompiledSemanticsPlugin<UserProto> via = new CompiledSemanticsProtoPlugin<UserProto>(UserProto.getDefaultInstance());
 * }}}
 *
 * Note that there are a number of equivalent ways to construct the plugin. See below:
 *
 * {{{
 * scala> val lst = Seq(
 *      |   CompiledSemanticsProtoPlugin[UserProto],
 *      |   new CompiledSemanticsProtoPlugin[UserProto],
 *      |   CompiledSemanticsProtoPlugin[UserProto](true),
 *      |   new CompiledSemanticsProtoPlugin[UserProto](true),
 *      |   new CompiledSemanticsProtoPlugin[UserProto](classOf[UserProto], true),
 *      |   new CompiledSemanticsProtoPlugin(classOf[UserProto], true),
 *      |   new CompiledSemanticsProtoPlugin(UserProto.getDefaultInstance, true)
 *      | )
 * lst: Seq[CompiledSemanticsProtoPlugin[com.eharmony.matching.common.value.UserProtoBuffs.UserProto]] =
 *   List(CompiledSemanticsProtoPlugin(true), CompiledSemanticsProtoPlugin(true), CompiledSemanticsProtoPlugin(true),
 *        CompiledSemanticsProtoPlugin(true), CompiledSemanticsProtoPlugin(true), CompiledSemanticsProtoPlugin(true),
 *        CompiledSemanticsProtoPlugin(true))
 *
 * scala> lst sameElements Seq.fill(lst.size)(lst.head)
 * res132: Boolean = true
 * }}}
 * @param dereferenceAsOptional Whether to treat dereferenced list variables as an Option. If '''true''', treat the
 *                              return value of a dereference operation as an Option. This removes the possibility
 *                              of a [[http://docs.oracle.com/javase/7/docs/api/java/lang/IndexOutOfBoundsException.html java.lang.IndexOutOfBoundsException]]
 *                              being thrown. Instead, it will silently return None. If '''false''', treat the
 *                              returned value as a required field and don't do any index checking. The default
 *                              value is '''true'''.
 * @param refInfoA reflection information for the message type A, usually supplied implicitly. It must not
 *                 be null.
 * @tparam A a type of generated protocol buffer message.
 */
case class CompiledSemanticsProtoPlugin[A <: GeneratedMessage](dereferenceAsOptional: Boolean = true)(implicit val refInfoA: RefInfo[A])
extends CompiledSemanticsPlugin[A]
with MorphableCompiledSemanticsPlugin
with EitherHelpers {
/**
 * This type represents the separation of a chain of field accessors into 3 values:
 *  1. the sequence of accessors before a repeated field accessor
 *  1. the repeated (list) field accessor, if there is one
 *  1. the sequence of accessors after a repeated field accessor
 *
 * If no repeated accessor is present, then the sequence before is empty, the repeated accessor is
 * `None` and all of the accessors appear in the sequence after.
 */
private[this]type FieldAccessorPartition = (List[FieldAccessor], Option[Repeated], List[FieldAccessor])
/**
 * This is a further breakdown of the List[FieldAccessor] values in FieldAccessorPartition.
 * Each `(Seq[Req], Opt)` pair is a chain of required accessors followed by one optional accessor.
 *  1. the first field holds the chains to be flat mapped
 *  1. the second field holds the (at most one) chain to be mapped
 *  1. the last field holds the required accessors that follow the last chain
 */
private[this]type MappingPartition = (List[(Seq[Req], Opt)], Option[(Seq[Req], Opt)], List[Req])
/**
 * Auxiliary constructor that derives the reflection information from the message's Class object.
 * @param classValue the Class object of the protocol buffer generated message. For instance, from Java:
 * {{{
 * final CompiledSemanticsPlugin<UserProto> plugin =
 *   new CompiledSemanticsProtoPlugin<UserProto>(UserProto.class, false);
 * }}}
 * @param dereferenceAsOptional Whether to treat dereferenced list variables as an Option.
 */
def this(classValue: Class[A], dereferenceAsOptional: Boolean) =
  this(dereferenceAsOptional)(RefInfoOps.fromSimpleClass(classValue))
/**
 * Auxiliary constructor that derives the reflection information from the runtime class of a
 * prototypical message instance.
 * @param prototype a non-null prototype instance of the message type. For instance, from Java:
 * {{{
 * final CompiledSemanticsPlugin<UserProto> plugin =
 *   new CompiledSemanticsProtoPlugin<UserProto>(UserProto.getDefaultInstance(), false);
 * }}}
 * @param dereferenceAsOptional Whether to treat dereferenced list variables as an Option.
 */
def this(prototype: A, dereferenceAsOptional: Boolean) =
  this(dereferenceAsOptional)({
    // The null check has to live inside the argument block because no statement may precede
    // the call to the primary constructor.
    require(prototype != null, "prototype value cannot be null.")
    RefInfoOps.fromSimpleClass(prototype.getClass)
  })
/**
 * Auxiliary constructor that derives the reflection information from the message's Class object
 * and uses the default value of ''dereferenceAsOptional''.
 * @param classValue the Class object of the protocol buffer generated message. For instance, from Java:
 * {{{
 * final CompiledSemanticsPlugin<UserProto> plugin =
 *   new CompiledSemanticsProtoPlugin<UserProto>(UserProto.class);
 * }}}
 */
def this(classValue: Class[A]) =
  this()(RefInfoOps.fromSimpleClass(classValue))
/**
 * Auxiliary constructor that derives the reflection information from the runtime class of a
 * prototypical message instance and uses the default value of ''dereferenceAsOptional''.
 * @param prototype a non-null prototype instance of the message type. For instance, from Java:
 * {{{
 * final CompiledSemanticsPlugin<UserProto> plugin =
 *   new CompiledSemanticsProtoPlugin<UserProto>(UserProto.getDefaultInstance());
 * }}}
 */
def this(prototype: A) =
  this()({
    // The null check has to live inside the argument block because no statement may precede
    // the call to the primary constructor.
    require(prototype != null, "prototype value cannot be null.")
    RefInfoOps.fromSimpleClass(prototype.getClass)
  })
// Fail at construction time if the RefInfo was explicitly supplied as null.
require(refInfoA != null, "Implicit RefInfo refInfoA cannot be null.")
/**
 * The descriptor associated with type A, or the error messages produced while trying to obtain it.
 * It is looked up through `RefInfoOps.execStaticNoArgFunc[A]("getDescriptor")` and the result is
 * cast to a Descriptor.
 *
 * This class needs to be serializable to work with Spark. The Descriptor class belongs to protobuf
 * (outside of our code) and is not serializable, so this value is marked `@transient` and `lazy`:
 * the value is memoized when possible, and it is left out of serialization when it is not.
 */
@transient private lazy val descriptor: ValidationNel[String, Descriptor] =
toValidationNel(RefInfoOps.execStaticNoArgFunc[A]("getDescriptor")).map(_.asInstanceOf[Descriptor])
// NOTE(review): commented-out, non-memoized variant of the value above; presumably dead code -- confirm before removing.
//def descriptor() = toValidationNel(RefInfoOps.execStaticNoArgFunc[A]("getDescriptor")).map(_.asInstanceOf[Descriptor])
/**
 * The textual parameter list that starts every generated function: a single parameter named `_0`
 * whose type is `inputTypeString`, followed by the function arrow.
 */
private[proto] val functionParamList: String =
  "(_0: " + inputTypeString + ") => "
/**
 * Generate the accessor function code for a feature specification.
 *
 * The steps are: obtain the descriptor of A, tokenize the spec, resolve the field descriptors for the
 * field tokens, build the field accessor partition and generate the function from it. The first step
 * that fails determines the errors that are returned.
 * @param spec a String specification of the feature for which function code should be generated
 * @return the error messages on the left or the generated accessor code on the right.
 */
def accessorFunctionCode(spec: String): Either[Seq[String], VariableAccessorCode] = {
  import DescriptorPimpz.PimpedDescriptor

  val generated = for {
    desc <- descriptor
    toks <- ProtobufTokenizer.getTokens(spec)
    fieldDescriptors <- desc.subfields(createSpec(toks))
    accessorPartition <- getFieldAccessorPartition(fieldDescriptors, toks)
  } yield generateFunction(accessorPartition)

  // Outward facing APIs should not expose scalaz types, so the non-empty list of errors is
  // rebuilt as a plain scala list.
  generated.toEither.left.map(errors => errors.head :: errors.tail)
}
/**
 * Create the spec used to extract Protocol Buffer Descriptor objects: the values of the Field
 * tokens, in order, joined with a period. All other tokens are skipped.
 * @param tokens a list of tokens
 * @return the period-delimited field spec.
 */
private[proto] def createSpec(tokens: Seq[Token]): String =
  tokens.collect { case Field(name) => name }.mkString(".")
/**
 * Convert, when appropriate, the dereferenced repeated fields that precede a non-dereferenced
 * repeated field.
 *
 * When ''dereferenceAsOptional'' is true the accessors are returned as they are. Otherwise every
 * `DerefReq` is replaced by its optional counterpart (`toOpt`): since we always want a
 * non-dereferenced repeated field to produce a list, we don't want to err due to an out-of-bounds
 * exception, even if ''dereferenceAsOptional'' is false.
 *
 * @param fa a list of field accessors that precede a non-dereferenced repeated field.
 * @return the accessors, with required dereferences made optional when necessary.
 */
private[proto] def convertLeadingFieldAccessors(fa: List[FieldAccessor]): List[FieldAccessor] =
  if (!dereferenceAsOptional)
    fa.map {
      case required: DerefReq => required.toOpt
      case other => other
    }
  else fa
/**
 * Produce the field accessor used to dereference a repeated field at an index. The accessor is a
 * `DerefOpt` when ''dereferenceAsOptional'' is true and a `DerefReq` otherwise.
 * @param field a protocol buffer field descriptor representing a repeated field.
 * @param index an index into the repeated field's list.
 * @return the dereferencing field accessor.
 */
private[proto] def dereferencedRepeatedField(field: FieldDescriptor, index: Int) =
  if (dereferenceAsOptional) {
    DerefOpt(field, index)
  }
  else {
    DerefReq(field, index)
  }
/**
 * Prepend the accessor for a directly accessed (not dereferenced) field to fa.
 *
 * A required field becomes a `Required` accessor and an optional field an `Optional` accessor; in
 * both cases fa is left untouched. Any other field becomes a `Repeated` accessor and fa is first
 * passed through `convertLeadingFieldAccessors`.
 * @param fd a field descriptor
 * @param fa a list of field accessors
 * @return fa (possibly transformed) with the new accessor at its head.
 */
private[proto] def directlyAccessedField(fd: FieldDescriptor, fa: List[FieldAccessor]): List[FieldAccessor] =
  fd match {
    case _ if fd.isRequired => Required(fd) :: fa
    case _ if fd.isOptional => Optional(fd) :: fa
    case _ => Repeated(fd) :: convertLeadingFieldAccessors(fa)
  }
/**
 * Produce a FieldAccessorPartition. This is:
 *  1. the list of fields preceding a non-dereferenced repeated field
 *  1. an optional non-dereferenced repeated field
 *  1. the list of fields following a non-dereferenced repeated field
 *
 * All three values are returned in forward order. When fa contains no `Repeated` accessor, the
 * first list is empty and every accessor ends up in the last list.
 * @param fa list of field accessors in reverse order.
 */
private[proto] def partition(fa: List[FieldAccessor]): FieldAccessorPartition = {
  // Since fa is reversed, everything in front of the first Repeated accessor comes after it in
  // the forward direction.
  val (trailingReversed, fromRepeated) = fa.span {
    case _: Repeated => false
    case _ => true
  }
  fromRepeated match {
    case (rep: Repeated) :: leadingReversed =>
      (leadingReversed.reverse, Some(rep), trailingReversed.reverse)
    case _ =>
      (Nil, None, trailingReversed.reverse)
  }
}
/**
 * Get the mapping partition. This includes the chains of variables that need to be:
 *  1. flat mapped
 *  1. mapped
 *  1. added at the end of the statement
 *
 * These correspond to the fields in the returned tuple. Accessors that are neither `Req` nor `Opt`
 * are skipped.
 * @param accessors A list of field accessors
 */
private[proto] def determineMappingPartition(accessors: List[FieldAccessor]): MappingPartition = {
  type Chain = (Seq[Req], Opt)
  val start = (List.empty[Chain], List.empty[Req])

  // Both accumulators are built by prepending, so they are in reverse order once the fold is done.
  // Every optional accessor closes a chain made of the required accessors gathered since the
  // previous chain; whatever required accessors remain at the end follow the last chain.
  val (chainsReversed, trailingReversed) = accessors.foldLeft(start) { (acc, accessor) =>
    val (chainsSoFar, pendingRequired) = acc
    accessor match {
      case r: Req => (chainsSoFar, r :: pendingRequired)
      case o: Opt => ((pendingRequired.reverse, o) :: chainsSoFar, Nil)
      case _ => acc
    }
  }

  // The most recent chain is the one to be mapped; the earlier ones are flat mapped.
  (chainsReversed.drop(1).reverse, chainsReversed.headOption, trailingReversed.reverse)
}
/**
 * Generate the function for a field accessor partition. The accessors before and after the
 * repeated field are each turned into a mapping partition and handed, along with the repeated
 * field, to `generateFunctionHelper`.
 *
 * @param p the field accessor partition
 * @return the generated accessor code.
 */
private[proto] def generateFunction(p: FieldAccessorPartition) = p match {
  case (leading, repeatedField, trailing) =>
    val leadingPartition = determineMappingPartition(leading)
    val trailingPartition = determineMappingPartition(trailing)
    generateFunctionHelper(leadingPartition, repeatedField, trailingPartition)
}
/**
 * Get a list of lines in the function. This includes function signature but not the name or return type.
 *
 * One line is generated per chain / element, in this order: flat mapped chains before the repeated
 * field, the mapped chain before it, the repeated field itself, flat mapped chains after it, the
 * mapped chain after it and finally the trailing required accessors. Each generator call receives
 * a 1-based position that keeps counting across these groups.
 * @param before the mapping partition of the accessors preceding the repeated field
 * @param repeated the non-dereferenced repeated field, if any
 * @param after the mapping partition of the accessors following the repeated field
 * @return an OptionalAccessorCode when there is no repeated field and ''after'' contains optional
 *         chains; a RequiredAccessorCode otherwise.
 */
private[proto] def generateFunctionHelper(before: MappingPartition, repeated: Option[Repeated], after: MappingPartition) = {
val (bfm, bm, br) = before
val (afm, am, ar) = after
// Generate any flat mapped elements before the appearance of the repeated element.
val bfms = (1 to bfm.size).zip(bfm) map { case (i, (r, o)) => CodeGenerators.containerCodeGen(r, o, i, FLAT_MAP) }
// Generate the 0 or 1 mapped element before the appearance of the repeated element.
val bms = bm map { case (r, o) => CodeGenerators.containerCodeGen(r, o, bfms.size + 1, MAP) }
// Generate the 0 or 1 repeated elements. The repeated element is only mapped over when something
// (a chain or a required accessor) follows it.
val mapList = Seq(afm, am.toSeq, ar).foldLeft(false)(_ || _.size > 0)
val rep = repeated map { case r => CodeGenerators.containerCodeGen(br, r, bfms.size + bms.size + 1, if (mapList) MAP else NONE) }
// Generate any flat mapped elements appearing after the repeated element.
val afmi = Seq.range(0, afm.size) map { _ + bfms.size + bms.size + rep.size + 1 }
val afms = afmi.zip(afm) map { case (i, (r, o)) => CodeGenerators.containerCodeGen(r, o, i, FLAT_MAP) }
// Generate the 0 or 1 mapped elements appearing after the repeated element. The element is only
// mapped over when required accessors follow it.
val ami = Seq.range(0, am.size) map { _ + bfms.size + bms.size + rep.size + afm.size + 1 }
val ams = ami.zip(am) map { case (i, (r, o)) => CodeGenerators.containerCodeGen(r, o, i, if (ar.nonEmpty) MAP else NONE) }
// Generate the required elements appearing after the repeated element.
val ari = bfms.size + bms.size + rep.size + afm.size + ams.size + 1
val ars = Option(ar.nonEmpty).filter(identity).map(_ => CodeGenerators.NoSuffixCodeGen.unit(ar, ari))
// Assemble all the generated lines of code. Determine the appropriate number of right parentheses
// (number of generated lines minus 1) and add to the last line. Finally, if the code produces optional
// data before the repeated element, then it would result in one of the following types:
// Option[Seq[A]] or Option[Seq[Option[A]]]. Because we want to avoid the outermost Option, we map None
// to an empty sequence.
val r = bfms ++ bms ++ rep ++ afms ++ ams ++ ars
val optBeforeList = hasOptionalStuff(before)
// The result is only optional when no repeated element is present and there are optional chains.
val optional = repeated.isEmpty && hasOptionalStuff(after)
val lastLineSuffix = rightParenthesize(r) + (if (optBeforeList) ".getOrElse(Nil)" else "")
val lines = r.dropRight(1) ++ r.lastOption.map(_ + lastLineSuffix)
// Make sure that we have the implicit function imported for converting a Java List to a Scala Buffer.
// This is only done when a repeated element is present, in which case the body is wrapped in braces.
val finalLines =
if (repeated.nonEmpty) Seq(functionParamList + "{", " import scala.collection.JavaConversions.asScalaBuffer;") ++ lines ++ Seq("}")
else Seq(functionParamList) ++ lines
if (optional) OptionalAccessorCode(finalLines) else RequiredAccessorCode(finalLines)
}
/**
 * Given a mapping partition, determine if it will generate optional (Option) data types. This is
 * the case when the partition has at least one chain to flat map or a chain to map.
 * @param pm a mapping partition
 * @return true if the partition contains a flat mapped or mapped chain, false otherwise.
 */
private[proto] def hasOptionalStuff(pm: MappingPartition): Boolean = {
  val (flatMapped, mapped, _) = pm
  flatMapped.nonEmpty || mapped.nonEmpty
}
/**
 * Generate the appropriate number of right parentheses for the function body. This is equal to the
 * number of lines minus 1. Nothing is generated for zero or one lines.
 *
 * @param a the generated lines
 * @return a string of right parentheses.
 */
private[proto] def rightParenthesize(a: Seq[_]): String =
  ")" * math.max(a.size - 1, 0)
/**
 * Build the field accessor partition for a tokenized feature specification, or report the position
 * of the first problem found.
 *
 * One descriptor is consumed per Field token. A Field followed by an Index dereferences the field,
 * which must be repeated. A Field alone is accessed directly, and at most one directly accessed
 * repeated field is allowed. Any tokens left once the descriptors are exhausted are not inspected.
 * @param descriptors the sequence of protocol buffer field descriptors that represents the subsequence of
 *                    all Field tokens in the tokens variable.
 * @param tokens a list of tokens
 * @return the field accessor partition on success, the error messages on failure.
 */
private[proto] def getFieldAccessorPartition(descriptors: List[FieldDescriptor], tokens: List[Token]): ValidationNel[String, FieldAccessorPartition] = {
  // seenTokens and accessors are accumulated in reverse order.
  @scala.annotation.tailrec
  def walk(
    pendingDescriptors: List[FieldDescriptor],
    pendingTokens: List[Token],
    seenTokens: List[Token],
    listLevels: Int,
    accessors: List[FieldAccessor]): ValidationNel[String, FieldAccessorPartition] =
    pendingDescriptors match {
      case Nil => partition(accessors).success
      case fd :: fds => pendingTokens match {
        case (f: Field) :: (i: Index) :: rest =>
          val seen: List[Token] = i :: f :: seenTokens
          if (fd.isRepeated) walk(fds, rest, seen, listLevels, dereferencedRepeatedField(fd, i.index) :: accessors)
          else err(seen, "The field is not repeated so it cannot be dereferenced.")
        case (f: Field) :: rest =>
          val seen: List[Token] = f :: seenTokens
          val levels = if (fd.isRepeated) listLevels + 1 else listLevels
          if (levels <= 1) walk(fds, rest, seen, levels, directlyAccessedField(fd, accessors))
          else err(seen, "Too many list levels produced. Limit 1.")
        case _ =>
          err(pendingTokens.headOption.toList ::: seenTokens, "This should never happen!")
      }
    }

  walk(descriptors, tokens, Nil, 0, Nil)
}
/**
 * Produce an error for use in the getFieldAccessorPartition function. The message shows the
 * consumed tokens in forward order: fields are separated by periods and indices are rendered in
 * square brackets directly after the field they follow.
 * @param consumed the tokens already consumed, in reverse order
 * @param addlMsg any additional message to add to the error that will be returned.
 * @return a failure containing the single error message.
 */
private[proto] def err(consumed: List[Token], addlMsg: String = ""): ValidationNel[String, Nothing] = {
  val rendered = consumed.reverse.map {
    case Field(f) => f
    case Index(i) => "[" + i + "]"
  }
  // Joining with periods leaves a period in front of every index, so it is removed afterwards.
  val problem = rendered.mkString(".").replaceAll("""\.\[""", "[")
  val message = "Problem found at: '" + problem + "'. " + addlMsg
  message.trim.failNel
}
/**
 * Attempt to produce an equivalent plugin for a different input type.
 * @param ri reflection information for the new input type B
 * @tparam B the new input type
 * @return a plugin for B with the same ''dereferenceAsOptional'' value when
 *         `RefInfoOps.isSubType` reports B as a subtype of GeneratedMessage; None otherwise.
 */
override def morph[B](implicit ri: RefInfo[B]): Option[CompiledSemanticsPlugin[B]] =
  if (RefInfoOps.isSubType(ri, RefInfo[GeneratedMessage])) {
    // TODO: Attempt to remove these horrible casts.
    //       The subtype check above is what makes them safe.
    //       Can implicit evidence somehow be provided instead?
    val messageRefInfo = ri.asInstanceOf[RefInfo[GeneratedMessage]]
    val morphed = CompiledSemanticsProtoPlugin(dereferenceAsOptional)(messageRefInfo)
    Some(morphed.asInstanceOf[CompiledSemanticsPlugin[B]])
  }
  else None
}
/**
 * Companion with a factory method and an opt-in implicit instance.
 */
object CompiledSemanticsProtoPlugin {
  /**
   * Create a plugin for message type A using the default constructor arguments and the RefInfo
   * available in implicit scope.
   * @tparam A a type of generated protocol buffer message.
   */
  def apply[A <: GeneratedMessage: RefInfo]: CompiledSemanticsProtoPlugin[A] =
    new CompiledSemanticsProtoPlugin[A]()

  /**
   * Import the contents of this object to make a plugin implicitly available for any message type
   * that has a RefInfo in implicit scope.
   */
  object Implicits {
    implicit def protoSemantics[A <: GeneratedMessage: RefInfo]: CompiledSemanticsProtoPlugin[A] =
      apply[A]
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy