All downloads are free. Search and download functionality uses the official Maven repository.

commonMain.com.xebia.functional.xef.AI.kt Maven / Gradle / Ivy

The newest version!
package com.xebia.functional.xef

import com.xebia.functional.openai.generated.model.CreateChatCompletionRequest
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.llm.models.modelType
import com.xebia.functional.xef.llm.prompt
import com.xebia.functional.xef.llm.promptStreaming
import com.xebia.functional.xef.prompt.Prompt
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow

/**
 * Runs a [Prompt] against `config.api` and turns the response into an `A`, according to the
 * [Tool] variant held in [serializer].
 *
 * NOTE(review): this listing appears to have lost its type parameters — `A` is used as a return
 * type below but is not declared on this class header, and `Flow`/`Tool`/`Enum`/`List` appear
 * without type arguments. TODO restore them from the upstream source before compiling.
 *
 * @property serializer decides how the request is executed and decoded (see [invoke]).
 */
class AI(private val config: AIConfig, val serializer: Tool) {

  /** Streams the reply to [prompt] without a structured serializer (presumably as text chunks). */
  private fun runStreamingWithStringSerializer(prompt: Prompt): Flow =
    config.api.promptStreaming(prompt, config.conversation, config.tools)

  /**
   * Executes [prompt] according to the concrete [Tool] variant of [serializer].
   *
   * - `Callable`, `Contextual`, `Primitive`, `Sealed`: one `config.api.prompt` call decoded by the tool.
   * - `Enumeration`: a single-token completion, see [runWithEnumSingleTokenSerializer].
   * - `FlowOfStreamedFunctions`, `FlowOfStrings`: the streaming flow itself is returned, via an
   *   unchecked cast to `A`.
   * - `FlowOfAIEventsSealed`, `FlowOfAIEvents`: a `channelFlow` (nothing runs until it is collected)
   *   that first sends [AIEvent.Start] and then runs `config.api.prompt` with the flow's producer
   *   scope passed as `collector`.
   */
  @PublishedApi
  internal suspend operator fun invoke(prompt: Prompt): A =
    when (val serializer = serializer) {
      is Tool.Callable -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.Contextual -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.Enumeration -> runWithEnumSingleTokenSerializer(serializer, prompt)
      is Tool.FlowOfStreamedFunctions<*> -> {
        config.api.promptStreaming(prompt, config.conversation, serializer, config.tools) as A
      }
      is Tool.FlowOfStrings -> runStreamingWithStringSerializer(prompt) as A
      is Tool.Primitive -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.Sealed -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.FlowOfAIEventsSealed ->
        channelFlow {
          send(AIEvent.Start)
          config.api.prompt(
            prompt = prompt,
            scope = config.conversation,
            serializer = serializer.sealedSerializer,
            tools = config.tools,
            collector = this
          )
        }
          as A
      is Tool.FlowOfAIEvents ->
        channelFlow {
          send(AIEvent.Start)
          config.api.prompt(
            prompt = prompt,
            scope = config.conversation,
            serializer = serializer.serializer,
            tools = config.tools,
            collector = this
          )
        }
          as A
    }

  /**
   * Classifies [prompt] into one of the cases of [serializer] using a single completion token.
   *
   * Each case's `function.name` is encoded with the prompt model's tokenizer, and every case must
   * map to exactly one token. Those tokens receive a logit bias of 100, and the request is limited
   * to `maxTokens = 1` with `temperature = 0.0`. The reply is therefore expected to be one of the
   * case tokens. It is decoded with `serializer.enumSerializer`.
   *
   * @throws IllegalStateException if a case name does not encode to exactly one token, or if the
   *   completion returns no choices or no message content.
   */
  private suspend fun runWithEnumSingleTokenSerializer(
    serializer: Tool.Enumeration,
    prompt: Prompt
  ): A {
    val encoding = prompt.model.modelType(forFunctions = false).encoding
    val cases = serializer.cases
    val logitBias =
      cases
        .flatMap { enumCase ->
          val tokens = encoding.encode(enumCase.function.name)
          // Exactly one token per case: more than one cannot be emitted with maxTokens = 1, and
          // zero would leave the case without any bias, so it could never be selected.
          if (tokens.size != 1) {
            error("Cannot encode enum case $enumCase into one token")
          }
          tokens
        }
        .associate { "$it" to 100 }
    val result =
      config.api.createChatCompletion(
        CreateChatCompletionRequest(
          messages = prompt.messages,
          model = prompt.model,
          logitBias = logitBias,
          maxTokens = 1,
          temperature = 0.0
        )
      )
    // firstOrNull: an empty `choices` list previously surfaced as an opaque IndexOutOfBoundsException.
    val choice =
      result.choices.firstOrNull()?.message?.content
        ?: error("Cannot decode enum case: the completion returned no choices or no content")
    val enumSerializer = serializer.enumSerializer
    return enumSerializer(choice)
  }

  companion object {
    /**
     * Classifies an [input]/[output] pair, with extra [context], into one value of enum `E`.
     *
     * The prompt is built from the first enum value's `template(input, output, context)`. The
     * result is produced by the top-level `AI` function using its default serializer.
     *
     * @throws IllegalStateException if `E` has no values.
     */
    @AiDsl
    suspend inline fun  classify(
      input: String,
      output: String,
      context: String,
      config: AIConfig = AIConfig(),
    ): E where E : Enum, E : PromptClassifier {
      val value = enumValues().firstOrNull() ?: error("No values to classify")
      return AI(
        prompt = value.template(input, output, context),
        config = config,
      )
    }

    /**
     * Selects the values of enum `E` that apply to [input].
     *
     * The model is asked, via the first enum value's `template(input)`, for a [SelectedItems]
     * result. Its `selectedItems` indices are mapped back to enum values, and indices outside
     * the enum's range are silently dropped.
     *
     * @throws IllegalStateException if `E` has no values.
     */
    @AiDsl
    suspend inline fun  multipleClassify(
      input: String,
      config: AIConfig = AIConfig(),
    ): List where E : Enum, E : PromptMultipleClassifier {
      val values = enumValues()
      val value = values.firstOrNull() ?: error("No values to classify")
      val selected: SelectedItems =
        AI(
          prompt = value.template(input),
          serializer = Tool.fromKotlin(),
          config = config
        )
      return selected.selectedItems.mapNotNull { values.elementAtOrNull(it) }
    }
  }
}

/**
 * Convenience entry point for a plain-text [prompt].
 *
 * Wraps the text in a [Prompt] bound to `config.model` and delegates to the [Prompt]-based `AI`
 * overload with the same [serializer] and [config].
 */
@AiDsl
suspend inline fun AI(
  prompt: String,
  serializer: Tool = Tool.fromKotlin(),
  config: AIConfig = AIConfig()
): A {
  val wrappedPrompt = Prompt(config.model, prompt)
  return AI(wrappedPrompt, serializer, config)
}

/**
 * Runs [prompt] once and returns the decoded result.
 *
 * Builds an `AI` instance from [config] and [serializer], then calls its `invoke` operator.
 * How the response is produced depends on the [Tool] variant passed as [serializer].
 */
@AiDsl
suspend inline fun AI(
  prompt: Prompt,
  serializer: Tool = Tool.fromKotlin(),
  config: AIConfig = AIConfig(),
): A {
  val runner = AI(config, serializer)
  return runner(prompt)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy