All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.skyenet.heart.JavaInterpreter.kt Maven / Gradle / Ivy

package com.simiacryptus.skyenet.heart

import com.simiacryptus.skyenet.Heart
import java.io.ByteArrayOutputStream
import java.io.OutputStream
import java.net.URI
import java.util.*
import javax.tools.*
import javax.tools.DiagnosticCollector
import javax.tools.JavaFileObject
/**
 * [Heart] implementation that compiles and runs Java snippets entirely in memory.
 *
 * Each call to [run] or [validate] wraps the snippet in a freshly named class (see [wrapJavaCode]) and
 * compiles it with the JDK's system compiler into an [InMemoryClassLoader]. [run] then invokes the
 * generated static `runSnippet()` method.
 *
 * @param defs values exposed to snippets as `public static Object` fields named after each key.
 *   Keys must be valid Java identifiers, otherwise every snippet fails to compile.
 */
open class JavaInterpreter(val defs: Map<String, Any> = HashMap()) : Heart {

    // ToolProvider returns null when running on a bare JRE; fail fast with an explanatory message.
    private val compiler: JavaCompiler = checkNotNull(ToolProvider.getSystemJavaCompiler()) {
        "No system Java compiler available; JavaInterpreter requires a JDK, not a JRE"
    }

    override fun getLanguage(): String = "Java"

    /**
     * Compiles and executes [code].
     *
     * @return the value the snippet returns (or `null` if it does not return). If compilation fails, the
     *   compilation-error [Exception] is returned rather than thrown, as before.
     * @throws Throwable whatever the snippet itself throws, unwrapped from the reflection wrapper.
     */
    override fun run(code: String): Any? {
        val (className, wrappedCode) = wrapJavaCode(code)
        val classLoader = InMemoryClassLoader()
        val (diagnosticCollector, success) = compile(className, wrappedCode, classLoader)
        return wrapExecution {
            if (success) {
                val clazz = classLoader.loadClass(className)
                // Bind definitions by assigning the generated public static fields before the snippet runs.
                defs.forEach { (key, value) -> clazz.getField(key).set(null, value) }
                try {
                    clazz.getMethod("runSnippet").invoke(null)
                } catch (e: java.lang.reflect.InvocationTargetException) {
                    // Surface the snippet's own exception rather than the reflection wrapper.
                    throw e.targetException ?: e
                }
            } else {
                compilationError(diagnosticCollector)
            }
        }
    }

    /** Minimal two-value holder; shadows [kotlin.Pair] inside this class. */
    data class Pair<A, B>(val a: A, val b: B)

    /**
     * Compiles [wrappedCode] (declaring [className]) and stores the resulting bytecode in [classLoader].
     *
     * @return the collected diagnostics and whether compilation succeeded.
     */
    private fun compile(
        className: String,
        wrappedCode: String,
        classLoader: InMemoryClassLoader,
    ): Pair<DiagnosticCollector<JavaFileObject>, Boolean> {
        val compilationUnits = listOf(stringJavaFileObject(className, wrappedCode))
        val diagnosticCollector = DiagnosticCollector<JavaFileObject>()
        // `use` closes the underlying standard file manager, which the previous version leaked.
        val success = InMemoryJavaFileManager(compiler, classLoader).use { fileManager ->
            compiler.getTask(null, fileManager, diagnosticCollector, null, null, compilationUnits).call() == true
        }
        return Pair(diagnosticCollector, success)
    }

    /**
     * Compiles [code] without running it.
     *
     * @return `null` if it compiles, otherwise an [Exception] listing every diagnostic.
     */
    override fun validate(code: String): Exception? {
        val (className, wrappedCode) = wrapJavaCode(code)
        val (diagnosticCollector, success) = compile(className, wrappedCode, InMemoryClassLoader())
        return if (success) null else compilationError(diagnosticCollector)
    }

    /** Builds the exception reported for a failed compilation, listing every collected diagnostic. */
    private fun compilationError(diagnosticCollector: DiagnosticCollector<JavaFileObject>): Exception =
        Exception("Compilation errors:\n" + diagnosticCollector.diagnostics.joinToString("\n"))

    /**
     * Wraps a snippet into a complete Java compilation unit.
     *
     * Lines starting with `import ` (ignoring leading whitespace) are hoisted above the class. Each key of
     * [defs] becomes a `public static Object` field. The remaining lines form the body of
     * `public static Object runSnippet()`.
     *
     * @return the generated class name and its source code.
     */
    private fun wrapJavaCode(code: String): Pair<String, String> {
        // Random hex suffix keeps class names unique across snippets.
        val className = "JavaSnippet" + UUID.randomUUID().toString().replace("-", "")
        val (imports, statements) = code.lines().partition { it.trimStart().startsWith("import ") }
        val wrappedCode = buildString {
            // Kept so snippets that relied on these implicit imports still compile.
            appendLine("import java.util.Map;")
            appendLine("import java.util.UUID;")
            imports.forEach { appendLine(it.trim()) }
            appendLine()
            appendLine("public class $className {")
            // Fields are assigned reflectively by run() before runSnippet() is invoked.
            defs.keys.forEach { key -> appendLine("    public static Object $key;") }
            appendLine()
            appendLine("    public static Object runSnippet() throws Throwable {")
            // `if (true)` keeps the trailing `return null;` reachable (JLS 14.22) even when the snippet returns.
            appendLine("        if (true) {")
            statements.forEach { appendLine("            $it") }
            appendLine("        }")
            appendLine("        return null;")
            appendLine("    }")
            appendLine("}")
        }
        return Pair(className, wrappedCode)
    }

    /** Source file object whose content is the in-memory string [code]. */
    class stringJavaFileObject(private val className: String, private val code: String) : SimpleJavaFileObject(
        URI.create(
            "string:///" + className.replace('.', '/') + JavaFileObject.Kind.SOURCE.extension
        ), JavaFileObject.Kind.SOURCE
    ) {
        override fun getCharContent(ignoreEncodingErrors: Boolean): CharSequence = code
    }

    /**
     * Class loader that defines classes from bytecode registered through [addClass].
     *
     * @param parent delegation parent; defaults to the loader of [JavaInterpreter] so that snippets can see
     *   the same classes as the code that created the interpreter.
     */
    class InMemoryClassLoader(
        parent: ClassLoader? = JavaInterpreter::class.java.classLoader,
    ) : ClassLoader(parent) {
        private val classes = mutableMapOf<String, ByteArray>()

        fun addClass(name: String, bytes: ByteArray) {
            classes[name] = bytes
        }

        override fun findClass(name: String): Class<*> {
            val bytes = classes[name] ?: throw ClassNotFoundException(name)
            return defineClass(name, bytes, 0, bytes.size)
        }
    }

    /** File manager that redirects compiler class output into [classLoader] instead of the file system. */
    class InMemoryJavaFileManager(compiler: JavaCompiler, private val classLoader: InMemoryClassLoader) :
        ForwardingJavaFileManager<StandardJavaFileManager>(compiler.getStandardFileManager(null, null, null)) {
        override fun getJavaFileForOutput(
            location: JavaFileManager.Location?,
            className: String,
            kind: JavaFileObject.Kind,
            sibling: FileObject?,
        ): JavaFileObject = object : SimpleJavaFileObject(
            URI.create("string:///" + className.replace('.', '/') + kind.extension), kind
        ) {
            // The bytes must be captured once the compiler has finished writing, i.e. on close().
            // Snapshotting them when the stream is opened would register an empty class file.
            override fun openOutputStream(): OutputStream = object : ByteArrayOutputStream() {
                override fun close() {
                    super.close()
                    classLoader.addClass(className, toByteArray())
                }
            }
        }
    }

    companion object {
        /**
         * Retained for compatibility only.
         *
         * Definitions are no longer passed through a GUID-keyed map. They are assigned directly to the
         * generated static fields. This returns a new, empty map that is not connected to any interpreter.
         */
        @JvmStatic
        fun getGuidMap(): MutableMap<String, Any> = HashMap()
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy