All downloads are free. The search and download features use the official Maven repository.

tri.ai.core.ModelLibrary.kt Maven / Gradle / Ivy

/*-
 * #%L
 * tri.promptfx:promptkt
 * %%
 * Copyright (C) 2023 - 2025 Johns Hopkins University Applied Physics Laboratory
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */
package tri.ai.core

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
import com.fasterxml.jackson.module.kotlin.KotlinModule
import com.fasterxml.jackson.module.kotlin.readValue
import java.io.File

/**
 * Library of models, including model information and list of enabled models.
 *
 * All properties are mutable with empty defaults, so `ModelLibrary()` yields an empty library.
 * The snake_case property names are kept as-is: they are referenced as property references
 * (e.g. `ModelLibrary::image_generator`) and presumably mirror keys in the model YAML file.
 */
class ModelLibrary {
    /**
     * Model definitions, grouped into lists.
     * NOTE(review): the map key is assumed to be a String group name — confirm against the model YAML.
     */
    var models = mapOf<String, List<ModelInfo>>()

    // The lists below hold model ids for each category.
    // NOTE(review): ids are assumed to be Strings, matching the type of ModelInfo.id — confirm.

    /** Ids of enabled audio models. */
    var audio = listOf<String>()
    /** Ids of enabled chat models. */
    var chat = listOf<String>()
    /** Ids of enabled multimodal models. */
    var multimodal = listOf<String>()
    /** Ids of enabled completion models. */
    var completion = listOf<String>()
    /** Ids of enabled embedding models. */
    var embeddings = listOf<String>()
    /** Ids of enabled image generator models. */
    var image_generator = listOf<String>()
    /** Ids of enabled moderation models. */
    var moderation = listOf<String>()
    /** Ids of enabled text-to-speech models. */
    var tts = listOf<String>()
    /** Ids of enabled vision-language models. */
    var vision_language = listOf<String>()

    /**
     * Create model index with unique identifiers, including any registered snapshots.
     *
     * Every model in [models] is followed by the result of its `createSnapshots()` call, and the
     * combined list is keyed by `id`. If two entries share an id, the later one wins.
     */
    fun modelInfoIndex() = models.values
        .flatten()
        .flatMap { model -> listOf(model) + model.createSnapshots() }
        .associateBy { it.id }
}

/**
 * A model index with configurable model information and runtime overrides.
 *
 * Two [ModelLibrary] instances are loaded when an instance is constructed:
 *  - the preconfigured library, read from the classpath resource `resources/<modelFileName>`;
 *  - an optional runtime library, read from the first existing file among `<modelFileName>` and
 *    `config/<modelFileName>` (relative to the working directory); empty when neither exists.
 *
 * @property modelFileName name of the YAML model file, used for both the bundled resource and the override files
 * @throws IllegalStateException if the bundled resource cannot be found
 */
abstract class ModelIndex(val modelFileName: String) {

    /** YAML object mapper with the Kotlin and Java time modules registered. */
    private val MAPPER = ObjectMapper(YAMLFactory()).apply {
        registerModule(KotlinModule.Builder().build())
        registerModule(JavaTimeModule())
    }

    /** Model definitions in library. */
    private val MODELS: ModelLibrary = MAPPER.readValue<ModelLibrary>(
        // javaClass is the runtime (sub)class and the path has no leading slash,
        // so the resource is resolved relative to that class's package.
        checkNotNull(javaClass.getResourceAsStream("resources/$modelFileName")) {
            "Model library resource not found: resources/$modelFileName (resolved relative to ${javaClass.name})"
        }
    )

    /** Model overrides at runtime - ids only. Candidate files are checked in order; the first that exists is used. */
    private val RUNTIME_MODELS: ModelLibrary = listOf(File(modelFileName), File("config/$modelFileName"))
        .firstOrNull { it.exists() }
        ?.let { overrideFile -> MAPPER.readValue<ModelLibrary>(overrideFile) }
        ?: ModelLibrary()

    /** [ModelInfo] by id, where config in runtime overrides preconfigured info. */
    internal val MODEL_INFO_INDEX = MODELS.modelInfoIndex() + RUNTIME_MODELS.modelInfoIndex()

    /** Get chat models, including vision-language models which have the same API. */
    fun chatModelsInclusive(includeSnapshots: Boolean = false) =
        models(ModelLibrary::chat, includeSnapshots) + models(ModelLibrary::vision_language, includeSnapshots)

    /** Get chat models without vision-language models. */
    fun chatModels(includeSnapshots: Boolean = false) = models(ModelLibrary::chat, includeSnapshots)

    /** Get multimodal models. */
    fun multimodalModels(includeSnapshots: Boolean = false) = models(ModelLibrary::multimodal, includeSnapshots)

    /** Get completion models. */
    fun completionModels(includeSnapshots: Boolean = false) = models(ModelLibrary::completion, includeSnapshots)

    /** Get embedding models. */
    fun embeddingModels(includeSnapshots: Boolean = false) = models(ModelLibrary::embeddings, includeSnapshots)

    /** Get moderation models. */
    fun moderationModels(includeSnapshots: Boolean = false) = models(ModelLibrary::moderation, includeSnapshots)

    /** Get audio models. */
    fun audioModels(includeSnapshots: Boolean = false) = models(ModelLibrary::audio, includeSnapshots)

    /** Get text-to-speech models. */
    fun ttsModels(includeSnapshots: Boolean = false) = models(ModelLibrary::tts, includeSnapshots)

    /** Get image generator models. */
    fun imageGeneratorModels(includeSnapshots: Boolean = false) = models(ModelLibrary::image_generator, includeSnapshots)

    /** Get vision-language models. */
    fun visionLanguageModels(includeSnapshots: Boolean = false) = models(ModelLibrary::vision_language, includeSnapshots)

    /**
     * Get list of available models for one category.
     *
     * @param op selects the category's id list from a [ModelLibrary]
     * @param includeSnapshots passed through to `ModelInfo.ids` for each resolved model
     * @return the concatenated results of `ids(includeSnapshots)` over the resolved models
     */
    private fun models(op: ModelLibrary.() -> List<String>, includeSnapshots: Boolean = false) =
        lookupModels(op(MODELS), op(RUNTIME_MODELS))
            .flatMap { it.ids(includeSnapshots) }

    /**
     * Get list of models from two lists of ids.
     * If no runtime models are provided, returns the preconfigured list, otherwise the runtime list.
     * Each id is resolved through [MODEL_INFO_INDEX]; an id with no definition there is not rejected but
     * mapped to a placeholder [ModelInfo] built with [ModelType.UNKNOWN] and the string "runtime".
     */
    private fun lookupModels(preconfigured: List<String>, runtime: List<String>) =
        runtime.ifEmpty { preconfigured }.map { id ->
            MODEL_INFO_INDEX[id] ?: ModelInfo(id, ModelType.UNKNOWN, "runtime")
        }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy