
// com.simiacryptus.jopenai.ChatClient.kt (jo-penai): a Java client for OpenAI's API
package com.simiacryptus.jopenai
import com.fasterxml.jackson.core.JsonProcessingException
import com.google.common.util.concurrent.ListeningScheduledExecutorService
import com.google.gson.Gson
import com.google.gson.JsonObject
import com.simiacryptus.jopenai.exceptions.ModerationException
import com.simiacryptus.jopenai.models.*
import com.simiacryptus.jopenai.models.ApiModel.*
import com.simiacryptus.jopenai.util.ClientUtil.allowedCharset
import com.simiacryptus.jopenai.util.ClientUtil.checkError
import com.simiacryptus.jopenai.util.ClientUtil.defaultApiProvider
import com.simiacryptus.jopenai.util.ClientUtil.keyMap
import com.simiacryptus.util.JsonUtil
import com.simiacryptus.util.StringUtil
import com.simiacryptus.util.runWithPermit
import org.apache.hc.client5.http.classic.methods.HttpPost
import org.apache.hc.core5.http.HttpRequest
import org.apache.hc.core5.http.io.entity.EntityUtils
import org.apache.hc.core5.http.io.entity.StringEntity
import org.slf4j.LoggerFactory
import org.slf4j.event.Level
import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain
import software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider
import software.amazon.awssdk.core.SdkBytes
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest
import java.io.BufferedOutputStream
import java.io.IOException
import java.util.*
import java.util.concurrent.Semaphore
import java.util.concurrent.ThreadPoolExecutor
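/**
 * Chat client that targets OpenAI-compatible endpoints as well as Google Gemini, Anthropic,
 * Perplexity, Mistral, Groq, ModelsLab, and AWS Bedrock, translating each request into the
 * provider's wire format and normalizing every response back into an OpenAI-style ChatResponse.
 *
 * Minimal usage sketch, assuming `chatModel` is any configured TextModel:
 *
 *     val client = ChatClient() // keys default to ClientUtil.keyMap
 *     val request = ChatRequest(
 *         model = chatModel.modelName,
 *         messages = listOf(
 *             ChatMessage(role = Role.user, content = listOf(ContentPart(type = "text", text = "Hello")))
 *         )
 *     )
 *     val response = client.chat(request, chatModel)
 */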
open class ChatClient(
protected var key: Map<APIProvider, String> = keyMap.mapKeys { APIProvider.valueOf(it.key) },
protected val apiBase: Map<APIProvider, String> = APIProvider.values().associate { it to (it.base ?: "") },
logLevel: Level = Level.INFO,
logStreams: MutableList<BufferedOutputStream> = mutableListOf(),
scheduledPool: ListeningScheduledExecutorService = HttpClientManager.scheduledPool,
workPool: ThreadPoolExecutor = HttpClientManager.workPool
) : HttpClientManager(
logLevel = logLevel,
logStreams = logStreams,
scheduledPool = scheduledPool,
workPool = workPool
) {
open var session: Any? = null
open var user: Any? = null
var budget: Number? = null
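// A ChatClient that shares the parent's keys and API bases and mirrors every log line back to the parent.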
protected open class ChildClient(
val inner: ChatClient,
key: Map<APIProvider, String> = inner.key,
apiBase: Map<APIProvider, String> = inner.apiBase
) : ChatClient(
key = key,
apiBase = apiBase,
logLevel = Level.INFO
) {
override fun log(level: Level, msg: String) {
super.log(level, msg)
inner.log(level, msg)
}
}
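// Creates a child client that logs through this client and carries over the current session and user.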
open fun getChildClient(): ChatClient = ChildClient(inner = this, key = key, apiBase = apiBase).apply {
session = inner.session
user = inner.user
}
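// Records token usage for the current session/user/model and, when a budget is set, deducts the request cost from it.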
protected open fun onUsage(model: OpenAIModel?, tokens: Usage) {
log.debug(
"Usage recorded for session: {}, user: {}, model: {}, tokens: {}",
session,
user,
model,
tokens
)
if (null != budget) budget = budget!!.toDouble() - (tokens.cost ?: 0.0)
}
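// Runs the text through the default provider's /moderations endpoint and throws a ModerationException
// if it is flagged. Providers without a moderation API (Groq, ModelsLab) are skipped.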
fun moderate(text: String) = withReliability {
when {
defaultApiProvider == APIProvider.Groq -> return@withReliability
defaultApiProvider == APIProvider.ModelsLab -> return@withReliability
}
withPerformanceLogging {
val body: String = try {
JsonUtil.objectMapper().writeValueAsString(
mapOf(
"input" to StringUtil.restrictCharacterSet(text, allowedCharset)
)
)
} catch (e: JsonProcessingException) {
throw RuntimeException(e)
}
val result: String = try {
this.post("${apiBase[defaultApiProvider]}/moderations", body, defaultApiProvider)
} catch (e: IOException) {
throw RuntimeException(e)
} catch (e: InterruptedException) {
throw RuntimeException(e)
}
val jsonObject = Gson().fromJson(
result, JsonObject::class.java
) ?: return@withPerformanceLogging
if (jsonObject.has("error")) {
val errorObject = jsonObject.getAsJsonObject("error")
throw RuntimeException(IOException(errorObject["message"].asString))
}
val moderationResult = jsonObject.getAsJsonArray("results")[0].asJsonObject
if (moderationResult["flagged"].asBoolean) {
val categoriesObj = moderationResult["categories"].asJsonObject
throw RuntimeException(
ModerationException("Moderation flagged this request due to " + categoriesObj.keySet()
.stream().filter { c: String? ->
categoriesObj[c].asBoolean
}.reduce { a: String, b: String -> "$a, $b" }.orElse("???")
)
)
}
}
}
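// POSTs a JSON body to the given URL with provider-specific authorization headers.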
@Throws(IOException::class, InterruptedException::class)
private fun post(url: String, json: String, apiProvider: APIProvider): String {
val request = HttpPost(url)
request.addHeader("Content-Type", "application/json")
request.addHeader("Accept", "application/json")
authorize(request, apiProvider)
request.entity = StringEntity(json, Charsets.UTF_8, false)
return post(request)
}
private fun post(request: HttpPost): String = withClient { EntityUtils.toString(it.execute(request).entity) }
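// Adds the provider-specific auth header (x-api-key for Anthropic, a Bearer token otherwise; Google
// requests carry the key as a query parameter instead) and enforces the optional spending budget.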
@Throws(IOException::class)
protected open fun authorize(request: HttpRequest, apiProvider: APIProvider) {
log.debug("Authorizing request for session: {}, user: {}, apiProvider: {}", session, user, apiProvider)
require(null == budget || budget!!.toDouble() > 0.0) { "Budget Exceeded" }
when (apiProvider) {
APIProvider.Google -> {
// request.addHeader("X-goog-api-key", "${key.get(apiProvider)}")
}
APIProvider.Anthropic -> {
request.addHeader("x-api-key", "${key.get(apiProvider)}")
request.addHeader("anthropic-version", "2023-06-01")
}
else -> request.addHeader("Authorization", "Bearer ${key.get(apiProvider)}")
}
}
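// Sends a chat completion request, translating it into the wire format of the model's provider
// (Gemini, Anthropic, Perplexity, Mistral, Groq, ModelsLab, AWS Bedrock, or OpenAI-compatible)
// and normalizing the raw response back into an OpenAI-style ChatResponse.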
open fun chat(
chatRequest: ChatRequest, model: TextModel
): ChatResponse {
var chatRequest = chatRequest
log.info("Starting chat with model: ${model.modelName}")
if (model.modelName in listOf("o1-preview", "o1-mini")) {
chatRequest = chatRequest.copy(
messages = chatRequest.messages.map { message ->
if (message.role == Role.system) {
message.copy(role = Role.user)
} else {
message
}
},
temperature = 1.0,
stop = null
)
log.debug("Adjusted chat request for model: ${model.modelName}")
}
// NOTE: assumed details below: the substring checked by the markup guard and the model.provider accessor.
if (chatRequest.messages.any { message -> message.content?.any { it.text?.contains("<") == true } == true }) {
log.debug("Chat request content contains markup for model: ${model.modelName}")
}
val apiProvider = model.provider
return withReliability {
withPerformanceLogging {
val requestID = UUID.randomUUID().toString()
val result = when {
apiProvider == APIProvider.Google -> {
val geminiChatRequest =
toGeminiChatRequest(chatRequest.copy(messages = chatRequest.messages.map {
it.copy(
role = when (it.role) {
Role.system -> Role.user
else -> it.role
}
)
}), model)
val json = JsonUtil.objectMapper().writerWithDefaultPrettyPrinter()
.writeValueAsString(geminiChatRequest)
log(
level = Level.DEBUG,
msg = String.format(
"Chat Request %s\nPrefix:\n\t%s\n",
requestID,
json.replace("\n", "\n\t")
)
)
fromGemini(
post(
"${apiBase[apiProvider]}/v1beta/${model.modelName}:generateContent?key=${key[apiProvider]}",
json,
apiProvider
)
)
}
apiProvider == APIProvider.Anthropic -> {
val anthropicChatRequest = mapToAnthropicChatRequest(chatRequest, model)
val json = JsonUtil.objectMapper().writerWithDefaultPrettyPrinter()
.writeValueAsString(anthropicChatRequest)
log(
level = Level.DEBUG,
msg = String.format(
"Chat Request %s\nPrefix:\n\t%s\n",
requestID,
json.replace("\n", "\n\t")
)
)
val request = HttpPost("${apiBase[apiProvider]}/messages")
request.addHeader("Content-Type", "application/json")
request.addHeader("Accept", "application/json")
request.addHeader("x-api-key", "${key.get(apiProvider)}")
request.addHeader("anthropic-version", "2023-06-01")
request.entity = StringEntity(json, Charsets.UTF_8, false)
val rawResponse = post(request)
fromAnthropicResponse(rawResponse)
}
apiProvider == APIProvider.Perplexity -> {
val json =
JsonUtil.objectMapper().writerWithDefaultPrettyPrinter()
.writeValueAsString(chatRequest.copy(stop = null))
log(
level = Level.DEBUG,
msg = String.format(
"Chat Request %s\nPrefix:\n\t%s\n",
requestID,
json.replace("\n", "\n\t")
)
)
post("${apiBase[apiProvider]}/chat/completions", json, apiProvider)
}
apiProvider == APIProvider.Mistral -> {
val json = JsonUtil.objectMapper().writerWithDefaultPrettyPrinter()
.writeValueAsString(toGroq(chatRequest))
log(
level = Level.DEBUG,
msg = String.format(
"Chat Request %s\nPrefix:\n\t%s\n",
requestID,
json.replace("\n", "\n\t")
)
)
post("${apiBase[apiProvider]}/chat/completions", json, apiProvider)
}
apiProvider == APIProvider.Groq -> {
val json = JsonUtil.objectMapper().writerWithDefaultPrettyPrinter()
.writeValueAsString(toGroq(chatRequest))
log(
level = Level.DEBUG,
msg = String.format(
"Chat Request %s\nPrefix:\n\t%s\n",
requestID,
json.replace("\n", "\n\t")
)
)
post("${apiBase[apiProvider]}/chat/completions", json, apiProvider)
}
apiProvider == APIProvider.ModelsLab -> {
modelsLabThrottle.runWithPermit {
val json =
JsonUtil.objectMapper().writerWithDefaultPrettyPrinter()
.writeValueAsString(toModelsLab(chatRequest))
log(
level = Level.DEBUG,
msg = String.format(
"Chat Request %s\nPrefix:\n\t%s\n",
requestID,
json.replace("\n", "\n\t")
)
)
fromModelsLab(post("${apiBase[apiProvider]}/llm/chat", json, apiProvider))
}
}
apiProvider == APIProvider.AWS -> {
val awsAuth = JsonUtil.fromJson(key[apiProvider]!!, AWSAuth::class.java)
val invokeModelRequest = toAWS(model, chatRequest)
val bedrockRuntimeClient = BedrockRuntimeClient.builder()
.credentialsProvider(awsCredentials(awsAuth))
.region(Region.of(awsAuth.region))
.build()
log(
level = Level.DEBUG,
msg = String.format(
"Chat Request %s\nPrefix:\n\t%s\n",
requestID, JsonUtil.toJson(chatRequest).replace("\n", "\n\t")
)
)
val invokeModelResponse = bedrockRuntimeClient
.invokeModel(invokeModelRequest)
val responseBody = invokeModelResponse.body().asString(Charsets.UTF_8)
fromAWS(responseBody, model.modelName)
}
else -> {
val json =
JsonUtil.objectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(chatRequest)
log(
level = Level.DEBUG,
msg = String.format(
"Chat Request %s\nPrefix:\n\t%s\n",
requestID,
json.replace("\n", "\n\t")
)
)
post("${apiBase[apiProvider]}/chat/completions", json, apiProvider)
}
}
checkError(result)
val response = JsonUtil.objectMapper().readValue(result, ChatResponse::class.java)
if (response.usage != null) {
onUsage(model, response.usage.copy(cost = model.pricing(response.usage)))
}
log(
level = Level.DEBUG,
msg = String.format(
"Chat Completion %s:\n\t%s", requestID,
response.choices.firstOrNull()?.message?.content?.trim { it <= ' ' }?.replace("\n", "\n\t")
?: JsonUtil.toJson(response)
)
)
response
}
}
}
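// Resolves AWS credentials from the instance profile first, then the named profile.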
private fun awsCredentials(awsAuth: AWSAuth): AwsCredentialsProviderChain? =
AwsCredentialsProviderChain.builder().credentialsProviders(
InstanceProfileCredentialsProvider.create(),
ProfileCredentialsProvider.create(awsAuth.profile),
).build()
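// Converts a Gemini GenerateContentResponse into OpenAI-style ChatResponse JSON.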
private fun fromGemini(responseBody: String): String {
val fromJson = JsonUtil.fromJson(responseBody, GenerateContentResponse::class.java)
return JsonUtil.toJson(
ChatResponse(
choices = fromJson.candidates.mapIndexed { index, candidate ->
ChatChoice(
message = ChatMessageResponse(
content = candidate.content.parts?.joinToString("\n") { it.text ?: "" }
),
index = index
)
}
)
)
}
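// Maps an OpenAI-style ChatRequest onto Gemini's generateContent request shape: system messages
// become user turns, consecutive same-role messages are merged, and assistant turns use the "model" role.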
private fun toGeminiChatRequest(chatRequest: ChatRequest, model: TextModel): GenerateContentRequest {
return GenerateContentRequest(
// model = model.modelName,
/*
system_instruction = chatRequest.messages.filter { it.role == Role.system }?.reduceOrNull {
acc, chatMessage ->
ChatMessage(
role = Role.system,
content = acc.content?.plus(chatMessage.content ?: emptyList())
?: chatMessage.content
)
}?.content?.let {
Content(
parts = it.map {
Part(
text = it.text ?: ""
)
}
)
},
*/
contents = collectRoleSequences(chatRequest.messages.filter {
when (it.role) {
// Role.system -> false
else -> true
}
}.map {
Content(
role = when (it.role) {
Role.user -> "user"
Role.system -> "user"
Role.assistant -> "model"
else -> throw RuntimeException("Unsupported role: ${it.role}")
},
parts = it.content?.map {
Part(
text = it.text
)
}
)
}).map { collectTextParts(it) },
generationConfig = GenerationConfig(
temperature = 0.3f,
/*chatRequest.temperature.toFloat(),*/
// candidateCount = 1,
// maxOutputTokens = model.maxOutTokens-1,
// topK = 0,
// topP = 0.9f,
// stopSequences = chatRequest.stop?.map { it.toString() }
)
/*
*/
)
}
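// Merges runs of consecutive text parts into a single Part while preserving non-text parts.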
private fun collectTextParts(it: Content): Content {
var text = ""
val partsList = it.parts?.toMutableList() ?: mutableListOf()
val newParts = mutableListOf<Part>()
while (partsList.isNotEmpty()) {
val parts = partsList.takeWhile { it.text != null }
text = parts.joinToString("\n") { it.text ?: "" }
partsList.removeAll(parts)
newParts.add(Part(text = text))
// Copy all non-text parts
val nonTextParts = partsList.takeWhile { it.text == null }
newParts.addAll(nonTextParts)
partsList.removeAll(nonTextParts)
}
return Content(parts = newParts)
}
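// Consolidates consecutive Contents with the same role into one Content, since Gemini expects
// strictly alternating roles in a conversation.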
private fun collectRoleSequences(map: List<Content>): List<Content> {
val alternatingMessages = mutableListOf<Content>()
val messagesCopy = map.toMutableList()
while (messagesCopy.isNotEmpty()) {
val thisRole = messagesCopy.firstOrNull()?.role
val toConsolidate = messagesCopy.takeWhile { it.role == thisRole }.toTypedArray()
messagesCopy.removeAll(toConsolidate)
val consolidatedMessage = toConsolidate.reduceOrNull { acc, chatMessage ->
Content(
role = acc.role,
parts = acc.parts?.plus(chatMessage.parts ?: emptyList())
?: chatMessage.parts
)
}
alternatingMessages.add(consolidatedMessage ?: Content())
}
return alternatingMessages
}
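// Request/response data classes for the Gemini generateContent REST API.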
private data class GenerateContentRequest(
val model: String? = null,
val contents: List<Content>? = null,
val system_instruction: Content? = null,
val safetySettings: List<SafetySetting>? = null,
val generationConfig: GenerationConfig? = null
)
private data class Content(
val role: String? = null,
val parts: List<Part>? = null
)
private data class Part(
val inlineData: Blob? = null,
val text: String? = null
)
private data class Blob(
val mimeType: String? = null,
val data: String? = null
)
private data class SafetySetting(
val threshold: String? = null,
val category: String? = null
)
private data class GenerationConfig(
val temperature: Float? = null,
val candidateCount: Int? = null,
val topK: Int? = null,
val maxOutputTokens: Int? = null,
val topP: Float? = null,
val stopSequences: List<String>? = null
)
private data class GenerateContentResponse(
val candidates: List<Candidate>
)
private data class Candidate(
val content: Content, // Reuse or adjust your existing Content class
val finishReason: String,
val index: Int,
val safetyRatings: List<SafetyRating>
)
private data class SafetyRating(
val category: String,
val probability: String
)
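// Maps an OpenAI-style ChatRequest onto Anthropic's Messages API shape: the first system message
// becomes the top-level system prompt and remaining messages are merged into alternating turns.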
private fun mapToAnthropicChatRequest(chatRequest: ChatRequest, model: TextModel): AnthropicChatRequest {
return AnthropicChatRequest(
model = chatRequest.model,
system = chatRequest.messages.firstOrNull { it.role == Role.system }?.content?.joinToString("\n\n") {
it.text ?: ""
},
messages = alternateAnthropicRoles(chatRequest.messages.filter { it.role != Role.system }),
max_tokens = chatRequest.max_tokens ?: model.maxOutTokens,
temperature = chatRequest.temperature,
// top_p = chatRequest.top_p,
// top_k = chatRequest.top_k
)
}
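// Collapses consecutive same-role messages into single AnthropicMessages with concatenated text.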
private fun alternateAnthropicRoles(messages: List<ChatMessage>): List<AnthropicMessage> {
val alternatingMessages = mutableListOf<AnthropicMessage>()
val remainingMessages = messages.toMutableList()
while (remainingMessages.isNotEmpty()) {
val thisRole = remainingMessages.firstOrNull()?.role
val toConsolidate = remainingMessages.takeWhile { it.role == thisRole }.toTypedArray()
remainingMessages.removeAll(toConsolidate)
alternatingMessages += AnthropicMessage(
role = thisRole.toString(),
content = toConsolidate.joinToString("\n\n") { it.content?.joinToString("\n") { it.text ?: "" } ?: "" }
)
}
return alternatingMessages
}
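// Request/response data classes for Anthropic's Messages API.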
private data class AnthropicChatRequest(
val model: String? = null,
val system: String? = null,
val messages: List<AnthropicMessage>? = null,
val max_tokens: Int? = null,
val temperature: Double? = null,
val top_p: Double? = null,
val top_k: Int? = null
)
private data class AnthropicMessage(
val role: String? = null,
val content: String? = null
)
private data class AnthropicResponse(
val id: String,
val type: String,
val role: String,
val content: List<AnthropicContentBlock>,
val model: String,
val stop_reason: String,
val stop_sequence: String?,
val usage: AnthropicUsage
)
private data class AnthropicContentBlock(
val type: String,
val text: String?
)
private data class AnthropicUsage(
val input_tokens: Int,
val output_tokens: Int
)
private data class AWSAuth(
val profile: String = "default",
val region: String = Region.US_WEST_2.id(),
)
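// Builds a Bedrock InvokeModelRequest whose JSON body depends on the model family (see awsBody).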
private fun toAWS(model: TextModel, chatRequest: ChatRequest) = InvokeModelRequest.builder()
.modelId(model.modelName)
.accept("application/json")
.contentType("application/json")
.body(SdkBytes.fromString(JsonUtil.toJson(awsBody(model, chatRequest)), Charsets.UTF_8))
.build()
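// Chooses the Bedrock request body format by model family: Llama, Mistral, Titan, Cohere, AI21,
// and Anthropic models each expect different field names and prompt shapes.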
private fun awsBody(
model: TextModel,
chatRequest: ChatRequest
): Map<String, Any?> = when {
model.modelName.contains("llama") -> {
mapOf(
"prompt" to toSimplePrompt(chatRequest),
"max_gen_len" to model.maxOutTokens,
"temperature" to chatRequest.temperature,
// "top_p" to 0.9,
)
}
//mistral
model.modelName.contains("mistral") -> {
mapOf(
"prompt" to toSimplePrompt(chatRequest),
"max_tokens" to model.maxOutTokens,
"temperature" to chatRequest.temperature,
// "top_p" to 0.9,
// "top_k" to 50,
)
}
model.modelName.contains("titan") -> {
mapOf(
"inputText" to toSimplePrompt(chatRequest),
"textGenerationConfig" to mapOf(
"maxTokenCount" to model.maxTotalTokens,
"stopSequences" to emptyList(),
"temperature" to chatRequest.temperature,
// "topP" to 0.9,
)
)
}
model.modelName.contains("cohere") -> {
mapOf(
"prompt" to toSimplePrompt(chatRequest),
"max_tokens" to model.maxTotalTokens,
"temperature" to chatRequest.temperature,
// "p" to 1,
// "k" to 0,
)
}
model.modelName.contains("ai21") -> {
mapOf(
"prompt" to toSimplePrompt(chatRequest),
"maxTokens" to model.maxTotalTokens,
"temperature" to chatRequest.temperature,
// "topP" to 0.9,
"stopSequences" to emptyList(),
"countPenalty" to mapOf("scale" to 0),
"presencePenalty" to mapOf("scale" to 0),
"frequencyPenalty" to mapOf("scale" to 0),
)
}
model.modelName.contains("anthropic") -> {
val alternatingMessages = alternateMessagesRoles(chatRequest.messages)
mapOf(
"anthropic_version" to anthropic_version(model),
"max_tokens" to model.maxOutTokens,
"temperature" to chatRequest.temperature,
"messages" to alternatingMessages.filter {
when (it.role) {
Role.system -> false
else -> true
}
}.map {
mapOf(
"role" to it.role.toString(),
"content" to it.content?.map {
mapOf(
"type" to "text",
"text" to it.text
)
}
)
},
"system" to toSimplePrompt(chatRequest) { it.role == Role.system },
).filterValues { it != null }
}
else -> throw RuntimeException("Unsupported model: $model")
}
private fun anthropic_version(model: TextModel) = when {
else -> "bedrock-2023-05-31"
// else -> null
}
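// Consolidates consecutive same-role messages and ensures the conversation starts with a user
// turn, as required by Bedrock's Anthropic message format.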
private fun alternateMessagesRoles(messages: List<ChatMessage>): List<ChatMessage> {
val alternatingMessages = mutableListOf<ChatMessage>()
val messagesCopy = messages.toMutableList()
var isFirst = true
while (messagesCopy.isNotEmpty()) {
val thisRole = messagesCopy.firstOrNull()?.role
val consolidatedMessage = takeAll(messagesCopy, thisRole)
if (isFirst) {
isFirst = false
if (consolidatedMessage?.role != Role.user) {
val chatMessage = takeAll(messagesCopy, Role.user)
alternatingMessages.add(
concat(
(consolidatedMessage ?: ChatMessage()).copy(role = Role.user),
chatMessage ?: ChatMessage()
)
)
continue
}
}
alternatingMessages.add(consolidatedMessage ?: ChatMessage())
}
return alternatingMessages
}
private fun takeAll(
messagesCopy: MutableList<ChatMessage>,
thisRole: Role?
): ChatMessage? {
val toConsolidate = messagesCopy.takeWhile { it.role == thisRole }.toTypedArray()
messagesCopy.removeAll(toConsolidate)
val consolidatedMessage = toConsolidate.reduceOrNull { acc, chatMessage ->
concat(acc, chatMessage)
}
return consolidatedMessage
}
private fun concat(
acc: ChatMessage,
chatMessage: ChatMessage
) = ChatMessage(
role = acc.role,
content = listOf(
ContentPart(
type = "text",
text = (acc.content?.plus(chatMessage.content ?: emptyList())
?: chatMessage.content)?.joinToString("\n") { it.text ?: "" }
)
)
)
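// Flattens selected messages into a single prompt string; role prefixes are added only when
// more than one role is present.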
private fun toSimplePrompt(
chatRequest: ChatRequest,
filterFn: (ChatMessage) -> Boolean = { true }
) = if (chatRequest.messages.filter(filterFn).map { it.role }.distinct().size <= 1) {
chatRequest.messages.filter(filterFn).joinToString("\n\n") {
it.content?.joinToString("\n") { it.text ?: "" } ?: ""
}
} else {
chatRequest.messages.filter(filterFn).joinToString("\n\n") {
"${it.role}: \n" + it.content?.joinToString("\n") { "\t" + (it.text ?: "") }
}
}
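// Parses a Bedrock response for the given model family and re-serializes it as an
// OpenAI-style ChatResponse JSON string.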
private fun fromAWS(responseBody: String, model: String): String {
return when {
model.contains("llama") -> {
val fromJson = JsonUtil.fromJson(responseBody, AwsResponseLlama2::class.java)
JsonUtil.toJson(
ChatResponse(
choices = listOf(
ChatChoice(
message = ChatMessageResponse(
content = fromJson.generation ?: ""
),
index = 0
)
),
usage = Usage(
prompt_tokens = fromJson.prompt_token_count?.toLong() ?: 0,
completion_tokens = fromJson.generation_token_count?.toLong() ?: 0,
total_tokens = (fromJson.prompt_token_count?.toLong()
?: 0) + (fromJson.generation_token_count?.toLong() ?: 0)
)
)
)
}
model.contains("mistral") -> {
val fromJson = JsonUtil.fromJson(responseBody, AwsResponseMistral::class.java)
JsonUtil.toJson(
ChatResponse(
choices = listOf(
ChatChoice(
message = ChatMessageResponse(
content = fromJson.outputs.first().text ?: ""
),
index = 0
)
)
)
)
}
model.contains("titan") -> {
val fromJson = JsonUtil.fromJson(responseBody, AwsResponseTitan::class.java)
JsonUtil.toJson(
ChatResponse(
choices = listOf(
ChatChoice(
message = ChatMessageResponse(
content = fromJson.results.first().outputText ?: ""
),
index = 0
)
)
)
)
}
model.contains("cohere") -> {
val fromJson = JsonUtil.fromJson(responseBody, AwsResponseCohere::class.java)
JsonUtil.toJson(
ChatResponse(
choices = listOf(
ChatChoice(
message = ChatMessageResponse(
content = fromJson.generations.first().text ?: ""
),
index = 0
)
)
)
)
}
model.contains("ai21") -> {
val fromJson = JsonUtil.objectMapper().readValue(responseBody, Ai21ChatResponse::class.java)
return JsonUtil.toJson(
ChatResponse(
choices = fromJson.completions?.mapIndexed { index, completion ->
ChatChoice(
message = ChatMessageResponse(
content = completion.data?.text ?: ""
),
index = index
)
} ?: emptyList(),
)
)
}
model.contains("anthropic") -> {
val fromJson = JsonUtil.fromJson(responseBody, AwsResponseAnthropic::class.java)
JsonUtil.toJson(
ChatResponse(
choices = listOf(
ChatChoice(
message = ChatMessageResponse(
content = fromJson.content?.firstOrNull()?.text ?: ""
),
index = 0
)
),
usage = Usage(
prompt_tokens = fromJson.usage?.input_tokens?.toLong() ?: 0,
completion_tokens = fromJson.usage?.output_tokens?.toLong() ?: 0,
total_tokens = (fromJson.usage?.input_tokens?.toLong()
?: 0) + (fromJson.usage?.output_tokens ?: 0)
)
)
)
}
else -> throw RuntimeException("Unsupported model: $model")
}
}
private data class AwsResponseAnthropic(
val id: String? = null,
val type: String? = null,
val role: String? = null,
val content: List<AwsResponseAnthropicContent>? = null,
val model: String? = null,
val stop_reason: String? = null,
val stop_sequence: String? = null,
val usage: AwsResponseAnthropicUsage? = null
)
private data class AwsResponseAnthropicContent(
val type: String? = null,
val text: String? = null
)
private data class AwsResponseAnthropicUsage(
val input_tokens: Int? = null,
val output_tokens: Int? = null
)
private data class Ai21ChatResponse(
val id: Int? = null,
val prompt: Ai21Prompt? = null,
val completions: List<Ai21Completion>? = null
)
private data class Ai21Completion(
val data: Ai21Data? = null,
val finishReason: Ai21FinishReason? = null
)
private data class Ai21FinishReason(
val reason: String? = null
)
private data class Ai21Data(
val text: String? = null,
val tokens: List<Ai21Token>? = null
)
private data class Ai21Prompt(
val text: String? = null,
val tokens: List<Ai21Token>? = null
)
private data class Ai21Token(
val generatedToken: Ai21GeneratedToken? = null,
val topTokens: List<Ai21TopToken>? = null,
val textRange: Ai21TextRange? = null
)
private data class Ai21GeneratedToken(
val token: String? = null,
val logprob: Double? = null,
val raw_logprob: Double? = null
)
private data class Ai21TopToken(
val token: String? = null,
val logprob: Double? = null,
val raw_logprob: Double? = null
)
private data class Ai21TextRange(
val start: Int? = null,
val end: Int? = null
)
private data class AwsResponseCohere(
val generations: List<AwsResponseCohereGeneration>
)
private data class AwsResponseCohereGeneration(
val text: String? = null
)
private data class AwsResponseMistral(
val outputs: List<AwsResponseMistralOutput>
)
private data class AwsResponseMistralOutput(
val text: String? = null,
val stop_reason: String? = null
)
private data class AwsResponseTitan(
val inputTextTokenCount: Int? = null,
val results: List<AwsResponseTitanResult>
)
private data class AwsResponseTitanResult(
val tokenCount: Int? = null,
val outputText: String? = null,
val completionReason: String? = null
)
private data class AwsResponseLlama2(
val generation: String? = null,
val prompt_token_count: Int? = null,
val generation_token_count: Int? = null,
val stop_reason: String? = null
)
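// Parses an Anthropic Messages API response (or error payload) into OpenAI-style ChatResponse JSON.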
private fun fromAnthropicResponse(rawResponse: String): String {
try {
val errorCheck = JsonUtil.objectMapper().readTree(rawResponse)
if (errorCheck.has("type") && errorCheck.get("type").asText() == "error") {
throw RuntimeException("Error response received: $rawResponse")
}
val response = JsonUtil.objectMapper().readValue(rawResponse, AnthropicResponse::class.java)
return JsonUtil.toJson(
ChatResponse(
id = response.id,
choices = listOf(
ChatChoice(
message = ChatMessageResponse(
content = response.content.joinToString("\n") { it.text ?: "" }
),
index = 0
)
),
usage = Usage(
prompt_tokens = response.usage.input_tokens.toLong(),
completion_tokens = response.usage.output_tokens.toLong(),
total_tokens = response.usage.input_tokens.toLong() + response.usage.output_tokens
)
)
)
} catch (e: Exception) {
throw RuntimeException("Error parsing Anthropic response: $rawResponse", e)
}
}
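// Parses a ModelsLab chat response; while the job is still processing, waits for the reported ETA
// and polls the queued-response endpoint recursively.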
private fun fromModelsLab(rawResponse: String): String {
val response = JsonUtil.objectMapper().readValue(rawResponse, ModelsLabDataModel.ChatResponse::class.java)
return when (response.status) {
"success" -> {
JsonUtil.toJson(ChatResponse(
id = response.chat_id,
choices = listOf(
ChatChoice(
message = ChatMessageResponse(content = response.message),
index = 0
)
),
usage = response.meta?.let {
Usage(
prompt_tokens = it.max_new_tokens?.toLong() ?: 0,
completion_tokens = 0, // Assuming no direct mapping; adjust as needed.
total_tokens = it.max_new_tokens?.toLong() ?: 0
)
}
))
}
"processing" -> {
val seconds = response?.eta ?: 1
log.info("Chat response is still processing; waiting ${seconds}s and trying again.")
Thread.sleep(seconds * 1000L)
val postCheck = JsonUtil.objectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(
mapOf(
"chat_id" to (response.meta?.chat_id ?: response.chat_id),
"key" to key[APIProvider.ModelsLab]
)
)
fromModelsLab(
post(
"${apiBase[defaultApiProvider]}/llm/get_queued_response",
postCheck,
defaultApiProvider
)
)
}
"error" -> {
throw RuntimeException("Error in chat request: ${response.message}\n$rawResponse")
}
"failed" -> {
throw RuntimeException("Chat request failed: ${response.message}\n$rawResponse")
}
else -> throw RuntimeException("Unknown status: ${response.status}\n${response.message}\n$rawResponse")
}
}
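// Builds a ModelsLab chat request from the prototype, splitting system messages into the system
// prompt and all other messages into the prompt body.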
private fun toModelsLab(chatRequest: ChatRequest) =
modelslab_chatRequest_prototype.copy(
key = key[APIProvider.ModelsLab],
model_id = chatRequest.model,
system_prompt = chatRequest.messages.filter { it.role == Role.system }.joinToString("\n") {
it.content?.joinToString("\n") { it.text ?: "" } ?: ""
},
prompt = chatRequest.messages.filter { it.role != Role.system }.joinToString("\n") {
it.content?.joinToString("\n") { it.text ?: "" } ?: ""
},
temperature = chatRequest.temperature,
)
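// Flattens message content into plain strings; used by both the Groq and Mistral branches.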
private fun toGroq(chatRequest: ChatRequest): GroqChatRequest = GroqChatRequest(
messages = chatRequest.messages.map { message ->
GroqChatMessage(
role = message.role,
content = message.content?.joinToString("\n") { it.text ?: "" } ?: "",
)
},
model = chatRequest.model,
max_tokens = chatRequest.max_tokens,
temperature = chatRequest.temperature,
)
companion object {
private val log = LoggerFactory.getLogger(OpenAIClient::class.java)
var modelsLabThrottle = Semaphore(1)
var modelslab_chatRequest_prototype = ModelsLabDataModel.ChatRequest(
max_new_tokens = 1000,
no_repeat_ngram_size = 5,
)
}
}