/*
 * All downloads are free. Search and download functionality uses the official Maven repository.
 *
 * com.twitter.elephantbird.util.Protobufs Maven / Gradle / Ivy
 *
 * The newest version!
 */
package com.twitter.elephantbird.util;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;

import com.google.common.base.Function;
import com.google.common.base.Predicates;
import com.google.common.collect.Lists;
import com.google.protobuf.DescriptorProtos.DescriptorProto;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.DescriptorValidationException;
import com.google.protobuf.Descriptors.EnumValueDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor.Type;
import com.google.protobuf.Message.Builder;
import com.google.protobuf.ByteString;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.DescriptorProtos;
import com.google.protobuf.Descriptors;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.ProtocolMessageEnum;
import com.google.protobuf.UninitializedMessageException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Protobufs {
  // Class-wide SLF4J logger used by the reflection and parsing helpers below.
  private static final Logger LOG = LoggerFactory.getLogger(Protobufs.class);

  // Fixed 16-byte marker sequence.
  // NOTE(review): the name suggests it marks a known-good position in a byte stream; it is not
  // referenced anywhere in this class, so confirm the exact usage against its callers.
  public static final byte[] KNOWN_GOOD_POSITION_MARKER = new byte[] { 0x29, (byte)0xd8, (byte)0xd5, 0x06, 0x58,
                                                                         (byte)0xcd, 0x4c, 0x29, (byte)0xb2,
                                                                         (byte)0xbc, 0x57, (byte)0x99, 0x21, 0x71,
                                                                         (byte)0xbd, (byte)0xff };

  // The '\n' character as a single byte, and the same byte wrapped in a one-element array.
  public static final byte NEWLINE_UTF8_BYTE = '\n';
  public static final byte[] NEWLINE_UTF8_BYTES = new byte[]{NEWLINE_UTF8_BYTE};

  // Translated field name that makes getFieldTransformerFor() drop the field (it maps to null).
  public static final String IGNORE_KEY = "IGNORE";

  // Prefix of the configuration key read by getTypeRef(Configuration, Class) and written by
  // setClassConf(); the full key is this prefix followed by genericClass.getName().
  private static final String CLASS_CONF_PREFIX = "elephantbird.protobuf.class.for.";

  /**
   * Returns the Protobuf class with the given name, loaded through
   * {@link Class#forName(String)}. The name may be either the normal class
   * name or the canonical name (the canonical name has no {@code $} for
   * inner classes).
   *
   * @param protoClassName normal or canonical name of the message class
   * @return the resolved class, checked to be a subclass of {@link Message}
   * @throws IllegalArgumentException if the fallback inner-class lookup finds
   *         no matching class
   * @throws ClassCastException if the resolved class is not a {@link Message}
   */
  public static Class getProtobufClass(String protoClassName) {
    return getProtobufClass(null, protoClassName);
  }

  /**
   * Resolves a Protobuf class by name. The class is loaded through
   * {@code conf} when one is given and through {@link Class#forName(String)}
   * when {@code conf} is {@code null}. If that lookup fails, the name is
   * retried as a canonical inner-class name.
   *
   * @param conf configuration used for class loading; may be {@code null}
   * @param protoClassName normal or canonical name of the message class
   * @return the resolved class, checked to be a subclass of {@link Message}
   * @throws IllegalArgumentException if the fallback inner-class lookup finds
   *         no matching class
   * @throws ClassCastException if the resolved class is not a {@link Message}
   */
  private static Class getProtobufClass(Configuration conf, String protoClassName) {
    Class protoClass;
    try {
      protoClass = (conf == null)
          ? Class.forName(protoClassName)
          : conf.getClassByName(protoClassName);
    } catch (ClassNotFoundException e) {
      // The name might be a canonical name: retry with the last segment
      // treated as an inner class of the preceding segments.
      protoClass = getInnerProtobufClass(protoClassName);
    }

    if (protoClass == null) {
      // The inner-class lookup signals "no such inner class" with null; report
      // that clearly instead of failing below with a bare NullPointerException.
      throw new IllegalArgumentException("Could not find protobuf class " + protoClassName);
    }
    return protoClass.asSubclass(Message.class);
  }

  /**
   * For a configured protoClass, tells whether the message should be dynamic
   * or is a pre-generated Message class. If protoClass is {@code null} or is
   * {@link DynamicMessage} (compared by canonical name), the configurer
   * intends for a dynamically generated protobuf to be used.
   *
   * @param protoClass the configured class; may be {@code null}
   * @return {@code true} if a {@link DynamicMessage} should be used
   */
  public static boolean useDynamicProtoMessage(Class protoClass) {
    if (protoClass == null) {
      return true;
    }
    // Class.getCanonicalName() is null for anonymous and local classes, so the
    // comparison is anchored on the DynamicMessage side to stay null-safe.
    return DynamicMessage.class.getCanonicalName().equals(protoClass.getCanonicalName());
  }

  /**
   * Resolves a class from its canonical name, treating the last dot-separated
   * segment as the simple name of an inner class and everything before it as
   * the name of the enclosing class.
   *
   * @param canonicalClassName canonical name such as {@code com.example.Outer.Inner}
   * @return whatever {@link #getInnerClass(String, String)} returns for the
   *         two halves of the name
   * @throws IllegalArgumentException if the name contains no {@code .} and so
   *         cannot be split into an enclosing class and an inner class
   */
  public static Class getInnerProtobufClass(String canonicalClassName) {
    int lastDot = canonicalClassName.lastIndexOf('.');
    if (lastDot < 0) {
      // Without a separator there is no enclosing class to search; substring()
      // would otherwise fail with an opaque StringIndexOutOfBoundsException.
      throw new IllegalArgumentException(
          "Class name " + canonicalClassName + " has no enclosing class or package");
    }
    String parentClass = canonicalClassName.substring(0, lastDot);
    String subclass = canonicalClassName.substring(lastDot + 1);
    return getInnerClass(parentClass, subclass);
  }

  /**
   * Looks up a class declared directly inside another class by its simple name.
   *
   * @param canonicalParentName name of the enclosing class, loadable via
   *        {@link Class#forName(String)}
   * @param subclassName simple name of the declared inner class
   * @return the inner class, checked to be a subclass of {@link Message}, or
   *         {@code null} if the enclosing class declares no class of that name
   * @throws IllegalArgumentException if the enclosing class cannot be loaded
   */
  public static Class getInnerClass(String canonicalParentName, String subclassName) {
    final Class enclosing;
    try {
      enclosing = Class.forName(canonicalParentName);
    } catch (ClassNotFoundException e) {
      LOG.error("Could not find class with parent " + canonicalParentName + " and inner class " + subclassName, e);
      throw new IllegalArgumentException(e);
    }

    Class[] declared = enclosing.getDeclaredClasses();
    for (int i = 0; i < declared.length; i++) {
      Class candidate = declared[i];
      if (candidate.getSimpleName().equals(subclassName)) {
        return candidate.asSubclass(Message.class);
      }
    }
    // No declared class carries the requested simple name.
    return null;
  }

  /**
   * Creates a new builder by reflectively calling the static no-argument
   * {@code newBuilder()} method of {@code protoClass}.
   *
   * @param protoClass class exposing a public static {@code newBuilder()}
   * @return the builder returned by {@code newBuilder()}
   * @throws IllegalArgumentException if the method is missing, inaccessible,
   *         or throws when invoked
   */
  public static Message.Builder getMessageBuilder(Class protoClass) {
    try {
      Method factory = protoClass.getMethod("newBuilder");
      // Static method: no receiver and no arguments.
      Object created = factory.invoke(null);
      return (Message.Builder) created;
    } catch (InvocationTargetException e) {
      LOG.error("Error invoking method newBuilder in class " + protoClass, e);
      throw new IllegalArgumentException(e);
    } catch (IllegalAccessException e) {
      LOG.error("Could not access method newBuilder in class " + protoClass, e);
      throw new IllegalArgumentException(e);
    } catch (NoSuchMethodException e) {
      LOG.error("Could not find method newBuilder in class " + protoClass, e);
      throw new IllegalArgumentException(e);
    }
  }

  /**
   * Obtains the message {@link Descriptor} by reflectively calling the static
   * no-argument {@code getDescriptor()} method of {@code protoClass}.
   *
   * @param protoClass class exposing a public static {@code getDescriptor()}
   * @return the descriptor, or {@code null} if {@code getDescriptor()} itself
   *         threw (the failure is logged, not rethrown)
   * @throws IllegalArgumentException if the method is missing or inaccessible
   */
  public static Descriptor getMessageDescriptor(Class protoClass) {
    Descriptor descriptor = null;
    try {
      Method accessor = protoClass.getMethod("getDescriptor");
      // Static method: no receiver and no arguments.
      descriptor = (Descriptor) accessor.invoke(null);
    } catch (InvocationTargetException e) {
      // Unlike the other failures this one is only logged; the caller gets null.
      LOG.error("Error invoking method getDescriptor in class " + protoClass, e);
    } catch (IllegalAccessException e) {
      LOG.error("Could not access method getDescriptor in class " + protoClass, e);
      throw new IllegalArgumentException(e);
    } catch (NoSuchMethodException e) {
      LOG.error("Could not find method getDescriptor in class " + protoClass, e);
      throw new IllegalArgumentException(e);
    }
    return descriptor;
  }

  /**
   * Returns the names of all fields declared by the message type of
   * {@code protoClass}, in the order reported by its descriptor.
   *
   * @param protoClass class whose descriptor is obtained through
   *        {@link #getMessageDescriptor(Class)}
   * @return the field names, produced by {@link Lists#transform} over the
   *         descriptor's field list
   */
  public static List<String> getMessageFieldNames(Class protoClass) {
    // The type arguments are required: a raw Function would demand
    // apply(Object), which apply(FieldDescriptor) does not implement.
    return Lists.transform(getMessageDescriptor(protoClass).getFields(),
        new Function<FieldDescriptor, String>() {
          @Override
          public String apply(FieldDescriptor f) {
            return f.getName();
          }
        });
  }

  /**
   * Builds a message {@link Descriptor} for a dynamically generated
   * {@link DescriptorProto} by wrapping it in a file descriptor that contains
   * only that message type and has no dependencies.
   *
   * @param descProto definition of the message type
   * @return the result of looking up {@code descProto.getName()} in the
   *         freshly built file descriptor
   * @throws DescriptorValidationException if the file descriptor cannot be built
   */
  public static Descriptor makeMessageDescriptor(DescriptorProto descProto)
                                      throws DescriptorValidationException {
    DescriptorProtos.FileDescriptorProto.Builder fileProtoBuilder =
        DescriptorProtos.FileDescriptorProto.newBuilder();
    fileProtoBuilder.addMessageType(descProto);

    // The synthetic file depends on no other files, hence the empty array.
    Descriptors.FileDescriptor fileDescriptor = Descriptors.FileDescriptor.buildFrom(
        fileProtoBuilder.build(), new Descriptors.FileDescriptor[0]);

    String messageName = descProto.getName();
    return fileDescriptor.findMessageTypeByName(messageName);
  }

   // Lists the field names of protoClass after renaming them through fieldNameTranslations,
   // leaving out every field whose translated name is IGNORE. For example, a protobuf with
   // fields a, b, c and the map { "a" => "A", "c" => "IGNORE" } yields A, b.
  public static List getMessageFieldNames(Class protoClass, Map fieldNameTranslations) {
    Descriptor descriptor = getMessageDescriptor(protoClass);
    return getMessageFieldNames(descriptor, fieldNameTranslations);
  }

   /**
    * Lists the field names of {@code descriptor} after applying the transformer
    * from {@link #getFieldTransformerFor(Map)}, then removes the {@code null}
    * entries that the transformer produces for ignored fields.
    *
    * @param descriptor message type whose fields are listed
    * @param fieldNameTranslations field-name overrides; may be {@code null}
    * @return the translated names, filtered through {@code ListHelper.filter}
    */
   public static List getMessageFieldNames(Descriptor descriptor, Map fieldNameTranslations) {
     Function renamer = getFieldTransformerFor(fieldNameTranslations);
     List renamed = Lists.transform(descriptor.getFields(), renamer);
     return ListHelper.filter(renamed, Predicates.notNull());
   }

   /**
    * Creates a function that maps a field to the name it should be exposed
    * under. A field whose name is a key of {@code fieldNameTranslations} is
    * renamed to the mapped value; other fields keep their name. A field whose
    * resulting name equals {@link #IGNORE_KEY} maps to {@code null}.
    *
    * @param fieldNameTranslations field-name overrides; may be {@code null},
    *        in which case no renaming takes place
    * @return the renaming function
    */
   public static Function<FieldDescriptor, String> getFieldTransformerFor(
       final Map<String, String> fieldNameTranslations) {
     // Type arguments are needed both to implement Function.apply and to
     // assign the looked-up translation to a String.
     return new Function<FieldDescriptor, String>() {
       @Override
       public String apply(FieldDescriptor f) {
         String name = f.getName();
         if (fieldNameTranslations != null && fieldNameTranslations.containsKey(name)) {
           name = fieldNameTranslations.get(name);
         }
         if (IGNORE_KEY.equals(name)) {
           name = null;
         }
         return name;
       }
     };
   }


  /**
   * Parses {@code messageBytes} by reflectively calling the static
   * {@code parseFrom(byte[])} method of {@code protoClass}.
   *
   * @param <M> message type produced by {@code protoClass}
   * @param protoClass class exposing a public static {@code parseFrom(byte[])}
   * @param messageBytes serialized message
   * @return the parsed message, or {@code null} if {@code parseFrom} itself
   *         threw (the failure is logged, not rethrown)
   * @throws IllegalArgumentException if the method is missing or inaccessible
   */
  @SuppressWarnings("unchecked")
  public static <M extends Message> M parseFrom(Class<M> protoClass, byte[] messageBytes) {
    try {
      Method parseFrom = protoClass.getMethod("parseFrom", byte[].class);
      // Wrap explicitly so the byte[] is passed as the single argument.
      return (M) parseFrom.invoke(null, new Object[] { messageBytes });
    } catch (NoSuchMethodException e) {
      LOG.error("Could not find method parseFrom in class " + protoClass, e);
      throw new IllegalArgumentException(e);
    } catch (IllegalAccessException e) {
      LOG.error("Could not access method parseFrom in class " + protoClass, e);
      throw new IllegalArgumentException(e);
    } catch (InvocationTargetException e) {
      // Failures raised by parseFrom are logged only; the caller receives null.
      LOG.error("Error invoking method parseFrom in class " + protoClass, e);
    }

    return null;
  }

  /**
   * Parses {@code messageBytes} into a {@link DynamicMessage} using the
   * descriptor obtained from {@link #getMessageDescriptor(Class)}.
   *
   * @param protoClass class supplying the message descriptor
   * @param messageBytes serialized message
   * @return the parsed message, or {@code null} if parsing failed (the
   *         failure is logged, not rethrown)
   */
  public static DynamicMessage parseDynamicFrom(Class protoClass, byte[] messageBytes) {
    DynamicMessage parsed = null;
    try {
      Descriptor descriptor = getMessageDescriptor(protoClass);
      parsed = DynamicMessage.parseFrom(descriptor, messageBytes);
    } catch (InvalidProtocolBufferException e) {
      LOG.error("Protocol buffer parsing error in parseDynamicFrom", e);
    } catch (UninitializedMessageException ume) {
      LOG.error("Uninitialized Message error in parseDynamicFrom " + protoClass.getName(), ume);
    }
    return parsed;
  }

  /**
   * Builds a message from a fresh, unmodified builder of the inner class
   * named by {@code canonicalClassName}.
   *
   * @param canonicalClassName canonical name resolved through
   *        {@link #getInnerProtobufClass(String)}
   * @return the result of calling {@code build()} on a new builder
   */
  public static Message instantiateFromClassName(String canonicalClassName) {
    return getMessageBuilder(getInnerProtobufClass(canonicalClassName)).build();
  }

  /**
   * Creates a message of the given class by parsing an empty byte array
   * through {@link #parseFrom(Class, byte[])}.
   *
   * @param <M> message type produced by {@code protoClass}
   * @param protoClass class exposing a public static {@code parseFrom(byte[])}
   * @return whatever {@code parseFrom} returns for zero bytes of input
   */
  public static <M extends Message> M instantiateFromClass(Class<M> protoClass) {
    return Protobufs.parseFrom(protoClass, new byte[0]);
  }

  /**
   * Returns a copy of {@code m} in which the named field has been set to
   * {@code value}, or cleared when {@code value} is {@code null}. The copy is
   * produced via {@code m.toBuilder()}.
   *
   * @param m message to copy
   * @param name name of the field to change
   * @param value new value, or {@code null} to clear the field
   * @return the rebuilt message
   */
  public static Message addField(Message m, String name, Object value) {
    Message.Builder copy = m.toBuilder();
    setFieldByName(copy, name, value);
    Message rebuilt = copy.build();
    return rebuilt;
  }

  /**
   * Sets the named field on {@code builder}, or clears it when {@code value}
   * is {@code null}.
   *
   * @param builder builder to modify
   * @param name name looked up in the builder's descriptor
   * @param value new value, or {@code null} to clear the field
   */
  public static void setFieldByName(Message.Builder builder, String name, Object value) {
    FieldDescriptor field = builder.getDescriptorForType().findFieldByName(name);
    if (value != null) {
      builder.setField(field, value);
      return;
    }
    // A null value means "unset" rather than "set to null".
    builder.clearField(field);
  }

  /**
   * Reports whether the named field is set on {@code message}, as answered by
   * {@code message.hasField(...)}.
   *
   * @param message message to inspect
   * @param name name looked up in the message's descriptor
   * @return the result of {@code hasField} for that field
   */
  public static boolean isFieldSetByName(Message message, String name) {
    FieldDescriptor field = message.getDescriptorForType().findFieldByName(name);
    return message.hasField(field);
  }

  /**
   * Reads the value of the named field from {@code message}.
   *
   * @param message message to read from
   * @param name name looked up in the message's descriptor
   * @return the result of {@code message.getField(...)} for that field
   */
  public static Object getFieldByName(Message message, String name) {
    FieldDescriptor field = message.getDescriptorForType().findFieldByName(name);
    return message.getField(field);
  }

  /**
   * Reports whether the message type of {@code message} declares a field with
   * the given name, i.e. whether the descriptor lookup yields a non-null field.
   *
   * @param message message whose descriptor is searched
   * @param name field name to look for
   * @return {@code true} if the lookup finds a field
   */
  public static boolean hasFieldByName(Message message, String name) {
    FieldDescriptor field = message.getDescriptorForType().findFieldByName(name);
    return field != null;
  }

  /**
   * Returns the declared {@link Type} of the named field.
   *
   * @param message message whose descriptor is searched
   * @param name name looked up in the message's descriptor
   * @return the type reported by the field's descriptor
   */
  public static Type getTypeByName(Message message, String name) {
    FieldDescriptor field = message.getDescriptorForType().findFieldByName(name);
    return field.getType();
  }

  /**
   * Returns a TypeRef for the Protobuf class with the given name. The class
   * is resolved through {@link #getProtobufClass(String)}.
   *
   * @param protoClassName normal or canonical name of the message class
   * @return an anonymous TypeRef subclass constructed from the resolved class
   */
  public static TypeRef getTypeRef(String protoClassName) {
    Class protoClass = getProtobufClass(protoClassName);
    return new TypeRef(protoClass){};
  }

  /**
   * Returns a TypeRef for the Protobuf class name stored in {@code jobConf}
   * under the key that {@link #setClassConf(Configuration, Class, Class)}
   * writes for {@code genericClass}.
   *
   * @param jobConf configuration holding the class name; also used to load the class
   * @param genericClass class whose name completes the configuration key
   * @return an anonymous TypeRef subclass constructed from the resolved class
   * @throws RuntimeException if the configuration key is not set
   */
  public static TypeRef getTypeRef(Configuration jobConf, Class genericClass) {
    String configured = jobConf.get(CLASS_CONF_PREFIX + genericClass.getName());
    if (configured != null) {
      Class protoClass = getProtobufClass(jobConf, configured);
      return new TypeRef(protoClass){};
    }
    throw new RuntimeException(CLASS_CONF_PREFIX + genericClass.getName() + " is not set");
  }

  /**
   * Records {@code protoClass} in {@code jobConf} under the key formed from
   * the class-configuration prefix and {@code genericClass.getName()}. The
   * write is delegated to {@code HadoopUtils.setClassConf}.
   *
   * @param jobConf configuration to modify
   * @param genericClass class whose name completes the configuration key
   * @param protoClass class to record
   */
  public static void setClassConf(Configuration jobConf, Class genericClass,
      Class protoClass) {
    String confKey = CLASS_CONF_PREFIX + genericClass.getName();
    HadoopUtils.setClassConf(jobConf, confKey, protoClass);
  }

  /**
   * Wraps the serialized bytes of {@code message} in a Hadoop {@link Text}.
   *
   * @param message message to serialize with {@code toByteArray()}
   * @return a new {@link Text} constructed from those bytes
   */
  public static Text toText(Message message) {
    byte[] serialized = message.toByteArray();
    return new Text(serialized);
  }

  /**
   * Merges the serialized message held in {@code bytes} into {@code builder}.
   *
   * @param <B> concrete builder type, returned so calls can be chained
   * @param builder builder to merge into
   * @param bytes serialized message; only its first {@code getLength()} bytes are read
   * @return the result of {@code builder.mergeFrom(...)}, cast back to {@code B}
   * @throws InvalidProtocolBufferException if the bytes cannot be parsed
   */
  public static <B extends Message.Builder> B mergeFromText(B builder, Text bytes)
                                    throws InvalidProtocolBufferException {
    // Text.getBytes() exposes the backing array, which can be longer than the
    // valid content (for example after the Text has been reused), so parsing
    // is limited to the first getLength() bytes.
    @SuppressWarnings("unchecked")
    B b = (B) builder.mergeFrom(bytes.getBytes(), 0, bytes.getLength());
    return b;
  }

  /**
   * Serializes the value of a single field. Values are written through
   * {@link #writeSingleFieldNoTag}, which uses the "NoTag" methods of
   * {@link CodedOutputStream} (e.g. writeInt32NoTag()), so the field index is
   * not written. For a repeated field the value must be a {@link List} and
   * its elements are written one after another. A {@code null} value writes
   * nothing.
   *
   * @param output stream to write to
   * @param fd descriptor of the field being written
   * @param value field value, a list of values for repeated fields, or {@code null}
   * @throws IOException if writing to {@code output} fails
   */
  public static void writeFieldNoTag(CodedOutputStream   output,
                                     FieldDescriptor     fd,
                                     Object              value)
                                     throws IOException {
    if (value == null) {
      return;
    }

    if (!fd.isRepeated()) {
      writeSingleFieldNoTag(output, fd, value);
      return;
    }

    for (Object element : (List<?>) value) {
      writeSingleFieldNoTag(output, fd, element);
    }
  }

  /**
   * Writes one value of the given field with the "NoTag" method of
   * {@link CodedOutputStream} that matches the field's declared type.
   *
   * @param output stream to write to
   * @param fd descriptor whose type selects the encoding
   * @param value a single (non-list) value; its Java type must match the field type
   * @throws IOException if writing to {@code output} fails
   * @throws IllegalArgumentException if the field type is not handled
   * @throws ClassCastException if {@code value} has the wrong Java type for the field
   */
  private static void writeSingleFieldNoTag(CodedOutputStream   output,
                                            FieldDescriptor     fd,
                                            Object              value)
                                            throws IOException {
    switch (fd.getType()) {
    case DOUBLE:
      output.writeDoubleNoTag((Double) value);    break;
    case FLOAT:
      output.writeFloatNoTag((Float) value);      break;
    case INT64:
    case UINT64:
      output.writeInt64NoTag((Long) value);       break;
    case INT32:
      output.writeInt32NoTag((Integer) value);    break;
    case FIXED64:
      output.writeFixed64NoTag((Long) value);     break;
    case FIXED32:
      output.writeFixed32NoTag((Integer) value);  break;
    case BOOL:
      output.writeBoolNoTag((Boolean) value);     break;
    case STRING:
      output.writeStringNoTag((String) value);    break;
    case GROUP:
    case MESSAGE:
      output.writeMessageNoTag((Message) value);  break;
    case BYTES:
      output.writeBytesNoTag((ByteString) value); break;
    case UINT32:
      output.writeUInt32NoTag((Integer) value);   break;
    case ENUM:
      output.writeEnumNoTag(enumNumberOf(value)); break;
    case SFIXED32:
      output.writeSFixed32NoTag((Integer) value); break;
    case SFIXED64:
      output.writeSFixed64NoTag((Long) value);    break;
    case SINT32:
      output.writeSInt32NoTag((Integer) value);   break;
    case SINT64:
      // A sint64 value is a Long; the former (Integer) cast rejected it with a
      // ClassCastException. Going through Number accepts the Long and keeps
      // accepting an Integer as before.
      output.writeSInt64NoTag(((Number) value).longValue()); break;

    default:
      throw new IllegalArgumentException("Unknown type " + fd.getType()
                                         + " for " + fd.getFullName());
    }
  }

  /**
   * Returns the wire number of an enum value given either as an
   * {@link EnumValueDescriptor} (the form produced when this class reads an
   * enum field) or as a generated {@link ProtocolMessageEnum} constant.
   */
  private static int enumNumberOf(Object value) {
    if (value instanceof EnumValueDescriptor) {
      return ((EnumValueDescriptor) value).getNumber();
    }
    return ((ProtocolMessageEnum) value).getNumber();
  }

  /**
   * Wrapper around {@link #readFieldNoTag(CodedInputStream, FieldDescriptor, Builder)}.
   * Same as <code>builder.setField(fd, readFieldNoTag(input, fd, builder));</code>
   *
   * @param input stream to read the value from
   * @param fd descriptor of the field to read and set
   * @param builder builder that receives the value
   * @throws IOException if reading from {@code input} fails
   */
  public static void setFieldValue(CodedInputStream input,
                                   FieldDescriptor  fd,
                                   Builder          builder)
                                   throws IOException {
    builder.setField(fd, readFieldNoTag(input, fd, builder));
  }

  /**
   * Deserializer for protobuf objects written with
   * {@link #writeFieldNoTag(CodedOutputStream, FieldDescriptor, Object)}.
   * For a repeated field, values are read until {@code input} reports that it
   * is at its end, and are returned as a list.
   *
   * @param input stream to read from
   * @param fd descriptor of the field being read
   * @param enclosingBuilder required to create a builder when the field
   *                         is a message
   * @return the value read, or a {@link List} of values for a repeated field
   * @throws IOException if reading from {@code input} fails
   */
  public static Object readFieldNoTag(CodedInputStream input,
                                      FieldDescriptor  fd,
                                      Builder          enclosingBuilder)
                                      throws IOException {
    if (!fd.isRepeated()) {
      return readSingleFieldNoTag(input, fd, enclosingBuilder);
    }

    // repeated field
    List<Object> values = Lists.newArrayList();
    while (!input.isAtEnd()) {
      values.add(readSingleFieldNoTag(input, fd, enclosingBuilder));
    }
    return values;
  }

  /**
   * Reads one value of the given field with the {@link CodedInputStream}
   * method that matches the field's declared type.
   *
   * @param input stream to read from
   * @param fd descriptor whose type selects the decoding
   * @param enclosingBuilder used to create the builder for message-typed fields
   * @return the decoded value
   * @throws IOException if reading from {@code input} fails
   * @throws IllegalArgumentException if the field type is not handled
   */
  private static Object readSingleFieldNoTag(CodedInputStream input,
                                             FieldDescriptor  fd,
                                             Builder          enclosingBuilder)
                                             throws IOException {
    switch (fd.getType()) {
    case DOUBLE:
      return input.readDouble();
    case FLOAT:
      return input.readFloat();
    case INT64:
    case UINT64:
      return input.readInt64();
    case INT32:
      return input.readInt32();
    case FIXED64:
      return input.readFixed64();
    case FIXED32:
      return input.readFixed32();
    case BOOL:
      return input.readBool();
    case STRING:
      return input.readString();
    case GROUP:
    case MESSAGE:
      Builder fieldBuilder = enclosingBuilder.newBuilderForField(fd);
      input.readMessage(fieldBuilder, null);
      return fieldBuilder.build();
    case BYTES:
      return input.readBytes();
    case UINT32:
      return input.readUInt32();
    case ENUM:
      EnumValueDescriptor eVal = fd.getEnumType().findValueByNumber(input.readEnum());
      // ideally if a given enum does not exist, we should search
      // unknown fields. but we don't have access to that here. return default
      return eVal != null ? eVal : fd.getDefaultValue();
    case SFIXED32:
      return input.readSFixed32();
    case SFIXED64:
      return input.readSFixed64();
    case SINT32:
      return input.readSInt32();
    case SINT64:
      return input.readSInt64();

    default:
      throw new IllegalArgumentException("Unknown type " + fd.getType()
                                         + " for " + fd.getFullName());
    }
  }
}