All downloads are FREE. Search and download functionality uses the official Maven repository.

com.simiacryptus.openai.proxy.ChatProxy.kt Maven / Gradle / Ivy

The newest version!
package com.simiacryptus.openai.proxy

import com.simiacryptus.openai.models.OpenAIModel
import com.simiacryptus.openai.models.ChatModels
import com.simiacryptus.openai.OpenAIClient
import com.simiacryptus.openai.OpenAIClient.*
import com.simiacryptus.openai.OpenAIClientBase.Companion.toContentList
import com.simiacryptus.openai.models.OpenAITextModel
import com.simiacryptus.util.JsonUtil.toJson
import java.util.concurrent.atomic.AtomicInteger

/**
 * A [GPTProxyBase] that answers proxied method calls through the OpenAI chat endpoint.
 *
 * Each call is rendered as a chat conversation:
 *  1. a system message framing the model as a JSON-RPC service, followed by the method's API YAML;
 *  2. one user/assistant message pair per example (arguments in, recorded response out);
 *  3. a final user message carrying the actual call arguments.
 *
 * The raw completion is then trimmed to the span running from the first `{`/`[` through the last `}`/`]`.
 *
 * @param clazz the proxied type, forwarded to [GPTProxyBase]
 * @param api client used for moderation and for the chat call itself
 * @param model chat model whose `modelName` is sent with every request
 * @param temperature forwarded to [GPTProxyBase]; the inherited `temperature` is what [complete] sends
 * @param verbose when true, the prompt and the raw completion are logged at INFO level
 * @param moderated when true, the serialized request is passed to `api.moderate` before the chat call
 * @param deserializerRetries forwarded to [GPTProxyBase] (presumably a retry count for failed
 *   deserialization — TODO confirm against GPTProxyBase)
 * @param validation forwarded to [GPTProxyBase]
 */
open class ChatProxy<T : Any>(
    clazz: Class<T>,
    val api: OpenAIClient,
    var model: OpenAITextModel = ChatModels.GPT35Turbo,
    temperature: Double = 0.7,
    private var verbose: Boolean = false,
    private val moderated: Boolean = true,
    val deserializerRetries: Int = 5,
    validation: Boolean = true
) : GPTProxyBase<T>(clazz, temperature, validation, deserializerRetries) {

    /**
     * Builds a proxy from a loosely typed parameter map, keyed by the primary constructor's parameter names.
     *
     * `"clazz"` must be present (the non-null cast fails otherwise). Every other key is optional and
     * falls back to the same default as the primary constructor. A missing `"api"` gets a fresh
     * `OpenAIClient()`.
     */
    @Suppress("UNCHECKED_CAST")
    constructor(params: LinkedHashMap<String, Any>) : this(
        clazz = params["clazz"] as Class<T>,
        api = params["api"] as OpenAIClient? ?: OpenAIClient(),
        model = params["model"] as OpenAITextModel? ?: ChatModels.GPT35Turbo,
        temperature = params["temperature"] as Double? ?: 0.7,
        verbose = params["verbose"] as Boolean? ?: false,
        moderated = params["moderated"] as Boolean? ?: true,
        deserializerRetries = params["deserializerRetries"] as Int? ?: 5,
        validation = params["validation"] as Boolean? ?: true,
    )

    // Running totals of String.length values, accumulated by complete() and reported via metrics.
    private val totalNonJsonPrefixLength = AtomicInteger(0)
    private val totalNonJsonSuffixLength = AtomicInteger(0)
    private val totalInputLength = AtomicInteger(0)
    private val totalYamlLength = AtomicInteger(0)
    private val totalExamplesLength = AtomicInteger(0)
    private val totalOutputLength = AtomicInteger(0)

    /** This proxy's length counters, merged with the base-class metrics and the client's metrics. */
    override val metrics: Map<String, Any>
        get() = hashMapOf(
            "totalInputLength" to totalInputLength.get(),
            "totalOutputLength" to totalOutputLength.get(),
            "totalNonJsonPrefixLength" to totalNonJsonPrefixLength.get(),
            "totalNonJsonSuffixLength" to totalNonJsonSuffixLength.get(),
            "totalYamlLength" to totalYamlLength.get(),
            "totalExamplesLength" to totalExamplesLength.get(),
        ) + super.metrics + api.metrics

    /**
     * Sends [prompt] to the chat model, preceded by [examples] as few-shot user/assistant turns.
     *
     * Side effects:
     *  - updates the counters exposed through [metrics];
     *  - when `moderated` is set, passes the serialized request to `api.moderate` first;
     *  - when `verbose` is set, logs the prompt and the raw completion.
     *
     * @return the completion from its first `{` or `[` through its last `}` or `]`. If no opening
     *   (closing) bracket is present, the start (end) of the text is left untouched.
     * @throws NoSuchElementException if the chat response contains no choices
     */
    override fun complete(prompt: ProxyRequest, vararg examples: RequestResponse): String {
        if (verbose) log.info(prompt.toString())
        totalYamlLength.addAndGet(prompt.apiYaml.length)

        val exampleMessages = examples.flatMap { example ->
            listOf(
                ChatMessage(Role.user, argsToString(example.argList).toContentList()),
                ChatMessage(Role.assistant, example.response.toContentList()),
            )
        }
        totalExamplesLength.addAndGet(toJson(exampleMessages).length)

        val messages = listOf(ChatMessage(Role.system, systemPrompt(prompt.apiYaml).toContentList())) +
            exampleMessages +
            ChatMessage(Role.user, argsToString(prompt.argList).toContentList())
        val request = ChatRequest().copy(
            messages = ArrayList(messages),
            model = model.modelName,
            temperature = temperature,
        )

        val json = toJson(request)
        if (moderated) api.moderate(json)
        totalInputLength.addAndGet(json.length)

        val completion = api.chat(request, model).choices.first().message?.content.orEmpty()
        if (verbose) log.info(completion)
        totalOutputLength.addAndGet(completion.length)

        val (withoutPrefix, prefix) = trimPrefix(completion)
        val (body, suffix) = trimSuffix(withoutPrefix)
        totalNonJsonPrefixLength.addAndGet(prefix.length)
        totalNonJsonSuffixLength.addAndGet(suffix.length)
        return body
    }

    companion object {

        private val log = org.slf4j.LoggerFactory.getLogger(ChatProxy::class.java)

        // Fixed instructions placed ahead of the method's API YAML in the system message.
        private val SYSTEM_PROMPT_PREAMBLE = listOf(
            "You are a JSON-RPC Service",
            "Responses are in JSON format",
            "Do not include explaining text outside the JSON",
            "All input arguments are optional",
            "Outputs are based on inputs, with any missing information filled randomly",
            "You will respond to the following method",
        ).joinToString("\n")

        /**
         * Builds the system message: the preamble, a blank line, then [apiYaml] verbatim, trimmed.
         *
         * The YAML is appended rather than interpolated into a `trimMargin()` template. With
         * interpolation, any YAML line whose first non-blank character is `|` would lose its
         * indentation and the `|`.
         */
        private fun systemPrompt(apiYaml: String): String =
            "$SYSTEM_PROMPT_PREAMBLE\n\n$apiYaml".trim()

        /**
         * Splits off any text before the first `{` or `[`.
         *
         * @return (text from that bracket onward, discarded prefix); when neither bracket occurs,
         *   (the whole input, "")
         */
        private fun trimPrefix(completion: String): Pair<String, String> {
            // Scan for whichever bracket comes first. Taking min() of the two indexOf results
            // would yield -1 whenever one bracket kind is absent, so nothing would be stripped.
            val start = completion.indexOfFirst { it == '{' || it == '[' }
            return if (start < 0) {
                completion to ""
            } else {
                completion.substring(start) to completion.substring(0, start)
            }
        }

        /**
         * Splits off any text after the last `}` or `]`.
         *
         * @return (text up to and including that bracket, discarded suffix); when neither bracket
         *   occurs, (the whole input, "")
         */
        private fun trimSuffix(completion: String): Pair<String, String> {
            val end = completion.indexOfLast { it == '}' || it == ']' }
            return if (end < 0) {
                completion to ""
            } else {
                completion.substring(0, end + 1) to completion.substring(end + 1)
            }
        }

        /**
         * Renders arguments as `{"name": value,\n"name2": value2}`.
         *
         * Names are quoted. Values are inserted as-is, so they presumably arrive already JSON-encoded
         * (TODO confirm against the caller).
         */
        private fun argsToString(argList: Map<String, *>) =
            argList.entries.joinToString(",\n", prefix = "{", postfix = "}") { (argName, argValue) ->
                """"$argName": $argValue"""
            }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy