/*-
* #%L
* tri.promptfx:promptkt
* %%
* Copyright (C) 2023 - 2025 Johns Hopkins University Applied Physics Laboratory
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package tri.ai.gemini

import kotlinx.serialization.SerializationException
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import tri.ai.core.*
import tri.ai.gemini.GeminiClient.Companion.fromGeminiRole
import tri.ai.prompt.trace.*
import tri.util.info
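
// Usage sketch: send a single user message through the Gemini chat wrapper and inspect the returned trace.
// MChatRole.User and the no-arg MChatParameters() constructor are assumptions about the core API rather than
// confirmed signatures; chat() is a suspend function, so the call must run inside a coroutine (e.g. runBlocking).
//
//   val chat = GeminiMultimodalChat(GeminiModelIndex.GEMINI_15_FLASH)
//   val trace = chat.chat(
//       listOf(MultimodalChatMessage(MChatRole.User, listOf(MChatMessagePart.text("Summarize this API.")))),
//       MChatParameters()
//   )
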
/** Chat completion with Gemini models. */
class GeminiMultimodalChat(
    override val modelId: String = GeminiModelIndex.GEMINI_15_FLASH,
    val client: GeminiClient = GeminiClient.INSTANCE
) : MultimodalChat {

    override fun toString() = modelId

    override suspend fun chat(
        messages: List<MultimodalChatMessage>,
        parameters: MChatParameters
    ): AiPromptTrace<MultimodalChatMessage> {
        val modelInfo = AiModelInfo.info(
            modelId,
            tokens = parameters.tokens,
            stop = parameters.stop,
            requestJson = parameters.responseFormat == MResponseFormat.JSON
        )
        val t0 = System.currentTimeMillis()
        if ((parameters.numResponses ?: 1) > 1)
            info("Gemini chat API does not support multiple responses; only the first response will be returned.")

        // Gemini takes the system prompt as a separate systemInstruction rather than as a chat message role,
        // so split system content out of the message list here.
        val system = messages.lastOrNull { it.role == MChatRole.System }?.content
        val nonSystem = messages.filter { it.role != MChatRole.System }
        val request = GenerateContentRequest(
            contents = nonSystem.map { it.gemini() },
            tools = parameters.geminiTools(),
            toolConfig = parameters.geminiToolConfig(),
            safetySettings = null,
            systemInstruction = system?.let { parts -> Content(parts.map { it.gemini() }, ContentRole.user) },
            generationConfig = parameters.gemini(),
            cachedContent = null
        )
        val response = client.generateContent(modelId, request)
        return if (response.promptFeedback != null)
            AiPromptTrace.invalidRequest(modelInfo, response.promptFeedback.toString())
        else
            response.trace(modelInfo, t0)
    }

    companion object {

        private const val DEFAULT_MAX_TOKENS = 500

        /** Create trace for chat message response, with given model info and start query time. */
        internal fun GenerateContentResponse.trace(modelInfo: AiModelInfo, t0: Long): AiPromptTrace<MultimodalChatMessage> {
            val pf = promptFeedback
            return if (pf?.blockReason != null) {
                val msg = "Gemini blocked response: ${pf.blockReason}"
                AiPromptTrace.error(modelInfo, msg, duration = System.currentTimeMillis() - t0)
            } else if (candidates.isNullOrEmpty()) {
                AiPromptTrace.error(modelInfo, "Gemini returned no candidates", duration = System.currentTimeMillis() - t0)
            } else {
                val firstCandidate = candidates!!.first()
                val msg = firstCandidate.fromGeminiCandidate()
                AiPromptTrace(
                    null,
                    modelInfo,
                    AiExecInfo(responseTimeMillis = System.currentTimeMillis() - t0),
                    AiOutputInfo.output(msg)
                )
            }
        }

        //region CONVERSIONS
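
        // These helpers map between Gemini wire types (Content, Part, Candidate) and the provider-neutral
        // multimodal chat types used elsewhere in promptkt. The mapping implemented below is:
        //   Part.text             -> MChatMessagePart.text
        //   Part.inlineData       -> MChatMessagePart.image
        //   Part.functionCall     -> MChatMessagePart.toolCall
        //   Part.functionResponse -> MChatMessagePart.toolResponse
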
        fun GenerateContentResponse.toMultimodalChatMessage(): MultimodalChatMessage =
            candidates!!.first().content.let { c ->
                MultimodalChatMessage(
                    role = c.role.fromGeminiRole(),
                    content = c.parts.map { it.fromGemini() }
                )
            }

        /** Convert a Gemini candidate into a [MultimodalChatMessage], separating tool calls from text/image parts. */
        fun Candidate.fromGeminiCandidate(): MultimodalChatMessage {
            val role = content.role.fromGeminiRole()
            val functionCalls = content.parts.filter { it.functionCall != null }.map { it.functionCall!! }
            val parts = content.parts.filter { it.text != null || it.inlineData != null }.map { it.fromGemini() }
            return MultimodalChatMessage(role, parts, functionCalls.map { it.fromGemini() })
        }

        fun Part.fromGemini(): MChatMessagePart = when {
            text != null -> MChatMessagePart.text(text)
            inlineData != null -> MChatMessagePart.image(inlineData.data)
            functionCall != null -> MChatMessagePart.toolCall(functionCall.name, functionCall.args)
            functionResponse != null -> MChatMessagePart.toolResponse(functionResponse.name, functionResponse.response)
            else -> throw UnsupportedOperationException("Unsupported Gemini part: $this")
        }

        fun MultimodalChatMessage.gemini(): Content {
            return if (!toolCalls.isNullOrEmpty()) {
                val args = toolCalls.first().argumentsAsJson
                val toolCallArgs = Json.decodeFromString