All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.mockk.impl.stub.MockKStub.kt Maven / Gradle / Ivy

package io.mockk.impl.stub

import io.mockk.*
import io.mockk.impl.InternalPlatform
import io.mockk.impl.InternalPlatform.customComputeIfAbsent
import kotlin.reflect.KClass

/**
 * Default [Stub] implementation backing a MockK mock.
 *
 * Holds the answers registered for the mock, the child mocks created for chained calls on
 * relaxed mocks, and the invocations recorded against it. All three collections come from
 * `InternalPlatform.synchronizedMutable*`; compound operations on them (iteration,
 * compute-if-absent) are performed while holding the collection's own monitor.
 *
 * @param type class of the mocked object
 * @param name mock name, used by [toStr] and to derive the names of child mocks
 * @param relaxed when `true`, calls with no matching answer return a generated value
 *   (possibly a child mock) instead of throwing [MockKException]
 * @param gatewayAccess access to the call recorder, mock factory, value generator and logger
 */
open class MockKStub(override val type: KClass<*>,
                     override val name: String,
                     val relaxed: Boolean = false,
                     val gatewayAccess: StubGatewayAccess) : Stub {

    private val answers = InternalPlatform.synchronizedMutableList<InvocationAnswer>()
    private val childs = InternalPlatform.synchronizedMutableMap<InvocationMatcher, Any>()
    private val recordedCalls = InternalPlatform.synchronizedMutableList<Invocation>()

    // NOTE(review): not read or written inside this class; presumably assigned by the
    // creator of the stub right after construction — confirm before relying on it.
    lateinit var hashCodeStr: String

    /** Registers [answer] for invocations accepted by [matcher]; later registrations take precedence. */
    override fun addAnswer(matcher: InvocationMatcher, answer: Answer<*>) {
        answers.add(InvocationAnswer(matcher, answer))
    }

    /**
     * Answers [invocation] using the most recently registered answer whose matcher accepts it,
     * or falls back to [defaultAnswer] when none does.
     */
    override fun answer(invocation: Invocation): Any? {
        // Only the lookup happens under the lock (iterating a synchronized list requires it).
        // The default answer is computed after releasing it, because it may call out to the value
        // generator and, for relaxed mocks, lock `childs` and create new mocks.
        // `lastOrNull` walks the list backwards, so the newest matching answer wins, without copying.
        val matched = synchronized(answers) {
            answers.lastOrNull { it.matcher.match(invocation) }
        } ?: return defaultAnswer(invocation)

        val (matcher, stubAnswer) = matched
        matcher.captureAnswer(invocation)

        val call = Call(
                invocation.method.returnType,
                invocation,
                matcher)

        return stubAnswer.answer(call)
    }

    /**
     * Implements `toString`, `hashCode` and `equals(Any)` for the mock itself; any other method
     * is delegated to [otherwise].
     *
     * `toString` yields [toStr], `hashCode` the identity hash of [self], and `equals` is reference
     * identity against the single argument. Kept `inline` so that callers' lambdas may use a
     * non-local `return`.
     */
    protected inline fun stdObjectFunctions(self: Any,
                                            method: MethodDescription,
                                            args: List<Any?>,
                                            otherwise: () -> Any?): Any? = when {
        method.isToString() -> toStr()
        method.isHashCode() -> InternalPlatformDsl.identityHashCode(self)
        method.isEquals() -> self === args[0]
        else -> otherwise()
    }

    /**
     * Answers only `equals`/`hashCode`/`toString`.
     *
     * @throws MockKException for any other method
     */
    override fun stdObjectAnswer(invocation: Invocation): Any? {
        return stdObjectFunctions(invocation.self, invocation.method, invocation.args) {
            throw MockKException("No other calls allowed in stdObjectAnswer than equals/hashCode/toString")
        }
    }

    /**
     * Answer used when no registered answer matches.
     *
     * Handles `equals`/`hashCode`/`toString` itself. For any other method, a relaxed stub returns
     * a value from the gateway's value generator. When a child mock is requested, it is keyed
     * by an all-equality matcher built from the invocation.
     *
     * @throws MockKException when the stub is not relaxed and the method is not one of the above
     */
    protected open fun defaultAnswer(invocation: Invocation): Any? {
        return stdObjectFunctions(invocation.self, invocation.method, invocation.args) {
            if (relaxed) {
                return gatewayAccess.anyValueGenerator.anyValue(invocation.method.returnType) {
                    childMockK(invocation.allEqMatcher(), invocation.method.returnType)
                }
            } else {
                throw MockKException("no answer found for: $invocation")
            }
        }
    }

    /** Appends [invocation] to the list of recorded calls. */
    override fun recordCall(invocation: Invocation) {
        recordedCalls.add(invocation)
    }

    /** Returns a snapshot copy of all recorded calls, taken while holding the list's monitor. */
    override fun allRecordedCalls(): List<Invocation> =
            synchronized(recordedCalls) { recordedCalls.toList() }

    /** Renders the mock as `SimpleClassName(name)`. */
    override fun toStr() = "${type.simpleName}($name)"

    /**
     * Returns the child mock associated with [matcher], creating it on first request.
     *
     * The child is created with the same `relaxed` flag and a name derived from this stub's name
     * (see [childName]).
     */
    override fun childMockK(matcher: InvocationMatcher, childType: KClass<*>): Any {
        return synchronized(childs) {
            gatewayAccess.safeLog.exec {
                childs.customComputeIfAbsent(matcher) {
                    // NOTE(review): `mockFactory` is nullable on the gateway; this assumes it is
                    // always configured before a child mock can be requested — confirm.
                    gatewayAccess.mockFactory!!.mockk(
                            childType,
                            childName(this.name),
                            moreInterfaces = arrayOf(),
                            relaxed = relaxed)
                }
            }
        }
    }

    /**
     * Derives a child-mock name from a parent name.
     *
     * `"X"` becomes `"child of X"`, `"child of X"` becomes `"child^2 of X"`, and
     * `"child^N of X"` becomes `"child^(N+1) of X"`. An absent or unparsable depth (e.g. a digit
     * run too large for [Int]) is treated as 1 rather than throwing.
     */
    private fun childName(name: String): String {
        val result = childOfRegex.matchEntire(name) ?: return "child of $name"
        val depth = result.groupValues[2].toIntOrNull() ?: 1
        return "child^${depth + 1} of ${result.groupValues[3]}"
    }

    /**
     * Wraps a raw intercepted call into an [Invocation] and hands it to the call recorder.
     *
     * The invocation's "original call" is replaced for `toString` by [toStr], so calling through to
     * the original implementation never runs the real `toString`.
     */
    override fun handleInvocation(self: Any,
                                  method: MethodDescription,
                                  originalCall: () -> Any?,
                                  args: Array<out Any?>): Any? {
        val originalPlusToString = {
            if (method.isToString()) {
                toStr()
            } else {
                originalCall()
            }
        }

        val invocation = Invocation(
                self,
                this,
                method,
                args.toList(),
                InternalPlatform.time(),
                originalPlusToString)

        return gatewayAccess.callRecorder().call(invocation)
    }

    /**
     * Selectively clears state.
     *
     * @param answers drop all registered answers
     * @param calls drop all recorded calls
     * @param childMocks forget all cached child mocks
     */
    override fun clear(answers: Boolean, calls: Boolean, childMocks: Boolean) {
        if (answers) {
            this.answers.clear()
        }
        if (calls) {
            this.recordedCalls.clear()
        }
        if (childMocks) {
            this.childs.clear()
        }
    }

    /** A registered answer together with the matcher that selects it. */
    private data class InvocationAnswer(val matcher: InvocationMatcher, val answer: Answer<*>)

    /** Builds a matcher that accepts exactly this invocation's arguments (`null` via a null check). */
    protected fun Invocation.allEqMatcher() =
            InvocationMatcher(self, method,
                    args.map {
                        if (it == null)
                            NullCheckMatcher()
                        else
                            EqMatcher(it)
                    }, false)

    companion object {
        /** Matches `child of X` / `child^N of X`; group 2 is the depth N, group 3 the root name. */
        val childOfRegex = Regex("child(\\^(\\d+))? of (.+)")

        fun MethodDescription.isToString() = name == "toString" && paramTypes.isEmpty()
        fun MethodDescription.isHashCode() = name == "hashCode" && paramTypes.isEmpty()
        fun MethodDescription.isEquals() = name == "equals" && paramTypes.size == 1 && paramTypes[0] == Any::class
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy