// Source listing of io.activej.codegen.ClassGenerator, part of the activej-codegen artifact
// (published for Maven / Gradle / Ivy).
/*
* Copyright (C) 2020 ActiveJ LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.activej.codegen;
import io.activej.codegen.expression.Expression;
import io.activej.codegen.expression.Expressions;
import io.activej.codegen.expression.impl.Constant;
import io.activej.codegen.util.DefiningClassWriter;
import io.activej.common.builder.AbstractBuilder;
import org.jetbrains.annotations.ApiStatus.Internal;
import org.jetbrains.annotations.VisibleForTesting;
import org.objectweb.asm.Type;
import org.objectweb.asm.commons.GeneratorAdapter;
import org.objectweb.asm.commons.Method;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.reflect.Field;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Stream;
import static io.activej.codegen.DefiningClassLoader.createInstance;
import static io.activej.codegen.expression.Expressions.*;
import static io.activej.codegen.util.Utils.getStringSetting;
import static io.activej.codegen.util.Utils.isJvmPrimitive;
import static io.activej.common.Checks.checkState;
import static java.util.stream.Collectors.toList;
import static org.objectweb.asm.Opcodes.*;
import static org.objectweb.asm.Type.*;
import static org.objectweb.asm.commons.Method.getMethod;
/**
* Intends for dynamic description of the object behaviour in runtime
*
* @param type of class to be generated
*/
@SuppressWarnings({"WeakerAccess", "unused"})
public final class ClassGenerator {
private final Logger logger = LoggerFactory.getLogger(getClass());
public static final String GENERATED_MARKER = "$GENERATED";
public static final String PACKAGE_PREFIX = getStringSetting(ClassGenerator.class, "packagePrefix", "io.activej.codegen.");
private static final AtomicInteger COUNTER = new AtomicInteger();
private static final ConcurrentHashMap STATIC_CONSTANTS = new ConcurrentHashMap<>();
final Class> superclass;
final List> interfaces;
private String className;
private final String autoClassName;
final Map> fields = new LinkedHashMap<>();
final Set fieldsFinal = new HashSet<>();
final Set fieldsStatic = new HashSet<>();
final Map fieldExpressions = new HashMap<>();
final Map fieldConstants = new HashMap<>();
final Map methods = new LinkedHashMap<>();
final Map staticMethods = new LinkedHashMap<>();
private final Map constructors = new LinkedHashMap<>();
private final List staticInitializers = new ArrayList<>();
private ClassGenerator(Class> superclass, List> interfaces, String className) {
this.superclass = superclass;
this.interfaces = interfaces;
this.autoClassName = PACKAGE_PREFIX + className;
}
/**
* Creates a new builder of ClassBuilder
*
* @param implementation type of dynamic class
* @param interfaces additional interfaces for the class to implement
*/
public static ClassGenerator.Builder builder(Class> implementation, List> interfaces) {
if (!interfaces.stream().allMatch(Class::isInterface))
throw new IllegalArgumentException();
if (implementation.isInterface()) {
return new ClassGenerator(
Object.class,
Stream.concat(Stream.of(implementation), interfaces.stream()).collect(toList()),
implementation.getName()).new Builder();
} else {
return new ClassGenerator(
implementation,
interfaces,
implementation.getName()).new Builder();
}
}
/**
* Creates a new builder of ClassBuilder
*
* @param implementation type of dynamic class
* @param interfaces additional interfaces for the class to implement
*/
public static ClassGenerator.Builder builder(Class implementation, Class>... interfaces) {
return builder(implementation, List.of(interfaces));
}
public final class Builder extends AbstractBuilder> {
private Builder() {
withStaticField(GENERATED_MARKER, Void.class);
}
/**
* Sets a class name for the generated class
*
* @param name a class name of the generated class
*/
public Builder withClassName(String name) {
checkNotBuilt(this);
ClassGenerator.this.className = name;
return this;
}
/**
* Adds static initializer for the generated class
* (code that would be executed inside a static initialization block of the generated class)
*
* @param expression an expression that represents static initializer
*/
public Builder withStaticInitializer(Expression expression) {
checkNotBuilt(this);
staticInitializers.add(expression);
return this;
}
public Builder withStaticInitializer(Supplier expressionFn) {
return withStaticInitializer(expressionFn.get());
}
/**
* Adds a constructor for the given class with specified argument types and an {@link Exception}
* that would be executed inside the constructor
*
* @param argumentTypes types of arguments to the constructor
* @param expression an expression that would be executed inside the constructor
*/
public Builder withConstructor(List extends Class>> argumentTypes, Expression expression) {
checkNotBuilt(this);
constructors.put(new Method("", VOID_TYPE, argumentTypes.stream().map(Type::getType).toArray(Type[]::new)), expression);
return this;
}
public Builder withConstructor(List extends Class>> argumentTypes, Supplier expressionFn) {
return withConstructor(argumentTypes, expressionFn.get());
}
/**
* Adds a constructor for the given class with an {@link Exception}
* that would be executed inside the constructor
*
* @param expression an expression that would be executed inside the constructor
* @see #withConstructor(List, Expression)
*/
public Builder withConstructor(Expression expression) {
checkNotBuilt(this);
return withConstructor(List.of(), expression);
}
public Builder withConstructor(Supplier expressionFn) {
return withConstructor(expressionFn.get());
}
/**
* Adds a new uninitialized field for a class
*
* @param field name of the field
* @param type type of the field
*/
public Builder withField(String field, Class> type) {
checkNotBuilt(this);
fields.put(field, type);
return this;
}
/**
* Adds a new initialized field for a class
*
* @param field name of the field
* @param type type of the field
* @param value an expression that represents how the new field will be initialized
*/
public Builder withField(String field, Class> type, Expression value) {
checkNotBuilt(this);
fields.put(field, type);
fieldExpressions.put(field, value);
return this;
}
public Builder withField(String field, Class> type, Supplier valueFn) {
return withField(field, type, valueFn.get());
}
/**
* Adds a new final initialized field for a class
*
* @param field name of the field
* @param type type of the field
* @param value an expression that represents how the new final field will be initialized
*/
public Builder withFinalField(String field, Class> type, Expression value) {
checkNotBuilt(this);
fields.put(field, type);
fieldsFinal.add(field);
fieldExpressions.put(field, value);
return this;
}
public Builder withFinalField(String field, Class> type, Supplier valueFn) {
return withFinalField(field, type, valueFn.get());
}
/**
* Adds a new method to a class
*
* @param methodName name of the method
* @param returnType a return type of the method
* @param argumentTypes list of the method's arguments
* @param expression an expression that represents the method's body
*/
public Builder withMethod(String methodName, Class> returnType, List extends Class>> argumentTypes, Expression expression) {
checkNotBuilt(this);
methods.put(new Method(methodName, getType(returnType), argumentTypes.stream().map(Type::getType).toArray(Type[]::new)), expression);
return this;
}
public Builder withMethod(String methodName, Class> returnType, List extends Class>> argumentTypes, Supplier expressionFn) {
return withMethod(methodName, returnType, argumentTypes, expressionFn.get());
}
/**
* Adds a new method to a class
*
* @param methodName name of the method
* @param expression an expression that represents the method's body
*/
public Builder withMethod(String methodName, Expression expression) {
checkNotBuilt(this);
if (methodName.contains("(")) {
Method method = Method.getMethod(methodName);
methods.put(method, expression);
return this;
}
Method foundMethod = null;
List> listOfMethods = new ArrayList<>();
listOfMethods.add(List.of(Object.class.getMethods()));
listOfMethods.add(List.of(superclass.getMethods()));
listOfMethods.add(List.of(superclass.getDeclaredMethods()));
for (Class> type : interfaces) {
listOfMethods.add(List.of(type.getMethods()));
listOfMethods.add(List.of(type.getDeclaredMethods()));
}
for (List list : listOfMethods) {
for (java.lang.reflect.Method m : list) {
if (m.getName().equals(methodName)) {
Method method = getMethod(m);
if (foundMethod != null && !method.equals(foundMethod))
throw new IllegalArgumentException("Method " + method + " collides with " + foundMethod);
foundMethod = method;
}
}
}
if (foundMethod == null)
throw new IllegalArgumentException(String.format("Could not find method '%s'", methodName));
methods.put(foundMethod, expression);
return this;
}
public Builder withMethod(String methodName, Supplier expression) {
return withMethod(methodName, expression.get());
}
/**
* Adds a static method to a class
*
* @param methodName a name of the method
* @param returnClass the method's return type
* @param argumentTypes types of the method's arguments
* @param expression an expression that represents the method's body
*/
public Builder withStaticMethod(String methodName, Class> returnClass, List extends Class>> argumentTypes, Expression expression) {
checkNotBuilt(this);
staticMethods.put(new Method(methodName, getType(returnClass), argumentTypes.stream().map(Type::getType).toArray(Type[]::new)), expression);
return this;
}
public Builder withStaticMethod(String methodName, Class> returnClass, List extends Class>> argumentTypes, Supplier expressionFn) {
return withStaticMethod(methodName, returnClass, argumentTypes, expressionFn.get());
}
/**
* Adds a new uninitialized static field for a class
*
* @param field name of the field
* @param type type of the field
*/
public Builder withStaticField(String field, Class> type) {
checkNotBuilt(this);
fields.put(field, type);
fieldsStatic.add(field);
return this;
}
/**
* Adds a new initialized static field for a class
*
* @param field name of the field
* @param type type of the field
* @param value an expression that represents how the new static field will be initialized
*/
public Builder withStaticField(String field, Class> type, Expression value) {
checkNotBuilt(this);
fields.put(field, type);
fieldsStatic.add(field);
fieldExpressions.put(field, value);
return this;
}
public Builder withStaticField(String field, Class> type, Supplier valueFn) {
return withStaticField(field, type, valueFn.get());
}
/**
* Adds a new static final initialized field for a class
*
* @param field name of the field
* @param type type of the field
* @param value an expression that represents how the new static final field will be initialized
*/
public Builder withStaticFinalField(String field, Class> type, Expression value) {
checkNotBuilt(this);
fields.put(field, type);
fieldsStatic.add(field);
fieldsFinal.add(field);
fieldExpressions.put(field, value);
if (value instanceof Constant constant) {
fieldConstants.put(field, constant);
} else {
fieldExpressions.put(field, value);
}
return this;
}
public Builder withStaticFinalField(String field, Class> type, Supplier valueFn) {
return withStaticFinalField(field, type, valueFn.get());
}
@Override
protected ClassGenerator doBuild() {
return ClassGenerator.this;
}
}
/**
* Returns a static constant by an integer id
*
* This method is used internally by generated classes for constant initialization
*
* @param id id of a static constant
* @return static constant
*/
@Internal
public static Object getStaticConstant(int id) {
return STATIC_CONSTANTS.get(id);
}
/**
* Returns a size of static constants
*/
@VisibleForTesting
public static int getStaticConstantsSize() {
return STATIC_CONSTANTS.size();
}
/**
* Clears all static constants
*/
@VisibleForTesting
public static void clearStaticConstants() {
STATIC_CONSTANTS.clear();
}
/**
* Defines a class from {@code this} {@link ClassGenerator} using given {@link DefiningClassLoader}
*
* @param classLoader a class loader that would be used to define a class
* @return a defined class
*/
public Class generateClass(DefiningClassLoader classLoader) {
try (GeneratedBytecode generatedBytecode = generateBytecode(classLoader)) {
//noinspection unchecked
return (Class) generatedBytecode.generateClass(classLoader);
}
}
/**
* Defines a class from {@code this} {@link ClassGenerator} using given {@link DefiningClassLoader}
* and creates an instance of defined class.
*
* @param classLoader a class loader that would be used to define a class
* @param arguments an array of parameters that would be passed to the constructor of a defined class
* @return an instance of a defined class
*/
public T generateClassAndCreateInstance(DefiningClassLoader classLoader, Object... arguments) {
Class generatedClass = generateClass(classLoader);
return createInstance(generatedClass, arguments);
}
/**
* Uses a given class loader to generate a bytecode out of this class builder.
*
* @param classLoader a class loader for generating a bytecode
* @return a generated bytecode which consists of actual bytecode as well as a class name
* @see GeneratedBytecode
*/
public GeneratedBytecode generateBytecode(ClassLoader classLoader) {
return generateBytecode(classLoader, className != null ? className : autoClassName + '_' + COUNTER.incrementAndGet());
}
/**
* Uses a given class loader to generate a bytecode out of this class builder.
*
* @param classLoader a class loader for generating a bytecode
* @param className a name of a class
* @return a generated bytecode which consists of actual bytecode as well as a class name
* @see GeneratedBytecode
*/
public GeneratedBytecode generateBytecode(ClassLoader classLoader, String className) {
DefiningClassWriter cw = new DefiningClassWriter(classLoader);
Type classType = getType('L' + className.replace('.', '/') + ';');
cw.visit(V1_6, ACC_PUBLIC + ACC_FINAL + ACC_SUPER,
classType.getInternalName(),
null,
getInternalName(superclass),
interfaces.stream().map(Type::getInternalName).toArray(String[]::new));
Map constantMap = new LinkedHashMap<>();
Map constructors = new LinkedHashMap<>(this.constructors);
if (constructors.isEmpty()) {
constructors.put(new Method("", VOID_TYPE, new Type[]{}), superConstructor());
}
for (Map.Entry entry : constructors.entrySet()) {
Method method = entry.getKey();
GeneratorAdapter g = new GeneratorAdapter(ACC_PUBLIC, method, null, null, cw);
Context ctx = new Context(classLoader, this, g, classType, method, constantMap);
Type type = entry.getValue().load(ctx);
if (type != null) {
ctx.cast(type, method.getReturnType());
g.returnValue();
}
g.endMethod();
}
for (String field : this.fields.keySet()) {
cw.visitField(
ACC_PUBLIC + (fieldsStatic.contains(field) ? ACC_STATIC : 0) +
(fieldsFinal.contains(field) ? ACC_FINAL : 0),
field, getType(this.fields.get(field)).getDescriptor(), null, null);
}
for (Method m : this.methods.keySet()) {
GeneratorAdapter g = new GeneratorAdapter(ACC_PUBLIC + ACC_FINAL, m, null, null, cw);
Context ctx = new Context(classLoader, this, g, classType, m, constantMap);
Expression expression = this.methods.get(m);
Type type = expression.load(ctx);
if (type != null) {
ctx.cast(type, m.getReturnType());
g.returnValue();
}
g.endMethod();
}
for (Method m : this.staticMethods.keySet()) {
GeneratorAdapter g = new GeneratorAdapter(ACC_PUBLIC + ACC_STATIC + ACC_FINAL, m, null, null, cw);
Context ctx = new Context(classLoader, this, g, classType, m, constantMap);
Expression expression = this.staticMethods.get(m);
Type type = expression.load(ctx);
if (type != null) {
ctx.cast(type, m.getReturnType());
g.returnValue();
}
g.endMethod();
}
{
Method m = getMethod("void ()");
GeneratorAdapter g = new GeneratorAdapter(ACC_PUBLIC + ACC_STATIC, m, null, null, cw);
Context ctx = new Context(classLoader, this, g, classType, m, constantMap);
for (Map.Entry entry : this.fieldConstants.entrySet()) {
String field = entry.getKey();
Constant expression = entry.getValue();
if (!isJvmPrimitive(expression.value)) {
STATIC_CONSTANTS.put(expression.id, expression.value);
Expressions.set(staticField(field), cast(
staticCall(ClassGenerator.class, "getStaticConstant", value(expression.id)),
this.fields.get(field)))
.load(ctx);
} else {
Expressions.set(staticField(field), expression).load(ctx);
}
}
for (Map.Entry entry : constantMap.entrySet()) {
String field = entry.getKey();
Constant expression = entry.getValue();
cw.visitField(ACC_PUBLIC + ACC_STATIC + ACC_FINAL,
field, getType(expression.getValueClass()).getDescriptor(), null, null);
checkState(!isJvmPrimitive(expression.value));
STATIC_CONSTANTS.put(expression.id, expression.value);
Type typeFrom = staticCall(ClassGenerator.class, "getStaticConstant", value(expression.id)).load(ctx);
g.checkCast(getType(expression.getValueClass()));
g.putStatic(ctx.getSelfType(), field, getType(expression.getValueClass()));
}
for (Expression initializer : staticInitializers) {
initializer.load(ctx);
}
g.returnValue();
g.endMethod();
}
cw.visitEnd();
byte[] bytecode = cw.toByteArray();
return new GeneratedBytecode(className, bytecode) {
@Override
protected void touchGeneratedClass(Class> generatedClass) {
try {
Field field = generatedClass.getField(GENERATED_MARKER);
//noinspection ResultOfMethodCallIgnored
field.get(null);
} catch (IllegalAccessException | NoSuchFieldException e) {
throw new AssertionError(e);
}
}
@Override
public void close() {
for (Map.Entry entry : fieldConstants.entrySet()) {
STATIC_CONSTANTS.remove(entry.getValue().id);
}
for (Constant expression : constantMap.values()) {
STATIC_CONSTANTS.remove(expression.id);
}
}
};
}
}