All downloads are free. Search and download functionality uses the official Maven repository.

dorkbox.network.rmi.RmiUtils.kt Maven / Gradle / Ivy

Go to download

Encrypted, high-performance, and event-driven/reactive network stack for Java 6+

There is a newer version: 6.14
Show newest version
/*
 * Copyright 2020 dorkbox, llc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package dorkbox.network.rmi

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.reflectasm.MethodAccess
import dorkbox.network.connection.Connection
import dorkbox.util.classes.ClassHelper
import mu.KLogger
import java.lang.reflect.Method
import java.lang.reflect.Modifier
import java.util.*
import kotlin.coroutines.Continuation

/**
 * Utility methods for creating a method cache for a class or interface.
 *
 * Additionally, this will override methods on the implementation so that methods can be called with a [Connection] parameter as the
 * first parameter, with all other parameters being equal to the interface.
 *
 * This is to support calling RMI methods from an interface (that does NOT pass the connection reference) to
 * an implType, that DOES pass the connection reference. The remote side (that initiates the RMI calls), MUST use
 * the interface, and the implType may override the method, so that we add the connection as the first in
 * the list of parameters.
 *
 * for example:
 * Interface: foo(String x)
 *
 * Impl: foo(String x) -> not called
 * Impl: foo(Connection c, String x) -> this is called instead
 *
 * The implType (if it exists, with the same name, and with the same signature + connection parameter) will be called from the interface
 * instead of the method that would NORMALLY be called.
 */
object RmiUtils {
    /**
     * Orders methods by name, then by parameter count, then by the fully-qualified names of the parameter types.
     *
     * Methods are sorted so they can be represented as an index.
     * Throws [RuntimeException] if two methods have an identical name + parameter type list.
     */
    private val METHOD_COMPARATOR = Comparator<Method> { o1, o2 ->
        val o1Name = o1.name
        val o2Name = o2.name

        var diff = o1Name.compareTo(o2Name)
        if (diff != 0) {
            return@Comparator diff
        }

        val argTypes1 = o1.parameterTypes
        val argTypes2 = o2.parameterTypes
        if (argTypes1.size > argTypes2.size) {
            return@Comparator 1
        }
        if (argTypes1.size < argTypes2.size) {
            return@Comparator -1
        }

        for (i in argTypes1.indices) {
            diff = argTypes1[i].name.compareTo(argTypes2[i].name)
            if (diff != 0) {
                return@Comparator diff
            }
        }

        throw RuntimeException("Two methods with same signature! ('$o1Name', '$o2Name')")
    }

    /**
     * Creates a ReflectASM [MethodAccess] for [clazz].
     *
     * @return the method access, or null if it could not be created (the failure is logged) or if ReflectASM found no methods at all
     */
    private fun getReflectAsmMethod(logger: KLogger, clazz: Class<*>): MethodAccess? {
        return try {
            val methodAccess = MethodAccess.get(clazz)

            if (methodAccess.methodNames.isEmpty() && methodAccess.parameterTypes.isEmpty() && methodAccess.returnTypes.isEmpty()) {
                // there was NOTHING that reflectASM found, so trying to use it doesn't do us any good
                null
            } else {
                methodAccess
            }
        } catch (e: Exception) {
            logger.error("Unable to create ReflectASM method access", e)
            null
        }
    }

    /**
     * Builds the RMI method cache for [iFace]. Entries are indexed by the (sorted) position of each interface method.
     *
     * @param iFace this is never null.
     * @param impl this is NULL on the rmi "client" side. This is NOT NULL on the "server" side (where the object lives)
     * @param asmEnabled if true, ReflectASM is attempted for each method; plain reflection is used as the fallback
     * @param classId stored on every cached method as its `methodClassId`
     *
     * @return one cached method per interface method, in the same order as [getMethods] returns them for [iFace]
     */
    fun getCachedMethods(logger: KLogger, kryo: Kryo, asmEnabled: Boolean, iFace: Class<*>, impl: Class<*>?, classId: Int): Array<CachedMethod> {
        var implAsmMethodAccess: MethodAccess? = null

        // RMI is **ALWAYS** based upon an interface, so we must always make sure to get the methods of the interface, instead of the
        // implementation, otherwise we will have the wrong order of methods, so invoking a method by it's index will fail.
        val methods = getMethods(iFace)
        val size = methods.size
        val cachedMethods = arrayOfNulls<CachedMethod>(size)

        val implMethods: Array<Method>?
        if (impl != null) {
            require(!impl.isInterface) { "Cannot have type as an interface, it must be an implementation" }
            implMethods = getMethods(impl)

            // reflectASM
            //   doesn't work on android (set correctly by the serialization manager)
            //   can't get any method from the 'Object' object (we get from the interface, which is NOT 'Object')
            //   and it MUST be public (iFace is always public)
            if (asmEnabled) {
                implAsmMethodAccess = getReflectAsmMethod(logger, impl)
            }
        } else {
            implMethods = null
        }

        // reflectASM (same restrictions as above)
        val ifaceAsmMethodAccess: MethodAccess? = if (asmEnabled) getReflectAsmMethod(logger, iFace) else null

        val hasConnectionOverrideMethods = hasOverwriteMethodWithConnectionParam(implMethods)

        for (i in 0 until size) {
            val method = methods[i]
            val declaringClass = method.declaringClass
            val parameterTypes = method.parameterTypes

            // Store the serializer for each final parameter.
            // this is ONLY for the ORIGINAL method, not the overridden one.
            val serializers = arrayOfNulls<Serializer<*>>(parameterTypes.size)
            parameterTypes.forEachIndexed { index, paramClazz ->
                if (kryo.isFinal(paramClazz) || paramClazz === Continuation::class.java) {
                    serializers[index] = kryo.getSerializer(paramClazz)
                }
            }

            // this is how we detect if the method has been changed from the interface -> implementation + connection parameter
            val overwrittenMethod: Method? = if (implMethods != null && hasConnectionOverrideMethods) {
                getOverwriteMethodWithConnectionParam(implMethods, method)
            } else {
                null
            }

            // an overwritten method lives on the implementation, so it must be invoked via the implementation's access (which might be null!)
            val methodAccess: MethodAccess? = if (overwrittenMethod != null) implAsmMethodAccess else ifaceAsmMethodAccess

            // reflectAsm doesn't like "Object" class methods
            val canUseAsm = asmEnabled && declaringClass != Any::class.java

            // without a MethodAccess there is nothing to look up, so fall back to plain reflection instead of logging an NPE
            val asmCachedMethod: CachedMethod? = if (canUseAsm && methodAccess != null) {
                try {
                    // have to take into account the overwritten method's first parameter will ALWAYS be "Connection"
                    val asmParameterTypes = overwrittenMethod?.parameterTypes ?: parameterTypes
                    val index = methodAccess.getIndex(method.name, *asmParameterTypes)

                    CachedAsmMethod(
                            methodAccessIndex = index,
                            methodAccess = methodAccess,
                            name = method.name,
                            method = method,
                            methodIndex = i,
                            methodClassId = classId,
                            serializers = serializers)
                } catch (e: Exception) {
                    logger.error("Unable to use ReflectAsm for ${makeFancyMethodName(method)}", e)
                    null
                }
            } else {
                null
            }

            val cachedMethod = asmCachedMethod ?: CachedMethod(
                    method = method,
                    methodIndex = i,
                    methodClassId = classId,
                    serializers = serializers)

            // this MIGHT be null, but if it is not, this is the method we will invoke INSTEAD of the "normal" method
            cachedMethod.overriddenMethod = overwrittenMethod

            cachedMethods[i] = cachedMethod

            if (overwrittenMethod != null && logger.isDebugEnabled) {
                val name = if (cachedMethod.method.declaringClass.isInterface) {
                    "iface"
                } else {
                    "method"
                }
                logger.debug("Overridden $name : ${makeFancyMethodName(cachedMethod)}")
                logger.debug("       to method : ${makeFancyMethodName(overwrittenMethod)}")
            }
        }

        // force the type, because we KNOW it is ok to do so (every index was assigned in the loop above)
        @Suppress("UNCHECKED_CAST")
        return cachedMethods as Array<CachedMethod>
    }

    /**
     * Check to see if there are ANY methods in this class that start with a "Connection" parameter.
     *
     * @return false if [implMethods] is null
     */
    private fun hasOverwriteMethodWithConnectionParam(implMethods: Array<Method>?): Boolean {
        return implMethods?.any { implMethod ->
            val implParameters = implMethod.parameterTypes
            // check if the FIRST parameter is "Connection"
            implParameters.isNotEmpty() && ClassHelper.hasInterface(Connection::class.java, implParameters[0])
        } ?: false
    }

    /**
     * This will overwrite an original (iface based) method with a method from the implementation ONLY if there is the extra 'Connection' parameter (as per above)
     * NOTE: does not null check. [origMethods] is modified in place.
     *
     * @param implMethods methods from the implementation
     * @param origMethods methods from the interface
     */
    private fun overwriteMethodsWithConnectionParam(implMethods: Array<Method>, origMethods: Array<Method>) {
        for (i in origMethods.indices) {
            val overwriteMethod = getOverwriteMethodWithConnectionParam(implMethods, origMethods[i])
            if (overwriteMethod != null) {
                origMethods[i] = overwriteMethod
            }
        }
    }

    /**
     * Finds the implementation method that matches [origMethod] with an extra 'Connection' parameter prepended (as per above).
     * NOTE: does not null check
     *
     * @param implMethods methods from the implementation
     * @param origMethod original method from the interface
     *
     * @return the matching implementation method, or null if there is none
     */
    private fun getOverwriteMethodWithConnectionParam(implMethods: Array<Method>, origMethod: Method): Method? {
        val origName = origMethod.name
        val origTypes = origMethod.parameterTypes
        val origLength = origTypes.size + 1

        for (implMethod in implMethods) {
            val implParameters = implMethod.parameterTypes
            val implLength = implParameters.size

            // implLength is always >= 1 past this check, so implParameters[0] is safe below
            if (origLength != implLength || origName != implMethod.name) {
                continue
            }

            // check if the FIRST parameter is "Connection"
            if (!ClassHelper.hasInterface(Connection::class.java, implParameters[0])) {
                continue
            }

            // make sure all the remaining parameters match. Cannot use arrays.equals(*), because one will have "Connection" as
            // a parameter. When "Connection" is the only parameter, the range is empty and this matches.
            val restMatches = (1 until implLength).all { k -> origTypes[k - 1] == implParameters[k] }
            if (restMatches) {
                return implMethod
            }
        }
        return null
    }

    /**
     * This will get the methods from an interface (for RMI), or from an implementation (for "connection" overriding the method signature).
     *
     * Static, private and synthetic methods are skipped, as are methods hidden by an override earlier in the hierarchy walk.
     * Methods of `Object` are always included.
     *
     * @return all found methods for this class, sorted by [METHOD_COMPARATOR] so that they can be addressed by index
     */
    fun getMethods(type: Class<*>): Array<Method> {
        val allMethods = ArrayList<Method>()
        val accessibleMethods = HashMap<String, ArrayList<Array<Class<*>>>>()
        val classes = LinkedList<Class<*>>()
        classes.add(type)

        // explicitly add Object.class because that can always be called, because it is common to 100% of all java objects (and it's methods
        // are not added via parentClass.getMethods()
        classes.add(Any::class.java)

        while (classes.isNotEmpty()) {
            val nextClass = classes.removeFirst()

            for (method in nextClass.methods) {
                // static and private methods cannot be called via RMI.
                val modifiers = method.modifiers
                if (Modifier.isStatic(modifiers) || Modifier.isPrivate(modifiers) || method.isSynthetic) {
                    continue
                }

                // methods that have been over-ridden by another method cannot be called.
                // the first one in the map, is the "highest" level method, and is what can be called.
                val types = method.parameterTypes // length 0 if there are no parameters
                val existingTypes = accessibleMethods.getOrPut(method.name) { ArrayList() }
                if (existingTypes.any { existing -> types.contentEquals(existing) }) {
                    // the method is overridden, so it should not be called.
                    continue
                }
                existingTypes.add(types)

                // safe to add this method to the list of recognized methods
                allMethods.add(method)
            }

            // add all interfaces from our class (if any)
            classes.addAll(nextClass.interfaces)

            // If we are an interface, one CANNOT call any methods NOT defined by the interface!
            // also, interfaces don't have a super-class.
            val superclass = nextClass.superclass
            if (superclass != null) {
                classes.add(superclass)
            }
        }

        allMethods.sortWith(METHOD_COMPARATOR)
        return allMethods.toTypedArray()
    }

    /**
     * Instantiates [serializerClass], trying constructors in this order: (Kryo, Class), (Kryo), (Class), then no-arg.
     *
     * @throws IllegalArgumentException if no constructor could be used or construction failed
     */
    fun resolveSerializerInstance(k: Kryo, superClass: Class<*>, serializerClass: Class<out Serializer<*>>): Serializer<*> {
        return try {
            try {
                serializerClass.getConstructor(Kryo::class.java, Class::class.java).newInstance(k, superClass)
            } catch (ex1: NoSuchMethodException) {
                try {
                    serializerClass.getConstructor(Kryo::class.java).newInstance(k)
                } catch (ex2: NoSuchMethodException) {
                    try {
                        serializerClass.getConstructor(Class::class.java).newInstance(superClass)
                    } catch (ex3: NoSuchMethodException) {
                        serializerClass.getDeclaredConstructor().newInstance()
                    }
                }
            }
        } catch (ex: Exception) {
            throw IllegalArgumentException(
                    "Unable to create serializer \"" + serializerClass.name + "\" for class: " + superClass.name, ex)
        }
    }

    /**
     * Walks the interfaces and superclasses of [clazz] breadth-first.
     *
     * @return every visited type (duplicates possible when an interface is reachable by several paths), without [clazz] itself
     */
    fun getHierarchy(clazz: Class<*>): ArrayList<Class<*>> {
        val allClasses = ArrayList<Class<*>>()
        val parseClasses = LinkedList<Class<*>>()
        parseClasses.add(clazz)

        while (parseClasses.isNotEmpty()) {
            val nextClass = parseClasses.removeFirst()
            allClasses.add(nextClass)

            // add all interfaces from our class (if any)
            parseClasses.addAll(nextClass.interfaces)
            val superclass = nextClass.superclass
            if (superclass != null) {
                parseClasses.add(superclass)
            }
        }

        // remove the first class, because we don't need it
        allClasses.remove(clazz)
        return allClasses
    }

    /** Packs [left] into the upper 16 bits and the low 16 bits of [right] into the lower 16 bits. */
    fun packShorts(left: Int, right: Int): Int {
        return (left shl 16) or (right and 0xFFFF)
    }

    /** Returns the upper 16 bits of [packedInt] (arithmetic shift, so the sign is preserved). */
    fun unpackLeft(packedInt: Int): Int {
        return packedInt shr 16
    }

    /** Returns the lower 16 bits of [packedInt], sign-extended as a Short. */
    fun unpackRight(packedInt: Int): Int {
        return packedInt.toShort().toInt()
    }

    /** Returns the lower 16 bits of [packedInt] as an unsigned value (0..65535). */
    @Suppress("EXPERIMENTAL_API_USAGE")
    fun unpackUnsignedRight(packedInt: Int): Int {
        // this just does a .toUShort().toInt() conversion. This is cleaner than doing it manually
        return packedInt.toUShort().toInt()
    }

    /** Returns the lower 16 bits of [packedLong] as an unsigned value (0..65535). */
    @Suppress("EXPERIMENTAL_API_USAGE")
    fun unpackUnsignedRight(packedLong: Long): Int {
        return packedLong.toUShort().toInt()
    }

    /** Human-readable name for the (original, non-overridden) method of [method]. */
    fun makeFancyMethodName(method: CachedMethod): String {
        return makeFancyMethodName(method.method)
    }

    /**
     * Human-readable name, e.g. `suspend com.example.Foo.bar(String, int)`.
     *
     * A method whose LAST parameter is [Continuation] is shown as `suspend`, and Continuation parameters are hidden.
     */
    fun makeFancyMethodName(method: Method): String {
        val parameterTypes: Array<Class<*>> = method.parameterTypes
        val size = parameterTypes.size
        val isSuspend = size > 0 && parameterTypes[size - 1] == Continuation::class.java

        val args: String = if (size == 0 || (size == 1 && isSuspend)) {
            ""
        } else {
            // ALWAYS remove Continuation, since it's REALLY with "suspend" modifier)
            if (isSuspend) {
                parameterTypes.filter { it != Continuation::class.java }.joinToString { it.simpleName }
            } else {
                parameterTypes.joinToString { it.simpleName }
            }
        }

        return if (isSuspend) {
            "suspend ${method.declaringClass.name}.${method.name}($args)"
        } else {
            "${method.declaringClass.name}.${method.name}($args)"
        }
    }

    /**
     * Remove from the stacktrace the "slice" of stack that relates to "dorkbox.network." package
     *
     * Then remove from the last occurrence of "dorkbox.network." ALL "kotlinx.coroutines." and "kotlin.coroutines." stacks
     *
     * We do this because these stack frames are not useful in resolving exception handling from a users perspective, and only clutter the stacktrace.
     *
     * @param remoteException if not null, the kept slice of the local stack is appended to this exception's stack instead
     */
    fun cleanStackTraceForProxy(localException: Exception, remoteException: Exception? = null) {
        val myClassName = RmiClient::class.java.name
        val stackTrace = localException.stackTrace
        var newStartIndex = 0
        var newEndIndex = stackTrace.size

        var foundStart = false

        for ((index, element) in stackTrace.withIndex()) {
            // step 1: Find the start of our method invocation
            if (!foundStart) {
                // "startsWith" because with continuations, the ACTUAL class name is mangled, ie: dorkbox.network.rmi.RmiClient$invoke$$inlined$Continuation$1
                if (element.className.startsWith(myClassName)) {
                    newStartIndex = index

                    // check 1 more time, because we want to remove the proxy invocation off the stack as well.
                    // (bounds-checked: our frame may be the last one on the stack)
                    val nextIndex = index + 1
                    if (nextIndex < stackTrace.size && stackTrace[nextIndex].className.startsWith("com.sun.proxy.")) {
                        newStartIndex++
                    }

                    // this is where we will START (not where we are)
                    newStartIndex++

                    newEndIndex = newStartIndex
                    foundStart = true
                }
            } else {
                // step 2: Find the start of coroutines
                val className = element.className
                if (className.startsWith("kotlin.coroutines.") || className.startsWith("kotlinx.coroutines.")) {
                    // -1 because we want to end BEFORE the coroutine suspend call starts.
                    // never end before the start, otherwise the range is negative (and copying it fails)
                    newEndIndex = maxOf(index - 1, newStartIndex)
                    break
                }
            }
        }

        if (remoteException == null) {
            // no remote exception, just cleanup our own callstack. We don't ALWAYS have a new stack.
            if (newEndIndex > newStartIndex) {
                localException.stackTrace = stackTrace.copyOfRange(newStartIndex, newEndIndex)
            }
        } else {
            // merge this info into the remote exception, so we can get the correct call stack info
            val remoteStack = remoteException.stackTrace
            remoteException.stackTrace = remoteStack + stackTrace.copyOfRange(newStartIndex, newEndIndex)
        }
    }

    /**
     * Remove from the stacktrace (going in reverse), kotlin coroutine info + dorkbox network call stack.
     *
     * Neither of these are useful in resolving exception handling from a users perspective, and only clutter the stacktrace.
     *
     * @param isSuspendFunction if true, one additional frame (the suspend-function wrapper) is removed
     */
    fun cleanStackTraceForImpl(exception: Exception, isSuspendFunction: Boolean) {
        val packageName = RmiUtils::class.java.packageName

        val stackTrace = exception.stackTrace
        val size = stackTrace.size

        if (size == 0) {
            return
        }

        var newEndIndex = -1 // because we index by size, but access from 0

        // step 1: starting at 0, find the start of our RMI method invocation
        for (element in stackTrace) {
            if (element.className.startsWith(packageName)) {
                break
            } else {
                newEndIndex++
            }
        }

        // step 2: starting at newEndIndex -> 0, find the start of reflection information (we are java11+ ONLY, so this is easy)
        for (i in newEndIndex downTo 0) {
            // this will be either JAVA reflection or ReflectASM reflection
            val stackModule = stackTrace[i].moduleName
            if (stackModule == "java.base") {
                newEndIndex--
            } else {
                break
            }
        }

        newEndIndex++ // have to add 1 back, because a copy must be by size (and we access from 0)

        if (newEndIndex > 0) {
            // if we are Java reflection, we are correct.
            // if we are ReflectASM reflection, there is ONE stack frame extra we have to remove
            // (bounds-checked: when nothing from our package is on the stack, newEndIndex can equal size)
            if (newEndIndex < size && stackTrace[newEndIndex].className == CachedAsmMethod::class.java.name) {
                newEndIndex--
            }

            // if we are a KOTLIN suspend function, there is ONE stack frame extra we have to remove
            if (isSuspendFunction && newEndIndex > 0) {
                newEndIndex--
            }

            exception.stackTrace = stackTrace.copyOfRange(0, newEndIndex)
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy