
com.simiacryptus.jopenai.OpenAIClient.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jo-penai Show documentation
Show all versions of jo-penai Show documentation
A Java client for OpenAI's API
The newest version!
package com.simiacryptus.jopenai
import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.node.ObjectNode
import com.google.common.util.concurrent.ListeningScheduledExecutorService
import com.google.gson.Gson
import com.google.gson.JsonObject
import com.simiacryptus.jopenai.exceptions.ModerationException
import com.simiacryptus.jopenai.models.*
import com.simiacryptus.jopenai.models.ApiModel.*
import com.simiacryptus.jopenai.util.ClientUtil.allowedCharset
import com.simiacryptus.jopenai.util.ClientUtil.checkError
import com.simiacryptus.jopenai.util.ClientUtil.defaultApiProvider
import com.simiacryptus.jopenai.util.ClientUtil.keyMap
import com.simiacryptus.util.JsonUtil
import com.simiacryptus.util.StringUtil
import org.apache.hc.client5.http.classic.methods.HttpGet
import org.apache.hc.client5.http.classic.methods.HttpPost
import org.apache.hc.client5.http.entity.mime.FileBody
import org.apache.hc.client5.http.entity.mime.HttpMultipartMode
import org.apache.hc.client5.http.entity.mime.MultipartEntityBuilder
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient
import org.apache.hc.core5.http.ContentType
import org.apache.hc.core5.http.HttpRequest
import org.apache.hc.core5.http.io.entity.EntityUtils
import org.apache.hc.core5.http.io.entity.StringEntity
import org.slf4j.event.Level
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.awt.image.BufferedImage
import java.io.BufferedOutputStream
import java.io.IOException
import java.net.URL
import java.util.*
import java.util.concurrent.ThreadPoolExecutor
import javax.imageio.ImageIO
/**
 * HTTP client for OpenAI-style REST endpoints (completions, edits, embeddings, moderation, audio and images).
 *
 * Endpoint URLs are built from [apiBase] for `ClientUtil.defaultApiProvider`, and credentials from [key] are
 * attached by [authorize]. Most public operations run inside `withReliability` / `withPerformanceLogging`,
 * helpers inherited from [HttpClientManager] (their exact retry/timing semantics are defined there).
 *
 * @param key API key per provider. The default converts the string keys of `ClientUtil.keyMap` with
 *   `APIProvider.valueOf`.
 * @param apiBase Base URL per provider. The default uses each provider's `base`, or `""` when it is null.
 * @param logLevel Forwarded to [HttpClientManager]; also reported in the initialization log line.
 * @param logStreams Forwarded to [HttpClientManager].
 * @param scheduledPool Forwarded to [HttpClientManager].
 * @param workPool Forwarded to [HttpClientManager].
 * @param client NOTE(review): accepted for API compatibility but neither forwarded to the base class nor used
 *   in this class — confirm whether that is intended.
 */
open class OpenAIClient(
    protected var key: Map<APIProvider, String> = keyMap.mapKeys { APIProvider.valueOf(it.key) },
    protected val apiBase: Map<APIProvider, String> = APIProvider.values().associate { it to (it.base ?: "") },
    logLevel: Level = Level.INFO,
    logStreams: MutableList<BufferedOutputStream> = mutableListOf(),
    scheduledPool: ListeningScheduledExecutorService = HttpClientManager.scheduledPool,
    workPool: ThreadPoolExecutor = HttpClientManager.workPool,
    client: CloseableHttpClient = createHttpClient()
) : HttpClientManager(
    logLevel = logLevel,
    logStreams = logStreams,
    scheduledPool = scheduledPool,
    workPool = workPool
) {
    private val logger: Logger = LoggerFactory.getLogger(OpenAIClient::class.java).apply {
        info("OpenAIClient initialized with log level: $logLevel")
    }

    /** Caller-assigned value; not read anywhere in this class (presumably consumed by subclasses — confirm). */
    var user: Any? = null

    /** Caller-assigned value; not read anywhere in this class (presumably consumed by subclasses — confirm). */
    var session: Any? = null

    /**
     * Hook invoked after a request whose usage is known. The default implementation does nothing;
     * override it to record token counts or cost.
     *
     * @param model Model the usage is attributed to, or `null` when the request's model name was not recognized.
     * @param tokens Usage for the request; callers in this class fill `cost` from the model's pricing.
     */
    open fun onUsage(model: OpenAIModel?, tokens: Usage) {
    }

    /**
     * POSTs the JSON string [json] to [url] with JSON `Content-Type`/`Accept` headers and the
     * credentials for [apiProvider].
     *
     * @return The response body as a string; no HTTP status check is performed here.
     * @throws IOException On transport failures.
     */
    @Throws(IOException::class, InterruptedException::class)
    protected fun post(url: String, json: String, apiProvider: APIProvider): String {
        val request = HttpPost(url)
        request.addHeader("Content-Type", "application/json")
        request.addHeader("Accept", "application/json")
        // Payloads may contain prompts or other user content, so they are only logged at DEBUG.
        logger.debug("Sending POST request to URL: $url with payload: $json")
        authorize(request, apiProvider)
        request.entity = StringEntity(json, Charsets.UTF_8, false)
        val body = post(request)
        // Previously this statement sat after `return` and was never executed.
        logger.info("Executed POST request: ${request.uri}")
        return body
    }

    /**
     * Executes [request] and returns the response body decoded by [EntityUtils.toString].
     * The response is closed afterwards so its connection is released even if reading the body fails.
     */
    protected fun post(request: HttpPost): String = withClient { httpClient ->
        httpClient.execute(request).use { response -> EntityUtils.toString(response.entity) }
    }

    /**
     * Adds provider-specific credentials to [request].
     *
     * - `Google`: nothing is added (the header line is commented out).
     * - `Anthropic`: `x-api-key` plus a fixed `anthropic-version` header.
     * - Any other provider: `Authorization: Bearer <key>`.
     *
     * NOTE(review): when [key] has no entry for [apiProvider], the literal text "null" is sent as the key.
     */
    @Throws(IOException::class)
    protected open fun authorize(request: HttpRequest, apiProvider: APIProvider) {
        when (apiProvider) {
            APIProvider.Google -> {
                // request.addHeader("X-goog-api-key", "${key.get(apiProvider)}")
            }
            APIProvider.Anthropic -> {
                request.addHeader("x-api-key", "${key[apiProvider]}")
                request.addHeader("anthropic-version", "2023-06-01")
            }
            else -> request.addHeader("Authorization", "Bearer ${key[apiProvider]}")
        }
    }

    /**
     * Sends an authorized GET request to [url] and returns the response body as a string.
     * The response is closed once the body has been read.
     */
    @Throws(IOException::class)
    protected operator fun get(url: String?, apiProvider: APIProvider): String = withClient { httpClient ->
        val request = HttpGet(url)
        request.addHeader("Content-Type", "application/json")
        request.addHeader("Accept", "application/json")
        logger.debug("Sending GET request to URL: $url")
        authorize(request, apiProvider)
        httpClient.execute(request).use { response -> EntityUtils.toString(response.entity) }
    }

    /**
     * Lists engines from the `/engines` endpoint of the default provider.
     *
     * @return The entries of the response's `data` array, or an empty list when `data` is absent.
     */
    fun listEngines(): List<Engine> {
        val mapper = JsonUtil.objectMapper()
        val body = get("${apiBase[defaultApiProvider]}/engines", defaultApiProvider)
        // Fall back to an empty JSON *array*; the former "{}" fallback cannot be deserialized as a List.
        val data = mapper.readValue(body, ObjectNode::class.java)["data"]?.toString() ?: "[]"
        return mapper.readValue(
            data,
            mapper.typeFactory.constructCollectionType(List::class.java, Engine::class.java)
        )
    }

    /**
     * Requests a text completion from `/engines/{model}/completions` of the default provider.
     *
     * The request is serialized with Jackson (pretty-printed), restricted to `allowedCharset`, posted,
     * checked with `checkError`, and parsed as a [CompletionResponse]. When the response carries usage,
     * it is reported via [onUsage] with cost from [model]'s pricing.
     *
     * @return The parsed response. The first choice, with the prompt prefix stripped, is only logged.
     */
    open fun complete(
        request: CompletionRequest, model: TextModel
    ): CompletionResponse = withReliability {
        withPerformanceLogging {
            if (request.suffix == null) {
                log(
                    msg = String.format(
                        "Text Completion Request\nPrefix:\n\t%s\n", request.prompt.replace("\n", "\n\t")
                    )
                )
                logger.debug("Text Completion Request with Prefix: ${request.prompt}")
            } else {
                log(
                    msg = String.format(
                        "Text Completion Request\nPrefix:\n\t%s\nSuffix:\n\t%s\n",
                        request.prompt.replace("\n", "\n\t"),
                        request.suffix.replace("\n", "\n\t")
                    )
                )
                logger.debug("Text Completion Request with Prefix: ${request.prompt} and Suffix: ${request.suffix}")
            }
            val result = post(
                "${apiBase[defaultApiProvider]}/engines/${model.modelName}/completions",
                StringUtil.restrictCharacterSet(
                    JsonUtil.objectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(request),
                    allowedCharset
                ),
                defaultApiProvider
            )
            checkError(result)
            val response = JsonUtil.objectMapper().readValue(
                result, CompletionResponse::class.java
            )
            if (response.usage != null) {
                onUsage(model, response.usage.copy(cost = model.pricing(response.usage)))
            }
            val completionResult =
                StringUtil.stripPrefix(response.firstChoice.orElse("").toString().trim { it <= ' ' },
                    request.prompt.trim { it <= ' ' })
            log(
                msg = String.format(
                    "Text Completion:\n\t%s", completionResult.toString().replace("\n", "\n\t")
                )
            )
            logger.debug("Text Completion Result: $completionResult")
            response
        }
    }

    /**
     * Transcribes WAV audio with the `whisper-1` model via `/audio/transcriptions`.
     *
     * @param wavAudio Audio bytes, uploaded as `audio.wav` with content type `audio/x-wav`.
     * @param prompt Optional prompt; only sent when non-empty.
     * @return The transcribed text, or `""` when the response carries no text.
     * @throws RuntimeException Wrapping an [IOException] when the response contains an `error` object.
     */
    open fun transcription(wavAudio: ByteArray, prompt: String = ""): String = withReliability {
        withPerformanceLogging {
            // Use the provider's base URL; the old "$apiBase/..." interpolated the whole map's toString().
            val url = "${apiBase[defaultApiProvider]}/audio/transcriptions"
            val request = HttpPost(url)
            request.addHeader("Accept", "application/json")
            authorize(request, defaultApiProvider)
            val entity = MultipartEntityBuilder.create()
            entity.setMode(HttpMultipartMode.EXTENDED)
            entity.addBinaryBody("file", wavAudio, ContentType.create("audio/x-wav"), "audio.wav")
            entity.addTextBody("model", "whisper-1")
            entity.addTextBody("response_format", "verbose_json")
            if (prompt.isNotEmpty()) entity.addTextBody("prompt", prompt)
            request.entity = entity.build()
            val response = post(request)
            logger.info("Transcription response received")
            val jsonObject = Gson().fromJson(response, JsonObject::class.java)
            if (jsonObject.has("error")) {
                val errorObject = jsonObject.getAsJsonObject("error")
                throw RuntimeException(IOException(errorObject["message"].asString))
            }
            try {
                val result = JsonUtil.objectMapper().readValue(response, TranscriptionResult::class.java)
                result.text ?: ""
            } catch (e: Exception) {
                // Fallback parse: tolerate a missing or JSON-null "text" field instead of throwing an NPE.
                jsonObject.get("text")?.takeUnless { it.isJsonNull }?.asString ?: ""
            }
        }
    }

    /**
     * Synthesizes speech via `/audio/speech`.
     *
     * @return The audio bytes, or `null` when the server answered with a textual/JSON body (which is
     *   passed to `checkError` first).
     */
    open fun createSpeech(request: ApiModel.SpeechRequest): ByteArray? = withReliability {
        withPerformanceLogging {
            val httpRequest = HttpPost("${apiBase[defaultApiProvider]}/audio/speech")
            authorize(httpRequest, defaultApiProvider)
            httpRequest.addHeader("Accept", "application/json")
            httpRequest.addHeader("Content-Type", "application/json")
            httpRequest.entity =
                StringEntity(JsonUtil.objectMapper().writeValueAsString(request), Charsets.UTF_8, false)
            // Read the content type and body while the response is still open, then close it.
            val (contentType: String?, bytes: ByteArray) = withClient { httpClient ->
                httpClient.execute(httpRequest).use { response ->
                    val entity = response.entity
                    entity.contentType to entity.content.readAllBytes()
                }
            }
            logger.info("Speech creation response received with content type: $contentType")
            // Parenthesized so a null content type can never reach startsWith().
            if (contentType != null &&
                (contentType.startsWith("text") || contentType.startsWith("application/json"))
            ) {
                checkError(bytes.toString(Charsets.UTF_8))
                null
            } else {
                val model = AudioModels.values().find { it.modelName.equals(request.model, true) }
                onUsage(
                    model, Usage(
                        prompt_tokens = request.input.length.toLong(),
                        cost = model?.pricing(request.input.length)
                    )
                )
                bytes
            }
        }
    }

    /**
     * Generates [count] square images of side [resolution] via `/images/generations` and downloads each
     * returned `url` with [ImageIO].
     *
     * @throws RuntimeException Wrapping an [IOException] when the response contains an `error` object.
     */
    open fun render(prompt: String = "", resolution: Int = 1024, count: Int = 1): List<BufferedImage> =
        withReliability {
            withPerformanceLogging {
                val url = "${apiBase[defaultApiProvider]}/images/generations"
                val request = HttpPost(url)
                request.addHeader("Accept", "application/json")
                request.addHeader("Content-Type", "application/json")
                authorize(request, defaultApiProvider)
                val jsonObject = JsonObject()
                jsonObject.addProperty("prompt", prompt)
                jsonObject.addProperty("n", count)
                jsonObject.addProperty("size", "${resolution}x$resolution")
                request.entity = StringEntity(jsonObject.toString(), Charsets.UTF_8, false)
                val response = post(request)
                logger.info("Image generation response received")
                val jsonObject2 = Gson().fromJson(response, JsonObject::class.java)
                if (jsonObject2.has("error")) {
                    val errorObject = jsonObject2.getAsJsonObject("error")
                    throw RuntimeException(IOException(errorObject["message"].asString))
                }
                jsonObject2.getAsJsonArray("data").map { element ->
                    ImageIO.read(URL(element.asJsonObject.get("url").asString))
                }
            }
        }

    /** Role-tagged list of [Part]s; not referenced elsewhere in this class. */
    data class Content(
        val role: String? = null,
        val parts: List<Part>? = null
    )

    /** A content part holding inline binary data and/or text. */
    data class Part(
        val inlineData: Blob? = null,
        val text: String? = null
    )

    /** Inline data with its MIME type (encoding of [data] not established here — presumably base64, confirm). */
    data class Blob(
        val mimeType: String? = null,
        val data: String? = null
    )

    /**
     * Runs [text] through the `/moderations` endpoint and throws when it is flagged.
     * Does nothing when the default provider is `Groq` or `ModelsLab`.
     *
     * @throws RuntimeException Wrapping a [ModerationException] listing the flagged categories, or wrapping
     *   the underlying I/O / serialization / interruption error.
     */
    open fun moderate(text: String) = withReliability {
        // NOTE(review): presumably these providers have no moderation endpoint — confirm.
        if (defaultApiProvider == APIProvider.Groq || defaultApiProvider == APIProvider.ModelsLab) {
            return@withReliability
        }
        withPerformanceLogging {
            val body: String = try {
                JsonUtil.objectMapper().writeValueAsString(
                    mapOf(
                        "input" to StringUtil.restrictCharacterSet(text, allowedCharset)
                    )
                )
            } catch (e: JsonProcessingException) {
                throw RuntimeException(e)
            }
            val result: String = try {
                this.post("${apiBase[defaultApiProvider]}/moderations", body, defaultApiProvider)
            } catch (e: IOException) {
                logger.warn("IOException during moderation request", e)
                throw RuntimeException(e)
            } catch (e: InterruptedException) {
                // Preserve the interrupt status for callers further up the stack.
                Thread.currentThread().interrupt()
                throw RuntimeException(e)
            }
            val jsonObject = Gson().fromJson(
                result, JsonObject::class.java
            ) ?: return@withPerformanceLogging
            if (jsonObject.has("error")) {
                val errorObject = jsonObject.getAsJsonObject("error")
                throw RuntimeException(IOException(errorObject["message"].asString))
            }
            val moderationResult = jsonObject.getAsJsonArray("results")[0].asJsonObject
            if (moderationResult["flagged"].asBoolean) {
                val categoriesObj = moderationResult["categories"].asJsonObject
                val flaggedCategories = categoriesObj.keySet()
                    .filter { category -> categoriesObj[category].asBoolean }
                    .joinToString(", ")
                    .ifEmpty { "???" }
                throw RuntimeException(
                    ModerationException("Moderation flagged this request due to $flaggedCategories")
                )
            }
        }
    }

    /**
     * Sends [editRequest] to `/edits`, reports usage via [onUsage] when present, and returns the
     * parsed [CompletionResponse].
     */
    open fun edit(
        editRequest: EditRequest
    ): CompletionResponse = withReliability {
        withPerformanceLogging {
            if (editRequest.input == null) {
                log(
                    msg = String.format(
                        "Text Edit Request\nInstruction:\n\t%s\n", editRequest.instruction.replace("\n", "\n\t")
                    )
                )
            } else {
                log(
                    msg = String.format(
                        "Text Edit Request\nInstruction:\n\t%s\nInput:\n\t%s\n",
                        editRequest.instruction.replace("\n", "\n\t"),
                        editRequest.input.replace("\n", "\n\t")
                    )
                )
            }
            val request: String = StringUtil.restrictCharacterSet(
                JsonUtil.objectMapper().writeValueAsString(editRequest), allowedCharset
            )
            val result = post("${apiBase[defaultApiProvider]}/edits", request, defaultApiProvider)
            logger.info("Edit response received")
            checkError(result)
            val response = JsonUtil.objectMapper().readValue(
                result, CompletionResponse::class.java
            )
            if (response.usage != null) {
                val model = EditModels.values().values.find { it.modelName.equals(editRequest.model, true) }
                onUsage(
                    model, response.usage.copy(cost = model?.pricing(response.usage))
                )
            }
            log(
                msg = String.format(
                    "Edit Completion:\n\t%s",
                    response.firstChoice.orElse("").toString().trim { it <= ' ' }.toString().replace("\n", "\n\t")
                )
            )
            response
        }
    }

    /** Fetches `/models` from the default provider, checks it with `checkError`, and parses the result. */
    open fun listModels(): ApiModel.ModelListResponse {
        val result = get("${apiBase[defaultApiProvider]}/models", defaultApiProvider)
        checkError(result)
        return JsonUtil.objectMapper().readValue(result, ModelListResponse::class.java)
    }

    /**
     * Creates embeddings via `/embeddings`. String inputs are logged; usage is reported via [onUsage]
     * when the response carries it.
     */
    open fun createEmbedding(
        request: EmbeddingRequest
    ): EmbeddingResponse {
        return withReliability {
            withPerformanceLogging {
                if (request.input is String) {
                    log(
                        msg = String.format(
                            "Embedding Creation Request\nModel:\n\t%s\nInput:\n\t%s\n",
                            request.model,
                            request.input.replace("\n", "\n\t")
                        )
                    )
                }
                val result = post(
                    "${apiBase[defaultApiProvider]}/embeddings", StringUtil.restrictCharacterSet(
                        JsonUtil.objectMapper().writeValueAsString(request), allowedCharset
                    ), defaultApiProvider
                )
                logger.info("Embedding creation response received")
                checkError(result)
                val response = JsonUtil.objectMapper().readValue(
                    result, EmbeddingResponse::class.java
                )
                if (response.usage != null) {
                    val model = EmbeddingModels.values().values.find { it.modelName.equals(request.model, true) }
                    onUsage(
                        model,
                        response.usage.copy(cost = model?.pricing(response.usage))
                    )
                }
                response
            }
        }
    }

    /**
     * Generates images via `/images/generations` (request body serialized with Gson) and reports one
     * completion "token" of usage priced from the requested `size` ("WIDTHxHEIGHT").
     */
    open fun createImage(request: ImageGenerationRequest): ImageGenerationResponse = withReliability {
        withPerformanceLogging {
            val url = "${apiBase[defaultApiProvider]}/images/generations"
            val httpRequest = HttpPost(url)
            httpRequest.addHeader("Accept", "application/json")
            httpRequest.addHeader("Content-Type", "application/json")
            authorize(httpRequest, defaultApiProvider)
            val requestBody = Gson().toJson(request)
            httpRequest.entity = StringEntity(requestBody, Charsets.UTF_8, false)
            val response = post(httpRequest)
            checkError(response)
            logger.info("Image creation response received")
            val model = ImageModels.values().find { it.modelName.equals(request.model, true) }
            val dims = request.size?.split("x")
            onUsage(
                model, Usage(
                    completion_tokens = 1, cost = model?.pricing(
                        // Tolerate a malformed size rather than failing after the image was already generated.
                        width = dims?.getOrNull(0)?.toIntOrNull() ?: 0,
                        height = dims?.getOrNull(1)?.toIntOrNull() ?: 0
                    )
                )
            )
            JsonUtil.objectMapper().readValue(response, ImageGenerationResponse::class.java)
        }
    }

    /** Edits an image via multipart `/images/edits`; optional request fields are only sent when non-null. */
    open fun createImageEdit(request: ImageEditRequest): ImageEditResponse = withReliability {
        withPerformanceLogging {
            val url = "${apiBase[defaultApiProvider]}/images/edits"
            val httpRequest = HttpPost(url)
            httpRequest.addHeader("Accept", "application/json")
            authorize(httpRequest, defaultApiProvider)
            val entityBuilder = MultipartEntityBuilder.create()
            entityBuilder.addPart("image", FileBody(request.image))
            entityBuilder.addTextBody("prompt", request.prompt)
            request.mask?.let { entityBuilder.addPart("mask", FileBody(it)) }
            request.model?.let { entityBuilder.addTextBody("model", it) }
            request.n?.let { entityBuilder.addTextBody("n", it.toString()) }
            request.size?.let { entityBuilder.addTextBody("size", it) }
            request.responseFormat?.let { entityBuilder.addTextBody("response_format", it) }
            request.user?.let { entityBuilder.addTextBody("user", it) }
            httpRequest.entity = entityBuilder.build()
            val response = post(httpRequest)
            checkError(response)
            logger.info("Image edit response received")
            JsonUtil.objectMapper().readValue(response, ImageEditResponse::class.java)
        }
    }

    /**
     * Requests variations of an image via multipart `/images/variations`.
     * The `model` field is deliberately not sent (its line is commented out).
     */
    open fun createImageVariation(request: ImageVariationRequest): ImageVariationResponse = withReliability {
        withPerformanceLogging {
            val url = "${apiBase[defaultApiProvider]}/images/variations"
            val httpRequest = HttpPost(url)
            httpRequest.addHeader("Accept", "application/json")
            authorize(httpRequest, defaultApiProvider)
            val entityBuilder = MultipartEntityBuilder.create()
            entityBuilder.addPart("image", FileBody(request.image))
            //request.model?.let { entityBuilder.addTextBody("model", it) }
            request.n?.let { entityBuilder.addTextBody("n", it.toString()) }
            request.responseFormat?.let { entityBuilder.addTextBody("response_format", it) }
            request.size?.let { entityBuilder.addTextBody("size", it) }
            request.user?.let { entityBuilder.addTextBody("user", it) }
            httpRequest.entity = entityBuilder.build()
            val response = post(httpRequest)
            checkError(response)
            logger.info("Image variation response received")
            JsonUtil.objectMapper().readValue(response, ImageVariationResponse::class.java)
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy