All downloads are free. Search and download functionality uses the official Maven repository.

co.cask.common.internal.io.ReflectionDatumWriter Maven / Gradle / Ivy

/*
 * Copyright © 2014 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package co.cask.common.internal.io;

import co.cask.common.io.Encoder;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.primitives.Longs;
import com.google.common.reflect.TypeToken;

import java.io.IOException;
import java.lang.reflect.Array;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.Map;
import java.util.Set;
import java.util.UUID;

/**
 * A {@link DatumWriter} that uses java reflection to encode data. The encoding schema it uses is
 * the same as the binary encoding as specified in Avro, with the addition of support for non-string
 * map keys.
 *
 * <p>For a record, each schema field is read from a non-transient, non-synthetic Java field of the
 * same name if one exists, otherwise from a public no-argument {@code getXxx()} / {@code isXxx()}
 * method. A record that (directly or indirectly) references itself is rejected with an
 * {@link IOException}; the same record instance may still appear several times in a non-circular
 * object graph.
 *
 * @param <T> Type of the object to be written.
 */
public final class ReflectionDatumWriter<T> implements DatumWriter<T> {

  private final Schema schema;

  /**
   * Creates a writer that encodes objects according to the given schema.
   *
   * @param schema schema describing the objects passed to {@link #encode(Object, Encoder)}
   */
  public ReflectionDatumWriter(Schema schema) {
    this.schema = schema;
  }

  /**
   * @return the schema this writer encodes with
   */
  public Schema getSchema() {
    return schema;
  }

  @Override
  public void encode(T data, Encoder encoder) throws IOException {
    // Identity-based, so distinct records that happen to be equal() are never mistaken for a cycle.
    Set<Object> seenRefs = Sets.newIdentityHashSet();
    write(data, encoder, schema, seenRefs);
  }

  /**
   * Encodes one value according to {@code objSchema}, recursing into arrays, maps, records and
   * unions.
   *
   * @param object value to encode; {@code null} is handled by the NULL and UNION schema types, the
   *               other types dereference or cast it
   * @param encoder destination of the encoded bytes
   * @param objSchema schema of {@code object}
   * @param seenRefs records currently being written on the path from the root to this value
   * @throws IOException on encoder failure, circular references, invalid enum values or
   *                     reflection failures while reading record fields
   */
  private void write(Object object, Encoder encoder, Schema objSchema,
                     Set<Object> seenRefs) throws IOException {
    switch (objSchema.getType()) {
      case NULL:
        encoder.writeNull();
        break;
      case BOOLEAN:
        encoder.writeBool((Boolean) object);
        break;
      case INT:
        encoder.writeInt(((Number) object).intValue());
        break;
      case LONG:
        encoder.writeLong(((Number) object).longValue());
        break;
      case FLOAT:
        // Accept any Number (not only Float), consistent with INT and LONG.
        encoder.writeFloat(((Number) object).floatValue());
        break;
      case DOUBLE:
        // Accept any Number (not only Double), consistent with INT and LONG.
        encoder.writeDouble(((Number) object).doubleValue());
        break;
      case STRING:
        encoder.writeString(object.toString());
        break;
      case BYTES:
        writeBytes(object, encoder);
        break;
      case ENUM:
        writeEnum(object, encoder, objSchema);
        break;
      case ARRAY:
        writeArray(object, encoder, objSchema.getComponentSchema(), seenRefs);
        break;
      case MAP:
        writeMap(object, encoder, objSchema.getMapSchema(), seenRefs);
        break;
      case RECORD:
        writeRecord(object, encoder, objSchema, seenRefs);
        break;
      case UNION:
        // Assumption in schema generation that index 0 is the object type, index 1 is null.
        if (object == null) {
          encoder.writeInt(1);
        } else {
          encoder.writeInt(0);
          write(object, encoder, objSchema.getUnionSchema(0), seenRefs);
        }
        break;
      default:
        // Writing nothing for an unhandled type would silently corrupt the encoded stream.
        throw new IOException("Unsupported schema type " + objSchema.getType());
    }
  }

  /**
   * Writes a BYTES value given as a {@link ByteBuffer}, a {@link UUID} (as its 16 big-endian bytes:
   * most significant long, then least significant long) or a {@code byte[]}.
   */
  private void writeBytes(Object object, Encoder encoder) throws IOException {
    if (object instanceof ByteBuffer) {
      encoder.writeBytes((ByteBuffer) object);
    } else if (object instanceof UUID) {
      UUID uuid = (UUID) object;
      ByteBuffer buf = ByteBuffer.allocate(Longs.BYTES * 2);
      buf.putLong(uuid.getMostSignificantBits()).putLong(uuid.getLeastSignificantBits());
      // The cast keeps the Buffer-returning flip() source-compatible with older JDKs.
      encoder.writeBytes((ByteBuffer) buf.flip());
    } else {
      encoder.writeBytes((byte[]) object);
    }
  }

  /**
   * Writes an ENUM value as its index in {@code enumSchema}.
   *
   * @param value an {@link Enum} constant, or any object whose {@code toString()} is the symbol
   * @throws IOException if the symbol is not part of the schema
   */
  private void writeEnum(Object value, Encoder encoder, Schema enumSchema) throws IOException {
    // Enum constants are identified by name(), which cannot be overridden, instead of toString(),
    // which can. NOTE(review): presumes the schema's enum symbols come from Enum.name() -- confirm
    // against the schema generator.
    String symbol = (value instanceof Enum) ? ((Enum<?>) value).name() : value.toString();
    int idx = enumSchema.getEnumIndex(symbol);
    if (idx < 0) {
      throw new IOException("Invalid enum value " + symbol);
    }
    encoder.writeInt(idx);
  }

  /**
   * Writes an ARRAY value given either as a {@link Collection} or as a Java array (primitive or
   * object) as a single block: the element count, the elements, then a terminating zero count.
   */
  private void writeArray(Object array, Encoder encoder,
                          Schema componentSchema, Set<Object> seenRefs) throws IOException {
    int size;
    if (array instanceof Collection) {
      Collection<?> col = (Collection<?>) array;
      size = col.size();
      encoder.writeInt(size);
      for (Object obj : col) {
        write(obj, encoder, componentSchema, seenRefs);
      }
    } else {
      size = Array.getLength(array);
      encoder.writeInt(size);
      for (int i = 0; i < size; i++) {
        write(Array.get(array, i), encoder, componentSchema, seenRefs);
      }
    }
    // For an empty array the count 0 written above already terminates the array.
    if (size > 0) {
      encoder.writeInt(0);
    }
  }

  /**
   * Writes a MAP value as a single block: the entry count, alternating key/value encodings, then
   * a terminating zero count. Keys are encoded with their own schema, so they need not be strings.
   *
   * @param mapSchema key schema ({@code getKey()}) and value schema ({@code getValue()})
   */
  private void writeMap(Object map, Encoder encoder, Map.Entry<Schema, Schema> mapSchema,
                        Set<Object> seenRefs) throws IOException {
    Map<?, ?> objMap = (Map<?, ?>) map;
    int size = objMap.size();
    encoder.writeInt(size);
    for (Map.Entry<?, ?> entry : objMap.entrySet()) {
      write(entry.getKey(), encoder, mapSchema.getKey(), seenRefs);
      write(entry.getValue(), encoder, mapSchema.getValue(), seenRefs);
    }
    if (size > 0) {
      encoder.writeInt(0);
    }
  }

  /**
   * Writes a RECORD value field by field, in schema order.
   *
   * @throws IOException if {@code record} is null, is already being written further up the object
   *                     graph (a cycle), or any field cannot be read or encoded; non-IO failures
   *                     are wrapped as the cause
   */
  private void writeRecord(Object record, Encoder encoder,
                           Schema recordSchema, Set<Object> seenRefs) throws IOException {
    if (record == null) {
      throw new IOException("Null value is not allowed for a record schema.");
    }
    // seenRefs holds only the records on the path from the root to here: meeting one of them again
    // is a cycle, whereas the same instance shared by siblings is written normally.
    if (!seenRefs.add(record)) {
      throw new IOException("Circular reference not supported.");
    }
    try {
      writeRecordFields(record, encoder, recordSchema, seenRefs);
    } catch (IOException e) {
      throw e;
    } catch (Exception e) {
      // Reflection failures and runtime errors from nested writes are reported as IOException.
      throw new IOException(e);
    } finally {
      seenRefs.remove(record);
    }
  }

  /**
   * Reads each schema field of {@code record} via a Java field or a getter and encodes it.
   */
  private void writeRecordFields(Object record, Encoder encoder, Schema recordSchema,
                                 Set<Object> seenRefs)
    throws IOException, ReflectiveOperationException {
    TypeToken<?> type = TypeToken.of(record.getClass());
    Map<String, Field> fields = collectByFields(type, Maps.<String, Field>newHashMap());
    // Getters are only needed for schema fields without a backing Java field; collect them lazily.
    Map<String, Method> methods = null;

    for (Schema.Field field : recordSchema.getFields()) {
      String fieldName = field.getName();
      Object value;
      Field recordField = fields.get(fieldName);
      if (recordField != null) {
        recordField.setAccessible(true);
        value = recordField.get(record);
      } else {
        if (methods == null) {
          methods = collectByMethod(type, Maps.<String, Method>newHashMap());
        }
        Method method = methods.get(fieldName);
        if (method == null) {
          throw new IOException("Unable to read field value through getter. Class=" + type + ", field=" + fieldName);
        }
        value = method.invoke(record);
      }
      write(value, encoder, field.getSchema(), seenRefs);
    }
  }

  /**
   * Collects the declared, non-transient, non-synthetic fields of {@code typeToken} and its
   * superclasses (excluding {@link Object}) into {@code fields}, keyed by field name.
   *
   * <p>Guava lists subtypes before supertypes, and later entries overwrite earlier ones, so on a
   * name clash the superclass field is kept. NOTE(review): presumably this mirrors the schema
   * generator's field collection -- confirm before changing.
   *
   * @return {@code fields}, for call chaining
   */
  private Map<String, Field> collectByFields(TypeToken<?> typeToken, Map<String, Field> fields) {
    for (TypeToken<?> classType : typeToken.getTypes().classes()) {
      Class<?> rawType = classType.getRawType();
      if (rawType.equals(Object.class)) {
        // Ignore all object fields
        continue;
      }

      for (Field field : rawType.getDeclaredFields()) {
        if (Modifier.isTransient(field.getModifiers()) || field.isSynthetic()) {
          continue;
        }
        fields.put(field.getName(), field);
      }
    }
    return fields;
  }

  /**
   * Collects the public no-argument {@code getXxx()} / {@code isXxx()} methods of
   * {@code typeToken}'s raw class (declared or inherited, excluding those declared by
   * {@link Object}) into {@code methods}, keyed by the property name with its first letter
   * lower-cased ({@code getFooBar} to {@code fooBar}).
   *
   * <p>If several methods map to the same name, the first one returned by
   * {@link Class#getMethods()} wins; that order is unspecified.
   *
   * @return {@code methods}, for call chaining
   */
  private Map<String, Method> collectByMethod(TypeToken<?> typeToken, Map<String, Method> methods) {
    for (Method method : typeToken.getRawType().getMethods()) {
      if (method.getDeclaringClass().equals(Object.class)) {
        // Ignore all object methods
        continue;
      }
      String methodName = method.getName();
      if (!(methodName.startsWith("get") || methodName.startsWith("is"))
           || method.isSynthetic() || method.getParameterTypes().length != 0) {
        // Ignore not getter methods
        continue;
      }
      String fieldName = methodName.startsWith("get")
        ? methodName.substring("get".length())
        : methodName.substring("is".length());
      if (fieldName.isEmpty()) {
        continue;
      }
      fieldName = Character.toLowerCase(fieldName.charAt(0)) + fieldName.substring(1);
      if (methods.containsKey(fieldName)) {
        continue;
      }
      methods.put(fieldName, method);
    }
    return methods;
  }
}