All Downloads are FREE. Search and download functionalities are using the official Maven repository.

tri.ai.gemini.GeminiAiPlugin.kt Maven / Gradle / Ivy

/*-
 * #%L
 * tri.promptfx:promptkt
 * %%
 * Copyright (C) 2023 - 2025 Johns Hopkins University Applied Physics Laboratory
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */
package tri.ai.gemini

import kotlinx.coroutines.runBlocking
import tri.ai.core.*
import java.time.LocalDate
import java.time.format.DateTimeFormatter
import java.time.format.DateTimeParseException
import java.util.Locale

/**
 * Plugin registering models and services via the Gemini API.
 *
 * All model lists are empty when the underlying [GeminiClient] reports that it is not configured.
 */
class GeminiAiPlugin : TextPlugin {

    private val client = GeminiClient()

    override fun modelSource() = "Gemini"

    /**
     * Lists the models reported by the Gemini API, converted to core model info.
     * Blocks the calling thread while the request runs. Returns an empty list when the client is
     * not configured, or when the request fails (the failure is printed via `printStackTrace`).
     */
    override fun modelInfo() = if (client.isConfigured())
        runBlocking {
            try {
                client.listModels().models.map { it.toCoreModelInfo() }
            } catch (x: Exception) {
                // best effort: a failed listing must not prevent the plugin from loading
                x.printStackTrace()
                emptyList()
            }
        }
    else listOf()

    override fun embeddingModels() =
        models(GeminiModelIndex.embeddingModels()) { GeminiEmbeddingService(it, client) }

    override fun chatModels() =
        models(GeminiModelIndex.chatModelsInclusive()) { GeminiTextChat(it, client) }

    override fun multimodalModels() =
        models(GeminiModelIndex.multimodalModels()) { GeminiMultimodalChat(it, client) }

    /** Text completion is offered for the dedicated completion models as well as all chat models. */
    override fun textCompletionModels() =
        models(GeminiModelIndex.completionModels() + GeminiModelIndex.chatModelsInclusive()) {
            GeminiTextCompletion(it, client)
        }

    override fun visionLanguageModels() =
        models(GeminiModelIndex.visionLanguageModels()) { GeminiVisionLanguageChat(it, client) }

    // NOTE(review): the factory is unimplemented, so this throws NotImplementedError as soon as the
    // index returns at least one image generator id while the client is configured.
    override fun imageGeneratorModels() = models(GeminiModelIndex.imageGeneratorModels()) { TODO() }

    override fun close() {
        client.close()
    }

    /**
     * Builds one service per model id using [factory].
     * @param ids model ids to wrap
     * @param factory creates the service for a single model id
     * @return the created services, or an empty list when the client is not configured
     */
    private fun <T> models(ids: List<String>, factory: (String) -> T): List<T> =
        if (client.isConfigured()) ids.map(factory) else emptyList()

    //region GETTING MODEL INFO (WITH HEURISTICS)

    /**
     * Converts a model record returned by the Gemini API into core model info.
     * The id is the API name with the "models/" prefix removed. Release date, deprecation date,
     * lifecycle, and type are not provided by the API and are derived heuristically from the
     * description text and the supported generation methods.
     */
    private fun ModelInfo.toCoreModelInfo() =
        ModelInfo(
            id = name.substringAfter("models/"),
            type = ModelType.UNKNOWN,
            source = modelSource()
        ).also {
            it.name = displayName
            it.version = version
            it.description = description
            it.inputTokenLimit = inputTokenLimit
            it.outputTokenLimit = outputTokenLimit

            it.created = findReleaseDate(description)
            it.deprecation = findDeprecation(description)
            it.lifecycle = findLifecycle(it.id, description)
            // type must be resolved before inputs/outputs, which are derived from it
            it.type = findType(it.id, supportedGenerationMethods.toSet())
            it.inputs = when (it.type) {
                ModelType.QUESTION_ANSWER -> listOf(DataModality.text)
                ModelType.TEXT_EMBEDDING -> listOf(DataModality.text)
                ModelType.TEXT_CHAT -> listOf(DataModality.text)
                ModelType.TEXT_VISION_CHAT -> listOf(DataModality.text, DataModality.image, DataModality.audio, DataModality.video)
                else -> null
            }
            it.outputs = when (it.type) {
                ModelType.QUESTION_ANSWER -> listOf(DataModality.text)
                ModelType.TEXT_EMBEDDING -> listOf(DataModality.embedding)
                ModelType.TEXT_CHAT -> listOf(DataModality.text)
                ModelType.TEXT_VISION_CHAT -> listOf(DataModality.text)
                else -> null
            }

            it.params(
                "supportedGenerationMethods" to supportedGenerationMethods,
                "baseModelId" to baseModelId,
                "temperature" to temperature,
                "maxTemperature" to maxTemperature,
                "topP" to topP,
                "topK" to topK
            )
        }

    /**
     * Extracts a deprecation date from phrases such as "will be discontinued on May 1st, 2025."
     * or "was deprecated on July 12, 2024.".
     * @return the ISO-8601 date string, or null if no such phrase is present or the date cannot be parsed
     */
    private fun findDeprecation(description: String?): String? = when {
        description == null -> null
        description.contains("will be discontinued on") ->
            description.dateBetween("will be discontinued on", ".")
        description.contains("was deprecated on") ->
            description.dateBetween("was deprecated on", ".")
        else -> null
    }?.toString()

    /**
     * Extracts a release date from phrases such as "released in May of 2024." or
     * "Experimental release (May 1st, 2025)".
     * @return the parsed date, or null if no such phrase is present or the date cannot be parsed
     */
    private fun findReleaseDate(description: String?): LocalDate? = when {
        description == null -> null
        description.contains("released in") ->
            description.dateBetween("released in", ".")
        description.contains("Experimental release (") ->
            description.dateBetween("Experimental release (", ")")
        else -> null
    }

    /** Parses the text between the first [marker] and the following [terminator] as a date. */
    private fun String.dateBetween(marker: String, terminator: String): LocalDate? =
        substringAfter(marker).substringBefore(terminator).parseDate()

    /**
     * Parses the date formats seen in Gemini model descriptions.
     * Month names are always parsed as English, independent of the JVM default locale.
     * @return the parsed date, or null if the receiver is null or matches none of the known formats
     */
    private fun String?.parseDate(): LocalDate? {
        if (this == null) return null
        // common pattern e.g. "May of 2024" -- treated as the first day of that month
        val monthOfYear = trim().replace(" of ", " 1, ")
        // common pattern e.g. "May 1st, 2025" -- drop the ordinal suffix
        val ordinalDay = replaceFirst(ORDINAL_SUFFIX, "$1").trim()
        return parseOrNull(monthOfYear) ?: parseOrNull(ordinalDay)
    }

    /** Parses [text] as "MMMM d, yyyy", returning null instead of throwing when it does not match. */
    private fun parseOrNull(text: String): LocalDate? = try {
        LocalDate.parse(text, DATE_FORMAT)
    } catch (e: DateTimeParseException) {
        null
    }

    /**
     * Guesses the lifecycle stage from the model id and description text.
     * Branches are evaluated top to bottom and the first match wins; models without any
     * recognized phrase are treated as production.
     */
    private fun findLifecycle(id: String, description: String?): ModelLifecycle {
        // NOTE(review): "released in" is checked before the deprecation phrases, so a description
        // containing both is reported as PRODUCTION -- confirm this ordering is intended.
        return when {
            "-exp" in id -> ModelLifecycle.EXPERIMENTAL
            description == null -> ModelLifecycle.PRODUCTION
            "Experimental release" in description -> ModelLifecycle.EXPERIMENTAL
            "released in" in description -> ModelLifecycle.PRODUCTION
            "most recent production" in description -> ModelLifecycle.PRODUCTION_ALIAS
            "legacy" in description -> ModelLifecycle.LEGACY
            "will be discontinued" in description -> ModelLifecycle.DEPRECATION_PLANNED
            "was deprecated" in description -> ModelLifecycle.DEPRECATED
            else -> ModelLifecycle.PRODUCTION
        }
    }

    /**
     * Guesses the model type from the exact set of generation methods the API reports,
     * then upgrades chat models to vision chat based on the model id.
     *
     * This is mostly done through trial and error.
     * See https://ai.google.dev/gemini-api/docs/models/gemini#model-variations for inputs supported.
     *
     * @param id model id (without the "models/" prefix)
     * @param methods supported generation methods; must match a known combination exactly
     * @return the guessed type, or [ModelType.UNKNOWN] for an unrecognized combination
     */
    private fun findType(id: String, methods: Set<String>): ModelType {
        val base = TYPE_BY_METHODS[methods] ?: ModelType.UNKNOWN
        val acceptsMedia = "vision" in id || "gemini-1.5" in id || "gemini-2.0" in id
        return if (base == ModelType.TEXT_CHAT && acceptsMedia) ModelType.TEXT_VISION_CHAT else base
    }

    //endregion

    private companion object {
        /** Date format used in model descriptions; English month names regardless of default locale. */
        private val DATE_FORMAT: DateTimeFormatter = DateTimeFormatter.ofPattern("MMMM d, yyyy", Locale.ENGLISH)

        /** Matches a day number followed by an English ordinal suffix, e.g. "1st" or "22nd". */
        private val ORDINAL_SUFFIX = Regex("(\\d+)(st|nd|rd|th)")

        // generation method names for current models
        private const val ANSWER = "generateAnswer"
        private const val BIDI_GENERATE = "bidiGenerateContent"
        private const val CACHED = "createCachedContent"
        private const val COUNT = "countTokens"
        private const val EMBED = "embedContent"
        private const val GENERATE = "generateContent"
        private const val TUNED = "createTunedModel"

        // message-based method names
        private const val GENERATE_MESSAGE = "generateMessage"
        private const val COUNT_MESSAGE = "countMessageTokens"

        // text-based method names
        private const val COUNT_TEXT = "countTextTokens"
        private const val EMBED_TEXT = "embedText"
        private const val GENERATE_TEXT = "generateText"
        private const val TUNED_TEXT = "createTunedTextModel"

        /** Known combinations of supported generation methods and the model type each one implies. */
        private val TYPE_BY_METHODS: Map<Set<String>, ModelType> = mapOf(
            setOf(EMBED) to ModelType.TEXT_EMBEDDING,
            setOf(EMBED_TEXT, COUNT_TEXT) to ModelType.TEXT_EMBEDDING,
            setOf(GENERATE) to ModelType.TEXT_CHAT,
            setOf(GENERATE, COUNT) to ModelType.TEXT_CHAT,
            setOf(GENERATE, COUNT, TUNED) to ModelType.TEXT_CHAT,
            setOf(GENERATE_MESSAGE, COUNT_MESSAGE) to ModelType.TEXT_CHAT,
            setOf(GENERATE, COUNT_MESSAGE, TUNED) to ModelType.TEXT_CHAT,
            setOf(GENERATE_TEXT, COUNT_TEXT, TUNED_TEXT) to ModelType.TEXT_CHAT,
            setOf(ANSWER) to ModelType.QUESTION_ANSWER,
            setOf(GENERATE, COUNT, CACHED) to ModelType.TEXT_CHAT,
            setOf(GENERATE, COUNT, BIDI_GENERATE) to ModelType.TEXT_CHAT
        )
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy