All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cn.wjybxx.dsonapt.PojoCodecGenerator Maven / Gradle / Ivy

There is a newer version: 2.2.0
Show newest version
/*
 * Copyright 2023-2024 wjybxx([email protected])
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package cn.wjybxx.dsonapt;

import cn.wjybxx.apt.AbstractGenerator;
import cn.wjybxx.apt.AptUtils;
import cn.wjybxx.apt.BeanUtils;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.TypeSpec;

import javax.lang.model.element.*;
import javax.lang.model.type.DeclaredType;
import javax.lang.model.type.TypeKind;
import javax.lang.model.type.TypeMirror;
import javax.tools.Diagnostic;
import java.util.EnumMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Fills in the methods of a generated codec class for a single pojo type:
 * {@code getEncoderClass}, {@code newInstance}, {@code readFields}, {@code writeFields} and the optional
 * {@code beforeEncode} / {@code afterDecode} callbacks.
 * <p>
 * All generated statements refer to the decoded/encoded object as {@code instance} and to the
 * stream objects as {@code reader} / {@code writer}; those names are expected to be the parameter
 * names of the method builders handed out by {@link CodecProcessor}.
 *
 * @author wjybxx
 * date 2023/4/13
 */
class PojoCodecGenerator extends AbstractGenerator<CodecProcessor> {

    // Names of the reader methods invoked by the generated code.
    public static final String MNAME_READ_STRING = "readString";
    public static final String MNAME_READ_BYTES = "readBytes";
    public static final String MNAME_READ_OBJECT = "readObject";

    // Names of the writer methods invoked by the generated code.
    public static final String MNAME_WRITE_STRING = "writeString";
    public static final String MNAME_WRITE_BYTES = "writeBytes";
    public static final String MNAME_WRITE_OBJECT = "writeObject";

    /** Primitive kind -> name of the reader method ("read" + capitalized kind name). */
    private static final Map<TypeKind, String> primitiveReadMethodNameMap = new EnumMap<>(TypeKind.class);
    /** Primitive kind -> name of the writer method ("write" + capitalized kind name). */
    private static final Map<TypeKind, String> primitiveWriteMethodNameMap = new EnumMap<>(TypeKind.class);

    private final Context context;
    private final TypeSpec.Builder typeBuilder;
    private final TypeMirror readerTypeMirror;
    private final TypeMirror writerTypeMirror;
    // NOTE(review): element type restored as "? extends Element" -- confirm against Context's declaration.
    private final List<? extends Element> allFieldsAndMethodWithInherit;

    protected ClassName rawTypeName;
    protected boolean containsReaderConstructor;
    // NOTE(review): never assigned in this class, so it is false unless set elsewhere (e.g. by a subclass).
    protected boolean containsNewInstanceMethod;
    protected boolean containsReadObjectMethod;
    protected boolean containsWriteObjectMethod;
    protected boolean containsBeforeEncodeMethod;
    protected boolean containsAfterDecodeMethod;

    protected DeclaredType superDeclaredType;
    protected MethodSpec.Builder newInstanceMethodBuilder;
    protected MethodSpec.Builder readFieldsMethodBuilder;
    protected MethodSpec.Builder afterDecodeMethodBuilder;
    protected MethodSpec.Builder beforeEncodeMethodBuilder;
    protected MethodSpec.Builder writeFiedlsMethodBuilder;

    static {
        for (TypeKind typeKind : TypeKind.values()) {
            if (!typeKind.isPrimitive()) {
                continue;
            }
            // Locale.ROOT keeps the generated method names independent of the build machine's default locale.
            final String name = BeanUtils.firstCharToUpperCase(typeKind.name().toLowerCase(Locale.ROOT));
            primitiveReadMethodNameMap.put(typeKind, "read" + name);
            primitiveWriteMethodNameMap.put(typeKind, "write" + name);
        }
    }

    /**
     * @param processor the owning processor; supplies method builders, type names and lookup helpers
     * @param context   per-type state (target element, type builder, reader/writer types, member list)
     */
    public PojoCodecGenerator(CodecProcessor processor, Context context) {
        super(processor, context.typeElement);
        this.context = context;
        this.typeBuilder = context.typeBuilder;
        this.readerTypeMirror = context.readerTypeMirror;
        this.writerTypeMirror = context.writerTypeMirror;
        this.allFieldsAndMethodWithInherit = context.allFieldsAndMethodWithInherit;
    }

    // region codec

    /** Runs {@link #init()} and then {@link #gen()}. */
    @Override
    public void execute() {
        init();

        gen();
    }

    /** Resolves the target type name, detects user-declared hooks and creates the method builders. */
    protected void init() {
        rawTypeName = ClassName.get(typeElement);
        containsReaderConstructor = processor.containsReaderConstructor(typeElement, readerTypeMirror);
        containsReadObjectMethod = processor.containsReadObjectMethod(allFieldsAndMethodWithInherit, readerTypeMirror);
        containsWriteObjectMethod = processor.containsWriteObjectMethod(allFieldsAndMethodWithInherit, writerTypeMirror);
        containsBeforeEncodeMethod = processor.containsBeforeEncodeMethod(allFieldsAndMethodWithInherit);
        containsAfterDecodeMethod = processor.containsAfterDecodeMethod(allFieldsAndMethodWithInherit);

        // superDeclaredType must be set before the builders below are created from it.
        superDeclaredType = context.superDeclaredType;
        newInstanceMethodBuilder = processor.newNewInstanceMethodBuilder(superDeclaredType);
        readFieldsMethodBuilder = processor.newReadFieldsMethodBuilder(superDeclaredType);
        afterDecodeMethodBuilder = processor.newAfterDecodeMethodBuilder(superDeclaredType);
        beforeEncodeMethodBuilder = processor.newBeforeEncodeMethodBuilder(superDeclaredType);
        writeFiedlsMethodBuilder = processor.newWriteFieldsMethodBuilder(superDeclaredType);
    }

    /** Populates the method builders and adds the finished methods to the type builder in a fixed order. */
    protected void gen() {
        // No serial annotation: the type is not serializable, so every codec method just throws.
        if (context.serialAnnoMirror == null) {
            newInstanceMethodBuilder.addStatement("throw new $T()", UnsupportedOperationException.class);
            readFieldsMethodBuilder.addStatement("throw new $T()", UnsupportedOperationException.class);
            writeFiedlsMethodBuilder.addStatement("throw new $T()", UnsupportedOperationException.class);
            typeBuilder.addAnnotation(context.scanIgnoreAnnoSpec)
                    .addMethod(processor.newGetEncoderClassMethod(superDeclaredType, rawTypeName))
                    .addMethod(newInstanceMethodBuilder.build())
                    .addMethod(readFieldsMethodBuilder.build())
                    .addMethod(writeFiedlsMethodBuilder.build());
            return;
        }

        // newInstance
        AptClassProps aptClassProps = context.aptClassProps;
        genNewInstanceMethod(aptClassProps);
        if (!aptClassProps.isSingleton()) {
            // User-defined writeObject/readObject calls are emitted before the per-field statements.
            genWriteObjectMethod(aptClassProps);
            genReadObjectMethod(aptClassProps);
            // Plain field reads and writes.
            for (VariableElement variableElement : context.serialFields) {
                final AptFieldProps aptFieldProps = context.fieldPropsMap.get(variableElement);
                if (processor.isAutoWriteField(variableElement, aptClassProps, aptFieldProps)) {
                    addWriteStatement(variableElement, aptFieldProps, aptClassProps);
                }
                if (processor.isAutoReadField(variableElement, aptClassProps, aptFieldProps)) {
                    addReadStatement(variableElement, aptFieldProps, aptClassProps);
                }
            }
        }

        // The calls below fix the order of the methods in the generated class.
        // getEncoder
        typeBuilder.addMethod(processor.newGetEncoderClassMethod(superDeclaredType, rawTypeName));

        // beforeEncode callback -- only emitted when the user declared one
        if (genBeforeEncodeMethod(aptClassProps)) {
            typeBuilder.addMethod(beforeEncodeMethodBuilder.build());
        }
        typeBuilder.addMethod(writeFiedlsMethodBuilder.build());
        typeBuilder.addMethod(newInstanceMethodBuilder.build())
                .addMethod(readFieldsMethodBuilder.build());
        // afterDecode callback -- only emitted when the user declared one
        if (genAfterDecodeMethod(aptClassProps)) {
            typeBuilder.addMethod(afterDecodeMethodBuilder.build());
        }

        // Extra annotations
        if (context.additionalAnnotations != null) {
            typeBuilder.addAnnotations(context.additionalAnnotations);
        }
    }

    /** Returns true if the CodecProxy's enclosed elements contain a static method with the given simple name. */
    private static boolean containsHookMethod(AptClassProps aptClassProps, String methodName) {
        return aptClassProps.codecProxyEnclosedElements.stream()
                .filter(e -> e.getKind() == ElementKind.METHOD && e.getModifiers().contains(Modifier.STATIC))
                .anyMatch(e -> e.getSimpleName().toString().equals(methodName));
    }

    /**
     * Emits a call to a user hook, either as a static call on the CodecProxy
     * ({@code Proxy.hook(instance, argument)}) or as an instance call ({@code instance.hook(argument)}).
     * When a CodecProxy is configured, only the proxy is consulted.
     *
     * @param builder            method builder that receives the statement
     * @param hookName           simple name of the hook method
     * @param declaredOnInstance whether the target type itself declares the hook (used only without a proxy)
     * @param argument           source text of the argument passed to the hook
     * @return true if a statement was added
     */
    private static boolean genHookCall(MethodSpec.Builder builder, AptClassProps aptClassProps,
                                       String hookName, boolean declaredOnInstance, String argument) {
        if (aptClassProps.codecProxyTypeElement != null) {
            if (!containsHookMethod(aptClassProps, hookName)) {
                return false;
            }
            builder.addStatement("$T.$L(instance, $L)", aptClassProps.codecProxyClassName, hookName, argument);
            return true;
        }
        if (!declaredOnInstance) {
            return false;
        }
        builder.addStatement("instance.$L($L)", hookName, argument);
        return true;
    }

    /** Emits the call to the user's readObject method: {@code CodecProxy.readObject(instance, reader)} or {@code instance.readObject(reader)}. */
    private boolean genReadObjectMethod(AptClassProps aptClassProps) {
        return genHookCall(readFieldsMethodBuilder, aptClassProps,
                CodecProcessor.MNAME_READ_OBJECT, containsReadObjectMethod, "reader");
    }

    /** Emits the call to the user's writeObject method: {@code CodecProxy.writeObject(instance, writer)} or {@code instance.writeObject(writer)}. */
    private boolean genWriteObjectMethod(AptClassProps aptClassProps) {
        return genHookCall(writeFiedlsMethodBuilder, aptClassProps,
                CodecProcessor.MNAME_WRITE_OBJECT, containsWriteObjectMethod, "writer");
    }

    /** Emits the call to the user's beforeEncode hook, passing {@code writer.options()}; a CodecProxy may supply it. */
    private boolean genBeforeEncodeMethod(AptClassProps aptClassProps) {
        return genHookCall(beforeEncodeMethodBuilder, aptClassProps,
                CodecProcessor.MNAME_BEFORE_ENCODE, containsBeforeEncodeMethod, "writer.options()");
    }

    /** Emits the call to the user's afterDecode hook, passing {@code reader.options()}; a CodecProxy may supply it. */
    private boolean genAfterDecodeMethod(AptClassProps aptClassProps) {
        return genHookCall(afterDecodeMethodBuilder, aptClassProps,
                CodecProcessor.MNAME_AFTER_DECODE, containsAfterDecodeMethod, "reader.options()");
    }

    /**
     * Emits the body of {@code newInstance}. Every path adds exactly one {@code return} or {@code throw}
     * statement, so the generated method is never left without a result.
     */
    private void genNewInstanceMethod(AptClassProps aptClassProps) {
        if (aptClassProps.isSingleton()) {
            // With a CodecProxy the singleton accessor is also taken from the proxy. The method name is
            // specified via the proxy configuration, so it is assumed to exist and is not validated.
            if (aptClassProps.codecProxyTypeElement != null) {
                newInstanceMethodBuilder.addStatement("return $T.$L()", aptClassProps.codecProxyClassName, aptClassProps.singleton);
            } else {
                newInstanceMethodBuilder.addStatement("return $T.$L()", rawTypeName, aptClassProps.singleton);
            }
            return;
        }
        // In theory a generic class should be instantiated with <> to avoid rawtype warnings.
        if (typeElement.getModifiers().contains(Modifier.ABSTRACT)) { // abstract class or interface
            newInstanceMethodBuilder.addStatement("throw new $T()", UnsupportedOperationException.class);
            return;
        }

        if (aptClassProps.codecProxyTypeElement != null) {
            if (containsHookMethod(aptClassProps, CodecProcessor.MNAME_NEW_INSTANCE)) {
                // CodecProxy.newInstance(reader);
                newInstanceMethodBuilder.addStatement("return $T.$L(reader)", aptClassProps.codecProxyClassName, CodecProcessor.MNAME_NEW_INSTANCE);
            } else {
                // The proxy declares no factory hook: fall back to the no-arg constructor instead of
                // emitting an empty method body.
                newInstanceMethodBuilder.addStatement("return new $T()", rawTypeName);
            }
        } else {
            if (containsNewInstanceMethod) { // static factory method takes precedence
                newInstanceMethodBuilder.addStatement("return $T.newInstance(reader)", rawTypeName);
            } else if (containsReaderConstructor) { // constructor taking the reader
                newInstanceMethodBuilder.addStatement("return new $T(reader)", rawTypeName);
            } else {
                newInstanceMethodBuilder.addStatement("return new $T()", rawTypeName);
            }
        }
    }

    /**
     * Emits the statement that reads one field into {@code instance}.
     * Order of preference: custom read proxy, then a setter (configured name or discovered non-private
     * setter), then direct field assignment.
     */
    private void addReadStatement(VariableElement variableElement, AptFieldProps properties, AptClassProps aptClassProps) {
        final String fieldName = variableElement.getSimpleName().toString();
        final MethodSpec.Builder builder = readFieldsMethodBuilder;
        if (properties.hasReadProxy()) { // custom read
            if (aptClassProps.codecProxyTypeElement != null) {
                // The method name is specified via the proxy configuration, so it is assumed to exist.
                builder.addStatement("$T.$L(instance, reader, $L)", aptClassProps.codecProxyClassName, properties.readProxy, serialName(fieldName));
            } else {
                builder.addStatement("instance.$L(reader, $L)", properties.readProxy, serialName(fieldName));
            }
            return;
        }
        final String readMethodName = getReadMethodName(variableElement);
        final ExecutableElement setterMethod = processor.findNotPrivateSetter(variableElement, allFieldsAndMethodWithInherit);

        // Prefer a setter ("instance.setX(<expr>)"), otherwise assign directly ("instance.x = <expr>").
        final String member;
        final String prefix;
        final String suffix;
        if (!AptUtils.isBlank(properties.setter)) {
            member = properties.setter;
            prefix = "instance.$L(";
            suffix = ")";
        } else if (setterMethod != null) {
            member = setterMethod.getSimpleName().toString();
            prefix = "instance.$L(";
            suffix = ")";
        } else {
            member = fieldName;
            prefix = "instance.$L = ";
            suffix = "";
        }

        if (!readMethodName.equals(MNAME_READ_OBJECT)) {
            // reader.readString(XXFields.name)
            builder.addStatement(prefix + "reader.$L($L)" + suffix,
                    member, readMethodName,
                    serialName(fieldName));
        } else if (properties.implMirror != null) {
            // Reading an object needs type info; an impl type additionally passes the field's factory.
            // reader.readObject(XXFields.name, types_name, factories_name)
            builder.addStatement(prefix + "reader.$L($L, $L, $L)" + suffix,
                    member, readMethodName,
                    serialName(fieldName), serialTypeArg(fieldName), serialFactory(fieldName));
        } else {
            // reader.readObject(XXFields.name, types_name, null)
            builder.addStatement(prefix + "reader.$L($L, $L, null)" + suffix,
                    member, readMethodName,
                    serialName(fieldName), serialTypeArg(fieldName));
        }
    }

    /**
     * Emits the statement that writes one field of {@code instance}.
     * Order of handling: custom write proxy, explicit dsonType (binary / ext types), numeric primitives
     * (with wire type and/or number style), strings, objects, then everything else.
     */
    private void addWriteStatement(VariableElement variableElement, AptFieldProps properties, AptClassProps aptClassProps) {
        final String fieldName = variableElement.getSimpleName().toString();
        final MethodSpec.Builder builder = this.writeFiedlsMethodBuilder;
        if (properties.hasWriteProxy()) { // custom write
            if (aptClassProps.codecProxyTypeElement != null) {
                // The method name is specified via the proxy configuration, so it is assumed to exist.
                builder.addStatement("$T.$L(instance, writer, $L)", aptClassProps.codecProxyClassName, properties.writeProxy, serialName(fieldName));
            } else {
                builder.addStatement("instance.$L(writer, $L)", properties.writeProxy, serialName(fieldName));
            }
            return;
        }
        // Prefer a getter (configured name, then discovered non-private getter), otherwise read the field.
        final String fieldAccess;
        final ExecutableElement getterMethod = processor.findNotPrivateGetter(variableElement, allFieldsAndMethodWithInherit);
        if (!AptUtils.isBlank(properties.getter)) {
            fieldAccess = properties.getter + "()";
        } else if (getterMethod != null) {
            fieldAccess = getterMethod.getSimpleName() + "()";
        } else {
            fieldAccess = fieldName;
        }

        // Types that carry a sub type are handled first.
        if (properties.dsonType != null) {
            switch (properties.dsonType) {
                case AptFieldProps.TYPE_BINARY -> {
                    // writer.writeBinary(Fields.FieldName, subType, instance.field)
                    builder.addStatement("writer.writeBinary($L, $L, instance.$L)",
                            serialName(fieldName), properties.dsonSubType, fieldAccess);
                }
                case AptFieldProps.TYPE_EXT_INT32 -> {
                    // writer.writeExtInt32(Fields.FieldName, subType, instance.field, WireType.VARINT, NumberStyle.SIMPLE)
                    builder.addStatement("writer.writeExtInt32($L, $L, instance.$L, $T.$L, $T.$L)",
                            serialName(fieldName),
                            properties.dsonSubType, fieldAccess,
                            processor.typeNameWireType, properties.wireType,
                            processor.typeNameNumberStyle, properties.numberStyle);
                }
                case AptFieldProps.TYPE_EXT_INT64 -> {
                    // writer.writeExtInt64(Fields.FieldName, subType, instance.field, WireType.VARINT, NumberStyle.SIMPLE)
                    builder.addStatement("writer.writeExtInt64($L, $L, instance.$L, $T.$L, $T.$L)",
                            serialName(fieldName),
                            properties.dsonSubType, fieldAccess,
                            processor.typeNameWireType, properties.wireType,
                            processor.typeNameNumberStyle, properties.numberStyle);
                }
                case AptFieldProps.TYPE_EXT_DOUBLE -> {
                    // writer.writeExtDouble(Fields.FieldName, subType, instance.field, NumberStyle.SIMPLE)
                    builder.addStatement("writer.writeExtDouble($L, $L, instance.$L, $T.$L)",
                            serialName(fieldName),
                            properties.dsonSubType, fieldAccess,
                            processor.typeNameNumberStyle, properties.numberStyle);
                }
                case AptFieldProps.TYPE_EXT_STRING -> {
                    // writer.writeExtString(Fields.FieldName, subType, instance.field, StringStyle.AUTO)
                    builder.addStatement("writer.writeExtString($L, $L, instance.$L, $T.$L)",
                            serialName(fieldName),
                            properties.dsonSubType, fieldAccess,
                            processor.typeNameStringStyle, properties.stringStyle);
                }
                default -> {
                    // Report the offending value so the user can locate the misconfigured annotation.
                    messager.printMessage(Diagnostic.Kind.ERROR, "bad dsonType " + properties.dsonType, variableElement);
                }
            }
            return;
        }

        // Numbers next -- they involve WireType and/or NumberStyle.
        final String writeMethodName = getWriteMethodName(variableElement);
        switch (variableElement.asType().getKind()) {
            case INT, LONG, SHORT, BYTE, CHAR -> {
                // writer.writeInt(Fields.FieldName, instance.field, WireType.VARINT, NumberStyle.SIMPLE)
                builder.addStatement("writer.$L($L, instance.$L, $T.$L, $T.$L)",
                        writeMethodName, serialName(fieldName), fieldAccess,
                        processor.typeNameWireType, properties.wireType,
                        processor.typeNameNumberStyle, properties.numberStyle);
                return;
            }
            case FLOAT, DOUBLE -> {
                // writer.writeDouble(Fields.FieldName, instance.field, NumberStyle.SIMPLE)
                builder.addStatement("writer.$L($L, instance.$L, $T.$L)",
                        writeMethodName, serialName(fieldName), fieldAccess,
                        processor.typeNameNumberStyle, properties.numberStyle);
                return;
            }
            default -> {
                // Not a numeric primitive: handled below.
            }
        }
        // Strings
        if (writeMethodName.equals(MNAME_WRITE_STRING)) {
            // writer.writeString(XXFields.name, instance.getName(), StringStyle.AUTO)
            builder.addStatement("writer.$L($L, instance.$L, $T.$L)",
                    writeMethodName, serialName(fieldName), fieldAccess,
                    processor.typeNameStringStyle, properties.stringStyle);
            return;
        }
        if (writeMethodName.equals(MNAME_WRITE_OBJECT)) {
            // Objects are written with their type info and an optional style.
            // writer.writeObject(XXFields.name, instance.getName(), types_name, ObjectStyle.INDENT)
            if (properties.objectStyle != null) {
                builder.addStatement("writer.$L($L, instance.$L, $L, $T.$L)",
                        writeMethodName, serialName(fieldName), fieldAccess, serialTypeArg(fieldName),
                        processor.typeNameObjectStyle, properties.objectStyle);
            } else {
                builder.addStatement("writer.$L($L, instance.$L, $L, null)",
                        writeMethodName, serialName(fieldName), fieldAccess, serialTypeArg(fieldName));
            }
        } else {
            // Remaining cases (boolean, byte[]): writer.writeBoolean(XXFields.name, instance.getName())
            builder.addStatement("writer.$L($L, instance.$L)",
                    writeMethodName, serialName(fieldName), fieldAccess);
        }
    }

    // endregion

    // region

    // Costs a temporary string concatenation, but keeps the statement templates much simpler.
    /** Source text that refers to the serial name of the field: the context's access prefix followed by the field name. */
    private String serialName(String fieldName) {
        return context.serialNameAccess + fieldName;
    }

    /** Name of the generated member holding the field's type info: {@code types_<fieldName>}. */
    private String serialTypeArg(String fieldName) {
        return "types_" + fieldName;
    }

    /** Name of the generated member holding the field's factory: {@code factories_<fieldName>}. */
    private String serialFactory(String fieldName) {
        return "factories_" + fieldName;
    }

    /** Returns the name of the writer method used to write the given field. */
    private String getWriteMethodName(VariableElement variableElement) {
        TypeMirror typeMirror = variableElement.asType();
        if (isPrimitiveType(typeMirror)) {
            return primitiveWriteMethodNameMap.get(typeMirror.getKind());
        }
        if (processor.isString(typeMirror)) {
            return MNAME_WRITE_STRING;
        }
        if (processor.isByteArray(typeMirror)) {
            return MNAME_WRITE_BYTES;
        }
        return MNAME_WRITE_OBJECT;
    }

    /** Returns the name of the reader method used to read the given field. */
    private String getReadMethodName(VariableElement variableElement) {
        TypeMirror typeMirror = variableElement.asType();
        if (isPrimitiveType(typeMirror)) {
            return primitiveReadMethodNameMap.get(typeMirror.getKind());
        }
        if (processor.isString(typeMirror)) {
            return MNAME_READ_STRING;
        }
        if (processor.isByteArray(typeMirror)) {
            return MNAME_READ_BYTES;
        }
        return MNAME_READ_OBJECT;
    }

    private static boolean isPrimitiveType(TypeMirror typeMirror) {
        return typeMirror.getKind().isPrimitive();
    }
    // endregion
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy