it.auties.protobuf.tool.schema.MessageSchemaCreator Maven / Gradle / Ivy
package it.auties.protobuf.tool.schema;
import it.auties.protobuf.base.ProtobufName;
import it.auties.protobuf.base.ProtobufProperty;
import it.auties.protobuf.base.ProtobufType;
import it.auties.protobuf.parser.statement.*;
import it.auties.protobuf.tool.util.AstElements;
import it.auties.protobuf.tool.util.AstUtils;
import spoon.reflect.code.*;
import spoon.reflect.declaration.*;
import spoon.reflect.factory.Factory;
import spoon.reflect.reference.CtFieldReference;
import spoon.reflect.reference.CtTypeReference;
import spoon.support.reflect.reference.CtArrayTypeReferenceImpl;
import java.util.*;
import java.util.stream.Collectors;
public final class MessageSchemaCreator extends SchemaCreator, ProtobufMessageStatement> {
private static final List> ORDER = List.of(
ProtobufFieldStatement.class,
ProtobufOneOfStatement.class,
ProtobufMessageStatement.class,
ProtobufEnumStatement.class
);
public MessageSchemaCreator(CtClass> ctType, ProtobufMessageStatement protoStatement, boolean accessors, Factory factory) {
super(ctType, protoStatement, accessors, factory);
}
public MessageSchemaCreator(ProtobufMessageStatement protoStatement, boolean accessors, Factory factory) {
super(protoStatement, accessors, factory);
}
public MessageSchemaCreator(CtClass> ctType, CtType> parent, ProtobufMessageStatement protoStatement, boolean accessors, Factory factory) {
super(ctType, parent, protoStatement, accessors, factory);
}
@Override
public CtClass> createSchema() {
this.ctType = createClass();
createMessage();
return ctType;
}
@Override
public CtClass> update() {
this.ctType = Objects.requireNonNullElseGet(ctType, this::createClass);
createMessage();
return ctType;
}
private void createMessage() {
protoStatement.statements()
.stream()
.sorted(Comparator.comparingInt(entry -> ORDER.indexOf(entry.getClass())))
.forEach(this::createMessageStatement);
createReservedMethod(true);
createReservedMethod(false);
}
private void createReservedMethod(boolean indexes){
var methodName = indexes ? "reservedFieldIndexes" : "reservedFieldNames";
var existing = ctType.getMethod(methodName);
if(existing != null){
return;
}
var elements = indexes ? protoStatement.reservedIndexes() : protoStatement.reservedNames();
if(elements.isEmpty()){
return;
}
var returnType = factory.Type().createReference(AstElements.LIST);
var returnTypeArg = indexes ? factory.Type().integerType() : factory.Type().stringType();
returnType.addActualTypeArgument(returnTypeArg);
var method = factory.createMethod(
ctType,
Set.of(ModifierKind.PUBLIC),
returnType,
methodName,
List.of(),
Set.of()
);
method.addAnnotation(factory.createAnnotation(factory.createReference(AstElements.OVERRIDE)));
var body = factory.createBlock();
var returnStatement = factory.createReturn();
var ofMethod = factory.Method().createReference(
factory.createReference(AstElements.LIST),
factory.createReference(AstElements.LIST),
"of",
factory.createArrayTypeReference()
);
var literals = elements.stream()
.map(factory::createLiteral)
.collect(Collectors.toCollection(ArrayList>::new));
var ofInvocation = factory.createInvocation(
factory.createTypeAccess(factory.createReference(AstElements.ARRAYS)),
ofMethod,
literals
);
returnStatement.setReturnedExpression(ofInvocation);
body.addStatement(returnStatement);
method.setBody(body);
}
private void createMessageStatement(ProtobufStatement statement) {
if (statement instanceof ProtobufMessageStatement messageStatement) {
createNestedMessageWithLookup(messageStatement);
return;
}
if (statement instanceof ProtobufEnumStatement enumStatement) {
createNestedEnumWithLookup(enumStatement);
return;
}
if (statement instanceof ProtobufOneOfStatement oneOfStatement) {
createNestedOneOf(oneOfStatement);
return;
}
if (statement instanceof ProtobufFieldStatement fieldStatement) {
createField(fieldStatement);
return;
}
throw new UnsupportedOperationException("Cannot create schema for statement: " + statement);
}
private boolean createField(ProtobufFieldStatement fieldStatement) {
var existingField = getExistingField(fieldStatement);
var field = createFieldInternal(fieldStatement, existingField);
var accessor = ctType.getMethod(fieldStatement.name());
if(accessor == null){
createFieldAccessor(fieldStatement, field);
}
return existingField != null;
}
private CtField> createFieldInternal(ProtobufFieldStatement fieldStatement, CtField> existingField) {
if(existingField == null){
return createProtobufProperty(fieldStatement);
}
if (!fieldStatement.type().type().isMessage()) {
return existingField;
}
var expectedName = fieldStatement.type().name();
var actualName = existingField.getType().getSimpleName();
if (Objects.equals(expectedName, actualName)) {
return existingField;
}
var target = existingField.getType();
if (target.getDeclaration() == null || hasProtobufMessageName(target.getDeclaration())) {
return existingField;
}
var name = factory.createAnnotation(
factory.createReference(AstElements.PROTOBUF_MESSAGE_NAME)
);
name.addValue("value", expectedName);
target.getDeclaration().addAnnotation(name);
return existingField;
}
private boolean hasProtobufMessageName(CtType> target) {
return target.getAnnotations()
.stream()
.anyMatch(entry -> entry.getName().equalsIgnoreCase(ProtobufName.class.getSimpleName()));
}
private CtField> getExistingField(ProtobufFieldStatement fieldStatement) {
return ctType.getFields()
.stream()
.filter(entry -> {
var annotation = entry.getAnnotation(ProtobufProperty.class);
return annotation != null && annotation.index() == fieldStatement.index();
})
.findFirst()
.orElse(null);
}
private CtField> createProtobufProperty(ProtobufFieldStatement fieldStatement) {
CtField> ctField = factory.createField(
ctType,
Set.of(ModifierKind.PRIVATE),
AstUtils.createReference(fieldStatement, true, factory),
fieldStatement.name()
);
if(fieldStatement.type().type() == ProtobufType.BYTES){
ctField.getType().putMetadata("DeclarationKind", CtArrayTypeReferenceImpl.DeclarationKind.TYPE);
}
var annotation = factory.createAnnotation(factory.createReference(AstElements.PROTOBUF_PROPERTY));
annotation.addValue("index", fieldStatement.index());
annotation.addValue("type", fieldStatement.type().type());
if(fieldStatement.required()){
annotation.addValue("required", true);
var nonNull = factory.createAnnotation(factory.createReference(AstElements.NON_NULL));
ctField.addAnnotation(nonNull);
}
if(fieldStatement.repeated()){
annotation.addValue("repeated", true);
annotation.addValue("implementation", factory.createClassAccess(AstUtils.createReference(fieldStatement, false, factory)));
createBuilderMethod(fieldStatement, ctField);
}
if(fieldStatement.packed()){
annotation.addValue("packed", true);
}
if (fieldStatement.deprecated()) {
var deprecated = factory.createAnnotation(factory.createReference(AstElements.DEPRECATED));
ctField.addAnnotation(deprecated);
}
if (fieldStatement.defaultValue() != null) {
var defaultBuilder = factory.createAnnotation(factory.createReference(AstElements.DEFAULT));
ctField.addAnnotation(defaultBuilder);
createDefaultExpression(ctField, fieldStatement);
}
annotation.addValue("name", fieldStatement.name());
ctField.addAnnotation(annotation);
return ctField;
}
@SuppressWarnings({"unchecked", "rawtypes"})
private void createFieldAccessor(ProtobufFieldStatement fieldStatement, CtField> ctField) {
if(!accessors){
return;
}
var returnType = createFieldAccessorType(fieldStatement, ctField);
var accessor = factory.createMethod(
ctType,
Set.of(ModifierKind.PUBLIC),
returnType,
ctField.getSimpleName(),
List.of(),
Set.of()
);
var body = factory.createBlock();
accessor.setBody(body);
CtReturn returnStatement = factory.createReturn();
if(fieldStatement.required() || fieldStatement.repeated()) {
returnStatement.setReturnedExpression(createFieldRead(fieldStatement));
}else {
var optionalType = factory.createReference(AstElements.OPTIONAL);
var ofMethod = factory.Method().createReference(
optionalType,
optionalType,
"ofNullable",
ctField.getType()
);
var ofInvocation = factory.createInvocation(
factory.createTypeAccess(optionalType),
ofMethod,
createFieldRead(fieldStatement)
);
returnStatement.setReturnedExpression(ofInvocation);
}
body.addStatement(returnStatement);
}
private CtTypeReference> createFieldAccessorType(ProtobufFieldStatement fieldStatement, CtField> ctField) {
if(fieldStatement.required() || fieldStatement.repeated()){
return ctField.getType();
}
var returnType = factory.createReference(AstElements.OPTIONAL);
returnType.addActualTypeArgument(ctField.getType());
return returnType;
}
@SuppressWarnings({"unchecked", "rawtypes"})
private void createBuilderMethod(ProtobufFieldStatement fieldStatement, CtField> ctField) {
var builderClass = getOrCreateBuilder();
var existing = builderClass.getMethod(fieldStatement.name());
if(existing != null){
return;
}
var method = factory.createMethod(
builderClass,
Set.of(ModifierKind.PUBLIC),
builderClass.getReference(),
fieldStatement.name(),
List.of(),
Set.of()
);
var parameter = factory.createParameter(
method,
ctField.getType(),
fieldStatement.name()
);
var body = factory.createBlock();
CtFieldReference localFieldReference = factory.Field().createReference(
builderClass.getReference(),
ctField.getType(),
fieldStatement.name()
);
var localFieldRead = factory.createFieldRead();
localFieldRead.setTarget(factory.createThisAccess(builderClass.getReference()));
localFieldRead.setVariable(localFieldReference);
CtBinaryOperator isNullCondition = factory.createBinaryOperator(
localFieldRead,
factory.createLiteral(null),
BinaryOperatorKind.EQ
);
var newArrayList = factory.createConstructorCall(
factory.createReference(AstElements.ARRAY_LIST)
);
var newArrayListDiamond = AstUtils.createReference(fieldStatement, false, factory);
newArrayListDiamond.setImplicit(true);
newArrayList.getType().addActualTypeArgument(newArrayListDiamond);
var isNullThen = factory.createVariableAssignment(
localFieldReference,
false,
newArrayList
);
var isNullIf = factory.createIf();
isNullIf.setCondition(isNullCondition);
isNullIf.setThenStatement(isNullThen);
body.addStatement(isNullIf);
var addAllMethod = factory.Method().createReference(
builderClass.getReference(),
localFieldRead.getType(),
"addAll",
factory.createReference(AstElements.COLLECTION)
);
var addAllInvocation = factory.createInvocation(
localFieldRead,
addAllMethod,
factory.createVariableRead(factory.createParameterReference(parameter), false)
);
body.addStatement(addAllInvocation);
CtReturn returnStatement = factory.createReturn();
returnStatement.setReturnedExpression(factory.createThisAccess(builderClass.getReference()));
body.addStatement(returnStatement);
method.setBody(body);
}
private CtClass> getOrCreateBuilder() {
var existing = AstUtils.getBuilderClass(ctType);
if (existing != null) {
return existing;
}
CtClass> builderClass = factory.createClass("%sBuilder".formatted(ctType.getSimpleName()));
builderClass.setModifiers(Set.of(ModifierKind.PUBLIC, ModifierKind.STATIC));
builderClass.setParent(ctType);
ctType.addNestedType(builderClass);
return builderClass;
}
@SuppressWarnings({"unchecked", "rawtypes"})
private void createDefaultExpression(CtField ctField, ProtobufFieldStatement fieldStatement) {
var literal = factory.createLiteral(fieldStatement.defaultValue());
ctField.setDefaultExpression(literal);
}
private void createNestedOneOf(ProtobufOneOfStatement oneOfStatement) {
var updating = createOneOfStatements(oneOfStatement);
var enumDescriptor = createOneOfEnumDescriptor(oneOfStatement, updating);
createOneOfMethod(oneOfStatement, enumDescriptor);
}
private void createOneOfMethod(ProtobufOneOfStatement oneOfStatement, CtEnum> oneOfEnum) {
var existing = ctType.getMethods()
.stream()
.filter(entry -> entry.getType().getSimpleName().equals(oneOfEnum.getSimpleName()))
.findFirst()
.orElse(null);
if (existing != null) {
return;
}
var body = factory.createBlock();
var iterator = oneOfStatement.statements().iterator();
var nullTarget = factory.createLiteral(null);
while (iterator.hasNext()){
var entry = iterator.next();
if (!iterator.hasNext()) {
var returnStatement = createReturn(oneOfEnum, entry);
body.addStatement(returnStatement);
continue;
}
var check = factory.createIf();
var fieldRead = createFieldRead(entry);
CtBinaryOperator condition = factory.createBinaryOperator(fieldRead, nullTarget, BinaryOperatorKind.NE);
check.setCondition(condition);
var block = factory.createBlock();
block.addStatement(createReturn(oneOfEnum, entry));
check.setThenStatement(block);
body.addStatement(check);
}
var method = factory.createMethod(
ctType,
Set.of(ModifierKind.PUBLIC),
oneOfEnum.getReference(),
oneOfStatement.methodName(),
List.of(),
Set.of()
);
method.setBody(body);
}
private boolean createOneOfStatements(ProtobufOneOfStatement oneOfStatement) {
return oneOfStatement.statements()
.stream()
.anyMatch(this::createField);
}
private CtEnum> createOneOfEnumDescriptor(ProtobufOneOfStatement oneOfStatement, boolean updatingInnerScope) {
var existing = !updating ? null : (CtEnum>) AstUtils.getProtobufClass(factory.getModel(), oneOfStatement.className(), true);
if (existing != null) {
return existing;
}
var enumStatement = createOneOfEnum(oneOfStatement);
if(!updating || !updatingInnerScope){
return createNestedEnum(enumStatement, null);
}
var possibleEnumType = getExistingOneOfDescriptor(enumStatement);
if (possibleEnumType != null) {
log.info("Oneof statement %s in %s doesn't have an enum descriptor(expected %s)."
.formatted(oneOfStatement.name(), oneOfStatement.parent().name(), oneOfStatement.className()));
log.info("%s looks like the missing descriptor, but no name override was specified using @ProtobufName so this is just speculation."
.formatted(ctType.getSimpleName()));
log.info("Type yes to use this enum, otherwise type enter to continue");
var scanner = new Scanner(System.in);
if (scanner.nextLine().equalsIgnoreCase("yes")) {
var name = factory.createAnnotation(factory.createReference(AstElements.PROTOBUF_MESSAGE_NAME));
name.addValue("value", oneOfStatement.className());
possibleEnumType.addAnnotation(name);
return createNestedEnum(enumStatement, possibleEnumType);
}
}
log.info("Oneof statement %s in %s doesn't have an enum descriptor(expected %s)."
.formatted(oneOfStatement.name(), oneOfStatement.parent().name(), oneOfStatement.className()));
log.info("Type its name or click enter to generate it:");
var suggestedNames = AstUtils.getSuggestedNames(factory.getModel(), oneOfStatement.className(), true);
log.info("Suggested names: %s".formatted(suggestedNames));
var scanner = new Scanner(System.in);
var newName = scanner.nextLine();
if (newName.isBlank()) {
return createNestedEnum(enumStatement, null);
}
var result = (CtEnum>) AstUtils.getProtobufClass(factory.getModel(), newName, true);
if(result != null){
var name = factory.createAnnotation(factory.createReference(AstElements.PROTOBUF_MESSAGE_NAME));
name.addValue("value", oneOfStatement.className());
result.addAnnotation(name);
return createNestedEnum(enumStatement, result);
}
log.info("Enum %s doesn't exist, try again".formatted(newName));
return createOneOfEnumDescriptor(oneOfStatement, true);
}
private CtEnum> getExistingOneOfDescriptor(ProtobufEnumStatement enumStatement) {
return ctType.getNestedTypes()
.stream()
.filter(CtTypeInformation::isEnum)
.map(entry -> (CtEnum>) entry)
.filter(entry -> inferExistingOneOfDescriptor(enumStatement, entry))
.findFirst()
.orElse(null);
}
private boolean inferExistingOneOfDescriptor(ProtobufEnumStatement enumStatement, CtEnum> entry) {
var enumStatementEntries = enumStatement.statements()
.stream()
.map(ProtobufFieldStatement::name)
.collect(Collectors.toUnmodifiableSet());
var success = entry.getEnumValues()
.stream()
.filter(enumValue -> enumStatementEntries.contains(enumValue.getSimpleName()
.toLowerCase()
.replaceAll("_", "")))
.count();
return ((float) success / enumStatementEntries.size()) > 0.5;
}
private ProtobufEnumStatement createOneOfEnum(ProtobufOneOfStatement oneOfStatement) {
var fieldsCounter = 0;
var enumStatement = new ProtobufEnumStatement(oneOfStatement.className(), oneOfStatement.packageName(), oneOfStatement.parent());
var defaultStatement = new ProtobufFieldStatement(
fieldsCounter++,
"UNKNOWN",
enumStatement.packageName(),
oneOfStatement.parent()
);
enumStatement.addStatement(defaultStatement);
for (var fieldStatement : oneOfStatement.statements()) {
enumStatement.addStatement(new ProtobufFieldStatement(
fieldsCounter++,
fieldStatement.nameAsConstant(),
enumStatement.packageName(),
oneOfStatement.parent()
));
}
return enumStatement;
}
private void createNestedEnumWithLookup(ProtobufEnumStatement enumStatement) {
var existing = (CtEnum>) AstUtils.getProtobufClass(factory.getModel(), enumStatement.name(), true);
createNestedEnum(enumStatement, existing);
}
private CtEnum> createNestedEnum(ProtobufEnumStatement enumStatement, CtEnum> existing) {
var creator = new EnumSchemaCreator(existing, ctType, enumStatement, accessors, factory);
var result = creator.update();
if(existing == null) {
result.setParent(parent);
ctType.addNestedType(result);
}
return result;
}
private void createNestedMessageWithLookup(ProtobufMessageStatement messageStatement) {
var existing = AstUtils.getProtobufClass(factory.getModel(), messageStatement.name(), false);
var creator = new MessageSchemaCreator(existing, ctType, messageStatement, accessors, factory);
var result = creator.update();
if(existing == null) {
result.setParent(parent);
ctType.addNestedType(result);
}
}
private CtClass> createClass() {
CtClass> ctClass = factory.createClass(protoStatement.staticallyQualifiedName());
ctClass.setModifiers(Set.of(ModifierKind.PUBLIC));
if(protoStatement.nested()){
ctClass.addModifier(ModifierKind.STATIC);
}
ctClass.addSuperInterface(factory.createReference(AstElements.PROTOBUF_MESSAGE));
ctClass.addAnnotation(factory.createAnnotation(factory.createReference(AstElements.ALL_ARGS_CONSTRUCTOR)));
if(!accessors){
ctClass.addAnnotation(factory.createAnnotation(factory.createReference(AstElements.DATA)));
}
ctClass.addAnnotation(factory.createAnnotation(factory.createReference(AstElements.JACKSONIZED)));
ctClass.addAnnotation(factory.createAnnotation(factory.createReference(AstElements.BUILDER)));
var name = factory.createAnnotation(factory.createReference(AstElements.PROTOBUF_MESSAGE_NAME));
name.addValue("value", protoStatement.name());
ctClass.addAnnotation(name);
return ctClass;
}
@SuppressWarnings({"unchecked", "rawtypes"})
private CtFieldRead createFieldRead(ProtobufFieldStatement entry) {
CtFieldRead fieldRead = factory.createFieldRead();
fieldRead.setType(ctType.getReference());
fieldRead.setVariable(getExistingField(entry).getReference());
return fieldRead;
}
private CtReturn> createReturn(CtEnum> javaEnum, ProtobufFieldStatement entry) {
return createReturn(javaEnum, entry, entry.nameAsConstant(), false);
}
@SuppressWarnings({"unchecked", "rawtypes"})
private CtReturn createReturn(CtEnum javaEnum, ProtobufFieldStatement entry, String constantName, boolean annotate) {
CtEnumValue constant = javaEnum.getEnumValue(constantName);
if(constant == null){
log.info("Enum constant %s doesn't exist inside %s".formatted(constantName, javaEnum.getSimpleName()));
log.info("Known constants: %s".formatted(javaEnum.getEnumValues()
.stream()
.map(element -> ((CtNamedElement) element).getSimpleName())
.collect(Collectors.joining(", "))));
log.info("Type the correct name or hit enter to generate it");
var scanner = new Scanner(System.in);
var newName = scanner.nextLine();
if (!newName.isBlank()) {
return createReturn(javaEnum, entry, newName, true);
}
var enumInitializer = factory.createConstructorCall();
enumInitializer.addArgument(factory.createLiteral(entry.index()));
var enumValue = factory.createEnumValue();
enumValue.setSimpleName(entry.nameAsConstant());
enumValue.setAssignment(enumInitializer);
enumValue.setType(javaEnum.getReference());
enumValue.setParent(javaEnum);
constant = enumValue;
}
if(annotate) {
var name = factory.createAnnotation(
factory.createReference(AstElements.PROTOBUF_MESSAGE_NAME)
);
name.addValue("value", entry.nameAsConstant());
constant.addAnnotation(name);
}
constant.setType(javaEnum.getReference());
CtFieldRead fieldRead = factory.createFieldRead();
fieldRead.setType(javaEnum.getReference());
fieldRead.setVariable(constant.getReference());
fieldRead.setTarget(factory.Code().createTypeAccess(javaEnum.getReference()));
return factory.createReturn()
.setReturnedExpression(fieldRead);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy