All downloads are free. Search and download functionality uses the official Maven repository.

com.facebook.swift.codec.internal.compiler.ThriftCodecByteCodeGenerator Maven / Gradle / Ivy

/*
 * Copyright (C) 2012 Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License. You may obtain
 * a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package com.facebook.swift.codec.internal.compiler;

import com.facebook.swift.codec.ThriftCodec;
import com.facebook.swift.codec.ThriftCodecManager;
import com.facebook.swift.codec.ThriftProtocolType;
import com.facebook.swift.codec.internal.TProtocolReader;
import com.facebook.swift.codec.internal.TProtocolWriter;
import com.facebook.swift.codec.internal.compiler.byteCode.CaseStatement;
import com.facebook.swift.codec.internal.compiler.byteCode.ClassDefinition;
import com.facebook.swift.codec.internal.compiler.byteCode.FieldDefinition;
import com.facebook.swift.codec.internal.compiler.byteCode.LocalVariableDefinition;
import com.facebook.swift.codec.internal.compiler.byteCode.MethodDefinition;
import com.facebook.swift.codec.internal.compiler.byteCode.NamedParameterDefinition;
import com.facebook.swift.codec.internal.compiler.byteCode.ParameterizedType;
import com.facebook.swift.codec.metadata.DefaultThriftTypeReference;
import com.facebook.swift.codec.metadata.FieldKind;
import com.facebook.swift.codec.metadata.ReflectionHelper;
import com.facebook.swift.codec.metadata.ThriftConstructorInjection;
import com.facebook.swift.codec.metadata.ThriftExtraction;
import com.facebook.swift.codec.metadata.ThriftFieldExtractor;
import com.facebook.swift.codec.metadata.ThriftFieldInjection;
import com.facebook.swift.codec.metadata.ThriftFieldMetadata;
import com.facebook.swift.codec.metadata.ThriftInjection;
import com.facebook.swift.codec.metadata.ThriftMethodExtractor;
import com.facebook.swift.codec.metadata.ThriftMethodInjection;
import com.facebook.swift.codec.metadata.ThriftParameterInjection;
import com.facebook.swift.codec.metadata.ThriftStructMetadata;
import com.facebook.swift.codec.metadata.ThriftType;
import com.facebook.swift.codec.metadata.ThriftTypeReference;
import com.google.common.base.Function;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.reflect.TypeToken;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import org.apache.thrift.protocol.TProtocol;
import com.facebook.swift.codec.$internal.asm.ClassReader;
import com.facebook.swift.codec.$internal.asm.ClassWriter;
import com.facebook.swift.codec.$internal.asm.util.CheckClassAdapter;

import javax.annotation.concurrent.NotThreadSafe;

import java.io.PrintWriter;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import static com.facebook.swift.codec.ThriftProtocolType.BINARY;
import static com.facebook.swift.codec.ThriftProtocolType.BOOL;
import static com.facebook.swift.codec.ThriftProtocolType.BYTE;
import static com.facebook.swift.codec.ThriftProtocolType.DOUBLE;
import static com.facebook.swift.codec.ThriftProtocolType.ENUM;
import static com.facebook.swift.codec.ThriftProtocolType.I16;
import static com.facebook.swift.codec.ThriftProtocolType.I32;
import static com.facebook.swift.codec.ThriftProtocolType.I64;
import static com.facebook.swift.codec.ThriftProtocolType.LIST;
import static com.facebook.swift.codec.ThriftProtocolType.MAP;
import static com.facebook.swift.codec.ThriftProtocolType.SET;
import static com.facebook.swift.codec.ThriftProtocolType.STRING;
import static com.facebook.swift.codec.ThriftProtocolType.STRUCT;
import static com.facebook.swift.codec.internal.compiler.byteCode.Access.BRIDGE;
import static com.facebook.swift.codec.internal.compiler.byteCode.Access.FINAL;
import static com.facebook.swift.codec.internal.compiler.byteCode.Access.PRIVATE;
import static com.facebook.swift.codec.internal.compiler.byteCode.Access.PUBLIC;
import static com.facebook.swift.codec.internal.compiler.byteCode.Access.SUPER;
import static com.facebook.swift.codec.internal.compiler.byteCode.Access.SYNTHETIC;
import static com.facebook.swift.codec.internal.compiler.byteCode.Access.a;
import static com.facebook.swift.codec.internal.compiler.byteCode.CaseStatement.caseStatement;
import static com.facebook.swift.codec.internal.compiler.byteCode.NamedParameterDefinition.arg;
import static com.facebook.swift.codec.internal.compiler.byteCode.ParameterizedType.type;
import static com.facebook.swift.codec.metadata.FieldKind.THRIFT_FIELD;
import static com.facebook.swift.codec.metadata.FieldKind.THRIFT_UNION_ID;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.lang.String.format;

@NotThreadSafe
public class ThriftCodecByteCodeGenerator
{
    // NOTE(review): several declarations below use raw types (Map, ThriftCodec,
    // ThriftStructMetadata), yet methods such as readFieldValues assign codecFields.get(...)
    // directly to a FieldDefinition without a cast, which only compiles with type arguments.
    // The type arguments appear to have been lost -- confirm against the upstream source.

    // presumably the package prefix used for generated codec class names (used outside this view)
    private static final String PACKAGE = "$wift";

    // Read/write method tables. These static finals have no initializer here, so they must be
    // assigned in a static initializer elsewhere in the file; presumably consulted by
    // getReadMethod/getWriteMethod -- TODO confirm.
    private static final Map READ_METHODS;
    private static final Map WRITE_METHODS;

    // Array variants of the read/write method tables (also assigned outside this view).
    private static final Map ARRAY_READ_METHODS;
    private static final Map ARRAY_WRITE_METHODS;

    // supplies delegate codecs for fields that need one (see declareCodecFields)
    private final ThriftCodecManager codecManager;
    // the struct or union for which a codec is being generated
    private final ThriftStructMetadata metadata;
    // type of the struct class, and of the codec class being generated
    private final ParameterizedType structType;
    private final ParameterizedType codecType;

    // in-memory model of the generated class; fields and methods are added before byte code is emitted
    private final ClassDefinition classDefinition;

    // fields and values for the generated constructor, accumulated as fields are declared
    private final ConstructorParameters parameters = new ConstructorParameters();

    // the generated "type" field, which holds the struct's ThriftType
    private final FieldDefinition typeField;
    // generated delegate-codec fields, keyed by Thrift field id
    private final Map codecFields;

    // instance of the generated codec class, created at the end of the constructor
    private final ThriftCodec thriftCodec;

    /**
     * Generates, loads and instantiates a codec class for the given struct or union.
     * <p>
     * The class fields are declared first because every field declaration also registers a
     * constructor parameter. Next, the constructor, {@code getType}, the typed read/write methods
     * and the bridge methods are added. The class model is then serialized with ASM, letting
     * {@link ClassWriter} compute stack map frames, and optionally verified. Finally it is defined
     * in {@code classLoader} and instantiated with the collected parameter values.
     *
     * @param codecManager supplies delegate codecs for fields that need one
     * @param metadata the struct or union to generate a codec for
     * @param classLoader loader in which the generated class is defined
     * @param debug when true, runs the ASM verifier and prints its results to standard out
     * @throws IllegalStateException if the metadata type is neither STRUCT nor UNION, or if the
     *         generated class cannot be instantiated
     */
    @SuppressWarnings("unchecked")
    @SuppressFBWarnings("DM_DEFAULT_ENCODING")
    public ThriftCodecByteCodeGenerator(
            ThriftCodecManager codecManager,
            ThriftStructMetadata metadata,
            DynamicClassLoader classLoader,
            boolean debug
    )
    {
        this.codecManager = codecManager;
        this.metadata = metadata;
        this.structType = type(metadata.getStructClass());
        this.codecType = toCodecType(metadata);

        // public class <codec> extends Object implements ThriftCodec<struct>
        this.classDefinition = new ClassDefinition(
                a(PUBLIC, SUPER),
                this.codecType.getClassName(),
                type(Object.class),
                type(ThriftCodec.class, this.structType)
        );

        // fields first: each declaration also registers a constructor parameter
        this.typeField = declareTypeField();
        this.codecFields = declareCodecFields();

        defineConstructor();
        defineGetTypeMethod();

        switch (metadata.getMetadataType()) {
            case STRUCT:
                defineReadStructMethod();
                defineWriteStructMethod();
                break;
            case UNION:
                defineReadUnionMethod();
                defineWriteUnionMethod();
                break;
            default:
                throw new IllegalStateException(format("encountered type %s", metadata.getMetadataType()));
        }

        // non-generic bridge overloads of read and write
        defineReadBridgeMethod();
        defineWriteBridgeMethod();

        ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
        this.classDefinition.getClassNode().accept(classWriter);
        byte[] byteCode = classWriter.toByteArray();

        if (debug) {
            // the verifier is very verbose, so it is reserved for debug mode
            CheckClassAdapter.verify(new ClassReader(byteCode), classLoader, true, new PrintWriter(System.out));
        }

        String binaryClassName = this.codecType.getClassName().replace('/', '.');
        Class codecClass = classLoader.defineClass(binaryClassName, byteCode);
        try {
            Constructor codecConstructor = codecClass.getConstructor(parameters.getTypes());
            this.thriftCodec = (ThriftCodec) codecConstructor.newInstance(parameters.getValues());
        }
        catch (Exception e) {
            throw new IllegalStateException("Generated class is invalid", e);
        }
    }

    /**
     * Returns the codec instance that was created from the generated byte code.
     *
     * @return the generated codec
     */
    public ThriftCodec getThriftCodec()
    {
        return this.thriftCodec;
    }

    /**
     * Adds a {@code private final ThriftType type} field to the generated class. It also
     * registers a matching constructor parameter whose value is {@code ThriftType.struct(metadata)}.
     *
     * @return the definition of the declared field
     */
    private FieldDefinition declareTypeField()
    {
        FieldDefinition field = new FieldDefinition(a(PRIVATE, FINAL), "type", type(ThriftType.class));
        classDefinition.addField(field);

        // the generated constructor receives the value for this field as a parameter
        ThriftType structThriftType = ThriftType.struct(metadata);
        parameters.add(field, structThriftType);

        return field;
    }

    /**
     * Declares one {@code private final} delegate-codec field for every struct field that
     * {@code needsCodec}. The field is named {@code <fieldName>Codec}, is typed as the delegate's
     * concrete class, and gets a matching constructor parameter whose value is the delegate
     * obtained from the codec manager.
     *
     * @return the declared codec fields, keyed by Thrift field id (sorted by id)
     */
    private Map declareCodecFields()
    {
        Map fieldsById = new TreeMap<>();
        for (ThriftFieldMetadata field : metadata.getFields()) {
            if (!needsCodec(field)) {
                continue;
            }

            ThriftCodec delegate = codecManager.getCodec(field.getThriftType());
            FieldDefinition definition = new FieldDefinition(
                    a(PRIVATE, FINAL),
                    field.getName() + "Codec",
                    type(delegate.getClass()));

            classDefinition.addField(definition);
            fieldsById.put(field.getId(), definition);
            parameters.add(definition, delegate);
        }
        return fieldsById;
    }

    /**
     * Defines the public constructor of the generated codec. It takes one parameter for each
     * entry registered in {@code parameters} (the ThriftType and the delegate codecs), invokes
     * {@code Object()}, and then copies each parameter into the class field with the same name.
     */
    private void defineConstructor()
    {
        // JVM constructors are methods with the special name "<init>"; any other name (notably
        // the empty string) yields an invalid class file, and Class.getConstructor would then fail
        MethodDefinition constructor = new MethodDefinition(
                a(PUBLIC),
                "<init>",
                type(void.class),
                parameters.getParameters()
        );

        // super();
        constructor.loadThis().invokeConstructor(type(Object.class));

        // this.foo = foo; for every registered field (parameters are looked up by field name)
        for (FieldDefinition fieldDefinition : parameters.getFields()) {
            constructor.loadThis()
                    .loadVariable(fieldDefinition.getName())
                    .putField(codecType, fieldDefinition);
        }

        // return;
        constructor.ret();

        classDefinition.addMethod(constructor);
    }

    /**
     * Adds {@code public ThriftType getType()} to the generated class. The generated body
     * returns {@code this.type}.
     */
    private void defineGetTypeMethod()
    {
        MethodDefinition getType = new MethodDefinition(a(PUBLIC), "getType", type(ThriftType.class));

        // return this.type;
        getType.loadThis()
                .getField(codecType, typeField)
                .retObject();

        classDefinition.addMethod(getType);
    }

    /**
     * Adds the typed {@code read(TProtocol)} method for a struct to the generated class. The
     * generated body wraps the protocol in a {@code TProtocolReader}, reads every field into a
     * local variable, builds the struct from those locals, and returns it.
     */
    private void defineReadStructMethod()
    {
        MethodDefinition read = new MethodDefinition(a(PUBLIC), "read", structType, arg("protocol", TProtocol.class))
                .addException(Exception.class);

        // TProtocolReader reader = new TProtocolReader(protocol);
        read.addLocalVariable(type(TProtocolReader.class), "reader");
        read.newObject(TProtocolReader.class).dup();
        read.loadVariable("protocol").invokeConstructor(type(TProtocolReader.class), type(TProtocol.class));
        read.storeVariable("reader");

        Map fieldValues = readFieldValues(read);
        LocalVariableDefinition struct = buildStruct(read, fieldValues);

        // return struct;
        read.loadVariable(struct).retObject();

        classDefinition.addMethod(read);
    }

    /**
     * Emits the body of the generated read loop for a struct. Every field value is read from the
     * protocol into a local variable.
     * <p>
     * For each THRIFT_FIELD, a local named {@code f_<fieldName>} is declared via
     * {@code addInitializedLocalVariable}. The generated code then iterates with
     * {@code reader.nextField()}, switches on {@code reader.getFieldId()}, reads the value
     * (casting and coercing as needed), and stores it in that field's local. Unknown field ids
     * are skipped with {@code skipFieldData()}.
     *
     * @param read the method being generated; it must already have a local named "reader"
     * @return map from Thrift field id to the local variable holding that field's value
     * @throws IllegalArgumentException (at generation time) if no read method exists for a
     *         field's Thrift type
     */
    private Map readFieldValues(MethodDefinition read)
    {
        // note: despite its name, this is the TProtocolReader local, not the raw TProtocol
        LocalVariableDefinition protocol = read.getLocalVariable("reader");

        // declare and init local variables here
        Map structData = new TreeMap<>();
        for (ThriftFieldMetadata field : metadata.getFields(FieldKind.THRIFT_FIELD)) {
            LocalVariableDefinition variable = read.addInitializedLocalVariable(
                    toParameterizedType(field.getThriftType()),
                    "f_" + field.getName()
            );
            structData.put(field.getId(), variable);
        }

        // protocol.readStructBegin();
        read.loadVariable(protocol).invokeVirtual(
                TProtocolReader.class,
                "readStructBegin",
                void.class
        );

        // while (protocol.nextField())
        read.visitLabel("while-begin");
        read.loadVariable(protocol).invokeVirtual(TProtocolReader.class, "nextField", boolean.class);
        read.ifZeroGoto("while-end");

        // switch (protocol.getFieldId())
        read.loadVariable(protocol).invokeVirtual(TProtocolReader.class, "getFieldId", short.class);
        List cases = new ArrayList<>();
        for (ThriftFieldMetadata field : metadata.getFields(THRIFT_FIELD)) {
            cases.add(caseStatement(field.getId(), field.getName() + "-field"));
        }
        read.switchStatement("default", cases);

        // one case body per field; each ends by jumping back to the loop head
        for (ThriftFieldMetadata field : metadata.getFields(THRIFT_FIELD)) {
            // case field.id:
            read.visitLabel(field.getName() + "-field");

            // push protocol (receiver of the read call)
            read.loadVariable(protocol);

            // push ThriftTypeCodec for this field; only fields with a declared codec field
            // have one, and for those the read method presumably takes the codec as its argument
            FieldDefinition codecField = codecFields.get(field.getId());
            if (codecField != null) {
                read.loadThis().getField(codecType, codecField);
            }

            // read value
            Method readMethod = getReadMethod(field.getThriftType());
            if (readMethod == null) {
                throw new IllegalArgumentException("Unsupported field type " + field.getThriftType().getProtocolType());
            }
            read.invokeVirtual(readMethod);

            // todo this cast should be based on readMethod return type and fieldType (or coercion type)
            // add cast if necessary
            if (needsCastAfterRead(field, readMethod)) {
                read.checkCast(toParameterizedType(field.getThriftType()));
            }

            // coerce the type (static conversion from the Thrift representation)
            if (field.getCoercion().isPresent()) {
                read.invokeStatic(field.getCoercion().get().getFromThrift());
            }

            // store protocol value
            read.storeVariable(structData.get(field.getId()));

            // go back to top of loop
            read.gotoLabel("while-begin");
        }

        // default case: unknown field id, skip its data
        read.visitLabel("default")
                .loadVariable(protocol)
                .invokeVirtual(TProtocolReader.class, "skipFieldData", void.class)
                .gotoLabel("while-begin");

        // end of while loop
        read.visitLabel("while-end");

        // protocol.readStructEnd();
        read.loadVariable(protocol)
                .invokeVirtual(TProtocolReader.class, "readStructEnd", void.class);
        return structData;
    }

    /**
     * Emits the code that turns the locals filled by the read loop into the struct instance.
     * The steps are constructor injection, then field injection, then method injection, and
     * finally the builder factory method when one is declared.
     *
     * @param read the method being generated
     * @param structData map from Thrift field id to the local holding that field's value
     * @return the local variable that holds the finished instance
     */
    private LocalVariableDefinition buildStruct(MethodDefinition read, Map structData)
    {
        LocalVariableDefinition result = constructStructInstance(read, structData);

        injectStructFields(read, result, structData);
        injectStructMethods(read, result, structData);

        // when a builder method is present, its return value overwrites 'result'
        invokeFactoryMethod(read, structData, result);

        return result;
    }

    /**
     * Emits {@code instance = new T(args...)}, where {@code T} is the builder class when the
     * metadata declares one and the struct class otherwise. The arguments are the locals
     * referenced by the struct's constructor injection, in parameter order.
     *
     * @param read the method being generated
     * @param structData map from Thrift field id to the local holding that field's value
     * @return the new local variable named "instance"
     */
    private LocalVariableDefinition constructStructInstance(MethodDefinition read, Map structData)
    {
        LocalVariableDefinition instance = read.addLocalVariable(structType, "instance");

        // new T; dup
        if (metadata.getBuilderClass() != null) {
            read.newObject(metadata.getBuilderClass()).dup();
        }
        else {
            read.newObject(structType).dup();
        }

        // constructor arguments, then invokespecial <init>
        ThriftConstructorInjection constructorInjection = metadata.getConstructorInjection().get();
        for (ThriftParameterInjection parameter : constructorInjection.getParameters()) {
            read.loadVariable(structData.get(parameter.getId()));
        }
        read.invokeConstructor(constructorInjection.getConstructor()).storeVariable(instance);

        return instance;
    }

    /**
     * Emits the field-injection code for every THRIFT_FIELD of the struct. Each field is
     * delegated to {@code injectField} together with the local that holds its value.
     */
    private void injectStructFields(MethodDefinition read, LocalVariableDefinition instance, Map structData)
    {
        for (ThriftFieldMetadata thriftField : metadata.getFields(THRIFT_FIELD)) {
            injectField(read, thriftField, instance, structData.get(thriftField.getId()));
        }
    }

    /**
     * Emits the setter-style injection code for every method injection declared by the struct
     * metadata. Each one is delegated to {@code injectMethod}.
     */
    private void injectStructMethods(MethodDefinition read, LocalVariableDefinition instance, Map structData)
    {
        for (ThriftMethodInjection injection : metadata.getMethodInjections()) {
            injectMethod(read, injection, instance, structData);
        }
    }

    /**
     * Adds the typed {@code read(TProtocol)} method for a union to the generated class. The
     * generated body wraps the protocol in a {@code TProtocolReader} and declares a
     * {@code short fieldId} local, which the read loop updates with each field id it sees.
     * It then reads field values into locals, and builds and returns the union instance.
     */
    private void defineReadUnionMethod()
    {
        MethodDefinition read = new MethodDefinition(a(PUBLIC), "read", structType, arg("protocol", TProtocol.class))
                .addException(Exception.class);

        // TProtocolReader reader = new TProtocolReader(protocol);
        read.addLocalVariable(type(TProtocolReader.class), "reader");
        read.newObject(TProtocolReader.class).dup();
        read.loadVariable("protocol").invokeConstructor(type(TProtocolReader.class), type(TProtocol.class));
        read.storeVariable("reader");

        // short fieldId;  (initialized local)
        read.addInitializedLocalVariable(type(short.class), "fieldId");

        Map fieldValues = readSingleFieldValue(read);
        LocalVariableDefinition union = buildUnion(read, fieldValues);

        // return union;
        read.loadVariable(union).retObject();

        classDefinition.addMethod(read);
    }

    /**
     * Emits the body of the generated read loop for a union. Field values are read from the
     * protocol into local variables.
     * <p>
     * This mirrors {@code readFieldValues}, with one difference: every field header's id is
     * also stored in the {@code fieldId} local before the switch. After the loop, buildUnion and
     * constructUnionInstance use {@code fieldId} to decide which field was set.
     * <p>
     * NOTE(review): {@code fieldId} is overwritten for every field header, including unknown
     * ids that fall through to {@code skipFieldData()}. If an unknown field follows the known
     * one on the wire, the known field's id is lost -- confirm whether that can happen.
     *
     * @param read the method being generated; it must already have locals named "reader" and
     *        "fieldId"
     * @return map from Thrift field id to the local variable ({@code f_<fieldName>}) holding
     *         that field's value
     * @throws IllegalArgumentException (at generation time) if no read method exists for a
     *         field's Thrift type
     */
    private Map readSingleFieldValue(MethodDefinition read)
    {
        // note: despite its name, this is the TProtocolReader local, not the raw TProtocol
        LocalVariableDefinition protocol = read.getLocalVariable("reader");

        // declare and init local variables here
        Map unionData = new TreeMap<>();
        for (ThriftFieldMetadata field : metadata.getFields(THRIFT_FIELD)) {
            LocalVariableDefinition variable = read.addInitializedLocalVariable(
                    toParameterizedType(field.getThriftType()),
                    "f_" + field.getName()
            );
            unionData.put(field.getId(), variable);
        }

        // protocol.readStructBegin();
        read.loadVariable(protocol).invokeVirtual(
                TProtocolReader.class,
                "readStructBegin",
                void.class
        );

        // while (protocol.nextField())
        read.visitLabel("while-begin");
        read.loadVariable(protocol).invokeVirtual(TProtocolReader.class, "nextField", boolean.class);
        read.ifZeroGoto("while-end");

        // fieldId = protocol.getFieldId()
        read.loadVariable(protocol).invokeVirtual(TProtocolReader.class, "getFieldId", short.class);
        read.storeVariable("fieldId");

        // switch (fieldId)
        read.loadVariable("fieldId");

        List cases = new ArrayList<>();
        for (ThriftFieldMetadata field : metadata.getFields(THRIFT_FIELD)) {
            cases.add(caseStatement(field.getId(), field.getName() + "-field"));
        }
        read.switchStatement("default", cases);

        // one case body per field; each ends by jumping back to the loop head
        for (ThriftFieldMetadata field : metadata.getFields(THRIFT_FIELD)) {
            // case field.id:
            read.visitLabel(field.getName() + "-field");

            // push protocol (receiver of the read call)
            read.loadVariable(protocol);

            // push ThriftTypeCodec for this field (only fields with a declared codec field)
            FieldDefinition codecField = codecFields.get(field.getId());
            if (codecField != null) {
                read.loadThis().getField(codecType, codecField);
            }

            // read value
            Method readMethod = getReadMethod(field.getThriftType());
            if (readMethod == null) {
                throw new IllegalArgumentException("Unsupported field type " + field.getThriftType().getProtocolType());
            }
            read.invokeVirtual(readMethod);

            // todo this cast should be based on readMethod return type and fieldType (or coercion type)
            // add cast if necessary
            if (needsCastAfterRead(field, readMethod)) {
                read.checkCast(toParameterizedType(field.getThriftType()));
            }

            // coerce the type (static conversion from the Thrift representation)
            if (field.getCoercion().isPresent()) {
                read.invokeStatic(field.getCoercion().get().getFromThrift());
            }

            // store protocol value
            read.storeVariable(unionData.get(field.getId()));

            // go back to top of loop
            read.gotoLabel("while-begin");
        }

        // default case: unknown field id, skip its data
        read.visitLabel("default")
                .loadVariable(protocol)
                .invokeVirtual(TProtocolReader.class, "skipFieldData", void.class)
                .gotoLabel("while-begin");

        // end of while loop
        read.visitLabel("while-end");

        // protocol.readStructEnd();
        read.loadVariable(protocol)
                .invokeVirtual(TProtocolReader.class, "readStructEnd", void.class);

        return unionData;
    }

    /**
     * Emits the code that builds the union instance from the locals filled by the read loop.
     * <p>
     * First the instance is constructed, using a per-field constructor when one matches
     * {@code fieldId}. Then a switch on {@code fieldId} injects only the matching field: field
     * injection, plus the field's method injection if it has one. Every case jumps to
     * "inject-default", where the {@code @ThriftUnionId} field is set from {@code fieldId}. The
     * builder factory method is invoked last, when present.
     *
     * @param read the method being generated; it must have a "fieldId" local
     * @param unionData map from Thrift field id to the local holding that field's value
     * @return the local variable that holds the finished instance
     * @throws IllegalArgumentException (at generation time, from getOnlyElement) unless exactly
     *         one THRIFT_UNION_ID field exists
     */
    private LocalVariableDefinition buildUnion(MethodDefinition read, Map unionData)
    {
        // construct the instance and store it in the instance local variable
        LocalVariableDefinition instance = constructUnionInstance(read);

        // switch (fieldId)
        read.loadVariable("fieldId");

        List cases = new ArrayList<>();
        for (ThriftFieldMetadata field : metadata.getFields(THRIFT_FIELD)) {
            cases.add(caseStatement(field.getId(), field.getName() + "-inject-field"));
        }
        read.switchStatement("inject-default", cases);

        for (ThriftFieldMetadata field : metadata.getFields(THRIFT_FIELD)) {
            // case field.id:
            read.visitLabel(field.getName() + "-inject-field");

            injectField(read, field, instance, unionData.get(field.getId()));

            if (field.getMethodInjection().isPresent()) {
                injectMethod(read, field.getMethodInjection().get(), instance, unionData);
            }

            // break (inject-default is also where the id field is injected)
            read.gotoLabel("inject-default");
        }

        // default case
        read.visitLabel("inject-default");

        // find the @ThriftUnionId field
        ThriftFieldMetadata idField = getOnlyElement(metadata.getFields(THRIFT_UNION_ID));

        injectIdField(read, idField, instance, unionData);

        // invoke factory method if present
        invokeFactoryMethod(read, unionData, instance);

        return instance;
    }

    /**
     * Emits the code that constructs the union (or its builder) and stores it in a new
     * "instance" local.
     * <p>
     * After {@code new T; dup}, the generated code switches on {@code fieldId}. When a field
     * with a constructor injection matches, that constructor is invoked with the field's local
     * ({@code f_<fieldName>}, the naming used by readSingleFieldValue). Otherwise control
     * reaches "no-field-ctor". There, the struct-level constructor injection is invoked if it
     * exists; if not, the generated code throws the exception produced by
     * {@code SwiftBytecodeHelper.NO_CONSTRUCTOR_FOUND}.
     *
     * @param read the method being generated; it must have a "fieldId" local
     * @return the new local variable named "instance"
     */
    private LocalVariableDefinition constructUnionInstance(MethodDefinition read)
    {
        LocalVariableDefinition instance = read.addLocalVariable(structType, "instance");

        // create the new instance (or builder)
        if (metadata.getBuilderClass() == null) {
            read.newObject(structType).dup();
        }
        else {
            read.newObject(metadata.getBuilderClass()).dup();
        }

        // switch (fieldId)
        read.loadVariable("fieldId");

        List cases = new ArrayList<>();
        for (ThriftFieldMetadata field : metadata.getFields(THRIFT_FIELD)) {
            if (field.getConstructorInjection().isPresent()) {
                cases.add(caseStatement(field.getId(), field.getName() + "-id-field"));
            }
        }
        read.switchStatement("no-field-ctor", cases);

        for (ThriftFieldMetadata field : metadata.getFields(THRIFT_FIELD)) {
            if (field.getConstructorInjection().isPresent()) {
                // case fieldId:
                read.visitLabel(field.getName() + "-id-field");

                // Load the read value (local name must match the one used by readSingleFieldValue)
                read.loadVariable("f_" + field.getName());
                read.invokeConstructor(field.getConstructorInjection().get().getConstructor())
                        .storeVariable(instance)
                        .gotoLabel("instance-ok");
            }
        }

        read.visitLabel("no-field-ctor");

        // No args c'tor present.
        // NOTE(review): no arguments are pushed here, so this assumes the struct-level
        // constructor injection takes no parameters -- confirm.
        if (metadata.getConstructorInjection().isPresent()) {
            ThriftConstructorInjection constructor = metadata.getConstructorInjection().get();
            // invoke constructor
            read.invokeConstructor(constructor.getConstructor())
                    .storeVariable(instance);
        }
        else {
            // NOTE(review): only one of the two references pushed by new/dup is popped; the
            // other is presumably abandoned when the exception is thrown -- confirm.
            read.pop() // get rid of the half-constructed element
                    .loadConstant(metadata.getStructClass())
                    .loadVariable("fieldId")
                    .invokeStatic(SwiftBytecodeHelper.NO_CONSTRUCTOR_FOUND)
                    .throwException();
        }
        read.visitLabel("instance-ok");

        return instance;
    }

    /**
     * Emits {@code instance.<field> = value} for every {@link ThriftFieldInjection} of the given
     * field. When the field's protocol type is not a Java primitive, the assignment is guarded
     * so that a {@code null} value leaves the target field untouched.
     *
     * @param read the method being generated
     * @param field metadata of the field being injected
     * @param instance local holding the object (or builder) being populated
     * @param sourceVariable local holding the value read for this field
     */
    private void injectField(MethodDefinition read, ThriftFieldMetadata field, LocalVariableDefinition instance, LocalVariableDefinition sourceVariable)
    {
        for (ThriftInjection injection : field.getInjections()) {
            if (!(injection instanceof ThriftFieldInjection)) {
                continue;
            }
            ThriftFieldInjection fieldInjection = (ThriftFieldInjection) injection;
            boolean guardNull = !isProtocolTypeJavaPrimitive(field);
            String skipLabel = "field_is_null_" + field.getName();

            // if (value != null) {
            if (guardNull) {
                read.loadVariable(sourceVariable).ifNullGoto(skipLabel);
            }

            //     instance.field = value;
            read.loadVariable(instance)
                    .loadVariable(sourceVariable)
                    .putField(fieldInjection.getField());

            // }
            if (guardNull) {
                read.visitLabel(skipLabel);
            }
        }
    }

    /**
     * Emits a call to an injection method (for example a setter) on {@code instance}.
     * <p>
     * The call is made only when at least one parameter is "set". A non-primitive parameter
     * counts as set when its local is non-null; a primitive parameter counts as always set,
     * since it emits an unconditional jump to the invoke label. If nothing jumps to the invoke
     * label, control falls to the skip label. A non-void return value is popped off the stack.
     * Labels are made unique per method via {@code Method.toGenericString()}.
     *
     * @param read the method being generated
     * @param methodInjection the method to invoke and the field ids of its parameters
     * @param instance local holding the object (or builder) being populated
     * @param structData map from Thrift field id to the local holding that field's value
     */
    private void injectMethod(MethodDefinition read, ThriftMethodInjection methodInjection, LocalVariableDefinition instance, Map structData)
    {
        // if any parameter is non-null, invoke the method
        String methodName = methodInjection.getMethod().toGenericString();
        for (ThriftParameterInjection parameter : methodInjection.getParameters()) {
            if (!isParameterTypeJavaPrimitive(parameter)) {
                read.loadVariable(structData.get(parameter.getId()))
                        .ifNotNullGoto("invoke_" + methodName);
            }
            else {
                // primitives cannot be null, so the method is always invoked
                read.gotoLabel("invoke_" + methodName);
            }
        }
        // every non-primitive parameter was null
        read.gotoLabel("skip_invoke_" + methodName);

        // invoke the method
        read.visitLabel("invoke_" + methodName)
                .loadVariable(instance);

        // push parameters on stack
        for (ThriftParameterInjection parameter : methodInjection.getParameters()) {
            read.loadVariable(structData.get(parameter.getId()));
        }

        // invoke the method
        read.invokeVirtual(methodInjection.getMethod());

        // if method has a return, we need to pop it off the stack
        if (methodInjection.getMethod().getReturnType() != void.class) {
            read.pop();
        }

        // skip invocation
        read.visitLabel("skip_invoke_" + methodName);
    }

    /**
     * Emits {@code instance = instance.<builderMethod>(args...)} when the metadata declares a
     * builder method; otherwise emits nothing. The arguments are the locals referenced by the
     * builder method's parameter injections, in order.
     *
     * @param read the method being generated
     * @param structData map from Thrift field id to the local holding that field's value
     * @param instance local holding the builder; it is overwritten with the method's result
     */
    private void invokeFactoryMethod(MethodDefinition read, Map structData, LocalVariableDefinition instance)
    {
        if (!metadata.getBuilderMethod().isPresent()) {
            return;
        }
        ThriftMethodInjection factoryMethod = metadata.getBuilderMethod().get();

        // receiver
        read.loadVariable(instance);

        // arguments
        for (ThriftParameterInjection parameter : factoryMethod.getParameters()) {
            read.loadVariable(structData.get(parameter.getId()));
        }

        // call and replace the local with the result
        read.invokeVirtual(factoryMethod.getMethod()).storeVariable(instance);
    }

    /**
     * Emits {@code instance.<idField> = fieldId} for every {@link ThriftFieldInjection} of the
     * union id field. The value comes from the generated method's "fieldId" local. When the
     * field's protocol type is not a Java primitive, the assignment is guarded by a null check.
     * The {@code structData} parameter is not used.
     *
     * @param read the method being generated; it must have a "fieldId" local
     * @param field metadata of the {@code @ThriftUnionId} field
     * @param instance local holding the object (or builder) being populated
     * @param structData unused
     */
    private void injectIdField(MethodDefinition read, ThriftFieldMetadata field, LocalVariableDefinition instance, Map structData)
    {
        for (ThriftInjection injection : field.getInjections()) {
            if (!(injection instanceof ThriftFieldInjection)) {
                continue;
            }
            ThriftFieldInjection fieldInjection = (ThriftFieldInjection) injection;
            boolean guardNull = !isProtocolTypeJavaPrimitive(field);

            // if (fieldId != null) {
            if (guardNull) {
                read.loadVariable("fieldId").ifNullGoto("field_is_null_fieldId");
            }

            //     instance.id = fieldId;
            read.loadVariable(instance)
                    .loadVariable("fieldId")
                    .putField(fieldInjection.getField());

            // }
            if (guardNull) {
                read.visitLabel("field_is_null_fieldId");
            }
        }
    }

    /**
     * Adds the typed {@code write(struct, TProtocol)} method for a struct to the generated
     * class. The generated body wraps the protocol in a {@code TProtocolWriter}, writes the
     * struct header using the struct name, writes every THRIFT_FIELD via {@code writeField},
     * and then writes the struct end. The return type is passed as {@code null}, presumably
     * meaning void -- TODO confirm.
     */
    private void defineWriteStructMethod()
    {
        MethodDefinition write = new MethodDefinition(
                a(PUBLIC),
                "write",
                null,
                arg("struct", structType),
                arg("protocol", TProtocol.class)
        );
        classDefinition.addMethod(write);

        // TProtocolWriter writer = new TProtocolWriter(protocol);
        write.addLocalVariable(type(TProtocolWriter.class), "writer");
        write.newObject(TProtocolWriter.class).dup();
        write.loadVariable("protocol").invokeConstructor(type(TProtocolWriter.class), type(TProtocol.class));
        write.storeVariable("writer");

        LocalVariableDefinition writer = write.getLocalVariable("writer");

        // writer.writeStructBegin(structName);
        write.loadVariable(writer)
                .loadConstant(metadata.getStructName())
                .invokeVirtual(TProtocolWriter.class, "writeStructBegin", void.class, String.class);

        for (ThriftFieldMetadata field : metadata.getFields(THRIFT_FIELD)) {
            writeField(write, writer, field);
        }

        // writer.writeStructEnd();
        write.loadVariable(writer)
                .invokeVirtual(TProtocolWriter.class, "writeStructEnd", void.class);

        // return;
        write.ret();
    }

    /**
     * Defines the {@code write(struct, protocol)} method for a union.
     * <p>
     * The generated code wraps {@code protocol} in a {@link TProtocolWriter} and writes the
     * struct header. It then switches on the value of the single {@code THRIFT_UNION_ID} field.
     * A matching {@code THRIFT_FIELD} is written, after which control jumps to the footer. An id
     * that matches no field goes straight to the footer, which writes the struct end.
     */
    private void defineWriteUnionMethod()
    {
        MethodDefinition write = new MethodDefinition(
                a(PUBLIC),
                "write",
                null,
                arg("struct", structType),
                arg("protocol", TProtocol.class)
        );
        classDefinition.addMethod(write);

        String footerLabel = "default-write";

        // TProtocolWriter writer = new TProtocolWriter(protocol);
        write.addLocalVariable(type(TProtocolWriter.class), "writer");
        write.newObject(TProtocolWriter.class);
        write.dup();
        write.loadVariable("protocol");
        write.invokeConstructor(type(TProtocolWriter.class), type(TProtocol.class));
        write.storeVariable("writer");
        LocalVariableDefinition writer = write.getLocalVariable("writer");

        // writer.writeStructBegin(<struct name>);
        write.loadVariable(writer);
        write.loadConstant(metadata.getStructName());
        write.invokeVirtual(TProtocolWriter.class, "writeStructBegin", void.class, String.class);

        // switch (<union id value>) { case <member id>: goto <member>-write-field; default: goto footer; }
        ThriftFieldMetadata unionIdField = getOnlyElement(metadata.getFields(THRIFT_UNION_ID));
        loadFieldValue(write, unionIdField);
        List<CaseStatement> cases = new ArrayList<>();
        for (ThriftFieldMetadata member : metadata.getFields(THRIFT_FIELD)) {
            cases.add(caseStatement(member.getId(), member.getName() + "-write-field"));
        }
        write.switchStatement(footerLabel, cases);

        // one section per member: write it, then jump past the remaining sections
        for (ThriftFieldMetadata member : metadata.getFields(THRIFT_FIELD)) {
            write.visitLabel(member.getName() + "-write-field");
            writeField(write, writer, member);
            write.gotoLabel(footerLabel);
        }

        // footer: writer.writeStructEnd();
        write.visitLabel(footerLabel);
        write.loadVariable(writer);
        write.invokeVirtual(TProtocolWriter.class, "writeStructEnd", void.class);

        write.ret();
    }

    /**
     * Emits bytecode that writes a single field of the {@code struct} argument through the
     * {@link TProtocolWriter} held in {@code protocol}.
     * <p>
     * The operand stack is built up as {@code writer, name, id, [codec], value}. It is then
     * consumed by the {@code TProtocolWriter.writeXxxField} method chosen by
     * {@link #getWriteMethod}. The codec slot is present only when {@code codecFields} has an
     * entry for the field id. If the value, or its coerced form, is null, the write is skipped:
     * control jumps to {@code field_is_null_<name>}, where those stack entries are popped.
     * <p>
     * Label names are derived from the field name, so this assumes field names are unique within
     * the generated method. NOTE(review): confirm for all callers.
     *
     * @param write the method being generated
     * @param protocol local variable holding the {@code TProtocolWriter} (callers in this file
     *        pass the writer local, not the raw {@code TProtocol})
     * @param field metadata of the field to write
     * @throws IllegalArgumentException if no write method exists for the field's thrift type;
     *         the preceding pushes have already been emitted by then
     */
    private void writeField(MethodDefinition write, LocalVariableDefinition protocol, ThriftFieldMetadata field)
    {
        // push writer (receiver of the writeXxxField call)
        write.loadVariable(protocol);

        // push (String) field.name
        write.loadConstant(field.getName());

        // push (short) field.id
        write.loadConstant(field.getId());

        // push the codec for this field, only if one was registered for its id
        FieldDefinition codecField = codecFields.get(field.getId());
        if (codecField != null) {
            write.loadThis().getField(codecType, codecField);
        }

        // push field value
        loadFieldValue(write, field);

        // if field value is null, don't coerce or write the field
        if (!isFieldTypeJavaPrimitive(field)) {
            // ifNullGoto consumes the top of the stack, so we need to duplicate the value
            write.dup();
            write.ifNullGoto("field_is_null_" + field.getName());
        }

        // coerce value to its thrift-side representation
        if (field.getCoercion().isPresent()) {
            write.invokeStatic(field.getCoercion().get().getToThrift());

            // if coerced value is null, don't write the field
            if (!isProtocolTypeJavaPrimitive(field)) {
                write.dup();
                write.ifNullGoto("field_is_null_" + field.getName());
            }
        }

        // write value: consumes writer, name, id, [codec], value
        Method writeMethod = getWriteMethod(field.getThriftType());
        if (writeMethod == null) {
            throw new IllegalArgumentException("Unsupported field type " + field.getThriftType().getProtocolType());
        }
        write.invokeVirtual(writeMethod);

        //
        // If not written because of a null, clean-up the stack.
        // The condition covers both null checks above; the second check is only emitted
        // when a coercion is present.
        if (!isProtocolTypeJavaPrimitive(field) || !isFieldTypeJavaPrimitive(field)) {

            // value was written so skip cleanup
            write.gotoLabel("field_end_" + field.getName());

            // cleanup stack for null field value, popping in reverse push order
            write.visitLabel("field_is_null_" + field.getName());
            // pop value
            write.pop();
            // pop codec (only pushed when codecField != null)
            if (codecField != null) {
                write.pop();
            }
            // pop id
            write.pop();
            // pop name
            write.pop();
            // pop writer
            write.pop();

            write.visitLabel("field_end_" + field.getName());
        }
    }

    /**
     * Emits bytecode that pushes the Java value of {@code field}, taken from the {@code struct}
     * variable of the method being generated.
     * <p>
     * A {@link ThriftFieldExtractor} reads the Java field directly. A
     * {@link ThriftMethodExtractor} invokes the accessor method. In both cases the result is cast
     * to the extractor's type when the extractor reports itself as generic.
     *
     * @param write the method being generated; it must declare a {@code struct} variable
     * @param field metadata of the field whose value is loaded
     * @throws IllegalArgumentException if the field has no extraction, or has an extraction kind
     *         this generator does not handle. Emitting code in that case would leave the struct
     *         itself on the stack in place of the field value.
     */
    private void loadFieldValue(MethodDefinition write, ThriftFieldMetadata field)
    {
        if (!field.getExtraction().isPresent()) {
            throw new IllegalArgumentException("Thrift field " + field.getName() + " (id " + field.getId() + ") has no extraction");
        }
        ThriftExtraction extraction = field.getExtraction().get();

        write.loadVariable("struct");
        if (extraction instanceof ThriftFieldExtractor) {
            ThriftFieldExtractor fieldExtractor = (ThriftFieldExtractor) extraction;
            write.getField(fieldExtractor.getField());
            if (fieldExtractor.isGeneric()) {
                write.checkCast(type(fieldExtractor.getType()));
            }
        }
        else if (extraction instanceof ThriftMethodExtractor) {
            ThriftMethodExtractor methodExtractor = (ThriftMethodExtractor) extraction;
            write.invokeVirtual(methodExtractor.getMethod());
            if (methodExtractor.isGeneric()) {
                write.checkCast(type(methodExtractor.getType()));
            }
        }
        else {
            throw new IllegalArgumentException("Unsupported extraction type " + extraction.getClass().getName() + " for thrift field " + field.getName());
        }
    }

    /**
     * Adds the public synthetic bridge {@code Object read(TProtocol)}. It forwards to the
     * struct-typed {@code read} method of the generated codec and returns its result.
     */
    private void defineReadBridgeMethod()
    {
        MethodDefinition bridge = new MethodDefinition(
                a(PUBLIC, BRIDGE, SYNTHETIC),
                "read",
                type(Object.class),
                arg("protocol", TProtocol.class))
                .addException(Exception.class);

        // return this.read(protocol);
        bridge.loadThis()
                .loadVariable("protocol")
                .invokeVirtual(codecType, "read", structType, type(TProtocol.class))
                .retObject();

        classDefinition.addMethod(bridge);
    }

    /**
     * Adds the public synthetic bridge {@code void write(Object, TProtocol)}. It casts the
     * untyped argument to the struct type and forwards to the struct-typed {@code write} method
     * of the generated codec.
     */
    private void defineWriteBridgeMethod()
    {
        MethodDefinition bridge = new MethodDefinition(
                a(PUBLIC, BRIDGE, SYNTHETIC),
                "write",
                null,
                arg("struct", Object.class),
                arg("protocol", TProtocol.class))
                .addException(Exception.class);

        // this.write((StructType) struct, protocol); return;
        bridge.loadThis()
                .loadVariable("struct", structType)
                .loadVariable("protocol")
                .invokeVirtual(codecType, "write", type(void.class), structType, type(TProtocol.class))
                .ret();

        classDefinition.addMethod(bridge);
    }

    /** Returns whether the raw class of the injected parameter's Java type is a primitive. */
    private boolean isParameterTypeJavaPrimitive(ThriftParameterInjection parameter)
    {
        TypeToken<?> parameterToken = TypeToken.of(parameter.getJavaType());
        return isJavaPrimitive(parameterToken);
    }

    /** Returns whether the raw class of the field's Java-side type is a primitive. */
    private boolean isFieldTypeJavaPrimitive(ThriftFieldMetadata field)
    {
        TypeToken<?> fieldToken = TypeToken.of(field.getThriftType().getJavaType());
        return isJavaPrimitive(fieldToken);
    }

    /**
     * Returns whether the Java type on the protocol side of the field is a primitive. For a
     * coerced thrift type this is the uncoerced type's Java type; otherwise it is the field's own
     * Java type.
     */
    private boolean isProtocolTypeJavaPrimitive(ThriftFieldMetadata field)
    {
        ThriftType thriftType = field.getThriftType();
        if (!thriftType.isCoerced()) {
            return isJavaPrimitive(TypeToken.of(thriftType.getJavaType()));
        }
        return isJavaPrimitive(TypeToken.of(thriftType.getUncoercedType().getJavaType()));
    }

    /** Returns whether the raw class represented by {@code typeToken} is a Java primitive. */
    private boolean isJavaPrimitive(TypeToken typeToken)
    {
        Class<?> rawClass = typeToken.getRawType();
        return rawClass.isPrimitive();
    }

    /**
     * Returns whether the value produced by {@code readMethod} must be cast before use.
     * The target is the first parameter type of the {@code fromThrift} coercion when one exists,
     * otherwise the raw class of the field's Java type. A cast is needed when that target is not
     * assignable from the read method's return type.
     */
    private static boolean needsCastAfterRead(ThriftFieldMetadata field, Method readMethod)
    {
        Class<?> readType = readMethod.getReturnType();
        Class<?> targetType = field.getCoercion().isPresent()
                ? field.getCoercion().get().getFromThrift().getParameterTypes()[0]
                : TypeToken.of(field.getThriftType().getJavaType()).getRawType();
        return !targetType.isAssignableFrom(readType);
    }

    /**
     * Returns whether writing or reading this field goes through a {@code ThriftCodec}. That is
     * the case for ENUM, STRUCT, SET, LIST and MAP protocol types, except when the field's Java
     * type is an array (arrays have dedicated protocol methods).
     */
    private boolean needsCodec(ThriftFieldMetadata fieldMetadata)
    {
        ThriftType thriftType = fieldMetadata.getThriftType();
        if (ReflectionHelper.isArray(thriftType.getJavaType())) {
            return false;
        }

        ThriftProtocolType protocolType = thriftType.getProtocolType();
        if (protocolType == null) {
            return false;
        }
        switch (protocolType) {
            case ENUM:
            case STRUCT:
            case SET:
            case LIST:
            case MAP:
                return true;
            default:
                return false;
        }
    }

    /**
     * Returns the type of the generated codec. It lives under {@code PACKAGE} and is named after
     * the struct's class name (as reported by {@code ParameterizedType.getClassName()}) with a
     * {@code Codec} suffix.
     */
    private ParameterizedType toCodecType(ThriftStructMetadata metadata)
    {
        String structClassName = type(metadata.getStructClass()).getClassName();
        String codecName = PACKAGE + "/" + structClassName + "Codec";
        return type(codecName);
    }

    private static class ConstructorParameters
    {
        private final List fields = new ArrayList<>();
        private final List values = new ArrayList<>();

        private void add(FieldDefinition field, Object value)
        {
            fields.add(field);
            values.add(value);
        }

        public List getFields()
        {
            return fields;
        }

        public Object[] getValues()
        {
            return values.toArray(new Object[values.size()]);
        }

        public List getParameters()
        {
            return Lists.transform(fields, new Function()
            {
                public NamedParameterDefinition apply(FieldDefinition field)
                {
                    return arg(field.getName(), field.getType());
                }
            });
        }

        public Class[] getTypes()
        {
            List> types = Lists.transform(values, new Function>()
            {
                public Class apply(Object value)
                {
                    return value.getClass();
                }
            });

            return types.toArray(new Class[types.size()]);
        }
    }

    /**
     * Converts a {@link ThriftType} by wrapping it in a {@link DefaultThriftTypeReference} and
     * delegating to {@link #toParameterizedType(ThriftTypeReference)}.
     */
    public static ParameterizedType toParameterizedType(ThriftType type)
    {
        ThriftTypeReference reference = new DefaultThriftTypeReference(type);
        return toParameterizedType(reference);
    }

    /**
     * Maps a thrift type reference to the bytecode type used in generated signatures:
     * <ul>
     * <li>array Java types map to their array class;</li>
     * <li>MAP, SET and LIST map to {@code Map}, {@code Set} and {@code List}, parameterized by
     * the recursively converted key/value types;</li>
     * <li>BOOL, BYTE, DOUBLE, I16, I32, I64, STRING, BINARY, STRUCT and ENUM map to the Java
     * class of the reference.</li>
     * </ul>
     *
     * @throws IllegalArgumentException for any other protocol type
     */
    public static ParameterizedType toParameterizedType(ThriftTypeReference typeRef)
    {
        if (ReflectionHelper.isArray(typeRef.getJavaType())) {
            return type((Class<?>) typeRef.getJavaType());
        }

        ThriftProtocolType protocolType = typeRef.getProtocolType();
        switch (protocolType) {
            case MAP: {
                ParameterizedType keyType = toParameterizedType(typeRef.get().getKeyTypeReference());
                ParameterizedType valueType = toParameterizedType(typeRef.get().getValueTypeReference());
                return type(Map.class, keyType, valueType);
            }
            case SET:
                return type(Set.class, toParameterizedType(typeRef.get().getValueTypeReference()));
            case LIST:
                return type(List.class, toParameterizedType(typeRef.get().getValueTypeReference()));
            case BOOL:
            case BYTE:
            case DOUBLE:
            case I16:
            case I32:
            case I64:
            case STRING:
            case BINARY:
            case STRUCT:
            case ENUM:
                return type((Class<?>) typeRef.getJavaType());
            default:
                throw new IllegalArgumentException("Unsupported thrift field type " + typeRef.getJavaType());
        }
    }

    /**
     * Returns the {@code TProtocolWriter} method for the given type. Array Java types are looked
     * up by class in {@code ARRAY_WRITE_METHODS}; all others by protocol type in
     * {@code WRITE_METHODS}. Returns null when no entry exists.
     */
    private Method getWriteMethod(ThriftType thriftType)
    {
        boolean isArray = ReflectionHelper.isArray(thriftType.getJavaType());
        return isArray
                ? ARRAY_WRITE_METHODS.get(thriftType.getJavaType())
                : WRITE_METHODS.get(thriftType.getProtocolType());
    }

    /**
     * Returns the {@code TProtocolReader} method for the given type. Array Java types are looked
     * up by class in {@code ARRAY_READ_METHODS}; all others by protocol type in
     * {@code READ_METHODS}. Returns null when no entry exists.
     */
    private Method getReadMethod(ThriftType thriftType)
    {
        boolean isArray = ReflectionHelper.isArray(thriftType.getJavaType());
        return isArray
                ? ARRAY_READ_METHODS.get(thriftType.getJavaType())
                : READ_METHODS.get(thriftType.getProtocolType());
    }

    // Looks up, once at class initialization, the TProtocolWriter/TProtocolReader methods that
    // generated codecs invoke. They are keyed by protocol type, or by Java array class for array
    // fields. A missing method is rethrown unchecked via Throwables.propagate.
    static {
        ImmutableMap.Builder writeBuilder = ImmutableMap.builder();
        ImmutableMap.Builder readBuilder = ImmutableMap.builder();
        ImmutableMap.Builder arrayWriteBuilder = ImmutableMap.builder();
        ImmutableMap.Builder arrayReadBuilder = ImmutableMap.builder();

        Class<TProtocolWriter> writerClass = TProtocolWriter.class;
        Class<TProtocolReader> readerClass = TProtocolReader.class;

        try {
            // scalar writers: (name, id, value)
            writeBuilder.put(BOOL, writerClass.getMethod("writeBoolField", String.class, short.class, boolean.class));
            writeBuilder.put(BYTE, writerClass.getMethod("writeByteField", String.class, short.class, byte.class));
            writeBuilder.put(DOUBLE, writerClass.getMethod("writeDoubleField", String.class, short.class, double.class));
            writeBuilder.put(I16, writerClass.getMethod("writeI16Field", String.class, short.class, short.class));
            writeBuilder.put(I32, writerClass.getMethod("writeI32Field", String.class, short.class, int.class));
            writeBuilder.put(I64, writerClass.getMethod("writeI64Field", String.class, short.class, long.class));
            writeBuilder.put(STRING, writerClass.getMethod("writeStringField", String.class, short.class, String.class));
            writeBuilder.put(BINARY, writerClass.getMethod("writeBinaryField", String.class, short.class, ByteBuffer.class));
            // codec-backed writers: (name, id, codec, value)
            writeBuilder.put(STRUCT, writerClass.getMethod("writeStructField", String.class, short.class, ThriftCodec.class, Object.class));
            writeBuilder.put(MAP, writerClass.getMethod("writeMapField", String.class, short.class, ThriftCodec.class, Map.class));
            writeBuilder.put(SET, writerClass.getMethod("writeSetField", String.class, short.class, ThriftCodec.class, Set.class));
            writeBuilder.put(LIST, writerClass.getMethod("writeListField", String.class, short.class, ThriftCodec.class, List.class));
            writeBuilder.put(ENUM, writerClass.getMethod("writeEnumField", String.class, short.class, ThriftCodec.class, Enum.class));

            // scalar readers take no arguments; codec-backed readers take a ThriftCodec
            readBuilder.put(BOOL, readerClass.getMethod("readBoolField"));
            readBuilder.put(BYTE, readerClass.getMethod("readByteField"));
            readBuilder.put(DOUBLE, readerClass.getMethod("readDoubleField"));
            readBuilder.put(I16, readerClass.getMethod("readI16Field"));
            readBuilder.put(I32, readerClass.getMethod("readI32Field"));
            readBuilder.put(I64, readerClass.getMethod("readI64Field"));
            readBuilder.put(STRING, readerClass.getMethod("readStringField"));
            readBuilder.put(BINARY, readerClass.getMethod("readBinaryField"));
            readBuilder.put(STRUCT, readerClass.getMethod("readStructField", ThriftCodec.class));
            readBuilder.put(MAP, readerClass.getMethod("readMapField", ThriftCodec.class));
            readBuilder.put(SET, readerClass.getMethod("readSetField", ThriftCodec.class));
            readBuilder.put(LIST, readerClass.getMethod("readListField", ThriftCodec.class));
            readBuilder.put(ENUM, readerClass.getMethod("readEnumField", ThriftCodec.class));

            // array fields use dedicated whole-array protocol methods
            arrayWriteBuilder.put(boolean[].class, writerClass.getMethod("writeBoolArrayField", String.class, short.class, boolean[].class));
            arrayWriteBuilder.put(short[].class, writerClass.getMethod("writeI16ArrayField", String.class, short.class, short[].class));
            arrayWriteBuilder.put(int[].class, writerClass.getMethod("writeI32ArrayField", String.class, short.class, int[].class));
            arrayWriteBuilder.put(long[].class, writerClass.getMethod("writeI64ArrayField", String.class, short.class, long[].class));
            arrayWriteBuilder.put(double[].class, writerClass.getMethod("writeDoubleArrayField", String.class, short.class, double[].class));

            arrayReadBuilder.put(boolean[].class, readerClass.getMethod("readBoolArrayField"));
            arrayReadBuilder.put(short[].class, readerClass.getMethod("readI16ArrayField"));
            arrayReadBuilder.put(int[].class, readerClass.getMethod("readI32ArrayField"));
            arrayReadBuilder.put(long[].class, readerClass.getMethod("readI64ArrayField"));
            arrayReadBuilder.put(double[].class, readerClass.getMethod("readDoubleArrayField"));

            // byte[] is encoded as BINARY and could follow the protocol-type rules above,
            // but it is simpler to map it explicitly here
            arrayWriteBuilder.put(byte[].class, writerClass.getMethod("writeBinaryField", String.class, short.class, ByteBuffer.class));
            arrayReadBuilder.put(byte[].class, readerClass.getMethod("readBinaryField"));
        }
        catch (NoSuchMethodException e) {
            throw Throwables.propagate(e);
        }

        WRITE_METHODS = writeBuilder.build();
        READ_METHODS = readBuilder.build();
        ARRAY_WRITE_METHODS = arrayWriteBuilder.build();
        ARRAY_READ_METHODS = arrayReadBuilder.build();
    }
}