All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.jopenai.ChatClient.kt Maven / Gradle / Ivy

The newest version!
package com.simiacryptus.jopenai

import com.fasterxml.jackson.core.JsonProcessingException
import com.google.common.util.concurrent.ListeningScheduledExecutorService
import com.google.gson.Gson
import com.google.gson.JsonObject
import com.simiacryptus.jopenai.exceptions.ModerationException
import com.simiacryptus.jopenai.models.*
import com.simiacryptus.jopenai.models.ApiModel.*
import com.simiacryptus.jopenai.util.ClientUtil.allowedCharset
import com.simiacryptus.jopenai.util.ClientUtil.checkError
import com.simiacryptus.jopenai.util.ClientUtil.defaultApiProvider
import com.simiacryptus.jopenai.util.ClientUtil.keyMap
import com.simiacryptus.util.JsonUtil
import com.simiacryptus.util.StringUtil
import com.simiacryptus.util.runWithPermit
import org.apache.hc.client5.http.classic.methods.HttpPost
import org.apache.hc.core5.http.HttpRequest
import org.apache.hc.core5.http.io.entity.EntityUtils
import org.apache.hc.core5.http.io.entity.StringEntity
import org.slf4j.LoggerFactory
import org.slf4j.event.Level
import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain
import software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider
import software.amazon.awssdk.core.SdkBytes
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest
import java.io.BufferedOutputStream
import java.io.IOException
import java.util.*
import java.util.concurrent.Semaphore
import java.util.concurrent.ThreadPoolExecutor

open class ChatClient(
    protected var key: Map = keyMap.mapKeys { APIProvider.valueOf(it.key) },
    protected val apiBase: Map = APIProvider.values().associate { it to (it.base ?: "") },
    logLevel: Level = Level.INFO,
    logStreams: MutableList = mutableListOf(),
    scheduledPool: ListeningScheduledExecutorService = HttpClientManager.scheduledPool,
    workPool: ThreadPoolExecutor = HttpClientManager.workPool
) : HttpClientManager(
    logLevel = logLevel,
    logStreams = logStreams,
    scheduledPool = scheduledPool,
    workPool = workPool
) {

    open var session: Any? = null
    open var user: Any? = null
    var budget: Number? = null

    protected open class ChildClient(
        val inner: ChatClient,
        key: Map = inner.key,
        apiBase: Map = inner.apiBase
    ) : ChatClient(
        key = key,
        apiBase = apiBase,
        logLevel = Level.INFO
    ) {
        override fun log(level: Level, msg: String) {
            super.log(level, msg)
            inner.log(level, msg)
        }
    }

    open fun getChildClient(): ChatClient = ChildClient(inner = this, key = key, apiBase = apiBase).apply {
        session = inner.session
        user = inner.user
    }

    protected open fun onUsage(model: OpenAIModel?, tokens: Usage) {
        log.debug(
            "Usage recorded for session: {}, user: {}, model: {}, tokens: {}",
            session,
            user,
            model,
            tokens
        )
        if (null != budget) budget = budget!!.toDouble() - (tokens.cost ?: 0.0)
    }

    fun moderate(text: String) = withReliability {
        when {
            defaultApiProvider == APIProvider.Groq -> return@withReliability
            defaultApiProvider == APIProvider.ModelsLab -> return@withReliability
        }
        withPerformanceLogging {
            val body: String = try {
                JsonUtil.objectMapper().writeValueAsString(
                    mapOf(
                        "input" to StringUtil.restrictCharacterSet(text, allowedCharset)
                    )
                )
            } catch (e: JsonProcessingException) {
                throw RuntimeException(e)
            }
            val result: String = try {
                this.post("${apiBase[defaultApiProvider]}/moderations", body, defaultApiProvider)
            } catch (e: IOException) {
                throw RuntimeException(e)
            } catch (e: InterruptedException) {
                throw RuntimeException(e)
            }
            val jsonObject = Gson().fromJson(
                result, JsonObject::class.java
            ) ?: return@withPerformanceLogging
            if (jsonObject.has("error")) {
                val errorObject = jsonObject.getAsJsonObject("error")
                throw RuntimeException(IOException(errorObject["message"].asString))
            }
            val moderationResult = jsonObject.getAsJsonArray("results")[0].asJsonObject
            if (moderationResult["flagged"].asBoolean) {
                val categoriesObj = moderationResult["categories"].asJsonObject
                throw RuntimeException(
                    ModerationException("Moderation flagged this request due to " + categoriesObj.keySet()
                        .stream().filter { c: String? ->
                            categoriesObj[c].asBoolean
                        }.reduce { a: String, b: String -> "$a, $b" }.orElse("???")
                    )
                )
            }
        }
    }

    @Throws(IOException::class, InterruptedException::class)
    private fun post(url: String, json: String, apiProvider: APIProvider): String {
        val request = HttpPost(url)
        request.addHeader("Content-Type", "application/json")
        request.addHeader("Accept", "application/json")
        authorize(request, apiProvider)
        request.entity = StringEntity(json, Charsets.UTF_8, false)
        return post(request)
    }

    private fun post(request: HttpPost): String = withClient { EntityUtils.toString(it.execute(request).entity) }

    @Throws(IOException::class)
    protected open fun authorize(request: HttpRequest, apiProvider: APIProvider) {
        log.debug("Authorizing request for session: {}, user: {}, apiProvider: {}", session, user, apiProvider)
        require(null == budget || budget!!.toDouble() > 0.0) { "Budget Exceeded" }
        when (apiProvider) {
            APIProvider.Google -> {
//        request.addHeader("X-goog-api-key", "${key.get(apiProvider)}")
            }

            APIProvider.Anthropic -> {
                request.addHeader("x-api-key", "${key.get(apiProvider)}")
                request.addHeader("anthropic-version", "2023-06-01")
            }

            else -> request.addHeader("Authorization", "Bearer ${key.get(apiProvider)}")
        }
    }

    open fun chat(
        chatRequest: ChatRequest, model: TextModel
    ): ChatResponse {
        var chatRequest = chatRequest
        log.info("Starting chat with model: ${model.modelName}")
        if (model.modelName in listOf("o1-preview", "o1-mini")) {
            chatRequest = chatRequest.copy(
                messages = chatRequest.messages.map { message ->
                    if (message.role == Role.system) {
                        message.copy(role = Role.user)
                    } else {
                        message
                    }
                },
                temperature = 1.0,
                stop = null
            )
            log.debug("Adjusted chat request for model: ${model.modelName}")
        }
        if (chatRequest.messages.any { it.content?.any { it.text?.contains("
{ val geminiChatRequest = toGeminiChatRequest(chatRequest.copy(messages = chatRequest.messages.map { it.copy( role = when (it.role) { Role.system -> Role.user else -> it.role } ) }), model) val json = JsonUtil.objectMapper().writerWithDefaultPrettyPrinter() .writeValueAsString(geminiChatRequest) log( level = Level.DEBUG, msg = String.format( "Chat Request %s\nPrefix:\n\t%s\n", requestID, json.replace("\n", "\n\t") ) ) fromGemini( post( "${apiBase[apiProvider]}/v1beta/${model.modelName}:generateContent?key=${key[apiProvider]}", json, apiProvider ) ) } apiProvider == APIProvider.Anthropic -> { val anthropicChatRequest = mapToAnthropicChatRequest(chatRequest, model) val json = JsonUtil.objectMapper().writerWithDefaultPrettyPrinter() .writeValueAsString(anthropicChatRequest) log( level = Level.DEBUG, msg = String.format( "Chat Request %s\nPrefix:\n\t%s\n", requestID, json.replace("\n", "\n\t") ) ) val request = HttpPost("${apiBase[apiProvider]}/messages") request.addHeader("Content-Type", "application/json") request.addHeader("Accept", "application/json") request.addHeader("x-api-key", "${key.get(apiProvider)}") request.addHeader("anthropic-version", "2023-06-01") request.entity = StringEntity(json, Charsets.UTF_8, false) val rawResponse = post(request) fromAnthropicResponse(rawResponse) } apiProvider == APIProvider.Perplexity -> { val json = JsonUtil.objectMapper().writerWithDefaultPrettyPrinter() .writeValueAsString(chatRequest.copy(stop = null)) log( level = Level.DEBUG, msg = String.format( "Chat Request %s\nPrefix:\n\t%s\n", requestID, json.replace("\n", "\n\t") ) ) post("${apiBase[apiProvider]}/chat/completions", json, apiProvider) } apiProvider == APIProvider.Mistral -> { val json = JsonUtil.objectMapper().writerWithDefaultPrettyPrinter() .writeValueAsString(toGroq(chatRequest)) log( level = Level.DEBUG, msg = String.format( "Chat Request %s\nPrefix:\n\t%s\n", requestID, json.replace("\n", "\n\t") ) ) post("${apiBase[apiProvider]}/chat/completions", json, apiProvider) } apiProvider == APIProvider.Groq -> { val json = JsonUtil.objectMapper().writerWithDefaultPrettyPrinter() .writeValueAsString(toGroq(chatRequest)) log( level = Level.DEBUG, msg = String.format( "Chat Request %s\nPrefix:\n\t%s\n", requestID, json.replace("\n", "\n\t") ) ) post("${apiBase[apiProvider]}/chat/completions", json, apiProvider) } apiProvider == APIProvider.ModelsLab -> { modelsLabThrottle.runWithPermit { val json = JsonUtil.objectMapper().writerWithDefaultPrettyPrinter() .writeValueAsString(toModelsLab(chatRequest)) log( level = Level.DEBUG, msg = String.format( "Chat Request %s\nPrefix:\n\t%s\n", requestID, json.replace("\n", "\n\t") ) ) fromModelsLab(post("${apiBase[apiProvider]}/llm/chat", json, apiProvider)) } } apiProvider == APIProvider.AWS -> { val awsAuth = JsonUtil.fromJson(key[apiProvider]!!, AWSAuth::class.java) val invokeModelRequest = toAWS(model, chatRequest) val bedrockRuntimeClient = BedrockRuntimeClient.builder() .credentialsProvider(awsCredentials(awsAuth)) .region(Region.of(awsAuth.region)) .build() log( level = Level.DEBUG, msg = String.format( "Chat Request %s\nPrefix:\n\t%s\n", requestID, JsonUtil.toJson(chatRequest).replace("\n", "\n\t") ) ) val invokeModelResponse = bedrockRuntimeClient .invokeModel(invokeModelRequest) val responseBody = invokeModelResponse.body().asString(Charsets.UTF_8) fromAWS(responseBody, model.modelName) } else -> { val json = JsonUtil.objectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(chatRequest) log( level = Level.DEBUG, msg = String.format( "Chat Request %s\nPrefix:\n\t%s\n", requestID, json.replace("\n", "\n\t") ) ) post("${apiBase[apiProvider]}/chat/completions", json, apiProvider) } } checkError(result) val response = JsonUtil.objectMapper().readValue(result, ChatResponse::class.java) if (response.usage != null) { onUsage(model, response.usage.copy(cost = model.pricing(response.usage))) } log( level = Level.DEBUG, msg = String.format( "Chat Completion %s:\n\t%s", requestID, response.choices.firstOrNull()?.message?.content?.trim { it <= ' ' }?.replace("\n", "\n\t") ?: JsonUtil.toJson(response) ) ) response } } } private fun awsCredentials(awsAuth: AWSAuth): AwsCredentialsProviderChain? = AwsCredentialsProviderChain.builder().credentialsProviders( InstanceProfileCredentialsProvider.create(), ProfileCredentialsProvider.create(awsAuth.profile), ).build() private fun fromGemini(responseBody: String): String { val fromJson = JsonUtil.fromJson(responseBody, GenerateContentResponse::class.java) return JsonUtil.toJson( ChatResponse( choices = fromJson.candidates.mapIndexed { index, candidate -> ChatChoice( message = ChatMessageResponse( content = candidate.content.parts?.joinToString("\n") { it.text ?: "" } ), index = index ) } ) ) } private fun toGeminiChatRequest(chatRequest: ChatRequest, model: TextModel): GenerateContentRequest { return GenerateContentRequest( // model = model.modelName, /* system_instruction = chatRequest.messages.filter { it.role == Role.system }?.reduceOrNull { acc, chatMessage -> ChatMessage( role = Role.system, content = acc.content?.plus(chatMessage.content ?: emptyList()) ?: chatMessage.content ) }?.content?.let { Content( parts = it.map { Part( text = it.text ?: "" ) } ) }, */ contents = collectRoleSequences(chatRequest.messages.filter { when (it.role) { // Role.system -> false else -> true } }.map { Content( role = when (it.role) { Role.user -> "user" Role.system -> "user" Role.assistant -> "model" else -> throw RuntimeException("Unsupported role: ${it.role}") }, parts = it.content?.map { Part( text = it.text ) } ) }).map { collectTextParts(it) }, generationConfig = GenerationConfig( temperature = 0.3f, /*chatRequest.temperature.toFloat(),*/ // candidateCount = 1, // maxOutputTokens = model.maxOutTokens-1, // topK = 0, // topP = 0.9f, // stopSequences = chatRequest.stop?.map { it.toString() } ) /* */ ) } private fun collectTextParts(it: Content): Content { var text = "" val partsList = it.parts?.toMutableList() ?: mutableListOf() val newParts = mutableListOf() while (partsList.isNotEmpty()) { val parts = partsList.takeWhile { it.text != null } text = parts.joinToString("\n") { it.text ?: "" } partsList.removeAll(parts) newParts.add(Part(text = text)) // Copy all non-text parts val nonTextParts = partsList.takeWhile { it.text == null } newParts.addAll(nonTextParts) partsList.removeAll(nonTextParts) } return Content(parts = newParts) } private fun collectRoleSequences(map: List): List { val alternatingMessages = mutableListOf() val messagesCopy = map.toMutableList() while (messagesCopy.isNotEmpty()) { val thisRole = messagesCopy.firstOrNull()?.role val toConsolidate = messagesCopy.takeWhile { it.role == thisRole }.toTypedArray() messagesCopy.removeAll(toConsolidate) val consolidatedMessage = toConsolidate.reduceOrNull { acc, chatMessage -> Content( role = acc.role, parts = acc.parts?.plus(chatMessage.parts ?: emptyList()) ?: chatMessage.parts ) } alternatingMessages.add(consolidatedMessage ?: Content()) } return alternatingMessages } private data class GenerateContentRequest( val model: String? = null, val contents: List? = null, val system_instruction: Content? = null, val safetySettings: List? = null, val generationConfig: GenerationConfig? = null ) private data class Content( val role: String? = null, val parts: List? = null ) private data class Part( val inlineData: Blob? = null, val text: String? = null ) private data class Blob( val mimeType: String? = null, val data: String? = null ) private data class SafetySetting( val threshold: String? = null, val category: String? = null ) private data class GenerationConfig( val temperature: Float? = null, val candidateCount: Int? = null, val topK: Int? = null, val maxOutputTokens: Int? = null, val topP: Float? = null, val stopSequences: List? = null ) private data class GenerateContentResponse( val candidates: List ) private data class Candidate( val content: Content, // Reuse or adjust your existing Content class val finishReason: String, val index: Int, val safetyRatings: List ) private data class SafetyRating( val category: String, val probability: String ) private fun mapToAnthropicChatRequest(chatRequest: ChatRequest, model: TextModel): AnthropicChatRequest { return AnthropicChatRequest( model = chatRequest.model, system = chatRequest.messages.firstOrNull { it.role == Role.system }?.content?.joinToString("\n\n") { it.text ?: "" }, messages = alternateAnthropicRoles(chatRequest.messages.filter { it.role != Role.system }), max_tokens = chatRequest.max_tokens ?: model.maxOutTokens, temperature = chatRequest.temperature, // top_p = chatRequest.top_p, // top_k = chatRequest.top_k ) } private fun alternateAnthropicRoles(messages: List): List { val alternatingMessages = mutableListOf() val remainingMessages = messages.toMutableList() while (remainingMessages.isNotEmpty()) { val thisRole = remainingMessages.firstOrNull()?.role val toConsolidate = remainingMessages.takeWhile { it.role == thisRole }.toTypedArray() remainingMessages.removeAll(toConsolidate) alternatingMessages += AnthropicMessage( role = thisRole.toString(), content = toConsolidate.joinToString("\n\n") { it.content?.joinToString("\n") { it.text ?: "" } ?: "" } ) } return alternatingMessages } private data class AnthropicChatRequest( val model: String? = null, val system: String? = null, val messages: List? = null, val max_tokens: Int? = null, val temperature: Double? = null, val top_p: Double? = null, val top_k: Int? = null ) private data class AnthropicMessage( val role: String? = null, val content: String? = null ) private data class AnthropicResponse( val id: String, val type: String, val role: String, val content: List, val model: String, val stop_reason: String, val stop_sequence: String?, val usage: AnthropicUsage ) private data class AnthropicContentBlock( val type: String, val text: String? ) private data class AnthropicUsage( val input_tokens: Int, val output_tokens: Int ) private data class AWSAuth( val profile: String = "default", val region: String = Region.US_WEST_2.id(), ) private fun toAWS(model: TextModel, chatRequest: ChatRequest) = InvokeModelRequest.builder() .modelId(model.modelName) .accept("application/json") .contentType("application/json") .body(SdkBytes.fromString(JsonUtil.toJson(awsBody(model, chatRequest)), Charsets.UTF_8)) .build() private fun awsBody( model: TextModel, chatRequest: ChatRequest ): Map = when { model.modelName.contains("llama") -> { mapOf( "prompt" to toSimplePrompt(chatRequest), "max_gen_len" to model.maxOutTokens, "temperature" to chatRequest.temperature, // "top_p" to 0.9, ) } //mistral model.modelName.contains("mistral") -> { mapOf( "prompt" to toSimplePrompt(chatRequest), "max_tokens" to model.maxOutTokens, "temperature" to chatRequest.temperature, // "top_p" to 0.9, // "top_k" to 50, ) } model.modelName.contains("titan") -> { mapOf( "inputText" to toSimplePrompt(chatRequest), "textGenerationConfig" to mapOf( "maxTokenCount" to model.maxTotalTokens, "stopSequences" to emptyList(), "temperature" to chatRequest.temperature, // "topP" to 0.9, ) ) } model.modelName.contains("cohere") -> { mapOf( "prompt" to toSimplePrompt(chatRequest), "max_tokens" to model.maxTotalTokens, "temperature" to chatRequest.temperature, // "p" to 1, // "k" to 0, ) } model.modelName.contains("ai21") -> { mapOf( "prompt" to toSimplePrompt(chatRequest), "maxTokens" to model.maxTotalTokens, "temperature" to chatRequest.temperature, // "topP" to 0.9, "stopSequences" to emptyList(), "countPenalty" to mapOf("scale" to 0), "presencePenalty" to mapOf("scale" to 0), "frequencyPenalty" to mapOf("scale" to 0), ) } model.modelName.contains("anthropic") -> { val alternatingMessages = alternateMessagesRoles(chatRequest.messages) mapOf( "anthropic_version" to anthropic_version(model), "max_tokens" to model.maxOutTokens, "temperature" to chatRequest.temperature, "messages" to alternatingMessages.filter { when (it.role) { Role.system -> false else -> true } }.map { mapOf( "role" to it.role.toString(), "content" to it.content?.map { mapOf( "type" to "text", "text" to it.text ) } ) }, "system" to toSimplePrompt(chatRequest) { it.role == Role.system }, ).filterValues { it != null } } else -> throw RuntimeException("Unsupported model: $model") } private fun anthropic_version(model: TextModel) = when { else -> "bedrock-2023-05-31" // else -> null } private fun alternateMessagesRoles(messages: List): List { val alternatingMessages = mutableListOf() val messagesCopy = messages.toMutableList() var isFirst = true while (messagesCopy.isNotEmpty()) { val thisRole = messagesCopy.firstOrNull()?.role val consolidatedMessage = takeAll(messagesCopy, thisRole) if (isFirst) { isFirst = false if ((consolidatedMessage?.role ?: "") != "user") { val chatMessage = takeAll(messagesCopy, Role.user) alternatingMessages.add( concat( (consolidatedMessage ?: ChatMessage()).copy(role = Role.user), chatMessage ?: ChatMessage() ) ) continue } } alternatingMessages.add(consolidatedMessage ?: ChatMessage()) } return alternatingMessages } private fun takeAll( messagesCopy: MutableList, thisRole: Role? ): ChatMessage? { val toConsolidate = messagesCopy.takeWhile { it.role == thisRole }.toTypedArray() messagesCopy.removeAll(toConsolidate) val consolidatedMessage = toConsolidate.reduceOrNull { acc, chatMessage -> concat(acc, chatMessage) } return consolidatedMessage } private fun concat( acc: ChatMessage, chatMessage: ChatMessage ) = ChatMessage( role = acc.role, content = listOf( ContentPart( type = "text", text = (acc.content?.plus(chatMessage.content ?: emptyList()) ?: chatMessage.content)?.joinToString("\n") { it.text ?: "" } ) ) ) private fun toSimplePrompt( chatRequest: ChatRequest, filterFn: (ChatMessage) -> Boolean = { true } ) = if (chatRequest.messages.filter(filterFn).map { it.role }.distinct().size <= 1) { chatRequest.messages.filter(filterFn).joinToString("\n\n") { it.content?.joinToString("\n") { it.text ?: "" } ?: "" } } else { chatRequest.messages.filter(filterFn).joinToString("\n\n") { "${it.role}: \n" + it.content?.joinToString("\n") { "\t" + (it.text ?: "") } } } private fun fromAWS(responseBody: String, model: String): String { return when { model.contains("llama") -> { val fromJson = JsonUtil.fromJson(responseBody, AwsResponseLlama2::class.java) JsonUtil.toJson( ChatResponse( choices = listOf( ChatChoice( message = ChatMessageResponse( content = fromJson.generation ?: "" ), index = 0 ) ), usage = Usage( prompt_tokens = fromJson.prompt_token_count?.toLong() ?: 0, completion_tokens = fromJson.generation_token_count?.toLong() ?: 0, total_tokens = (fromJson.prompt_token_count?.toLong() ?: 0) + (fromJson.generation_token_count?.toLong() ?: 0) ) ) ) } model.contains("mistral") -> { val fromJson = JsonUtil.fromJson(responseBody, AwsResponseMistral::class.java) JsonUtil.toJson( ChatResponse( choices = listOf( ChatChoice( message = ChatMessageResponse( content = fromJson.outputs.first().text ?: "" ), index = 0 ) ) ) ) } model.contains("titan") -> { val fromJson = JsonUtil.fromJson(responseBody, AwsResponseTitan::class.java) JsonUtil.toJson( ChatResponse( choices = listOf( ChatChoice( message = ChatMessageResponse( content = fromJson.results.first().outputText ?: "" ), index = 0 ) ) ) ) } model.contains("cohere") -> { val fromJson = JsonUtil.fromJson(responseBody, AwsResponseCohere::class.java) JsonUtil.toJson( ChatResponse( choices = listOf( ChatChoice( message = ChatMessageResponse( content = fromJson.generations.first().text ?: "" ), index = 0 ) ) ) ) } model.contains("ai21") -> { val fromJson = JsonUtil.objectMapper().readValue(responseBody, Ai21ChatResponse::class.java) return JsonUtil.toJson( ChatResponse( choices = fromJson.completions?.mapIndexed { index, completion -> ChatChoice( message = ChatMessageResponse( content = completion.data?.text ?: "" ), index = index ) } ?: emptyList(), ) ) } model.contains("anthropic") -> { val fromJson = JsonUtil.fromJson(responseBody, AwsResponseAnthropic::class.java) JsonUtil.toJson( ChatResponse( choices = listOf( ChatChoice( message = ChatMessageResponse( content = fromJson.content?.firstOrNull()?.text ?: "" ), index = 0 ) ), usage = Usage( prompt_tokens = fromJson.usage?.input_tokens?.toLong() ?: 0, completion_tokens = fromJson.usage?.output_tokens?.toLong() ?: 0, total_tokens = (fromJson.usage?.input_tokens?.toLong() ?: 0) + (fromJson.usage?.output_tokens ?: 0) ) ) ) } else -> throw RuntimeException("Unsupported model: $model") } } private data class AwsResponseAnthropic( val id: String? = null, val type: String? = null, val role: String? = null, val content: List? = null, val model: String? = null, val stop_reason: String? = null, val stop_sequence: String? = null, val usage: AwsResponseAnthropicUsage? = null ) private data class AwsResponseAnthropicContent( val type: String? = null, val text: String? = null ) private data class AwsResponseAnthropicUsage( val input_tokens: Int? = null, val output_tokens: Int? = null ) private data class Ai21ChatResponse( val id: Int? = null, val prompt: Ai21Prompt? = null, val completions: List? = null ) private data class Ai21Completion( val data: Ai21Data? = null, val finishReason: Ai21FinishReason? = null ) private data class Ai21FinishReason( val reason: String? = null ) private data class Ai21Data( val text: String? = null, val tokens: List? = null ) private data class Ai21Prompt( val text: String? = null, val tokens: List? = null ) private data class Ai21Token( val generatedToken: Ai21GeneratedToken? = null, val topTokens: List? = null, val textRange: Ai21TextRange? = null ) private data class Ai21GeneratedToken( val token: String? = null, val logprob: Double? = null, val raw_logprob: Double? = null ) private data class Ai21TopToken( val token: String? = null, val logprob: Double? = null, val raw_logprob: Double? = null ) private data class Ai21TextRange( val start: Int? = null, val end: Int? = null ) private data class AwsResponseCohere( val generations: List ) private data class AwsResponseCohereGeneration( val text: String? = null ) private data class AwsResponseMistral( val outputs: List ) private data class AwsResponseMistralOutput( val text: String? = null, val stop_reason: String? = null ) private data class AwsResponseTitan( val inputTextTokenCount: Int? = null, val results: List ) private data class AwsResponseTitanResult( val tokenCount: Int? = null, val outputText: String? = null, val completionReason: String? = null ) private data class AwsResponseLlama2( val generation: String? = null, val prompt_token_count: Int? = null, val generation_token_count: Int? = null, val stop_reason: String? = null ) private fun fromAnthropicResponse(rawResponse: String): String { try { val errorCheck = JsonUtil.objectMapper().readTree(rawResponse) if (errorCheck.has("type") && errorCheck.get("type").asText() == "error") { throw RuntimeException("Error response received: $rawResponse") } val response = JsonUtil.objectMapper().readValue(rawResponse, AnthropicResponse::class.java) return JsonUtil.toJson( ChatResponse( id = response.id, choices = listOf( ChatChoice( message = ChatMessageResponse( content = response.content.joinToString("\n") { it.text ?: "" } ), index = 0 ) ), usage = Usage( prompt_tokens = response.usage.input_tokens.toLong(), completion_tokens = response.usage.output_tokens.toLong(), total_tokens = response.usage.input_tokens.toLong() + response.usage.output_tokens ) ) ) } catch (e: Exception) { throw RuntimeException("Error parsing Anthropic response: $rawResponse", e) } } private fun fromModelsLab(rawResponse: String): String { val response = JsonUtil.objectMapper().readValue(rawResponse, ModelsLabDataModel.ChatResponse::class.java) return when (response.status) { "success" -> { JsonUtil.toJson(ChatResponse( id = response.chat_id, choices = listOf( ChatChoice( message = ChatMessageResponse(content = response.message), index = 0 ) ), usage = response.meta?.let { Usage( prompt_tokens = it.max_new_tokens?.toLong() ?: 0, completion_tokens = 0, // Assuming no direct mapping; adjust as needed. total_tokens = it.max_new_tokens?.toLong() ?: 0 ) } )) } "processing" -> { val seconds = response?.eta ?: 1 log.info("Chat response is still processing; waiting ${seconds}s and trying again.") Thread.sleep(seconds * 1000L) val postCheck = JsonUtil.objectMapper().writerWithDefaultPrettyPrinter().writeValueAsString( mapOf( "chat_id" to (response.meta?.chat_id ?: response.chat_id), "key" to key[APIProvider.ModelsLab] ) ) fromModelsLab( post( "${apiBase[defaultApiProvider]}/llm/get_queued_response", postCheck, defaultApiProvider ) ) } "error" -> { throw RuntimeException("Error in chat request: ${response.message}\n$rawResponse") } "failed" -> { throw RuntimeException("Chat request failed: ${response.message}\n$rawResponse") } else -> throw RuntimeException("Unknown status: ${response.status}\n${response.message}\n$rawResponse") } } private fun toModelsLab(chatRequest: ChatRequest) = modelslab_chatRequest_prototype.copy( key = key[APIProvider.ModelsLab], model_id = chatRequest.model, system_prompt = chatRequest.messages.filter { it.role == Role.system }.joinToString("\n") { it.content?.joinToString("\n") { it.text ?: "" } ?: "" }, prompt = chatRequest.messages.filter { it.role != Role.system }.joinToString("\n") { it.content?.joinToString("\n") { it.text ?: "" } ?: "" }, temperature = chatRequest.temperature, ) private fun toGroq(chatRequest: ChatRequest): GroqChatRequest = GroqChatRequest( messages = chatRequest.messages.map { message -> GroqChatMessage( role = message.role, content = message.content?.joinToString("\n") { it.text ?: "" } ?: "", ) }, model = chatRequest.model, max_tokens = chatRequest.max_tokens, temperature = chatRequest.temperature, ) companion object { private val log = LoggerFactory.getLogger(OpenAIClient::class.java) var modelsLabThrottle = Semaphore(1) var modelslab_chatRequest_prototype = ModelsLabDataModel.ChatRequest( max_new_tokens = 1000, no_repeat_ngram_size = 5, ) } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy