tri.ai.gemini.GeminiClient.kt Maven / Gradle / Ivy
/*-
* #%L
* tri.promptfx:promptkt
* %%
* Copyright (C) 2023 - 2025 Johns Hopkins University Applied Physics Laboratory
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package tri.ai.gemini
import io.ktor.client.call.*
import io.ktor.client.request.*
import kotlinx.serialization.Serializable
import tri.ai.core.TextChatMessage
import tri.ai.core.MChatRole
import tri.ai.core.VisionLanguageChatMessage
import java.io.Closeable
/**
* General purpose client for the Gemini API.
* See https://ai.google.dev/api?lang=web
*/
class GeminiClient : Closeable {
private val settings = GeminiSettings()
private val client = settings.client
/** Returns true if the client is configured with an API key. */
fun isConfigured() = settings.apiKey.isNotBlank()
//region CORE API METHODS
suspend fun listModels(): ModelsResponse {
return client.get("models")
.body()
}
suspend fun generateContent(modelId: String, request: GenerateContentRequest): GenerateContentResponse {
return client.post("models/$modelId:generateContent") {
setBody(request)
}.body()
}
//endregion
//region ALTERNATE API METHODS
suspend fun embedContent(content: String, modelId: String, outputDimensionality: Int? = null): EmbedContentResponse {
val request = EmbedContentRequest(Content(listOf(Part(content))), outputDimensionality = outputDimensionality)
return client.post("models/$modelId:embedContent") {
setBody(request)
}.body()
}
suspend fun batchEmbedContents(content: List, modelId: String, outputDimensionality: Int? = null): BatchEmbedContentsResponse {
val request = BatchEmbedContentRequest(
content.map { EmbedContentRequest(Content(listOf(Part(it))), model = "models/$modelId", outputDimensionality = outputDimensionality) }
)
return client.post("models/$modelId:batchEmbedContents") {
setBody(request)
}.body()
}
suspend fun generateContent(prompt: String, modelId: String, numResponses: Int? = null, history: List): GenerateContentResponse {
val system = history.lastOrNull { it.role == MChatRole.System }?.content
val request = GenerateContentRequest(
contents = history.filter { it.role != MChatRole.System }.map {
val role = it.role.toGeminiRole()
Content(listOf(Part(it.content)), role)
} + Content.text(prompt),
systemInstruction = system?.let { Content(listOf(Part(it)), ContentRole.user) },
// TODO - enable when Gemini API supports candidateCount, see https://ai.google.dev/api/generate-content#v1beta.GenerationConfig
// generationConfig = numResponses?.let { GenerationConfig(candidateCount = it) }
)
return generateContent(modelId, request)
}
suspend fun generateContent(prompt: String, image: String, modelId: String, numResponses: Int? = null, history: List): GenerateContentResponse {
val system = history.lastOrNull { it.role == MChatRole.System }?.content
val request = GenerateContentRequest(
contents = history.filter { it.role != MChatRole.System }.map {
val role = it.role.toGeminiRole()
Content(listOf(Part(it.content)), role)
} + Content(listOf(
Part(text = prompt),
Part(inlineData = Blob(image, MIME_TYPE_JPEG))
)),
systemInstruction = system?.let { Content(listOf(Part(it)), ContentRole.user) },
// TODO - enable when Gemini API supports candidateCount, see https://ai.google.dev/api/generate-content#v1beta.GenerationConfig
// generationConfig = numResponses?.let { GenerationConfig(candidateCount = it) }
)
return generateContent(modelId, request)
}
suspend fun generateContent(messages: List, modelId: String, config: GenerationConfig? = null): GenerateContentResponse {
val system = messages.lastOrNull { it.role == MChatRole.System }?.content
val request = GenerateContentRequest(
messages.filter { it.role != MChatRole.System }.map {
val role = it.role.toGeminiRole()
Content(listOf(Part(it.content)), role)
},
systemInstruction = system?.let { Content(listOf(Part(it)), ContentRole.user) },
generationConfig = config
)
return generateContent(modelId, request)
}
suspend fun generateContentVision(messages: List, modelId: String, config: GenerationConfig? = null): GenerateContentResponse {
val system = messages.lastOrNull { it.role == MChatRole.System }?.content
val request = GenerateContentRequest(
messages.filter { it.role != MChatRole.System }.map {
val role = it.role.toGeminiRole()
Content(listOf(
Part(it.content),
Part(null, Blob.fromDataUrl(it.image))
), role)
},
systemInstruction = system?.let { Content(listOf(Part(it)), ContentRole.user) }, // TODO - support for system messages
generationConfig = config
)
return generateContent(modelId, request)
}
//endregion
override fun close() {
client.close()
}
companion object {
val INSTANCE by lazy { GeminiClient() }
/** Convert from [MChatRole] to string representing Gemini role. */
fun MChatRole.toGeminiRole() = when (this) {
MChatRole.User -> ContentRole.user
MChatRole.Assistant -> ContentRole.model
else -> error("Invalid role: $this")
}
/** Convert from string representing Gemini role to [MChatRole]. */
fun ContentRole?.fromGeminiRole() = when (this) {
ContentRole.user -> MChatRole.User
ContentRole.model -> MChatRole.Assistant
else -> error("Invalid role: $this")
}
}
}
//region DTO's - see https://ai.google.dev/api?lang=web
@Serializable
data class ModelsResponse(
val models: List
)
@Serializable
data class ModelInfo(
val name: String,
val baseModelId: String? = null, // though marked as required, not returned by API
val version: String,
val displayName: String,
val description: String? = null, // though marked as required, not always returned by API
val inputTokenLimit: Int,
val outputTokenLimit: Int,
val supportedGenerationMethods: List,
val temperature: Double? = null,
val maxTemperature: Double? = null,
val topP: Double? = null,
val topK: Int? = null
)
//endregion
© 2015 - 2025 Weber Informatics LLC | Privacy Policy