All downloads are free. Search and download functionality uses the official Maven repository.

commonMain.com.xebia.functional.xef.Tool.kt Maven / Gradle / Ivy

There is a newer version: 0.0.5-alpha.119
Show newest version
package com.xebia.functional.xef

import com.xebia.functional.openai.generated.model.FunctionObject
import com.xebia.functional.xef.conversation.Description
import com.xebia.functional.xef.llm.FunctionCall
import com.xebia.functional.xef.llm.StreamedFunction
import com.xebia.functional.xef.llm.chatFunction
import kotlin.jvm.JvmName
import kotlin.reflect.KClass
import kotlin.reflect.KFunction1
import kotlin.reflect.KSuspendFunction1
import kotlin.reflect.KType
import kotlin.reflect.typeOf
import kotlinx.coroutines.flow.Flow
import kotlinx.serialization.*
import kotlinx.serialization.builtins.ListSerializer
import kotlinx.serialization.builtins.SetSerializer
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.json.*

sealed class Tool(
  open val function: FunctionObject,
  open val invoke: suspend (FunctionCall) -> A
) {

  /**
   * Tool variant for enum targets: besides the usual [function] and [invoke], it carries one
   * entry in [cases] per enum value and an [enumSerializer] that turns a string into a value.
   *
   * NOTE(review): the generic parameter lists look stripped in this copy (`List>` is not valid
   * Kotlin, and `E` is never declared); restore them from the upstream source before compiling.
   *
   * @property function function definition forwarded to [Tool].
   * @property invoke suspend handler forwarded to [Tool].
   * @property cases per-value tools.
   * @property enumSerializer converts a string into a value of the enum type.
   */
  data class Enumeration(
    override val function: FunctionObject,
    override val invoke: suspend (FunctionCall) -> E,
    val cases: List>,
    val enumSerializer: (String) -> E
  ) : Tool(function = function, invoke = invoke)

  /**
   * Marker tool selected by `fromKotlin` for the flow-of-strings type. It is built from a
   * [FunctionObject] made of two empty strings, and its handler throws
   * `IllegalStateException("Not invoked")` if it is ever called.
   */
  class FlowOfStrings :
    Tool(function = FunctionObject("", ""), invoke = { error("Not invoked") })

  /**
   * Tool variant produced for `Flow` types whose element classifier is `StreamedFunction`.
   * It adds no state of its own; [function] and [invoke] are forwarded to [Tool].
   *
   * @property function function definition forwarded to [Tool].
   * @property invoke suspend handler forwarded to [Tool].
   */
  class FlowOfStreamedFunctions(
    override val function: FunctionObject,
    override val invoke: suspend (FunctionCall) -> A
  ) : Tool(function = function, invoke = invoke)

  /**
   * Tool variant that wraps another tool, [serializer].
   *
   * @property serializer the wrapped tool.
   * @property function defaults to the wrapped tool's function definition.
   * @property invoke defaults to a handler that delegates to the wrapped tool's `invoke`.
   */
  class FlowOfAIEvents(
    val serializer: Tool,
    override val function: FunctionObject = serializer.function,
    override val invoke: suspend (FunctionCall) -> A = { serializer.invoke(it) }
  ) : Tool(function = function, invoke = invoke)

  /**
   * Tool variant that wraps a [Sealed] tool, [sealedSerializer].
   *
   * @property sealedSerializer the wrapped sealed tool.
   * @property function defaults to the wrapped tool's function definition.
   * @property invoke defaults to a handler that delegates to the wrapped tool's `invoke`.
   */
  class FlowOfAIEventsSealed(
    val sealedSerializer: Sealed,
    override val function: FunctionObject = sealedSerializer.function,
    override val invoke: suspend (FunctionCall) -> A = { sealedSerializer.invoke(it) }
  ) : Tool(function = function, invoke = invoke)

  /**
   * Tool variant for sealed hierarchies: the top-level [function] and [invoke] plus one [Case]
   * per subclass in [cases].
   *
   * @property function function definition forwarded to [Tool].
   * @property invoke suspend handler forwarded to [Tool].
   * @property cases the per-subclass entries.
   */
  data class Sealed(
    override val function: FunctionObject,
    override val invoke: suspend (FunctionCall) -> A,
    val cases: List,
  ) : Tool(function = function, invoke = invoke) {
    /** Pairs a subclass name ([className]) with the [tool] that handles it. */
    data class Case(val className: String, val tool: Tool<*>)
  }

  /**
   * Plain holder variant that forwards [function] and [invoke] to [Tool].
   * It is not instantiated anywhere in the visible code.
   *
   * @property function function definition forwarded to [Tool].
   * @property invoke suspend handler forwarded to [Tool].
   */
  data class Contextual(
    override val function: FunctionObject,
    override val invoke: suspend (FunctionCall) -> A,
  ) : Tool(function = function, invoke = invoke)

  /**
   * General-purpose variant that forwards [function] and [invoke] to [Tool]. The companion
   * factories use it for plain classes, collections, sealed cases and user-supplied functions.
   *
   * @property function function definition forwarded to [Tool].
   * @property invoke suspend handler forwarded to [Tool].
   */
  data class Callable(
    override val function: FunctionObject,
    override val invoke: suspend (FunctionCall) -> A,
  ) : Tool(function = function, invoke = invoke)

  /**
   * Variant used by the companion factories for wrapped primitive values and for the individual
   * cases of an [Enumeration]. It forwards [function] and [invoke] to [Tool].
   *
   * @property function function definition forwarded to [Tool].
   * @property invoke suspend handler forwarded to [Tool].
   */
  data class Primitive(
    override val function: FunctionObject,
    override val invoke: suspend (FunctionCall) -> A
  ) : Tool(function = function, invoke = invoke)

  companion object {

    /**
     * Convenience overload: obtains the `KType` through `typeOf` and delegates to the
     * `fromKotlin(KType)` overload.
     *
     * NOTE(review): the type parameter list and the `typeOf` type argument appear to have been
     * stripped from this copy of the file.
     */
    inline fun  fromKotlin(): Tool = fromKotlin(typeOf())

    /**
     * Builds the [Tool] that corresponds to the Kotlin [type].
     *
     * The first matching rule wins, in this order: sealed serializer kind, enum serializer kind,
     * the flow-of-strings type, `Flow` of `StreamedFunction`, `Flow` of `AIEvent` over a sealed
     * type, `Flow` of `AIEvent`, types that need wrapping (see `requiresWrapping`), and finally
     * any other class.
     *
     * NOTE(review): generic arguments look stripped in this copy (`typeOf>()` is not valid
     * Kotlin); restore them from the upstream source before compiling.
     *
     * @param type the Kotlin type to describe; its classifier must be a `KClass`.
     * @return the tool for [type].
     */
    @OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
    fun  fromKotlin(type: KType): Tool {
      val targetClass = type.getTargetClass()
      // The serializer is looked up for every type, before any branch is selected.
      val descriptor = targetClass.serializer().descriptor
      val kind = descriptor.kind
      return when {
        kind == PolymorphicKind.SEALED -> sealedTool(targetClass, descriptor)
        kind == SerialKind.ENUM -> enumerationTool(targetClass, descriptor)
        type == typeOf>() -> flowOfStringsTool()
        isFlowOfStreamedFunctions(targetClass, type) -> flowOfStreamedFunctionsTool(type)
        // Must precede the plain AIEvent check: the sealed predicate implies the plain one.
        isFlowOfAiEventsSealed(targetClass, type) -> flowOfAIEventsSealedTool(type)
        isFlowOfAIEvents(targetClass, type) -> flowOfAIEventsTool(type)
        requiresWrapping(type) -> wrappedValueTool(type, targetClass)
        else -> defaultClassTool(targetClass)
      }
    }

    /**
     * Tells whether [type] is a `Flow` of `AIEvent` whose nested type argument (the first
     * argument of the first argument) has a serializer of kind `PolymorphicKind.SEALED`.
     *
     * NOTE(review): the nested class's serializer is looked up before the cheaper
     * `isFlowOfAIEvents` check runs, so this executes for any type with two levels of type
     * arguments — consider checking `isFlowOfAIEvents` first.
     *
     * @param targetClass classifier of [type].
     * @param type the full Kotlin type being inspected.
     */
    @OptIn(ExperimentalSerializationApi::class, InternalSerializationApi::class)
    private fun  isFlowOfAiEventsSealed(
      targetClass: KClass,
      type: KType,
    ): Boolean {
      // Null-safe walk: any missing level (no arguments, star projection, non-class) yields null.
      val innerTargetClass =
        type.arguments.firstOrNull()?.type?.arguments?.firstOrNull()?.type?.classifier as? KClass<*>
      val kind = innerTargetClass?.serializer()?.descriptor?.kind
      return isFlowOfAIEvents(targetClass, type) && kind == PolymorphicKind.SEALED
    }

    /**
     * Tells whether [type] is a `Flow` whose first type argument is classified as `AIEvent`.
     *
     * @param targetClass classifier of [type].
     * @param type the full Kotlin type being inspected.
     */
    private fun isFlowOfAIEvents(targetClass: KClass<*>, type: KType): Boolean {
      if (targetClass != Flow::class) return false
      val elementClassifier = type.arguments[0].type?.classifier
      return elementClassifier == AIEvent::class
    }

    /**
     * Tells whether [type] is a `Flow` whose first type argument is classified as
     * `StreamedFunction`.
     *
     * @param targetClass classifier of [type].
     * @param type the full Kotlin type being inspected.
     */
    private fun isFlowOfStreamedFunctions(targetClass: KClass<*>, type: KType): Boolean {
      if (targetClass != Flow::class) return false
      val elementClassifier = type.arguments[0].type?.classifier
      return elementClassifier == StreamedFunction::class
    }

    /**
     * Builds the tool for a type that needs wrapping: a collection tool when [type] has a first
     * type argument whose classifier is a `KClass`, otherwise a primitive tool.
     *
     * @param type the full Kotlin type.
     * @param targetClass classifier of [type].
     */
    private fun  wrappedValueTool(type: KType, targetClass: KClass): Tool {
      val collectionTypeArg = type.arguments.firstOrNull()?.type?.classifier as? KClass<*>
      return if (collectionTypeArg != null) {
        collectionTool(collectionTypeArg, targetClass)
      } else {
        primitiveTool(targetClass)
      }
    }

    /**
     * Returns this type's classifier as a `KClass`, failing with
     * `IllegalStateException("Expected KClass got …")` when the classifier is not a `KClass`
     * (for example a type parameter). The outer cast is unchecked.
     */
    private fun  KType.getTargetClass(): KClass =
      (classifier as? KClass<*> ?: error("Expected KClass got $classifier")) as KClass

    /**
     * Builds a [Callable] for a class that is described directly by its own serializer: the
     * function definition comes from the serializer's descriptor, and the handler decodes the
     * call arguments with that serializer.
     *
     * @param targetClass the class to describe.
     */
    @OptIn(InternalSerializationApi::class)
    private fun  defaultClassTool(targetClass: KClass): Tool {
      val typeSerializer = targetClass.serializer()
      val functionObject = chatFunction(typeSerializer.descriptor)
      return Callable(functionObject) {
        Config.DEFAULT.json.decodeFromString(typeSerializer, it.arguments)
      }
    }

    /**
     * Builds a [Primitive] tool for [targetClass]. The class serializer is wrapped in the
     * `Value` serializer, the function definition is derived from the wrapper's descriptor, and
     * the handler decodes the wrapper and returns its `value`.
     *
     * @param targetClass the class to wrap.
     */
    @OptIn(InternalSerializationApi::class)
    private fun  primitiveTool(targetClass: KClass): Tool {
      val functionSerializer = Value.serializer(targetClass.serializer())
      val functionObject = chatFunction(functionSerializer.descriptor)
      return Primitive(functionObject) {
        Config.DEFAULT.json.decodeFromString(functionSerializer, it.arguments).value
      }
    }

    /**
     * Builds a [Callable] for a `List` or `Set` of [collectionTypeArg]. The element serializer
     * is wrapped in a list or set serializer and then in the `Value` serializer; the handler
     * decodes the wrapper and returns its `value` through an unchecked cast.
     *
     * @param collectionTypeArg class of the collection's elements.
     * @param targetClass the collection class; only `List::class` and `Set::class` are accepted.
     * @throws IllegalStateException when [targetClass] is neither `List` nor `Set`.
     */
    @OptIn(InternalSerializationApi::class)
    private fun  collectionTool(
      collectionTypeArg: KClass<*>,
      targetClass: KClass<*>
    ): Callable {
      val innerSerializer = collectionTypeArg.serializer()
      val functionSerializer =
        when (targetClass) {
          List::class -> {
            Value.serializer(ListSerializer(innerSerializer))
          }
          Set::class -> {
            Value.serializer(SetSerializer(innerSerializer))
          }
          else -> {
            error("Unsupported collection type: $targetClass, expected List or Set")
          }
        }
      val functionObject = chatFunction(functionSerializer.descriptor)
      return Callable(functionObject) {
        Config.DEFAULT.json.decodeFromString(functionSerializer, it.arguments).value as A
      }
    }

    /**
     * Builds the [FlowOfAIEventsSealed] tool for [type]: resolves the nested type argument,
     * looks up its serializer and wraps the resulting sealed tool.
     *
     * @param type the flow type whose nested type argument is the sealed target.
     * @throws IllegalStateException ("No serializer found for …") when the nested type is
     *   missing or its classifier is not a `KClass`.
     */
    @OptIn(InternalSerializationApi::class)
    private fun  flowOfAIEventsSealedTool(type: KType): FlowOfAIEventsSealed {
      val targetType = flowInnerContainerTypeArg(type)
      val targetClass = (targetType?.classifier as? KClass<*>)
      val typeSerializer = targetClass?.serializer() ?: error("No serializer found for $targetType")
      val tool = sealedTool(targetClass, typeSerializer.descriptor)
      return FlowOfAIEventsSealed(tool)
    }

    /**
     * Builds the [FlowOfAIEvents] tool for [type] by resolving the nested type argument and
     * wrapping the tool that `fromKotlin` produces for it.
     *
     * @param type the flow type whose nested type argument is the target.
     * @throws IllegalStateException ("No serializer found for …") when the nested type is
     *   missing.
     */
    private fun  flowOfAIEventsTool(type: KType): FlowOfAIEvents {
      val targetType = flowInnerContainerTypeArg(type)
      val functionSerializer =
        targetType?.let { fromKotlin(it) } ?: error("No serializer found for $type")
      return FlowOfAIEvents(functionSerializer)
    }

    /**
     * Builds the [FlowOfStreamedFunctions] tool for [type]. The function definition comes from
     * the nested type's serializer descriptor, while invocation is delegated to the tool that
     * `fromKotlin` builds for that nested type.
     *
     * @param type the flow type whose nested type argument is the target.
     * @throws IllegalStateException ("No serializer found for …") when the nested type is
     *   missing or its classifier is not a `KClass`.
     */
    @OptIn(InternalSerializationApi::class)
    private fun  flowOfStreamedFunctionsTool(
      type: KType,
    ): FlowOfStreamedFunctions {
      val targetType = flowInnerContainerTypeArg(type)
      val typeSerializer =
        (targetType?.classifier as? KClass<*>)?.serializer()
          ?: error("No serializer found for $targetType")
      val functionSerializer = fromKotlin(targetType)
      val functionObject = chatFunction(typeSerializer.descriptor)
      return FlowOfStreamedFunctions(functionObject) { functionSerializer.invoke(it) }
    }

    /**
     * Returns the type argument nested two levels deep in [type]: the first argument of
     * [type]'s first argument (for `Outer<Inner<T>>` this is `T`).
     *
     * Every step is null-safe, so a type without arguments at either level, or a star
     * projection, yields `null` instead of an `IndexOutOfBoundsException`. Callers already
     * turn `null` into a descriptive error.
     *
     * @param type the outer type to inspect.
     * @return the nested type argument, or `null` when it does not exist.
     */
    private fun flowInnerContainerTypeArg(type: KType): KType? =
      type.arguments.firstOrNull()?.type?.arguments?.firstOrNull()?.type

    /** Creates the marker tool that `fromKotlin` returns for the flow-of-strings type. */
    private fun flowOfStringsTool(): FlowOfStrings {
      return FlowOfStrings()
    }

    /**
     * Builds the [Enumeration] tool for an enum class. Each element descriptor of [descriptor]
     * becomes a [Primitive] case whose handler decodes that element's serial name; the
     * enumeration's own handler is never expected to run and throws if it does.
     *
     * NOTE(review): the serial name is handed to `decodeFromString` as-is (no quotes are added
     * here) — confirm the configured `Json` accepts that input.
     *
     * @param targetClass the enum class.
     * @param descriptor serial descriptor of [targetClass].
     */
    @OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
    private fun  enumerationTool(
      targetClass: KClass<*>,
      descriptor: SerialDescriptor
    ): Enumeration {
      // Decodes a string with the enum's serializer; the result is cast unchecked.
      val enumSerializer = { value: String ->
        Config.DEFAULT.json.decodeFromString(targetClass.serializer(), value) as A
      }
      val functionObject = chatFunction(descriptor)
      val cases =
        descriptor.elementDescriptors.map {
          val enumValue = it.serialName
          val enumFunction = chatFunction(it)
          // The case ignores the incoming call and always yields its own enum value.
          Primitive(enumFunction) { enumSerializer(enumValue) }
        }
      return Enumeration(functionObject, { error("should not get called") }, cases, enumSerializer)
    }

    /**
     * Builds the [Sealed] tool for a sealed class. One function definition is created per
     * subclass descriptor; each subclass becomes a `Sealed.Case` whose handler, like the
     * top-level handler, decodes the call through `callSealedCase`.
     *
     * @param targetClass the sealed class.
     * @param descriptor descriptor used for the top-level function definition.
     * @throws IllegalStateException when the class serializer is not a `SealedClassSerializer`.
     */
    @OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
    private fun  sealedTool(targetClass: KClass<*>, descriptor: SerialDescriptor): Sealed {
      val sealedClassSerializer =
        targetClass.serializer() as? SealedClassSerializer
          ?: error("expected SealedClassSerializer got ${targetClass.serializer()}")
      // The sealed descriptor's second element (index 1) holds the subclass descriptors as its
      // own elements.
      val casesDescriptors =
        sealedClassSerializer.descriptor.elementDescriptors.toList()[1].elementDescriptors.toList()
      val functionObjectMap = casesDescriptors.associateWith { chatFunction(it) }
      val cases =
        casesDescriptors.map {
          val caseFunction =
            functionObjectMap[it] ?: error("No function found for ${it.serialName}")
          Sealed.Case(
            tool =
              Callable(caseFunction) {
                callSealedCase(it, functionObjectMap, sealedClassSerializer)
              },
            className = it.serialName
          )
        }
      return Sealed(
        chatFunction(descriptor),
        { callSealedCase(it, functionObjectMap, sealedClassSerializer) },
        cases
      )
    }

    /**
     * Decodes a function call into a value of the sealed hierarchy. The call arguments are
     * first rewritten by `descriptorChoice` to include the type discriminator, re-encoded to a
     * string, and then decoded with [sealedClassSerializer]; the result is cast unchecked.
     *
     * @param it the function call to decode.
     * @param functionObjectMap subclass descriptors with their function definitions.
     * @param sealedClassSerializer serializer of the sealed class.
     */
    @OptIn(InternalSerializationApi::class)
    private fun  callSealedCase(
      it: FunctionCall,
      functionObjectMap: Map,
      sealedClassSerializer: SealedClassSerializer
    ): A {
      val newJson = descriptorChoice(it, functionObjectMap)
      return Config.DEFAULT.json.decodeFromString(
        sealedClassSerializer,
        Json.encodeToString(newJson)
      ) as A
    }

    /**
     * Rebuilds the arguments of a sealed-case [call] as a JSON object that also carries the
     * type discriminator of the selected case.
     *
     * The case is resolved by first looking for the entry whose function name equals
     * `call.functionName`. Only when no entry matches exactly does it fall back to the first
     * descriptor whose serial name ends with that name. A pure suffix match can pick the wrong
     * case when one case name is a suffix of another (`Foo` vs `BarFoo`).
     *
     * NOTE(review): the type arguments of `Map` appear to have been stripped from this copy of
     * the file; the body treats keys as serial descriptors (`serialName`) and values as
     * function definitions (`name`).
     *
     * @param call the function call; its `arguments` must be a JSON object.
     * @param descriptors case descriptors with their function definitions.
     * @return the call arguments plus a `Config.TYPE_DISCRIMINATOR` entry holding the serial
     *   name of the selected case.
     * @throws IllegalStateException when no case corresponds to `call.functionName`.
     */
    private fun descriptorChoice(
      call: FunctionCall,
      descriptors: Map
    ): JsonObject {
      val arguments =
        Config.DEFAULT.json.decodeFromString(JsonElement.serializer(), call.arguments)
      // Same guard the previous unused lookup provided: some function definition must
      // correspond to the call, otherwise fail with the same exception type and message.
      check(descriptors.values.any { it.name.endsWith(call.functionName) }) {
        "No descriptor found for ${call.functionName}"
      }
      val targetDescriptor =
        descriptors.entries.firstOrNull { it.value.name == call.functionName }?.key
          ?: descriptors.keys.firstOrNull { it.serialName.endsWith(call.functionName) }
          ?: error("No descriptor found for ${call.functionName}")
      return JsonObject(
        arguments.jsonObject +
          (Config.TYPE_DISCRIMINATOR to JsonPrimitive(targetDescriptor.serialName))
      )
    }

    /**
     * Decides whether [type] has to go through `wrappedValueTool` instead of being described
     * directly from its own serializer.
     *
     * Wrapped types are the collections that `collectionTool` supports (`List` and `Set`) and
     * the listed primitive classes.
     *
     * NOTE(review): `Long` and `Short` are not listed — confirm whether that is intentional.
     *
     * @param type the Kotlin type to classify.
     * @return `true` when the type's classifier is one of the wrapped classes.
     * @throws IllegalStateException when the classifier of [type] is not a `KClass`.
     */
    private fun requiresWrapping(type: KType): Boolean {
      val targetClass =
        type.classifier as? KClass<*> ?: error("expected KClass got ${type.classifier}")
      return when (targetClass) {
        List::class,
        // Set was missing even though collectionTool has a dedicated Set branch; without it
        // that branch was unreachable and Set types fell through to the plain-class path.
        Set::class,
        Int::class,
        String::class,
        Boolean::class,
        Double::class,
        Float::class,
        Char::class,
        Byte::class -> true
        else -> false
      }
    }

    /**
     * Creates a [Callable] tool from a function value. The function definition is the one
     * `fromKotlin` derives, with its name and description replaced by [name] and
     * [description]; the handler decodes the call into the input and passes it to [fn].
     *
     * NOTE(review): the type parameter list is garbled in this copy (`fun  B> invoke(`);
     * restore it from the upstream source before compiling.
     *
     * @param name name to publish for the tool.
     * @param description description to publish for the tool.
     * @param fn function invoked with the decoded input.
     */
    @JvmName("fromKotlinFunction1")
    inline operator fun  B> invoke(
      name: String,
      description: Description,
      fn: F,
    ): Tool {
      val tool = fromKotlin()
      return Callable(
        function = tool.function.copy(name = name, description = description.value),
        invoke = {
          val input = tool.invoke(it)
          fn(input)
        }
      )
    }

    /**
     * Creates a [Callable] tool from a suspending function value. The function definition is
     * the one `fromKotlin` derives, with its name and description replaced by [name] and
     * [description]; the handler decodes the call into the input and passes it to [fn].
     *
     * @param name name to publish for the tool.
     * @param description description to publish for the tool.
     * @param fn suspending function invoked with the decoded input.
     */
    @JvmName("fromKotlinSuspendFunction1")
    inline fun  suspend(
      name: String,
      description: Description,
      noinline fn: suspend (A) -> B,
    ): Tool {
      val tool = fromKotlin()
      return Callable(
        function = tool.function.copy(name = name, description = description.value),
        invoke = {
          val input = tool.invoke(it)
          fn(input)
        }
      )
    }

    /**
     * Creates a [Callable] tool from a function reference. The tool is published under the
     * reference's own name; the description defaults to that name as well. The handler decodes
     * the call into the input and passes it to [fn].
     *
     * NOTE(review): the type parameter list is garbled in this copy (`fun > invoke(`);
     * restore it from the upstream source before compiling.
     *
     * @param fn function reference invoked with the decoded input.
     * @param description description to publish; defaults to `Description(fn.name)`.
     */
    @JvmName("fromKotlinKFunction1")
    inline operator fun > invoke(
      fn: F,
      description: Description = Description(fn.name)
    ): Tool {
      val tool = fromKotlin()
      return Callable(
        function = tool.function.copy(name = fn.name, description = description.value),
        invoke = {
          val input = tool.invoke(it)
          fn(input)
        }
      )
    }

    /**
     * Creates a [Callable] tool from a suspending function reference. The tool is published
     * under the reference's own name; the description defaults to that name as well. The
     * handler decodes the call into the input and passes it to [fn].
     *
     * @param fn suspending function reference invoked with the decoded input.
     * @param description description to publish; defaults to `Description(fn.name)`.
     */
    @JvmName("fromKotlinKSuspendFunction1")
    inline operator fun  invoke(
      fn: KSuspendFunction1,
      description: Description = Description(fn.name)
    ): Tool {
      val tool = fromKotlin()
      return Callable(
        function = tool.function.copy(name = fn.name, description = description.value),
        invoke = {
          val input = tool.invoke(it)
          fn(input)
        }
      )
    }
  }
}