dorkbox.network.rmi.RmiUtils.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of Network Show documentation
Show all versions of Network Show documentation
Encrypted, high-performance, and event-driven/reactive network stack for Java 6+
/*
* Copyright 2020 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.rmi
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.reflectasm.MethodAccess
import dorkbox.network.connection.Connection
import dorkbox.util.classes.ClassHelper
import mu.KLogger
import java.lang.reflect.Method
import java.lang.reflect.Modifier
import java.util.*
import kotlin.coroutines.Continuation
/**
* Utility methods for creating a method cache for a class or interface.
*
* Additionally, this will override methods on the implementation so that methods can be called with a [Connection] parameter as the
* first parameter, with all other parameters being equal to the interface.
*
* This is to support calling RMI methods from an interface (that does NOT pass the connection reference) to
* an implType, that DOES pass the connection reference. The remote side (that initiates the RMI calls), MUST use
* the interface, and the implType may override the method, so that we add the connection as the first in
* the list of parameters.
*
* for example:
* Interface: foo(String x)
*
* Impl: foo(String x) -> not called
* Impl: foo(Connection c, String x) -> this is called instead
*
* The implType (if it exists, with the same name, and with the same signature + connection parameter) will be called from the interface
* instead of the method that would NORMALLY be called.
*/
object RmiUtils {
/**
 * Orders methods by name, then by parameter count (fewer first), then by the class names of their
 * parameter types, compared left to right.
 *
 * Methods are sorted so they can be represented as an index.
 *
 * Throws a [RuntimeException] when both methods have the same name and the same parameter types, because
 * such a pair cannot be given distinct positions.
 */
private val METHOD_COMPARATOR = Comparator<Method> { o1, o2 ->
    val o1Name = o1.name
    val o2Name = o2.name

    val nameDiff = o1Name.compareTo(o2Name)
    if (nameDiff != 0) {
        return@Comparator nameDiff
    }

    val argTypes1 = o1.parameterTypes
    val argTypes2 = o2.parameterTypes

    // same name: the overload with fewer parameters sorts first
    val sizeDiff = argTypes1.size.compareTo(argTypes2.size)
    if (sizeDiff != 0) {
        return@Comparator sizeDiff
    }

    // same name and arity: the first differing parameter type (by class name) decides
    for (i in argTypes1.indices) {
        val typeDiff = argTypes1[i].name.compareTo(argTypes2[i].name)
        if (typeDiff != 0) {
            return@Comparator typeDiff
        }
    }

    throw RuntimeException("Two methods with same signature! ('$o1Name', '$o2Name')")
}
/**
 * Creates the ReflectASM [MethodAccess] for [clazz].
 *
 * @param logger receives the error if ReflectASM throws while building or inspecting the accessor
 * @param clazz the class to build the accessor for
 *
 * @return the accessor, or null when ReflectASM reported no method names, parameter types and return types at all,
 *         or when an exception was thrown (which is logged)
 */
private fun getReflectAsmMethod(logger: KLogger, clazz: Class<*>): MethodAccess? {
    return try {
        // there was NOTHING that reflectASM found, so trying to use it doesn't do us any good
        MethodAccess.get(clazz).takeUnless { access ->
            access.methodNames.isEmpty() && access.parameterTypes.isEmpty() && access.returnTypes.isEmpty()
        }
    } catch (e: Exception) {
        logger.error("Unable to create ReflectASM method access", e)
        null
    }
}
/**
 * Creates the RMI method cache for [iFace]: one [CachedMethod] per method returned by [getMethods] for the interface,
 * in that (sorted) order, so that a method can be addressed by its index.
 *
 * When [impl] has a method with the same name whose parameters are the interface method's parameters preceded by a
 * [Connection] parameter, that method is stored as the `overriddenMethod` of the cached entry.
 *
 * @param logger receives ReflectASM failures and, at debug level, the detected overrides
 * @param kryo used to look up the serializer of every parameter that is final (or is a [Continuation])
 * @param asmEnabled true to try ReflectASM method access; a plain [CachedMethod] is used whenever that is not possible
 * @param iFace this is never null.
 * @param impl this is NULL on the rmi "client" side. This is NOT NULL on the "server" side (where the object lives)
 * @param classId stored in every cached method as its `methodClassId`
 *
 * @throws IllegalArgumentException if [impl] is an interface
 */
fun getCachedMethods(logger: KLogger, kryo: Kryo, asmEnabled: Boolean, iFace: Class<*>, impl: Class<*>?, classId: Int): Array<CachedMethod> {
    // RMI is **ALWAYS** based upon an interface, so we must always make sure to get the methods of the interface, instead of the
    // implementation, otherwise we will have the wrong order of methods, so invoking a method by its index will fail.
    val methods = getMethods(iFace)

    var implMethods: Array<Method>? = null
    var implAsmMethodAccess: MethodAccess? = null

    if (impl != null) {
        require(!impl.isInterface) { "Cannot have type as an interface, it must be an implementation" }

        implMethods = getMethods(impl)

        // reflectASM
        //   doesn't work on android (set correctly by the serialization manager)
        //   can't get any method from the 'Object' object (we get from the interface, which is NOT 'Object')
        //   and it MUST be public (iFace is always public)
        if (asmEnabled) {
            implAsmMethodAccess = getReflectAsmMethod(logger, impl)
        }
    }

    // same reflectASM restrictions as above
    val ifaceAsmMethodAccess: MethodAccess? = if (asmEnabled) getReflectAsmMethod(logger, iFace) else null

    // frozen copies, so they can be smart-cast inside the array initializer below
    val implementationMethods: Array<Method>? = implMethods
    val implementationAccess: MethodAccess? = implAsmMethodAccess

    val hasConnectionOverrideMethods = hasOverwriteMethodWithConnectionParam(implementationMethods)

    return Array(methods.size) { i ->
        val method = methods[i]
        val parameterTypes = method.parameterTypes

        // Store the serializer for each final parameter.
        // this is ONLY for the ORIGINAL method, not the overridden one.
        val serializers = arrayOfNulls<Serializer<*>>(parameterTypes.size)
        parameterTypes.forEachIndexed { index, parameterType ->
            if (kryo.isFinal(parameterType) || parameterType === Continuation::class.java) {
                serializers[index] = kryo.getSerializer(parameterType)
            }
        }

        // this is how we detect if the method has been changed from the interface -> implementation + connection parameter
        val overwrittenMethod: Method? = if (implementationMethods != null && hasConnectionOverrideMethods) {
            getOverwriteMethodWithConnectionParam(implementationMethods, method)
        } else {
            null
        }

        // an overwritten method lives on the implementation, so it needs the implementation's accessor (still might be null!)
        val methodAccess: MethodAccess? = if (overwrittenMethod != null) implementationAccess else ifaceAsmMethodAccess

        // reflectAsm doesn't like "Object" class methods
        val canUseAsm = asmEnabled && method.declaringClass != Any::class.java

        var asmMethod: CachedMethod? = null
        // a missing accessor is not an error here (it was already reported, or ReflectASM found nothing to access),
        // so quietly use plain reflection instead of failing on a null accessor
        if (canUseAsm && methodAccess != null) {
            try {
                // have to take into account the overwritten method's first parameter will ALWAYS be "Connection"
                val lookupTypes = overwrittenMethod?.parameterTypes ?: parameterTypes

                asmMethod = CachedAsmMethod(
                    methodAccessIndex = methodAccess.getIndex(method.name, *lookupTypes),
                    methodAccess = methodAccess,
                    name = method.name,
                    method = method,
                    methodIndex = i,
                    methodClassId = classId,
                    serializers = serializers)
            } catch (e: Exception) {
                logger.error("Unable to use ReflectAsm for ${makeFancyMethodName(method)}", e)
            }
        }

        val cachedMethod: CachedMethod = asmMethod ?: CachedMethod(
            method = method,
            methodIndex = i,
            methodClassId = classId,
            serializers = serializers)

        // this MIGHT be null, but if it is not, this is the method we will invoke INSTEAD of the "normal" method
        cachedMethod.overriddenMethod = overwrittenMethod

        if (overwrittenMethod != null && logger.isDebugEnabled) {
            val name = if (cachedMethod.method.declaringClass.isInterface) {
                "iface"
            } else {
                "method"
            }
            logger.debug("Overridden $name : ${makeFancyMethodName(cachedMethod)}")
            logger.debug(" to method : ${makeFancyMethodName(overwrittenMethod)}")
        }

        cachedMethod
    }
}
/**
 * Check to see if there are ANY methods in this class that start with a "Connection" parameter.
 *
 * @param implMethods methods from the implementation, or null when there is no implementation
 *
 * @return true if at least one method's FIRST parameter passes the `ClassHelper.hasInterface` check against
 *         [Connection]; false otherwise (including when [implMethods] is null)
 */
private fun hasOverwriteMethodWithConnectionParam(implMethods: Array<Method>?): Boolean {
    if (implMethods == null) {
        return false
    }

    return implMethods.any { implMethod ->
        val implParameters = implMethod.parameterTypes
        // check if the FIRST parameter is "Connection"
        implParameters.isNotEmpty() && ClassHelper.hasInterface(Connection::class.java, implParameters[0])
    }
}
/**
 * This will overwrite, in place, each original (iface based) method with the matching method from the implementation,
 * ONLY if that implementation method has the extra 'Connection' parameter as its first parameter.
 * Entries without such a match are left untouched.
 *
 * NOTE: does not null check
 *
 * @param implMethods methods from the implementation
 * @param origMethods methods from the interface; this array is modified
 */
private fun overwriteMethodsWithConnectionParam(implMethods: Array<Method>, origMethods: Array<Method>) {
    for (i in origMethods.indices) {
        val replacement = getOverwriteMethodWithConnectionParam(implMethods, origMethods[i])
        if (replacement != null) {
            origMethods[i] = replacement
        }
    }
}
/**
 * Finds the implementation method that should be called INSTEAD of an original (iface based) method: it must have the
 * same name, and its parameters must be the original parameters preceded by the extra 'Connection' parameter.
 *
 * NOTE: does not null check
 *
 * @param implMethods methods from the implementation
 * @param origMethod original method from the interface
 *
 * @return the first matching implementation method, or null if there is none
 */
private fun getOverwriteMethodWithConnectionParam(implMethods: Array<Method>, origMethod: Method): Method? {
    val origName = origMethod.name
    val origTypes = origMethod.parameterTypes
    // the replacement has exactly one parameter more: the leading "Connection"
    val expectedLength = origTypes.size + 1

    return implMethods.firstOrNull { implMethod ->
        val implParameters = implMethod.parameterTypes

        // the length check guarantees at least one parameter, so reading index 0 is safe
        implParameters.size == expectedLength &&
            implMethod.name == origName &&
            // check if the FIRST parameter is "Connection"
            ClassHelper.hasInterface(Connection::class.java, implParameters[0]) &&
            // make sure all the remaining parameters match. Cannot compare the arrays directly, because one has "Connection"
            // as an extra parameter. With only "Connection" as a parameter there is nothing left to compare.
            origTypes.indices.all { k -> origTypes[k] == implParameters[k + 1] }
    }
}
/**
 * This will get the methods from an interface (for RMI), or from an implementation (for "connection" overriding the
 * method signature), including the public methods of `Object`.
 *
 * Static, private and synthetic methods are skipped, and for each name + parameter-type combination only the first
 * method encountered is kept.
 *
 * @param type the class or interface to collect the methods of
 *
 * @return an array of all found methods for this class, sorted by name, parameter count and parameter type names
 */
fun getMethods(type: Class<*>): Array<Method> {
    val allMethods = ArrayList<Method>()
    val accessibleMethods = HashMap<String, ArrayList<Array<Class<*>>>>()

    val classes = LinkedList<Class<*>>()
    // a class can be queued more than once (Object, or an interface shared by several types); scanning it again cannot
    // add anything because every one of its signatures is already recorded, so it is skipped
    val visited = HashSet<Class<*>>()

    classes.add(type)

    // explicitly add Object.class because that can always be called, because it is common to 100% of all java objects (and its
    // methods are not added via parentClass.getMethods() when the type is an interface)
    classes.add(Any::class.java)

    while (classes.isNotEmpty()) {
        val nextClass = classes.removeFirst()
        if (!visited.add(nextClass)) {
            continue
        }

        for (method in nextClass.methods) {
            // static and private methods cannot be called via RMI.
            val modifiers = method.modifiers
            if (Modifier.isStatic(modifiers) || Modifier.isPrivate(modifiers) || method.isSynthetic) {
                continue
            }

            // methods that have been over-ridden by another method cannot be called.
            // the first one in the map, is the "highest" level method, and is what can be called.
            val types = method.parameterTypes // length 0 if there are no parameters
            val existingTypes = accessibleMethods.getOrPut(method.name) { ArrayList() }

            if (existingTypes.any { Arrays.equals(types, it) }) {
                // the method is overridden, so it should not be called.
                continue
            }

            // remember the signature for checking later
            existingTypes.add(types)

            // safe to add this method to the list of recognized methods
            allMethods.add(method)
        }

        // add all interfaces from our class (if any)
        classes.addAll(nextClass.interfaces)

        // If we are an interface, one CANNOT call any methods NOT defined by the interface!
        // also, interfaces don't have a super-class.
        val superclass = nextClass.superclass
        if (superclass != null) {
            classes.add(superclass)
        }
    }

    Collections.sort(allMethods, METHOD_COMPARATOR)
    return allMethods.toTypedArray()
}
/**
 * Creates an instance of [serializerClass], trying its public constructors in this order:
 * `(Kryo, Class)`, `(Kryo)`, `(Class)`, and finally the declared no-argument constructor.
 *
 * @param k passed to the constructors that accept a [Kryo]
 * @param superClass passed to the constructors that accept a [Class]; also named in the error message
 * @param serializerClass the serializer type to instantiate
 *
 * @return the new serializer instance
 *
 * @throws IllegalArgumentException if no constructor could be found or the instantiation failed (the cause is attached)
 */
fun resolveSerializerInstance(k: Kryo, superClass: Class<*>, serializerClass: Class<out Serializer<*>>): Serializer<*> {
    // constructor signatures to try, most specific first, each paired with the arguments it is called with
    val attempts = listOf<Pair<Array<Class<*>>, Array<Any>>>(
        arrayOf<Class<*>>(Kryo::class.java, Class::class.java) to arrayOf<Any>(k, superClass),
        arrayOf<Class<*>>(Kryo::class.java) to arrayOf<Any>(k),
        arrayOf<Class<*>>(Class::class.java) to arrayOf<Any>(superClass)
    )

    try {
        for ((parameterTypes, arguments) in attempts) {
            val constructor = try {
                serializerClass.getConstructor(*parameterTypes)
            } catch (ignored: NoSuchMethodException) {
                // this signature does not exist, try the next one
                continue
            }

            return constructor.newInstance(*arguments)
        }

        return serializerClass.getDeclaredConstructor().newInstance()
    } catch (ex: Exception) {
        throw IllegalArgumentException(
            "Unable to create serializer \"${serializerClass.name}\" for class: ${superClass.name}", ex)
    }
}
/**
 * Collects the super-types of [clazz], breadth first: for every visited class its interfaces are queued first,
 * then its superclass.
 *
 * A type that is reachable along several paths appears once per path in the result.
 *
 * @param clazz the class to walk the hierarchy of
 *
 * @return all visited interfaces and superclasses, with the first occurrence of [clazz] itself removed
 */
fun getHierarchy(clazz: Class<*>): ArrayList<Class<*>> {
    val allClasses = ArrayList<Class<*>>()

    val parseClasses = LinkedList<Class<*>>()
    parseClasses.add(clazz)

    while (parseClasses.isNotEmpty()) {
        val nextClass = parseClasses.removeFirst()
        allClasses.add(nextClass)

        // add all interfaces from our class (if any)
        parseClasses.addAll(nextClass.interfaces)

        // interfaces (and Object) have no superclass
        val superclass = nextClass.superclass
        if (superclass != null) {
            parseClasses.add(superclass)
        }
    }

    // remove the first class, because we don't need it
    allClasses.remove(clazz)

    return allClasses
}
/**
 * Packs two values into a single int: [left] is shifted into the upper 16 bits, the lower 16 bits of [right]
 * fill the lower half.
 */
fun packShorts(left: Int, right: Int): Int {
    val upperHalf = left shl 16
    val lowerHalf = right and 0xFFFF
    return upperHalf or lowerHalf
}
/**
 * Returns the upper 16 bits of [packedInt], using a sign-preserving shift.
 */
fun unpackLeft(packedInt: Int): Int = packedInt shr 16
/**
 * Returns the lower 16 bits of [packedInt], interpreted as a signed short.
 */
fun unpackRight(packedInt: Int): Int {
    val lowerHalf: Short = packedInt.toShort()
    return lowerHalf.toInt()
}
/**
 * Returns the lower 16 bits of [packedInt], interpreted as an unsigned short (0..65535).
 */
@Suppress("EXPERIMENTAL_API_USAGE")
fun unpackUnsignedRight(packedInt: Int): Int {
    // going through UShort is cleaner than masking manually
    val lowerHalf = packedInt.toUShort()
    return lowerHalf.toInt()
}
/**
 * Returns the lowest 16 bits of [packedLong], interpreted as an unsigned short (0..65535).
 */
@Suppress("EXPERIMENTAL_API_USAGE")
fun unpackUnsignedRight(packedLong: Long): Int {
    val lowerHalf = packedLong.toUShort()
    return lowerHalf.toInt()
}
/**
 * Formats the reflected method held by [method], by delegating to the overload that takes a `java.lang.reflect.Method`.
 */
fun makeFancyMethodName(method: CachedMethod): String = makeFancyMethodName(method.method)
/**
 * Formats [method] as `declaring.class.Name.methodName(ParamType, ParamType)`, using the simple names of the
 * parameter types.
 *
 * A method whose LAST parameter is a [Continuation] is shown as a `suspend` function: the result is prefixed with
 * `suspend ` and that trailing parameter is left out. Any other [Continuation] parameter is kept, because it is
 * part of the declared signature.
 *
 * @param method the method to format
 *
 * @return the human readable method name
 */
fun makeFancyMethodName(method: Method): String {
    val parameterTypes = method.parameterTypes
    val isSuspend = parameterTypes.lastOrNull() == Continuation::class.java

    // ALWAYS remove the trailing Continuation, since it's REALLY the "suspend" modifier
    val visibleTypes = if (isSuspend) parameterTypes.dropLast(1) else parameterTypes.asList()
    val args = visibleTypes.joinToString { it.simpleName }

    val prefix = if (isSuspend) "suspend " else ""
    return "$prefix${method.declaringClass.name}.${method.name}($args)"
}
/**
 * Remove from the stacktrace the "slice" of stack that relates to the RMI proxy invocation (everything up to and
 * including the `RmiClient` frame, plus the proxy frame directly after it).
 *
 * Then cut the stack off where the "kotlinx.coroutines." / "kotlin.coroutines." frames start.
 *
 * We do this because these stack frames are not useful in resolving exception handling from a users perspective, and only
 * clutter the stacktrace.
 *
 * @param localException the exception whose stack trace is inspected; its stack trace is replaced only when
 *                       [remoteException] is null and a non-empty slice was found
 * @param remoteException when not null, the slice found in [localException] is appended to the stack trace of this
 *                        exception instead
 */
fun cleanStackTraceForProxy(localException: Exception, remoteException: Exception? = null) {
    val myClassName = RmiClient::class.java.name
    val stackTrace = localException.stackTrace

    var newStartIndex = 0
    var newEndIndex = stackTrace.size
    var foundStart = false

    for ((index, element) in stackTrace.withIndex()) {
        val className = element.className

        if (!foundStart) {
            // step 1: Find the start of our method invocation
            // "startsWith" because with continuations, the ACTUAL class name is mangled, ie: dorkbox.network.rmi.RmiClient$invoke$$inlined$Continuation$1
            if (className.startsWith(myClassName)) {
                // this is where we will START (not where we are)
                newStartIndex = index + 1

                // check 1 more time, because we want to remove the proxy invocation off the stack as well.
                // (there is no next frame when the RmiClient frame is the last one, so look it up safely)
                // NOTE(review): only "com.sun.proxy." is recognized; newer JDKs appear to name proxy classes differently - confirm
                val nextClassName = stackTrace.getOrNull(newStartIndex)?.className
                if (nextClassName != null && nextClassName.startsWith("com.sun.proxy.")) {
                    newStartIndex++
                }

                // NOTE(review): unless a coroutine frame is found below, the slice stays empty - confirm this is intended
                newEndIndex = newStartIndex
                foundStart = true
            }
        } else if (className.startsWith("kotlin.coroutines.") || className.startsWith("kotlinx.coroutines.")) {
            // step 2: Find the start of coroutines
            // -1 because we want to end BEFORE the coroutine suspend call starts. Never end before the start though,
            // a negative slice length cannot be copied.
            newEndIndex = maxOf(newStartIndex, index - 1)
            break
        }
    }

    if (remoteException == null) {
        // no remote exception, just cleanup our own callstack. We don't ALWAYS have a new stack.
        if (newEndIndex > newStartIndex) {
            localException.stackTrace = stackTrace.copyOfRange(newStartIndex, newEndIndex)
        }
    } else {
        // merge this info into the remote exception, so we can get the correct call stack info
        remoteException.stackTrace = remoteException.stackTrace + stackTrace.copyOfRange(newStartIndex, newEndIndex)
    }
}
/**
 * Remove from the stacktrace (going in reverse), kotlin coroutine info + dorkbox network call stack.
 *
 * Neither of these are useful in resolving exception handling from a users perspective, and only clutter the stacktrace.
 *
 * The stack trace of [exception] is cut off at the first frame that belongs to the package of this class. The
 * "java.base" (reflection) frames directly before that point are removed too, plus one more frame when the call went
 * through `CachedAsmMethod`, plus one more frame for a suspend function. The stack trace is left untouched when
 * nothing would remain in front of the cut.
 *
 * @param exception the exception whose stack trace is trimmed in place
 * @param isSuspendFunction true if the invoked method was a kotlin suspend function
 */
fun cleanStackTraceForImpl(exception: Exception, isSuspendFunction: Boolean) {
    val packageName = RmiUtils::class.java.packageName

    val stackTrace = exception.stackTrace
    val size = stackTrace.size
    if (size == 0) {
        return
    }

    // step 1: starting at 0, find the start of our RMI method invocation. This is the (exclusive) end of what we keep;
    // without any RMI frame, that is the end of the stack.
    val firstRmiFrame = stackTrace.indexOfFirst { it.className.startsWith(packageName) }
    var newEndIndex = if (firstRmiFrame == -1) size else firstRmiFrame

    // step 2: walking backwards from there, drop the reflection information (we are java11+ ONLY, so this is easy)
    // this will be either JAVA reflection or ReflectASM reflection
    while (newEndIndex > 0 && stackTrace[newEndIndex - 1].moduleName == "java.base") {
        newEndIndex--
    }

    if (newEndIndex > 0) {
        // if we are Java reflection, we are correct.
        // if we are ReflectASM reflection, there is ONE stack frame extra we have to remove
        // (there is no frame to look at when the whole stack is kept, ie: no RMI frame was found)
        if (newEndIndex < size && stackTrace[newEndIndex].className == CachedAsmMethod::class.java.name) {
            newEndIndex--
        }

        // if we are a KOTLIN suspend function, there is ONE stack frame extra we have to remove
        if (isSuspendFunction && newEndIndex > 0) {
            newEndIndex--
        }

        exception.stackTrace = stackTrace.copyOfRange(0, newEndIndex)
    }
}
}