com.simiacryptus.openai.proxy.ChatProxy.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of joe-penai Show documentation
Show all versions of joe-penai Show documentation
A Java client for OpenAI's API
package com.simiacryptus.openai.proxy
import com.simiacryptus.openai.OpenAIClient.ChatMessage
import com.simiacryptus.openai.OpenAIClient.ChatRequest
import com.simiacryptus.openai.OpenAIClient
import com.simiacryptus.util.JsonUtil.toJson
import java.util.concurrent.atomic.AtomicInteger
@Suppress("MemberVisibilityCanBePrivate")
class ChatProxy(
clazz: Class,
val api: OpenAIClient,
var model: OpenAIClient.Model = OpenAIClient.Models.GPT35Turbo,
temperature: Double = 0.7,
var verbose: Boolean = false,
private val moderated: Boolean = true,
val deserializerRetries: Int = 5,
validation: Boolean = true,) : GPTProxyBase(clazz, temperature, validation, deserializerRetries) {
constructor(params: LinkedHashMap) : this(
clazz = params["clazz"] as Class,
api = params["api"] as OpenAIClient? ?: OpenAIClient(),
model = params["model"] as OpenAIClient.Model? ?: OpenAIClient.Models.GPT35Turbo,
temperature = params["temperature"] as Double? ?: 0.7,
verbose = params["verbose"] as Boolean? ?: false,
moderated = params["moderated"] as Boolean? ?: true,
deserializerRetries = params["deserializerRetries"] as Int? ?: 5,
validation = params["validation"] as Boolean? ?: true,
)
override val metrics: Map
get() = hashMapOf(
"totalInputLength" to totalInputLength.get(),
"totalOutputLength" to totalOutputLength.get(),
"totalNonJsonPrefixLength" to totalNonJsonPrefixLength.get(),
"totalNonJsonSuffixLength" to totalNonJsonSuffixLength.get(),
"totalYamlLength" to totalYamlLength.get(),
"totalExamplesLength" to totalExamplesLength.get(),
) + super.metrics + api.metrics
protected val totalNonJsonPrefixLength = AtomicInteger(0)
protected val totalNonJsonSuffixLength = AtomicInteger(0)
protected val totalInputLength = AtomicInteger(0)
protected val totalYamlLength = AtomicInteger(0)
protected val totalExamplesLength = AtomicInteger(0)
protected val totalOutputLength = AtomicInteger(0)
override fun complete(prompt: ProxyRequest, vararg examples: RequestResponse): String {
if (verbose) log.info(prompt.toString())
val request = ChatRequest()
totalYamlLength.addAndGet(prompt.apiYaml.length)
val exampleMessages = examples.flatMap {
listOf(
ChatMessage(
ChatMessage.Role.user,
argsToString(it.argList)
),
ChatMessage(
ChatMessage.Role.assistant,
it.response
)
)
}
totalExamplesLength.addAndGet(toJson(exampleMessages).length)
request.messages = (
listOf(
ChatMessage(
ChatMessage.Role.system, """
|You are a JSON-RPC Service
|Responses are in JSON format
|Do not include explaining text outside the JSON
|All input arguments are optional
|Outputs are based on inputs, with any missing information filled randomly
|You will respond to the following method
|
|${prompt.apiYaml}
|""".trimMargin().trim()
)
) +
exampleMessages +
listOf(
ChatMessage(
ChatMessage.Role.user,
argsToString(prompt.argList)
)
)
).toTypedArray()
request.model = model.modelName
request.max_tokens = model.maxTokens
request.temperature = temperature
val json = toJson(request)
if (moderated) api.moderate(json)
totalInputLength.addAndGet(json.length)
val completion = api.chat(request, model).choices?.first()?.message?.content.orEmpty()
if (verbose) log.info(completion)
totalOutputLength.addAndGet(completion.length)
val trimPrefix = trimPrefix(completion)
val trimSuffix = trimSuffix(trimPrefix.first)
totalNonJsonPrefixLength.addAndGet(trimPrefix.second.length)
totalNonJsonSuffixLength.addAndGet(trimSuffix.second.length)
return trimSuffix.first
}
companion object {
val log = org.slf4j.LoggerFactory.getLogger(ChatProxy::class.java)
private fun trimPrefix(completion: String): Pair {
val start = completion.indexOf('{').coerceAtMost(completion.indexOf('['))
return if (start < 0) {
completion to ""
} else {
val substring = completion.substring(start)
substring to completion.substring(0, start)
}
}
private fun trimSuffix(completion: String): Pair {
val end = completion.lastIndexOf('}').coerceAtLeast(completion.lastIndexOf(']'))
return if (end < 0) {
completion to ""
} else {
val substring = completion.substring(0, end + 1)
substring to completion.substring(end + 1)
}
}
private fun argsToString(argList: Map) =
"{" + argList.entries.joinToString(",\n", transform = { (argName, argValue) ->
""""$argName": $argValue"""
}) + "}"
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy