org.neo4j.codegen.api.CodeGeneration.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of neo4j-codegen Show documentation
Show all versions of neo4j-codegen Show documentation
Simple library for generating code.
The newest version!
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [https://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.codegen.api
import org.neo4j.codegen
import org.neo4j.codegen.ClassHandle
import org.neo4j.codegen.CodeGenerationNotSupportedException
import org.neo4j.codegen.CodeGenerator
import org.neo4j.codegen.CodeGenerator.generateCode
import org.neo4j.codegen.CodeGeneratorOption
import org.neo4j.codegen.DisassemblyVisitor
import org.neo4j.codegen.Expression
import org.neo4j.codegen.Expression.constant
import org.neo4j.codegen.Expression.getStatic
import org.neo4j.codegen.Expression.invoke
import org.neo4j.codegen.Expression.invokeSuper
import org.neo4j.codegen.Expression.newInitializedArray
import org.neo4j.codegen.FieldReference
import org.neo4j.codegen.FieldReference.field
import org.neo4j.codegen.FieldReference.staticField
import org.neo4j.codegen.Parameter.param
import org.neo4j.codegen.TypeReference
import org.neo4j.codegen.TypeReference.OBJECT
import org.neo4j.codegen.api.CodeGeneration.ByteCodeGeneration
import org.neo4j.codegen.api.CodeGeneration.CodeGenerationMode
import org.neo4j.codegen.api.CodeGeneration.CodeGenerationMode.modeFromDebugOptions
import org.neo4j.codegen.api.CodeGeneration.DEBUG_PRINT_BYTECODE
import org.neo4j.codegen.api.CodeGeneration.DEBUG_PRINT_SOURCE
import org.neo4j.codegen.api.CodeGeneration.SourceCodeGeneration
import org.neo4j.codegen.api.SizeEstimation.estimateByteCodeSize
import org.neo4j.codegen.bytecode.ByteCode.BYTECODE
import org.neo4j.codegen.bytecode.ByteCode.PRINT_BYTECODE
import org.neo4j.codegen.source.SourceCode.PRINT_SOURCE
import org.neo4j.codegen.source.SourceCode.SOURCECODE
import org.neo4j.codegen.source.SourceCode.sourceLocation
import org.neo4j.codegen.source.SourceVisitor
import org.neo4j.cypher.internal.options.CypherDebugOption
import org.neo4j.cypher.internal.options.CypherDebugOptions
import org.neo4j.exceptions.CantCompileQueryException
import org.neo4j.exceptions.InternalException
import java.nio.file.Path
import java.nio.file.Paths
import scala.annotation.nowarn
import scala.collection.mutable.ArrayBuffer
import scala.language.existentials
/**
* Produces runnable code from an IntermediateRepresentation
*/
object CodeGeneration {

  /** Largest method (in bytes of bytecode) the JVM accepts; used as the default method-size budget. */
  final val MAX_METHOD_LIMIT: Int = 65535

  // Debugging switches: when enabled, generated code is printed to stdout.
  private val DEBUG_PRINT_SOURCE = false
  private val DEBUG_PRINT_BYTECODE = false

  final val GENERATE_JAVA_SOURCE_DEBUG_OPTION = CypherDebugOption.generateJavaSource.name

  /** System property naming a directory to which generated Java source is written (source mode only). */
  final val GENERATED_SOURCE_LOCATION_PROPERTY = "org.neo4j.cypher.DEBUG.generated_source_location"

  /**
   * Builds a [[CodeGeneration]] whose mode is derived from the given Cypher debug options.
   *
   * @param methodLimit  maximum estimated bytecode size per generated method
   * @param debugOptions debug options deciding between bytecode and source generation
   */
  def fromDebugOptions(methodLimit: Int = MAX_METHOD_LIMIT, debugOptions: CypherDebugOptions): CodeGeneration = {
    val mode = CodeGenerationMode.modeFromDebugOptions(debugOptions)
    codeGeneration(methodLimit, mode)
  }

  /**
   * Builds a [[CodeGeneration]]; by default it generates bytecode and saves neither source nor bytecode.
   */
  def codeGeneration(
    methodLimit: Int = MAX_METHOD_LIMIT,
    mode: CodeGenerationMode = ByteCodeGeneration(new CodeSaver(saveSource = false, saveByteCode = false))
  ): CodeGeneration = new CodeGeneration(methodLimit, mode)

  /** How code is produced, together with the saver collecting debug artifacts. */
  sealed trait CodeGenerationMode {
    def saver: CodeSaver
  }

  /** Generate JVM bytecode directly. */
  case class ByteCodeGeneration(saver: CodeSaver) extends CodeGenerationMode

  /** Generate Java source and compile it. */
  case class SourceCodeGeneration(saver: CodeSaver) extends CodeGenerationMode

  object CodeGenerationMode {

    /**
     * Picks source generation when `generateJavaSourceEnabled` is set, bytecode generation otherwise.
     * In source mode the output directory is read from [[GENERATED_SOURCE_LOCATION_PROPERTY]], if set.
     */
    def modeFromDebugOptions(debugOptions: CypherDebugOptions): CodeGenerationMode =
      if (!debugOptions.generateJavaSourceEnabled) {
        ByteCodeGeneration(new CodeSaver(saveSource = false, saveByteCode = debugOptions.showBytecodeEnabled))
      } else {
        val sourceDirectory = Option(System.getProperty(GENERATED_SOURCE_LOCATION_PROPERTY)).map(Paths.get(_))
        SourceCodeGeneration(new CodeSaver(
          saveSource = debugOptions.showJavaSourceEnabled,
          saveByteCode = debugOptions.showBytecodeEnabled,
          saveSourceToFileLocation = sourceDirectory
        ))
      }
  }

  /**
   * Collects generated source and/or disassembled bytecode by registering visitors as generator options.
   *
   * @param saveSource               record generated source (class name -> source)
   * @param saveByteCode             record disassembled bytecode (class name -> disassembly)
   * @param saveSourceToFileLocation if defined, also asks the generator to write source to this path
   */
  class CodeSaver(saveSource: Boolean, saveByteCode: Boolean, saveSourceToFileLocation: Option[Path] = None) {
    private val _source: ArrayBuffer[(String, String)] = new ArrayBuffer()
    private val _bytecode: ArrayBuffer[(String, String)] = new ArrayBuffer()

    private def sourceVisitor: SourceVisitor =
      (reference: TypeReference, sourceCode: CharSequence) =>
        _source.append((reference.name(), sourceCode.toString))

    private def byteCodeVisitor: DisassemblyVisitor =
      (className: String, disassembly: CharSequence) =>
        _bytecode.append((className, disassembly.toString))

    /** Generator options to install: file location first, then bytecode visitor, then source visitor. */
    def options: List[CodeGeneratorOption] = {
      val fileLocation: List[CodeGeneratorOption] = saveSourceToFileLocation.toList.map(path => sourceLocation(path))
      val byteCode: List[CodeGeneratorOption] = if (saveByteCode) List(byteCodeVisitor) else Nil
      val source: List[CodeGeneratorOption] = if (saveSource) List(sourceVisitor) else Nil
      fileLocation ++ byteCode ++ source
    }

    /** Snapshot of the recorded (class name, source) pairs. */
    def sourceCode: Seq[(String, String)] = _source.toSeq

    /** Snapshot of the recorded (class name, disassembly) pairs. */
    def bytecode: Seq[(String, String)] = _bytecode.toSeq
  }
}
/**
 * Turns declarations expressed as [[IntermediateRepresentation]] into generated JVM classes.
 *
 * @param methodLimit        upper bound on the estimated bytecode size of any single generated method;
 *                           larger methods are rejected with a [[CantCompileQueryException]]
 * @param codeGenerationMode selects bytecode vs. Java-source generation and carries the saver whose
 *                           options are handed to the generator
 */
class CodeGeneration(methodLimit: Int, val codeGenerationMode: CodeGenerationMode) {

  /** Creates a generator that uses the class loader of [[IntermediateRepresentation]] as parent class loader. */
  def createGenerator(): CodeGenerator = {
    createGenerator(classOf[IntermediateRepresentation].getClassLoader)
  }

  /**
   * Creates a generator with the given parent class loader.
   *
   * The source strategy is used when the mode requests it or when `DEBUG_PRINT_SOURCE` is on; otherwise
   * bytecode is generated. Options are passed in the order: PRINT_BYTECODE (if enabled), PRINT_SOURCE
   * (if enabled), then the saver's own options.
   */
  def createGenerator(parentClassLoader: ClassLoader): CodeGenerator = {
    val strategy = codeGenerationMode match {
      case SourceCodeGeneration(_) => SOURCECODE
      case ByteCodeGeneration(_)   => if (DEBUG_PRINT_SOURCE) SOURCECODE else BYTECODE
    }
    val debugPrintOptions: List[CodeGeneratorOption] =
      (if (DEBUG_PRINT_BYTECODE) List[CodeGeneratorOption](PRINT_BYTECODE) else Nil) ++
        (if (DEBUG_PRINT_SOURCE) List[CodeGeneratorOption](PRINT_SOURCE) else Nil)
    val options: List[CodeGeneratorOption] = debugPrintOptions ++ codeGenerationMode.saver.options
    generateCode(parentClassLoader, strategy, options: _*)
  }

  /** Generates the class described by `c` with `generator` and returns its handle (the class is not loaded). */
  def compileClass[T](c: ClassDeclaration[T], generator: CodeGenerator): ClassHandle = {
    compileClassDeclaration(c, generator)
  }

  /** Loads the class behind `handle` and assigns the values of the static fields declared in `declaration`. */
  def loadAndSetConstants[T](handle: ClassHandle, declaration: ClassDeclaration[T]): Class[T] = {
    val clazz = handle.loadClass()
    setConstants(clazz, declaration.fields)
    clazz.asInstanceOf[Class[T]]
  }

  /** Generates `c`, loads it via `loadAnonymousClass` and assigns the values of its static fields. */
  def compileAnonymousClass[T](c: ClassDeclaration[T], generator: CodeGenerator): Class[T] = {
    val handle = compileClassDeclaration(c, generator)
    val clazz = handle.loadAnonymousClass()
    setConstants(clazz, c.fields)
    clazz.asInstanceOf[Class[T]]
  }

  /**
   * Reflectively assigns each distinct `StaticField` that carries a value to the field of the same name
   * declared on `clazz`.
   *
   * @throws NoSuchFieldException if such a field is not declared on `clazz`
   */
  private def setConstants(clazz: Class[_], fields: collection.Seq[Field]): Unit = {
    if (fields.isEmpty) {
      return
    }
    val declaredFields = clazz.getDeclaredFields

    @nowarn("msg=return statement")
    def findField(fields: Array[java.lang.reflect.Field], name: String): java.lang.reflect.Field = {
      for (field <- fields) {
        if (field.getName == name) return field
      }
      throw new NoSuchFieldException(name)
    }

    fields.distinct.foreach {
      // `null` receiver: the field being set is static
      case StaticField(_, name, Some(value)) =>
        findField(declaredFields, name).set(null, value)
      case _ =>
    }
  }

  /** Runs `exhaustBlock` on `block`, then closes `block`, returning the callback's result. */
  private def beginBlock[BlockType <: AutoCloseable, T](block: BlockType)(exhaustBlock: BlockType => T): T = {
    /*
     * In the java API we are using try-with-resources for this. This is slightly problematic since we
     * are then always calling close which potentially will hide errors thrown in code generation.
     */
    val result = exhaustBlock(block)
    block.close()
    result
  }

  /**
   * Generates the constructor: calls the parent's no-arg constructor, initializes instance fields that
   * have an initializer, declares static fields, and finally emits `initializationCode`.
   */
  private def generateConstructor(
    clazz: codegen.ClassGenerator,
    fields: collection.Seq[Field],
    params: collection.Seq[Parameter],
    initializationCode: codegen.CodeBlock => codegen.Expression,
    parent: Option[TypeReference]
  ): Unit = {
    beginBlock(clazz.generateConstructor(params.map(_.asCodeGen).toSeq: _*)) { block =>
      block.expression(invokeSuper(parent.getOrElse(OBJECT)))
      fields.distinct.foreach {
        case field @ InstanceField(typ, name) =>
          val reference = clazz.field(typ, name)
          field.initializer.map(ir => compileExpression(ir(), block)).foreach { value =>
            block.put(block.self(), reference, value)
          }
        case StaticField(typ, name, _) =>
          clazz.publicStaticField(typ, name)
      }
      initializationCode(block)
    }
  }

  /**
   * Compiles `ir` into `block`. Statement-like nodes emit code into `block` and return
   * `codegen.Expression.EMPTY`; value-like nodes return the corresponding expression.
   *
   * @throws CodeGenerationNotSupportedException for an IR node this method does not handle
   */
  private def compileExpression(ir: IntermediateRepresentation, block: codegen.CodeBlock): codegen.Expression =
    ir match {
      // Foo.method(p1, p2,...)
      case InvokeStatic(method, params) =>
        invoke(method.asReference, params.map(p => compileExpression(p, block)): _*)

      // Foo.method(p1, p2,...); as a statement, any result is popped
      case InvokeStaticSideEffect(method, params) =>
        val invocation = invoke(method.asReference, params.map(p => compileExpression(p, block)): _*)
        if (method.returnType.isVoid) {
          block.expression(invocation)
        } else {
          block.expression(codegen.Expression.pop(invocation))
        }
        codegen.Expression.EMPTY

      // target.method(p1,p2,...)
      case Invoke(target, method, params) =>
        invoke(compileExpression(target, block), method.asReference, params.map(p => compileExpression(p, block)): _*)

      // this.method(p1,p2,...)
      case InvokeLocal(method, params) =>
        invoke(
          block.self(),
          codegen.MethodReference.methodReference(block.owner(), method.returnType, method.name, method.params: _*),
          params.map(p => compileExpression(p, block)): _*
        )

      // target.method(p1,p2,...); as a statement, any result is popped
      case InvokeSideEffect(target, method, params) =>
        val invocation =
          invoke(compileExpression(target, block), method.asReference, params.map(p => compileExpression(p, block)): _*)
        if (method.returnType.isVoid) {
          block.expression(invocation)
        } else {
          block.expression(codegen.Expression.pop(invocation))
        }
        codegen.Expression.EMPTY

      // this.method(p1,p2,...); as a statement, any result is popped
      case InvokeLocalSideEffect(method, params) =>
        val invocation = invoke(
          block.self(),
          codegen.MethodReference.methodReference(block.owner(), method.returnType, method.name, method.params: _*),
          params.map(p => compileExpression(p, block)): _*
        )
        if (method.returnType.isVoid) {
          block.expression(invocation)
        } else {
          block.expression(codegen.Expression.pop(invocation))
        }
        codegen.Expression.EMPTY

      // loads local variable by name
      case Load(variable, _) => block.load(variable)

      // loads field on this object
      case LoadField(None, f) =>
        codegen.Expression.get(block.self(), field(block.owner(), f.typ, f.name))

      // loads field on the given owner object
      case LoadField(Some(owner), f) =>
        val ownerExpr = compileExpression(owner, block)
        codegen.Expression.get(ownerExpr, field(ownerExpr.`type`, f.typ, f.name))

      // sets a field on this object
      case SetField(None, f, v) =>
        block.put(block.self(), field(block.owner(), f.typ, f.name), compileExpression(v, block))
        codegen.Expression.EMPTY

      // sets a field on the given owner object
      case SetField(Some(owner), f, v) =>
        val ownerExpr = compileExpression(owner, block)
        block.put(ownerExpr, field(ownerExpr.`type`, f.typ, f.name), compileExpression(v, block))
        codegen.Expression.EMPTY

      // loads a given constant
      case Constant(value) => constant(value)

      // new ArrayValue[]{p1, p2,...}
      case ArrayLiteral(typ, values) => newInitializedArray(typ, values.map(v => compileExpression(v, block)): _*)

      // array[offset] = value
      case ArraySet(array, offset, value) =>
        block.expression(codegen.Expression.arraySet(
          compileExpression(array, block),
          compileExpression(offset, block),
          compileExpression(value, block)
        ))
        codegen.Expression.EMPTY

      // array.length
      case ArrayLength(array) =>
        codegen.Expression.arrayLength(compileExpression(array, block))

      // array[offset]
      case ArrayLoad(array, offset) =>
        codegen.Expression.arrayLoad(compileExpression(array, block), compileExpression(offset, block))

      // Foo.BAR (defaults to a static field of the class being generated)
      case GetStatic(owner, typ, name) =>
        getStatic(staticField(owner.getOrElse(block.classGenerator().handle()), typ, name))

      // condition ? onTrue : onFalse
      case Ternary(condition, onTrue, onFalse) =>
        codegen.Expression.ternary(
          compileExpression(condition, block),
          compileExpression(onTrue, block),
          compileExpression(onFalse, block)
        )

      // lhs + rhs
      case Add(lhs, rhs) =>
        codegen.Expression.add(compileExpression(lhs, block), compileExpression(rhs, block))

      // lhs - rhs
      case Subtract(lhs, rhs) =>
        codegen.Expression.subtract(compileExpression(lhs, block), compileExpression(rhs, block))

      // lhs * rhs
      case Multiply(lhs, rhs) =>
        codegen.Expression.multiply(compileExpression(lhs, block), compileExpression(rhs, block))

      // lhs < rhs
      case Lt(lhs, rhs) =>
        codegen.Expression.lt(compileExpression(lhs, block), compileExpression(rhs, block))

      // lhs <= rhs
      case Lte(lhs, rhs) =>
        codegen.Expression.lte(compileExpression(lhs, block), compileExpression(rhs, block))

      // lhs > rhs
      case Gt(lhs, rhs) =>
        codegen.Expression.gt(compileExpression(lhs, block), compileExpression(rhs, block))

      // lhs >= rhs
      case Gte(lhs, rhs) =>
        codegen.Expression.gte(compileExpression(lhs, block), compileExpression(rhs, block))

      // lhs == rhs
      case Eq(lhs, rhs) =>
        codegen.Expression.equal(compileExpression(lhs, block), compileExpression(rhs, block))

      // lhs != rhs
      case NotEq(lhs, rhs) =>
        codegen.Expression.notEqual(compileExpression(lhs, block), compileExpression(rhs, block))

      // test == null
      case IsNull(test) => codegen.Expression.isNull(compileExpression(test, block))

      // run multiple ops in a block, the value of the block is the last expression
      case Block(ops) =>
        if (ops.isEmpty) codegen.Expression.EMPTY else ops.map(compileExpression(_, block)).last

      // placeholder: compiled like a Block over its current ops
      case p: PlaceHolder =>
        val ops = p.ops
        if (ops.isEmpty) codegen.Expression.EMPTY else ops.map(compileExpression(_, block)).last

      case Comment(c: String) =>
        block.comment(c)
        codegen.Expression.EMPTY

      // if (test) {onTrue}
      // NOTE(review): returns the expression compiled inside the if-block rather than EMPTY, unlike the
      // if/else case below — confirm callers rely on this.
      case Condition(test, onTrue, None) =>
        beginBlock(block.ifStatement(compileExpression(test, block)))(compileExpression(onTrue, _))

      // if (test) {onTrue} else {onFalse}
      case Condition(test, onTrue, Some(onFalse)) =>
        block.ifElseStatement(
          compileExpression(test, block),
          (t: codegen.CodeBlock) => compileExpression(onTrue, t),
          (f: codegen.CodeBlock) => compileExpression(onFalse, f)
        )
        codegen.Expression.EMPTY

      // typ name;
      case DeclareLocalVariable(typ, name) =>
        block.declare(typ, name)

      // name = value;
      case AssignToLocalVariable(name, value) =>
        block.assign(block.local(name), compileExpression(value, block))
        codegen.Expression.EMPTY

      // try {ops} catch(exception name)(onError)
      case TryCatch(ops, onError, exception, name) =>
        beginBlock(block.tryCatch(
          (errorBlock: codegen.CodeBlock) => compileExpression(onError, errorBlock),
          param(exception, name)
        ))(
          compileExpression(ops, _)
        )
        codegen.Expression.EMPTY

      // throw error
      case Throw(error) =>
        block.throwException(compileExpression(error, block))
        codegen.Expression.EMPTY

      // lhs && rhs
      case BooleanAnd(Seq(lhs, rhs)) =>
        codegen.Expression.and(compileExpression(lhs, block), compileExpression(rhs, block))

      // a1 && a2 && ...
      case BooleanAnd(as) =>
        codegen.Expression.ands(as.map(a => compileExpression(a, block)).toArray)

      // lhs || rhs
      case BooleanOr(Seq(lhs, rhs)) =>
        codegen.Expression.or(compileExpression(lhs, block), compileExpression(rhs, block))

      // a1 || a2 || ...
      case BooleanOr(as) =>
        codegen.Expression.ors(as.map(a => compileExpression(a, block)).toArray)

      // new Foo(args[0], args[1], ...)
      case NewInstance(constructor, args) =>
        codegen.Expression.invoke(
          codegen.Expression.newInstance(constructor.owner),
          constructor.asReference,
          args.map(compileExpression(_, block)): _*
        )

      // new Foo[5]
      case NewArray(baseType, size) =>
        codegen.Expression.newArray(baseType, size)

      // new Foo[size]
      case NewArrayDynamicSize(baseType, size) =>
        codegen.Expression.newArray(baseType, compileExpression(size, block))

      // return value;
      case Returns(value: IntermediateRepresentation) =>
        block.returns(compileExpression(value, block))
        codegen.Expression.EMPTY

      // while(test) { body }
      case Loop(test, body, labelName) =>
        beginBlock(block.whileLoop(compileExpression(test, block), labelName))(compileExpression(body, _))

      // break label
      case Break(labelName) =>
        block.breaks(labelName)
        codegen.Expression.EMPTY

      // (to) expression
      case Cast(to, expression) => codegen.Expression.cast(to, compileExpression(expression, block))

      // expression instanceof typ
      case InstanceOf(typ, expression) => codegen.Expression.instanceOf(typ, compileExpression(expression, block))

      // !test
      case Not(test) => codegen.Expression.not(compileExpression(test, block))

      // compiled the first time it is encountered; EMPTY on every later encounter
      case e @ OneTime(inner) =>
        if (!e.isUsed) {
          e.use()
          compileExpression(inner, block)
        } else {
          codegen.Expression.EMPTY
        }

      case Noop =>
        codegen.Expression.EMPTY

      case Box(expression) =>
        codegen.Expression.box(compileExpression(expression, block))

      case Unbox(expression) =>
        codegen.Expression.unbox(compileExpression(expression, block))

      case Self(_) => block.self()

      // Generates a class extending `overrides` in the current package, then instantiates it with `args`.
      case NewInstanceInnerClass(ExtendClass(className, overrides, params, methods, fields), args) =>
        val parentClass: ClassHandle = block.classGenerator().handle()
        val generator = parentClass.generator
        val classHandle = beginBlock(generator.generateClass(overrides, parentClass.packageName(), className)) {
          (clazz: codegen.ClassGenerator) =>
            beginBlock(clazz.generateConstructor(params.map(_.asCodeGen): _*)) { constructor =>
              // forward every constructor parameter to the super constructor
              constructor.expression(Expression.invokeSuper(overrides, params.map(p => constructor.load(p.name)): _*))
              fields.distinct.foreach {
                case field @ InstanceField(typ, name) =>
                  val reference = clazz.field(typ, name)
                  field.initializer.map(ir => compileExpression(ir(), constructor)).foreach { value =>
                    constructor.put(constructor.self(), reference, value)
                  }
                case StaticField(typ, name, _) =>
                  // copy the value of the same-named static field from the enclosing class
                  val field = clazz.publicStaticField(typ, name)
                  constructor.putStatic(
                    field,
                    Expression.getStatic(FieldReference.staticField(parentClass, field.`type`(), field.name()))
                  )
              }
            }
            // methods
            methods.foreach { m =>
              compileMethodDeclaration(clazz, m)
            }
            clazz.handle()
        }
        val constructor =
          if (args.isEmpty) {
            codegen.MethodReference.constructorReference(classHandle)
          } else {
            codegen.MethodReference.constructorReference(classHandle, params.map(_.typ): _*)
          }
        codegen.Expression.invoke(
          codegen.Expression.newInstance(classHandle),
          constructor,
          args.map(compileExpression(_, block)): _*
        )

      case unknownIr =>
        throw new CodeGenerationNotSupportedException(null, s"Unknown ir `$unknownIr`") {}
    }

  /** Generates the class described by `c` (constructor plus all methods) and returns its handle. */
  private def compileClassDeclaration(c: ClassDeclaration[_], generator: CodeGenerator): codegen.ClassHandle =
    beginBlock(generator.generateClass(
      c.extendsClass.getOrElse(codegen.TypeReference.OBJECT),
      c.packageName,
      c.className,
      c.classDependencies.toArray,
      c.implementsInterfaces.toArray
    )) { (clazz: codegen.ClassGenerator) =>
      generateConstructor(
        clazz,
        c.fields,
        c.constructorParameters,
        block => compileExpression(c.initializationCode, block),
        c.extendsClass
      )
      c.methods.foreach { m =>
        compileMethodDeclaration(clazz, m)
      }
      clazz.handle()
    }

  /**
   * Generates method `m` on `clazz`: declares its local variables, then emits the body as a statement
   * (void methods) or as the returned value.
   *
   * @throws CantCompileQueryException if the estimated bytecode size of `m` exceeds `methodLimit`
   * @throws InternalException         wrapping any exception raised while generating the method
   */
  private def compileMethodDeclaration(clazz: codegen.ClassGenerator, m: MethodDeclaration): Unit = {
    val estimatedSize = estimateByteCodeSize(m)
    if (estimatedSize > methodLimit) {
      throw new CantCompileQueryException(
        s"Method '${m.methodName}' is too big, estimated size $estimatedSize is bigger than $methodLimit"
      )
    }
    val method = codegen.MethodDeclaration.method(
      m.returnType,
      m.methodName,
      m.parameters.map(_.asCodeGen): _*
    ).modifiers(m.modifiers)
    m.parameterizedWith.foreach {
      case (name, bound) => method.parameterizedWith(name, bound)
    }
    m.throws.foreach(method.throwsException)
    try {
      beginBlock(clazz.generate(method)) { block =>
        m.localVariables.distinct.foreach { v =>
          block.assign(v.typ, v.name, compileExpression(v.value, block))
        }
        if (m.returnType == codegen.TypeReference.VOID) {
          block.expression(compileExpression(m.body, block))
        } else {
          block.returns(compileExpression(m.body, block))
        }
      }
    } catch {
      case e: ArrayIndexOutOfBoundsException =>
        // NOTE: This could be a CantCompileQueryException, but then it would be handled at runtime, and may pass unnoticed
        // stripMargin is applied to the static hint only, so the wrapped exception's message is never altered.
        val hint =
          """This could mean that an intermediate representation instruction has been generated with an incorrect type.
            |One common mistake is that a method type parameter of an invoke has been set to the wrong type:
            | -> Check that type parameters of e.g. invoke(..., method[OWNER, OUT, IN1, IN2, ...](...), ...) are exactly matching the method's declaration.
            |If your problem is something different, please extend this error message with more examples.""".stripMargin
        throw new InternalException(s"${methodFailureMessage(m, clazz, e)}\n$hint", e)
      case e: Exception =>
        // NOTE: This could be a CantCompileQueryException, but then it would be handled at runtime, and may pass unnoticed
        throw new InternalException(methodFailureMessage(m, clazz, e), e)
    }
  }

  /** First line of the error reported when generating method `m` on `clazz` fails with `e`. */
  private def methodFailureMessage(m: MethodDeclaration, clazz: codegen.ClassGenerator, e: Exception): String =
    s"Method '${m.methodName}' in class '${clazz.handle().name()}' failed in code generation: ${e.getClass.getSimpleName} '${e.getMessage}'"
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy