All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.jopenai.OpenAIClient.kt Maven / Gradle / Ivy

There is a newer version: 1.1.12
Show newest version
package com.simiacryptus.jopenai

import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.node.ObjectNode
import com.google.common.util.concurrent.ListeningScheduledExecutorService
import com.google.gson.Gson
import com.google.gson.JsonObject
import com.simiacryptus.jopenai.ApiModel.*
import com.simiacryptus.jopenai.util.ClientUtil.allowedCharset
import com.simiacryptus.jopenai.util.ClientUtil.checkError
import com.simiacryptus.jopenai.util.ClientUtil.keyTxt
import com.simiacryptus.jopenai.exceptions.ModerationException
import com.simiacryptus.jopenai.models.*
import com.simiacryptus.jopenai.util.JsonUtil
import com.simiacryptus.jopenai.util.StringUtil
import org.apache.hc.client5.http.classic.methods.HttpGet
import org.apache.hc.client5.http.classic.methods.HttpPost
import org.apache.hc.client5.http.entity.mime.FileBody
import org.apache.hc.client5.http.entity.mime.HttpMultipartMode
import org.apache.hc.client5.http.entity.mime.MultipartEntityBuilder
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient
import org.apache.hc.core5.http.ContentType
import org.apache.hc.core5.http.HttpRequest
import org.apache.hc.core5.http.io.entity.EntityUtils
import org.apache.hc.core5.http.io.entity.StringEntity
import org.slf4j.event.Level
import java.awt.image.BufferedImage
import java.io.BufferedOutputStream
import java.io.IOException
import java.net.URL
import java.util.*
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.atomic.AtomicInteger
import javax.imageio.ImageIO

/**
 * Client for the OpenAI-style REST endpoints rooted at [apiBase].
 *
 * Each public operation builds an HTTP request, adds a bearer `Authorization` header through
 * [authorize], sends it, and decodes the JSON (or binary) response. Most operations run inside the
 * `withReliability { withPerformanceLogging { ... } }` wrappers, which are not defined in this file.
 *
 * Per-operation call counts and the running token total are exposed through [metrics].
 *
 * @param key API key sent as the bearer token; defaults to `keyTxt`.
 * @param apiBase base URL that endpoint paths are appended to.
 * @param logLevel forwarded to [HttpClientManager].
 * @param logStreams forwarded to [HttpClientManager].
 * @param scheduledPool forwarded to [HttpClientManager].
 * @param workPool forwarded to [HttpClientManager].
 * @param client forwarded to [HttpClientManager].
 */
open class OpenAIClient(
  protected var key: String = keyTxt,
  private val apiBase: String = "https://api.openai.com/v1",
  logLevel: Level = Level.INFO,
  logStreams: MutableList<BufferedOutputStream> = mutableListOf(),
  scheduledPool: ListeningScheduledExecutorService = HttpClientManager.scheduledPool,
  workPool: ThreadPoolExecutor = HttpClientManager.workPool,
  client: CloseableHttpClient = HttpClientManager.client
) : HttpClientManager(
  logLevel = logLevel,
  logStreams = logStreams,
  scheduledPool = scheduledPool,
  workPool = workPool,
  client = client
) {

  /** Running sum of `total_tokens` reported to [onUsage]. */
  private val tokenCounter = AtomicInteger(0)

  /** Number of [chat] invocations (incremented once per attempt body). */
  protected val chatCounter = AtomicInteger(0)

  /** Number of [complete] invocations. */
  protected val completionCounter = AtomicInteger(0)

  /** Number of [moderate] invocations. */
  protected val moderationCounter = AtomicInteger(0)

  /** Number of [render] invocations. */
  protected val renderCounter = AtomicInteger(0)

  /** Number of [transcription] invocations. */
  protected val transcriptionCounter = AtomicInteger(0)

  /** Number of [edit] invocations. */
  protected val editCounter = AtomicInteger(0)

  /**
   * Usage callback invoked after a response that carries (or implies) usage information.
   *
   * This implementation only adds `tokens.total_tokens` to the token counter; [model] is not used
   * here but is available to overriding subclasses.
   */
  open fun onUsage(model: OpenAIModel?, tokens: Usage) {
    tokenCounter.addAndGet(tokens.total_tokens)
  }

  /**
   * Snapshot of the current counter values, keyed by counter name.
   *
   * A fresh map is built on every read.
   */
  // NOTE(review): the type arguments were missing in this copy of the file; the keys are Strings and
  // the values are Ints here, `Any` was chosen as the value type — confirm against subclasses/callers.
  open val metrics: Map<String, Any>
    get() = hashMapOf(
      "tokens" to tokenCounter.get(),
      "chats" to chatCounter.get(),
      "completions" to completionCounter.get(),
      "moderations" to moderationCounter.get(),
      "renders" to renderCounter.get(),
      "transcriptions" to transcriptionCounter.get(),
      "edits" to editCounter.get(),
    )

  /**
   * Sends [json] as a UTF-8 JSON body to [url] with JSON `Content-Type`/`Accept` headers and the
   * authorization header, and returns the response body as text.
   */
  @Throws(IOException::class, InterruptedException::class)
  protected fun post(url: String, json: String): String {
    val request = HttpPost(url)
    request.addHeader("Content-Type", "application/json")
    request.addHeader("Accept", "application/json")
    authorize(request)
    request.entity = StringEntity(json, Charsets.UTF_8, false)
    return post(request)
  }

  /**
   * Executes an already-prepared [request] and returns the response body as text.
   *
   * The response is closed once the body has been read.
   */
  protected fun post(request: HttpPost): String = withClient { httpClient ->
    httpClient.execute(request).use { httpResponse -> EntityUtils.toString(httpResponse.entity) }
  }

  /** Adds the `Authorization: Bearer <key>` header to [request]. */
  @Throws(IOException::class)
  protected open fun authorize(request: HttpRequest) {
    request.addHeader("Authorization", "Bearer $key")
  }

  /**
   * Issues an authorized GET to [url] with JSON `Content-Type`/`Accept` headers and returns the
   * response body as text. The response is closed once the body has been read.
   */
  @Throws(IOException::class)
  protected operator fun get(url: String?): String = withClient { httpClient ->
    val request = HttpGet(url)
    request.addHeader("Content-Type", "application/json")
    request.addHeader("Accept", "application/json")
    authorize(request)
    httpClient.execute(request).use { httpResponse -> EntityUtils.toString(httpResponse.entity) }
  }

  /**
   * Fetches `$apiBase/engines` and decodes the `data` field of the response into a list of [Engine].
   *
   * The raw response is passed through `checkError` first (as [listModels] does) so that an API error
   * payload is not mistaken for an empty result. When the response has no `data` field an empty list
   * is returned.
   */
  fun listEngines(): List<Engine> {
    val result = get("$apiBase/engines")
    checkError(result)
    val mapper = JsonUtil.objectMapper()
    // Fall back to an empty JSON array: the target type is a List, so "{}" could never be decoded.
    val dataJson = mapper.readValue(result, ObjectNode::class.java)["data"]?.toString() ?: "[]"
    return mapper.readValue(
      dataJson, mapper.typeFactory.constructCollectionType(List::class.java, Engine::class.java)
    )
  }

  /** Returns the ids of all engines from [listEngines], sorted ascending. */
  // NOTE(review): the element type argument of the return type is missing in this copy of the file;
  // it must match the declared type of Engine.id, which is not visible here — restore before compiling.
  fun getEngineIds(): Array = listEngines().map { it.id }.sortedBy { it }.toTypedArray()

  /**
   * Runs a text completion for [request] against `$apiBase/engines/<model>/completions`.
   *
   * The request is serialized as pretty-printed JSON and filtered through
   * `StringUtil.restrictCharacterSet` before sending. When the response carries usage data it is
   * reported to [onUsage] with the cost computed by `model.pricing`.
   *
   * @return the decoded [CompletionResponse]; errors detected by `checkError` propagate as thrown.
   */
  open fun complete(
    request: CompletionRequest, model: OpenAITextModel
  ): CompletionResponse = withReliability {
    withPerformanceLogging {
      completionCounter.incrementAndGet()
      val suffix = request.suffix
      val requestLog = if (suffix == null) {
        String.format(
          "Text Completion Request\nPrefix:\n\t%s\n", request.prompt.replace("\n", "\n\t")
        )
      } else {
        String.format(
          "Text Completion Request\nPrefix:\n\t%s\nSuffix:\n\t%s\n",
          request.prompt.replace("\n", "\n\t"),
          suffix.replace("\n", "\n\t")
        )
      }
      log(msg = requestLog)
      val result = post(
        "$apiBase/engines/${model.modelName}/completions", StringUtil.restrictCharacterSet(
          JsonUtil.objectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(request),
          allowedCharset
        )
      )
      checkError(result)
      val response = JsonUtil.objectMapper().readValue(
        result, CompletionResponse::class.java
      )
      response.usage?.let { usage ->
        onUsage(model, usage.copy(cost = model.pricing(usage)))
      }
      // Only used for logging: the first choice with the echoed prompt removed from its start.
      val completionResult =
        StringUtil.stripPrefix(response.firstChoice.orElse("").toString().trim { it <= ' ' },
          request.prompt.trim { it <= ' ' })
      log(
        msg = String.format(
          "Chat Completion:\n\t%s", completionResult.toString().replace("\n", "\n\t")
        )
      )
      response
    }
  }

  /**
   * Uploads [wavAudio] to `$apiBase/audio/transcriptions` and returns the transcribed text.
   *
   * The audio is sent as multipart part `file` (`audio.wav`, content type `audio/x-wav`) together
   * with `model=whisper-1` and `response_format=verbose_json`; [prompt] is only sent when non-empty.
   *
   * @return the `text` of the result, or an empty string when it is absent.
   * @throws RuntimeException wrapping an [IOException] when the response contains an `error` object.
   */
  open fun transcription(wavAudio: ByteArray, prompt: String = ""): String = withReliability {
    withPerformanceLogging {
      transcriptionCounter.incrementAndGet()
      val url = "$apiBase/audio/transcriptions"
      val request = HttpPost(url)
      request.addHeader("Accept", "application/json")
      authorize(request)
      val entity = MultipartEntityBuilder.create()
      entity.setMode(HttpMultipartMode.EXTENDED)
      entity.addBinaryBody("file", wavAudio, ContentType.create("audio/x-wav"), "audio.wav")
      entity.addTextBody("model", "whisper-1")
      entity.addTextBody("response_format", "verbose_json")
      if (prompt.isNotEmpty()) entity.addTextBody("prompt", prompt)
      request.entity = entity.build()
      val response = post(request)
      val jsonObject = Gson().fromJson(response, JsonObject::class.java)
      if (jsonObject.has("error")) {
        val errorObject = jsonObject.getAsJsonObject("error")
        throw RuntimeException(IOException(errorObject["message"].asString))
      }
      try {
        val result = JsonUtil.objectMapper().readValue(response, TranscriptionResult::class.java)
        result.text ?: ""
      } catch (e: Exception) {
        // Typed decoding failed; fall back to the raw "text" field, which may itself be missing.
        jsonObject.get("text")?.asString ?: ""
      }
    }
  }

  /**
   * Posts [request] to `$apiBase/audio/speech` and returns the raw response bytes.
   *
   * When the response content type is textual (`text...` or `application/json`) the body is handed to
   * `checkError` and `null` is returned if that does not throw. Otherwise usage is reported to
   * [onUsage] using the input length as the prompt token count, and the bytes are returned.
   */
  open fun createSpeech(request: SpeechRequest): ByteArray? = withReliability {
    withPerformanceLogging {
      val httpRequest = HttpPost("$apiBase/audio/speech")
      authorize(httpRequest)
      httpRequest.addHeader("Accept", "application/json")
      httpRequest.addHeader("Content-Type", "application/json")
      httpRequest.entity = StringEntity(JsonUtil.objectMapper().writeValueAsString(request), Charsets.UTF_8, false)
      // Read everything needed from the entity before the response is closed.
      val (contentType, bytes) = withClient { httpClient ->
        httpClient.execute(httpRequest).use { httpResponse ->
          val entity = httpResponse.entity
          entity.contentType to entity.content.readAllBytes()
        }
      }
      // The null check must guard both prefix tests; a missing content type is treated as binary.
      val isTextual = contentType != null &&
        (contentType.startsWith("text") || contentType.startsWith("application/json"))
      if (isTextual) {
        checkError(bytes.toString(Charsets.UTF_8))
        null
      } else {
        val model = AudioModels.values().find { it.modelName.equals(request.model, true) }
        onUsage(
          model, Usage(
          prompt_tokens = request.input.length,
          cost = model?.pricing(request.input.length)
        ))
        bytes
      }
    }
  }


  /**
   * Requests [count] square images of side [resolution] for [prompt] from
   * `$apiBase/images/generations`, then downloads each returned `url` with [ImageIO].
   *
   * @return one entry per element of the response `data` array, in response order.
   * @throws RuntimeException wrapping an [IOException] when the response contains an `error` object.
   */
  open fun render(prompt: String = "", resolution: Int = 1024, count: Int = 1): List<BufferedImage> =
    withReliability {
      withPerformanceLogging {
        renderCounter.incrementAndGet()
        val url = "$apiBase/images/generations"
        val request = HttpPost(url)
        request.addHeader("Accept", "application/json")
        request.addHeader("Content-Type", "application/json")
        authorize(request)
        val payload = JsonObject()
        payload.addProperty("prompt", prompt)
        payload.addProperty("n", count)
        payload.addProperty("size", "${resolution}x$resolution")
        request.entity = StringEntity(payload.toString(), Charsets.UTF_8, false)
        val response = post(request)
        val responseJson = Gson().fromJson(response, JsonObject::class.java)
        if (responseJson.has("error")) {
          val errorObject = responseJson.getAsJsonObject("error")
          throw RuntimeException(IOException(errorObject["message"].asString))
        }
        responseJson.getAsJsonArray("data").map { item ->
          ImageIO.read(URL(item.asJsonObject.get("url").asString))
        }
      }
    }

  /**
   * Sends [chatRequest] to `$apiBase/chat/completions` and returns the decoded [ChatResponse].
   *
   * The request is logged pretty-printed but sent in compact form. When the response carries usage
   * data it is reported to [onUsage] with the cost computed by `model.pricing`.
   */
  open fun chat(
    chatRequest: ChatRequest, model: OpenAITextModel
  ): ChatResponse = withReliability {
    withPerformanceLogging {
      chatCounter.incrementAndGet()
      val reqJson =
        JsonUtil.objectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(chatRequest)
      log(
        msg = String.format(
          "Chat Request\nPrefix:\n\t%s\n", reqJson.replace("\n", "\n\t")
        )
      )

      val jsonRequest = JsonUtil.objectMapper().writeValueAsString(chatRequest)
      val result = post("$apiBase/chat/completions", jsonRequest)
      checkError(result)
      val response = JsonUtil.objectMapper().readValue(result, ChatResponse::class.java)
      response.usage?.let { usage ->
        onUsage(model, usage.copy(cost = model.pricing(usage)))
      }
      // Log the first choice's text when there is one, otherwise the whole response as JSON.
      val firstContent =
        response.choices.firstOrNull()?.message?.content?.trim { it <= ' ' }?.replace("\n", "\n\t")
      log(
        msg = String.format(
          "Chat Completion:\n\t%s",
          firstContent ?: JsonUtil.toJson(response)
        )
      )
      response
    }
  }

  /**
   * Submits [text] (filtered through `StringUtil.restrictCharacterSet`) to `$apiBase/moderations`.
   *
   * Returns normally when the first result is not flagged, or when the response body decodes to null.
   *
   * @throws RuntimeException wrapping a [ModerationException] that lists the flagged categories
   *   (or `???` when none is marked), wrapping an [IOException] when the response contains an
   *   `error` object, or wrapping serialization / transport exceptions.
   */
  open fun moderate(text: String) = withReliability {
    withPerformanceLogging {
      moderationCounter.incrementAndGet()
      val body: String = try {
        JsonUtil.objectMapper().writeValueAsString(
          mapOf(
            "input" to StringUtil.restrictCharacterSet(text, allowedCharset)
          )
        )
      } catch (e: JsonProcessingException) {
        throw RuntimeException(e)
      }
      val result: String = try {
        this.post("$apiBase/moderations", body)
      } catch (e: IOException) {
        throw RuntimeException(e)
      } catch (e: InterruptedException) {
        throw RuntimeException(e)
      }
      val jsonObject = Gson().fromJson(
        result, JsonObject::class.java
      ) ?: return@withPerformanceLogging
      if (jsonObject.has("error")) {
        val errorObject = jsonObject.getAsJsonObject("error")
        throw RuntimeException(IOException(errorObject["message"].asString))
      }
      val moderationResult = jsonObject.getAsJsonArray("results")[0].asJsonObject
      if (moderationResult["flagged"].asBoolean) {
        val categoriesObj = moderationResult["categories"].asJsonObject
        val flaggedCategories = categoriesObj.keySet().filter { categoriesObj[it].asBoolean }
        val reason = if (flaggedCategories.isEmpty()) "???" else flaggedCategories.joinToString(", ")
        throw RuntimeException(
          ModerationException("Moderation flagged this request due to $reason")
        )
      }
    }
  }

  /**
   * Sends [editRequest] to `$apiBase/edits` and returns the decoded [CompletionResponse].
   *
   * When the response carries usage data, the model is looked up in [EditModels] by name
   * (case-insensitive) and the usage is reported to [onUsage]; the cost is null when no model matches.
   */
  open fun edit(
    editRequest: EditRequest
  ): CompletionResponse = withReliability {
    withPerformanceLogging {
      editCounter.incrementAndGet()
      val input = editRequest.input
      val requestLog = if (input == null) {
        String.format(
          "Text Edit Request\nInstruction:\n\t%s\n", editRequest.instruction.replace("\n", "\n\t")
        )
      } else {
        String.format(
          "Text Edit Request\nInstruction:\n\t%s\nInput:\n\t%s\n",
          editRequest.instruction.replace("\n", "\n\t"),
          input.replace("\n", "\n\t")
        )
      }
      log(msg = requestLog)
      val request: String = StringUtil.restrictCharacterSet(
        JsonUtil.objectMapper().writeValueAsString(editRequest), allowedCharset
      )
      val result = post("$apiBase/edits", request)
      checkError(result)
      val response = JsonUtil.objectMapper().readValue(
        result, CompletionResponse::class.java
      )
      response.usage?.let { usage ->
        val model = EditModels.values().find { it.modelName.equals(editRequest.model, true) }
        onUsage(
          model, usage.copy(cost = model?.pricing(usage))
        )
      }
      log(
        msg = String.format(
          "Chat Completion:\n\t%s",
          response.firstChoice.orElse("").toString().trim { it <= ' ' }.replace("\n", "\n\t")
        )
      )
      response
    }
  }

  /** Fetches `$apiBase/models`, runs `checkError` on the body and decodes it. */
  open fun listModels(): ModelListResponse {
    val result = get("$apiBase/models")
    checkError(result)
    return JsonUtil.objectMapper().readValue(result, ModelListResponse::class.java)
  }

  /**
   * Sends [request] to `$apiBase/embeddings` and returns the decoded [EmbeddingResponse].
   *
   * The request is only logged when its input is a single String. When the response carries usage
   * data, the model is looked up in [EmbeddingModels] by name (case-insensitive) and the usage is
   * reported to [onUsage]; the cost is null when no model matches.
   */
  open fun createEmbedding(
    request: EmbeddingRequest
  ): EmbeddingResponse {
    return withReliability {
      withPerformanceLogging {
        (request.input as? String)?.let { textInput ->
          log(
            msg = String.format(
              "Embedding Creation Request\nModel:\n\t%s\nInput:\n\t%s\n",
              request.model,
              textInput.replace("\n", "\n\t")
            )
          )
        }
        val result = post(
          "$apiBase/embeddings", StringUtil.restrictCharacterSet(
            JsonUtil.objectMapper().writeValueAsString(request), allowedCharset
          )
        )
        checkError(result)
        val response = JsonUtil.objectMapper().readValue(
          result, EmbeddingResponse::class.java
        )
        response.usage?.let { usage ->
          val model = EmbeddingModels.values().find { it.modelName.equals(request.model, true) }
          onUsage(
            model,
            usage.copy(cost = model?.pricing(usage))
          )
        }
        response
      }
    }
  }

  /**
   * Sends [request] (serialized with Gson) to `$apiBase/images/generations` and returns the decoded
   * [ImageGenerationResponse].
   *
   * After a successful call, usage is reported to [onUsage] with one completion token and a cost
   * derived from the requested `size` (`<width>x<height>`). A missing or unparsable dimension is
   * treated as 0 so that cost accounting cannot fail a request that already succeeded.
   */
  open fun createImage(request: ImageGenerationRequest): ImageGenerationResponse = withReliability {
    withPerformanceLogging {
      val url = "$apiBase/images/generations"
      val httpRequest = HttpPost(url)
      httpRequest.addHeader("Accept", "application/json")
      httpRequest.addHeader("Content-Type", "application/json")
      authorize(httpRequest)

      val requestBody = Gson().toJson(request)
      httpRequest.entity = StringEntity(requestBody, Charsets.UTF_8, false)

      val response = post(httpRequest)
      checkError(response)
      val model = ImageModels.values().find { it.modelName.equals(request.model, true) }
      val dims = request.size?.split("x")
      onUsage(model, Usage(completion_tokens = 1, cost = model?.pricing(
        width = dims?.getOrNull(0)?.toIntOrNull() ?: 0,
        height = dims?.getOrNull(1)?.toIntOrNull() ?: 0
      )))

      JsonUtil.objectMapper().readValue(response, ImageGenerationResponse::class.java)
    }
  }

  /**
   * Uploads an image edit request to `$apiBase/images/edits` as multipart form data and returns the
   * decoded [ImageEditResponse].
   *
   * `image` and `prompt` are always sent; `mask`, `model`, `n`, `size`, `response_format` and `user`
   * are sent only when set on [request].
   */
  open fun createImageEdit(request: ImageEditRequest): ImageEditResponse = withReliability {
    withPerformanceLogging {
      val url = "$apiBase/images/edits"
      val httpRequest = HttpPost(url)
      httpRequest.addHeader("Accept", "application/json")
      authorize(httpRequest)

      val entityBuilder = MultipartEntityBuilder.create()
      entityBuilder.addPart("image", FileBody(request.image))
      entityBuilder.addTextBody("prompt", request.prompt)
      request.mask?.let { entityBuilder.addPart("mask", FileBody(it)) }
      request.model?.let { entityBuilder.addTextBody("model", it) }
      request.n?.let { entityBuilder.addTextBody("n", it.toString()) }
      request.size?.let { entityBuilder.addTextBody("size", it) }
      request.responseFormat?.let { entityBuilder.addTextBody("response_format", it) }
      request.user?.let { entityBuilder.addTextBody("user", it) }

      httpRequest.entity = entityBuilder.build()
      val response = post(httpRequest)
      checkError(response)

      JsonUtil.objectMapper().readValue(response, ImageEditResponse::class.java)
    }
  }

  /**
   * Uploads an image variation request to `$apiBase/images/variations` as multipart form data and
   * returns the decoded [ImageVariationResponse].
   *
   * `image` is always sent; `n`, `response_format`, `size` and `user` are sent only when set on
   * [request]. The `model` field is currently not sent (see the commented-out line).
   */
  open fun createImageVariation(request: ImageVariationRequest): ImageVariationResponse = withReliability {
    withPerformanceLogging {
      val url = "$apiBase/images/variations"
      val httpRequest = HttpPost(url)
      httpRequest.addHeader("Accept", "application/json")
      authorize(httpRequest)

      val entityBuilder = MultipartEntityBuilder.create()
      entityBuilder.addPart("image", FileBody(request.image))
      //request.model?.let { entityBuilder.addTextBody("model", it) }
      request.n?.let { entityBuilder.addTextBody("n", it.toString()) }
      request.responseFormat?.let { entityBuilder.addTextBody("response_format", it) }
      request.size?.let { entityBuilder.addTextBody("size", it) }
      request.user?.let { entityBuilder.addTextBody("user", it) }

      httpRequest.entity = entityBuilder.build()
      val response = post(httpRequest)
      checkError(response)

      JsonUtil.objectMapper().readValue(response, ImageVariationResponse::class.java)
    }
  }

  companion object {
    /** SLF4J logger bound to this class. */
    private val log = org.slf4j.LoggerFactory.getLogger(OpenAIClient::class.java)
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy