tri.ai.openai.OpenAiClient.kt Maven / Gradle / Ivy
/*-
* #%L
* promptkt-0.1.0-SNAPSHOT
* %%
* Copyright (C) 2023 Johns Hopkins University Applied Physics Laboratory
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package tri.ai.openai
import com.aallam.openai.api.LegacyOpenAI
import com.aallam.openai.api.audio.TranscriptionRequest
import com.aallam.openai.api.chat.ChatCompletionRequest
import com.aallam.openai.api.completion.CompletionRequest
import com.aallam.openai.api.core.Usage
import com.aallam.openai.api.edits.EditsRequest
import com.aallam.openai.api.embedding.EmbeddingRequest
import com.aallam.openai.api.file.FileSource
import com.aallam.openai.api.http.Timeout
import com.aallam.openai.api.image.ImageCreation
import com.aallam.openai.api.logging.LogLevel
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.LoggingConfig
import com.aallam.openai.client.OpenAI
import com.aallam.openai.client.OpenAIConfig
import com.aallam.openai.client.OpenAIHost
import com.fasterxml.jackson.databind.DeserializationFeature
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
import com.fasterxml.jackson.module.kotlin.KotlinModule
import io.ktor.http.*
import okio.FileSystem
import okio.Path.Companion.toOkioPath
import tri.ai.pips.AiTaskResult
import tri.ai.pips.AiTaskResult.Companion.result
import tri.ai.pips.UsageUnit
import java.io.File
import java.util.*
import java.util.logging.Logger
import kotlin.time.Duration.Companion.seconds
/** OpenAI API client with built-in usage tracking. */
class OpenAiClient(val settings: OpenAiSettings) {
/** OpenAI API client. */
val client
get() = settings.client
/** OpenAI API usage stats. */
val usage = mutableMapOf()
//region QUICK API CALLS
/** Runs an embedding using ADA embedding model. */
suspend fun quickEmbedding(modelId: String = EMBEDDING_ADA, inputs: List) =
quickEmbedding(modelId, *inputs.toTypedArray())
/** Runs an embedding using ADA embedding model. */
suspend fun quickEmbedding(modelId: String, vararg inputs: String) = client.embeddings(EmbeddingRequest(
ModelId(modelId),
inputs.toList()
)).let { it ->
usage.increment(it.usage)
result(it.embeddings.map { it.embedding }, modelId)
}
/** Runs a quick audio transcription for a given file. */
suspend fun quickTranscribe(modelId: String = AUDIO_WHISPER, audioFile: File): AiTaskResult {
if (!audioFile.isAudioFile())
return AiTaskResult.invalidRequest("Audio file not provided.")
val request = TranscriptionRequest(
model = ModelId(modelId),
audio = FileSource(audioFile.toOkioPath(), FileSystem.SYSTEM)
)
return client.transcription(request).let {
usage.increment(0, UsageUnit.AUDIO_MINUTES)
result(it.text, modelId)
}
}
//endregion
//region DIRECT API CALLS
/** Runs a text completion request. */
@OptIn(LegacyOpenAI::class)
suspend fun completion(completionRequest: CompletionRequest) =
client.completion(completionRequest).let {
usage.increment(it.usage)
result(it.choices[0].text, completionRequest.model.id)
}
/** Runs a text completion request using a chat model. */
suspend fun chatCompletion(completionRequest: ChatCompletionRequest) =
client.chatCompletion(completionRequest).let {
usage.increment(it.usage)
result(it.choices[0].message.content ?: "", completionRequest.model.id)
}
/** Runs a chat response. */
suspend fun chat(completionRequest: ChatCompletionRequest) =
client.chatCompletion(completionRequest).let {
usage.increment(it.usage)
result(it.choices[0].message, completionRequest.model.id)
}
/** Runs an edit request (deprecated API). */
@Suppress("DEPRECATION")
suspend fun edit(request: EditsRequest) =
client.edit(request).let {
usage.increment(it.usage)
result(it.choices[0].text, request.model.id)
}
/** Runs an image creation request. */
suspend fun imageURL(imageCreation: ImageCreation) =
client.imageURL(imageCreation).let {
usage.increment(it.size, UsageUnit.IMAGES)
result(it.map { it.url }, IMAGE_DALLE)
}
//endregion
//region USAGE TRACKING
/** Increment usage map with usage from response. */
private fun MutableMap.increment(usage: Usage?) {
this[UsageUnit.TOKENS] = (this[UsageUnit.TOKENS] ?: 0) + (usage?.totalTokens ?: 0)
}
/** Increment usage map with usage from response. */
private fun MutableMap.increment(totalTokens: Int, unit: UsageUnit) {
this[unit] = (this[unit] ?: 0) + totalTokens
}
//endregion
companion object {
val INSTANCE by lazy { OpenAiClient(OpenAiSettings()) }
}
}
/** Manages OpenAI API key and client. */
class OpenAiSettings {
companion object {
const val API_KEY_FILE = "apikey.txt"
const val API_KEY_ENV = "OPENAI_API_KEY"
}
var baseUrl: String? = null
set(value) {
field = value
buildClient()
}
var apiKey = readApiKey()
set(value) {
field = value
buildClient()
}
var logLevel = LogLevel.Info
set(value) {
field = value
buildClient()
}
var timeoutSeconds = 60
set(value) {
field = value
buildClient()
}
var client: OpenAI
private set
init {
client = buildClient()
}
/** Read API key by first checking for [API_KEY_FILE], and then checking user environment variable [API_KEY_ENV]. */
private fun readApiKey(): String {
val file = File("apikey.txt")
val key = if (file.exists()) {
file.readText()
} else
System.getenv(API_KEY_ENV)
return if (key.isNullOrBlank()) {
Logger.getLogger(OpenAiSettings::class.java.name).warning(
"No API key found. Please create a file named $API_KEY_FILE in the root directory, or set an environment variable named $API_KEY_ENV."
)
""
} else
key
}
@Throws(IllegalStateException::class)
private fun buildClient(): OpenAI {
client = OpenAI(
OpenAIConfig(
host = if (baseUrl == null) OpenAIHost.OpenAI else OpenAIHost(baseUrl!!),
token = apiKey,
logging = LoggingConfig(LogLevel.None),
timeout = Timeout(socket = timeoutSeconds.seconds)
)
)
return client
}
}
//region MODELS
const val COMBO_GPT4 = "gpt-4"
const val COMBO_GPT35 = "gpt-3.5-turbo"
const val COMBO_GPT35_16K = "gpt-3.5-turbo-16k"
const val TEXT_DAVINCI3 = "text-davinci-003"
const val TEXT_DAVINCI2 = "text-davinci-002"
const val TEXT_CURIE = "text-curie-001"
const val TEXT_BABBAGE = "text-babbage-001"
const val TEXT_ADA = "text-ada-001"
const val EDIT_DAVINCI = "text-davinci-edit-001"
const val INSERT_DAVINCI2 = "text-davinci-insert-002"
const val INSERT_DAVINCI = "text-davinci-insert-001"
const val CODE_DAVINCI2 = "code-davinci-002"
const val CODE_CUSHMAN1 = "code-cushman-001"
const val CODE_EDIT_DAVINCI = "code-davinci-edit-001"
const val EMBEDDING_ADA = "text-embedding-ada-002"
const val AUDIO_WHISPER = "whisper-1"
const val IMAGE_DALLE = "dalle-2"
val chatModels = listOf(COMBO_GPT35, COMBO_GPT4, "$COMBO_GPT35-0301", "$COMBO_GPT4-0314")
val textModels = listOf(TEXT_DAVINCI3, TEXT_CURIE, TEXT_BABBAGE, TEXT_ADA)
val codeModels = listOf(CODE_DAVINCI2, CODE_CUSHMAN1)
val completionModels = textModels + codeModels
val insertModels = listOf(TEXT_DAVINCI3, INSERT_DAVINCI2, INSERT_DAVINCI, TEXT_DAVINCI2, CODE_DAVINCI2)
val editsModels = listOf(EDIT_DAVINCI, CODE_EDIT_DAVINCI)
val embeddingsModels = listOf(EMBEDDING_ADA)
val audioModels = listOf(AUDIO_WHISPER)
val imageModels = listOf(IMAGE_DALLE)
//endregion
//region UTILS
val mapper = ObjectMapper()
.registerModule(JavaTimeModule())
.registerModule(KotlinModule.Builder().build())
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)!!
fun File.isAudioFile() = extension.lowercase(Locale.getDefault()) in
listOf("mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm")
//endregion
© 2015 - 2025 Weber Informatics LLC | Privacy Policy