// xef-core: Building applications with LLMs through composability in Kotlin.
package com.xebia.functional.xef

import com.xebia.functional.openai.generated.model.CreateChatCompletionRequest
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.llm.models.modelType
import com.xebia.functional.xef.llm.prompt
import com.xebia.functional.xef.llm.promptStreaming
import com.xebia.functional.xef.prompt.Prompt
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow
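
/**
 * Runs a [Prompt] against the configured LLM and decodes the response into a
 * value of type [A] using [serializer]. Instances are normally obtained
 * through the top-level [AI] functions below rather than constructed directly.
 */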
class AI<A>(private val config: AIConfig, val serializer: Tool<A>) {

  private fun runStreamingWithStringSerializer(prompt: Prompt): Flow<String> =
    config.api.promptStreaming(prompt, config.conversation, config.tools)
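
  // Dispatches on the shape of [serializer]: tool-backed and primitive
  // serializers run a single completion, the Flow variants stream partial
  // results, and the AIEvent variants emit AIEvent.Start before delegating.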
  @PublishedApi
  internal suspend operator fun invoke(prompt: Prompt): A =
    when (val serializer = serializer) {
      is Tool.Callable -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.Contextual -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.Enumeration -> runWithEnumSingleTokenSerializer(serializer, prompt)
      is Tool.FlowOfStreamedFunctions<*> -> {
        config.api.promptStreaming(prompt, config.conversation, serializer, config.tools) as A
      }
      is Tool.FlowOfStrings -> runStreamingWithStringSerializer(prompt) as A
      is Tool.Primitive -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.Sealed -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.FlowOfAIEventsSealed ->
        channelFlow {
          send(AIEvent.Start)
          config.api.prompt(
            prompt = prompt,
            scope = config.conversation,
            serializer = serializer.sealedSerializer,
            tools = config.tools,
            collector = this
          )
        } as A
      is Tool.FlowOfAIEvents ->
        channelFlow {
          send(AIEvent.Start)
          config.api.prompt(
            prompt = prompt,
            scope = config.conversation,
            serializer = serializer.serializer,
            tools = config.tools,
            collector = this
          )
        } as A
    }
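
  /**
   * Forces an enum answer in a single completion token: every enum case name
   * must encode to exactly one token, each such token gets a logit bias of
   * 100, and the request is capped at `maxTokens = 1` with `temperature = 0.0`,
   * so the model can effectively only answer with one of the case tokens.
   */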
  private suspend fun runWithEnumSingleTokenSerializer(
    serializer: Tool.Enumeration<A>,
    prompt: Prompt
  ): A {
    val encoding = prompt.model.modelType(forFunctions = false).encoding
    val cases = serializer.cases
    val logitBias =
      cases
        .flatMap {
          val result = encoding.encode(it.function.name)
          if (result.size > 1) {
            error("Cannot encode enum case $it into one token")
          }
          result
        }
        .associate { "$it" to 100 }
    val result =
      config.api.createChatCompletion(
        CreateChatCompletionRequest(
          messages = prompt.messages,
          model = prompt.model,
          logitBias = logitBias,
          maxTokens = 1,
          temperature = 0.0
        )
      )
    val choice = result.choices[0].message.content
    val enumSerializer = serializer.enumSerializer
    return if (choice != null) {
      enumSerializer(choice)
    } else {
      error("Cannot decode enum case from $choice")
    }
  }
  companion object {
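
    /**
     * Classifies [input] against [output] given [context], using the cases of
     * enum [E] as the label set. [E] must implement [PromptClassifier], whose
     * `template` builds the classification prompt.
     */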
    @AiDsl
    suspend inline fun <reified E> classify(
      input: String,
      output: String,
      context: String,
      config: AIConfig = AIConfig(),
    ): E where E : Enum<E>, E : PromptClassifier {
      val value = enumValues<E>().firstOrNull() ?: error("No values to classify")
      return AI(
        prompt = value.template(input, output, context),
        config = config,
      )
    }
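
    /**
     * Like [classify] but lets the model pick zero or more cases of [E]: the
     * model answers with [SelectedItems], a list of case indices, which are
     * mapped back to enum entries; out-of-range indices are silently dropped.
     */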
    @AiDsl
    suspend inline fun <reified E> multipleClassify(
      input: String,
      config: AIConfig = AIConfig(),
    ): List<E> where E : Enum<E>, E : PromptMultipleClassifier {
      val values = enumValues<E>()
      val value = values.firstOrNull() ?: error("No values to classify")
      val selected: SelectedItems =
        AI(
          prompt = value.template(input),
          serializer = Tool.fromKotlin<SelectedItems>(),
          config = config
        )
      return selected.selectedItems.mapNotNull { values.elementAtOrNull(it) }
    }
  }
}
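
/**
 * Runs [prompt] as a plain string against the model configured in [config]
 * and decodes the response into [A] via [serializer].
 */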
@AiDsl
suspend inline fun <reified A> AI(
  prompt: String,
  serializer: Tool<A> = Tool.fromKotlin(),
  config: AIConfig = AIConfig()
): A = AI(Prompt(config.model, prompt), serializer, config)
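
/**
 * Runs a structured [Prompt] by building an [AI] runner for [A] and invoking
 * it; [serializer], derived from [A] by default, decides how the response is
 * decoded and whether it is streamed.
 */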
@AiDsl
suspend inline fun <reified A> AI(
  prompt: Prompt,
  serializer: Tool<A> = Tool.fromKotlin(),
  config: AIConfig = AIConfig(),
): A = AI(config, serializer).invoke(prompt)
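
// Minimal usage sketch, not part of the library: `Sentiment` is hypothetical,
// and it assumes the default AIConfig() is backed by a configured OpenAI key
// and that PromptClassifier.template(input, output, context) returns the final
// prompt string.
//
//   enum class Sentiment : PromptClassifier {
//     positive,
//     negative;
//
//     override fun template(input: String, output: String, context: String): String =
//       "Classify the sentiment of the following text as positive or negative:\n$input"
//   }
//
//   suspend fun example() {
//     // Enum classification; runs through runWithEnumSingleTokenSerializer,
//     // so each case name must encode to a single token.
//     val sentiment: Sentiment = AI.classify("I loved it", output = "", context = "")
//     // Free-form completion decoded as String via the top-level AI entry point.
//     val summary: String = AI("Summarize Kotlin coroutines in one sentence.")
//     println("$sentiment / $summary")
//   }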