All downloads are free. Search and download functionality uses the official Maven repository.

com.zeoflow.depot.compiler.processing.ksp.KspClassFileUtility.kt Maven / Gradle / Ivy

Go to download

The Depot persistence library provides an abstraction layer over SQLite to allow for more robust database access while using the full power of SQLite.

There is a newer version: 1.0.5
Show newest version
/*
 * Copyright (C) 2021 ZeoFlow SRL
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.zeoflow.depot.compiler.processing.ksp

import com.zeoflow.depot.compiler.processing.XProcessingConfig
import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.Origin
import java.lang.reflect.InvocationHandler
import java.lang.reflect.Method
import java.lang.reflect.Proxy

/**
 * When a compiled kotlin class is loaded from a `.class` file, its fields are not ordered in the
 * same way they are declared in code.
 * This particularly hurts Depot where we generate the table structure in that order.
 *
 * This class implements a port of https://github.com/google/ksp/pull/260 via reflection until KSP
 * (or kotlin compiler) fixes the problem. As this uses reflection, it is fail safe such that if it
 * cannot find the correct order, it will just return in the order KSP returned instead of crashing.
 *
 * KSP Bugs:
 *  * https://github.com/google/ksp/issues/250
 */
internal object KspClassFileUtility {
    /**
     * Sorts the given fields in the order they are declared in the backing class declaration.
     *
     * @param owner the class declaration that declares [fields].
     * @param fields the fields to sort, in the order KSP returned them.
     * @return [fields] re-ordered to match the `.class` declaration order, or [fields] unchanged
     * when that order cannot be determined (e.g. the owner is not a compiled Kotlin class).
     */
    fun orderFields(
        owner: KSClassDeclaration,
        fields: List<KspFieldElement>
    ): List<KspFieldElement> = orderMembers(owner, fields, Type.FIELD, KspFieldElement::name)

    /**
     * Sorts the given methods in the order they are declared in the backing class declaration.
     * Note that this does not check signatures so ordering might break if there are multiple
     * methods with the same name.
     *
     * @param owner the class declaration that declares [methods].
     * @param methods the methods to sort, in the order KSP returned them.
     * @return [methods] re-ordered to match the `.class` declaration order, or [methods]
     * unchanged when that order cannot be determined.
     */
    fun orderMethods(
        owner: KSClassDeclaration,
        methods: List<KspMethodElement>
    ): List<KspMethodElement> = orderMembers(owner, methods, Type.METHOD, KspMethodElement::name)

    /**
     * Shared implementation of [orderFields] and [orderMethods]: sorts [members] by the order in
     * which members of the given [type] appear in the `.class` file backing [owner].
     */
    private fun <T> orderMembers(
        owner: KSClassDeclaration,
        members: List<T>,
        type: Type,
        getName: T.() -> String,
    ): List<T> {
        // no reason to try to load .class if we don't have any members to sort
        if (members.isEmpty()) return members
        val comparator = getNamesComparator(owner, type, getName) ?: return members
        // make sure each name gets registered so that if we didn't find it in .class for
        // whatever reason, we keep the order given from KSP.
        members.forEach { comparator.register(it.getName()) }
        // sortedWith is stable, so members sharing a name keep their relative KSP order.
        return members.sortedWith(comparator)
    }

    /**
     * Builds a member names comparator from the given class declaration if and only if its
     * origin is Kotlin .class.
     * If it fails to find the order, returns null (or throws in
     * [XProcessingConfig.STRICT_MODE]).
     *
     * @param ksClassDeclaration the declaration whose backing binary class is read.
     * @param type selects which visitor callback (fields or methods) contributes names.
     * @param getName extracts the name used by the returned comparator to order elements.
     */
    @Suppress("BanUncheckedReflection")
    private fun <T> getNamesComparator(
        ksClassDeclaration: KSClassDeclaration,
        type: Type,
        getName: T.() -> String,
    ): MemberNameComparator<T>? {
        return try {
            // this is needed only for compiled kotlin classes
            // https://github.com/google/ksp/issues/250#issuecomment-761108924
            if (ksClassDeclaration.origin != Origin.KOTLIN_LIB) return null
            val typeReferences = ReflectionReferences.getInstance(ksClassDeclaration) ?: return null
            val descriptor = typeReferences.getDescriptorMethod.invoke(ksClassDeclaration)
                ?: return null
            if (!typeReferences.deserializedClassDescriptor.isInstance(descriptor)) {
                return null
            }
            val descriptorSrc = typeReferences.descriptorSourceMethod.invoke(descriptor)
                ?: return null
            if (!typeReferences.kotlinJvmBinarySourceElement.isInstance(descriptorSrc)) {
                return null
            }
            val binarySource = typeReferences.binaryClassMethod.invoke(descriptorSrc)
                ?: return null

            val nameComparator = MemberNameComparator(getName)
            // Every member visited in the .class file is registered in visitation order; the
            // handler returns null for all calls, including the visitor's return value.
            val invocationHandler = InvocationHandler { _, method, args ->
                if (method.name == type.visitorName) {
                    val nameAsString = typeReferences.asStringMethod.invoke(args[0])
                    if (nameAsString is String) {
                        nameComparator.register(nameAsString)
                    }
                }
                null
            }

            val proxy = Proxy.newProxyInstance(
                ksClassDeclaration.javaClass.classLoader,
                arrayOf(typeReferences.memberVisitor),
                invocationHandler
            )
            typeReferences.visitMembersMethod.invoke(binarySource, proxy, null)
            nameComparator.seal()
            nameComparator
        } catch (failure: Throwable) {
            // this is best effort, if it failed, just ignore
            if (XProcessingConfig.STRICT_MODE) {
                throw RuntimeException("failed to get fields", failure)
            }
            null
        }
    }

    /**
     * Holder object to keep references to class/method instances.
     *
     * Constructing it throws if any of the reflected classes or methods cannot be found.
     */
    private class ReflectionReferences private constructor(
        classLoader: ClassLoader
    ) {

        val deserializedClassDescriptor: Class<*> = classLoader.loadClass(
            "org.jetbrains.kotlin.serialization.deserialization.descriptors" +
                ".DeserializedClassDescriptor"
        )

        val ksClassDeclarationDescriptorImpl: Class<*> = classLoader.loadClass(
            "com.google.devtools.ksp.symbol.impl.binary.KSClassDeclarationDescriptorImpl"
        )
        val kotlinJvmBinarySourceElement: Class<*> = classLoader.loadClass(
            "org.jetbrains.kotlin.load.kotlin.KotlinJvmBinarySourceElement"
        )

        val kotlinJvmBinaryClass: Class<*> = classLoader.loadClass(
            "org.jetbrains.kotlin.load.kotlin.KotlinJvmBinaryClass"
        )

        val memberVisitor: Class<*> = classLoader.loadClass(
            "org.jetbrains.kotlin.load.kotlin.KotlinJvmBinaryClass\$MemberVisitor"
        )

        val name: Class<*> = classLoader.loadClass(
            "org.jetbrains.kotlin.name.Name"
        )

        val getDescriptorMethod: Method = ksClassDeclarationDescriptorImpl
            .getDeclaredMethod("getDescriptor")

        val descriptorSourceMethod: Method = deserializedClassDescriptor.getMethod("getSource")

        val binaryClassMethod: Method = kotlinJvmBinarySourceElement.getMethod("getBinaryClass")

        val visitMembersMethod: Method = kotlinJvmBinaryClass.getDeclaredMethod(
            "visitMembers",
            memberVisitor, ByteArray::class.java
        )

        val asStringMethod: Method = name.getDeclaredMethod("asString")

        companion object {
            // Sentinel cached when construction fails, so the lookup is not retried.
            private val FAILED = Any()
            private var instance: Any? = null

            /**
             * Gets the cached instance or create a new one using the class loader of the given
             * [ref] parameter.
             *
             * NOTE(review): the result (success or failure) is cached for the first class loader
             * seen; later calls with a different loader reuse it. Access is not synchronized —
             * presumably callers are single-threaded; confirm.
             */
            fun getInstance(ref: Any): ReflectionReferences? {
                if (instance == null) {
                    instance = try {
                        ReflectionReferences(ref::class.java.classLoader)
                    } catch (ignored: Throwable) {
                        FAILED
                    }
                }
                return instance as? ReflectionReferences
            }
        }
    }

    /**
     * Orders elements by the sequence in which their names were first registered.
     *
     * @param getName extracts the name of an element being compared.
     */
    private class MemberNameComparator<T>(
        val getName: T.() -> String
    ) : Comparator<T> {
        private var nextOrder: Int = 0
        private var sealed: Boolean = false
        private val orders = mutableMapOf<String, Int>()

        /**
         * Called when fields are read to lock the ordering.
         * This is only relevant in tests as at runtime, we just do a best effort (add a new id
         * for it) and continue.
         */
        fun seal() {
            sealed = true
        }

        /**
         * Registers the name with the next order id
         */
        fun register(name: String) {
            getOrder(name)
        }

        /**
         * Gets the order of the name. If it was not seen before, adds it to the list, giving it a
         * new ID. Throws if the comparator is sealed and [XProcessingConfig.STRICT_MODE] is on.
         */
        private fun getOrder(name: String) = orders.getOrPut(name) {
            if (sealed && XProcessingConfig.STRICT_MODE) {
                error("expected to find field $name but it is non-existent")
            }
            nextOrder++
        }

        override fun compare(elm1: T, elm2: T): Int {
            return getOrder(elm1.getName()).compareTo(getOrder(elm2.getName()))
        }
    }

    /**
     * The type of declaration that we want to extract from class descriptor.
     *
     * @property visitorName name of the `MemberVisitor` callback that reports this member kind.
     */
    private enum class Type(
        val visitorName: String
    ) {
        FIELD("visitField"),
        METHOD("visitMethod")
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy