// All Downloads are FREE. Search and download functionalities are using the official Maven repository.
//
// tri.ai.core.MultimodalChatMessage.kt Maven / Gradle / Ivy

/*-
 * #%L
 * tri.promptfx:promptkt
 * %%
 * Copyright (C) 2023 - 2025 Johns Hopkins University Applied Physics Laboratory
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */
package tri.ai.core

import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonContentPolymorphicSerializer
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive

/**
 * Generic representation of a multimodal chat message.
 *
 * Only [role] is required; every other field defaults to `null`.
 */
data class MultimodalChatMessage(
    /** Role for what generated the message. */
    val role: MChatRole,
    /** Content of the message, as a list of message parts. */
    val content: List<MChatMessagePart>? = null,
    /** List of tool calls (typically populated automatically by AI) to invoke for more information. */
    val toolCalls: List<MToolCall>? = null,
    /** Unique id of a tool call (typically populated by the code wrapping up the tool result for the next chat call), used to link a tool call request to the result. */
    val toolCallId: String? = null
) {
    companion object {
        /** Chat message under the given [role] whose content is a single text part holding [text]. */
        fun text(role: MChatRole, text: String): MultimodalChatMessage = MultimodalChatMessage(
            role,
            listOf(MChatMessagePart(text = text))
        )

        /**
         * Chat message with a tool result: role is [MChatRole.Tool], content is a single
         * text part holding [result], and [toolCallId] is set to [toolId].
         */
        fun tool(result: String, toolId: String): MultimodalChatMessage = MultimodalChatMessage(
            MChatRole.Tool,
            listOf(MChatMessagePart(text = result)),
            toolCallId = toolId
        )
    }
}

/**
 * Model parameters for multimodal chat.
 *
 * All constructor arguments have defaults, so `MChatParameters()` is a valid call.
 */
class MChatParameters(
    /** Parameters for varying output. */
    val variation: MChatVariation = MChatVariation(),
    /** Parameters for tool use; `null` by default. */
    val tools: MChatTools? = null,
    /** Parameters for token limit; defaults to 1000. */
    val tokens: Int? = 1000,
    /**
     * Parameters for stopping criteria.
     * NOTE(review): element type restored as String (presumably stop sequences) — confirm against callers.
     */
    val stop: List<String>? = null,
    /** Parameters for response format; defaults to [MResponseFormat.TEXT]. */
    val responseFormat: MResponseFormat = MResponseFormat.TEXT,
    /** Parameters for number of responses; `null` by default. */
    val numResponses: Int? = null
)

/**
 * Model parameters related to likelihood, variation, and probabilities.
 *
 * Each setting is optional and is `null` unless supplied.
 */
class MChatVariation(
    val seed: Int? = null,
    val temperature: Double? = null,
    val topP: Double? = null,
    val topK: Int? = null,
    val frequencyPenalty: Double? = null,
    val presencePenalty: Double? = null,
)

/** Model parameters related to tool use. */
class MChatTools(
    /** How a tool is to be chosen; defaults to [MToolChoice.AUTO]. */
    val toolChoice: MToolChoice = MToolChoice.AUTO,
    /** The tool definitions. */
    val tools: List<MTool>
)

/**
 * Tool selection setting: either a string-valued [Mode] or a [Named] reference to a specific function.
 * Serialization is delegated to [MToolChoiceSerializer].
 */
@Serializable(with = MToolChoiceSerializer::class)
sealed interface MToolChoice {

    /** Tool choice expressed as a single string [value]. */
    @JvmInline
    @Serializable
    value class Mode(val value: String) : MToolChoice

    /** Tool choice that names a specific [function] together with its tool [type]. */
    @Serializable
    data class Named(
        @SerialName("type") val type: MToolType,
        @SerialName("function") val function: MFunctionToolChoice,
    ) : MToolChoice

    companion object {
        /** The `auto` mode, carried as a [Mode] with value `"AUTO"`. */
        val AUTO: MToolChoice = Mode(value = "AUTO")

        /** The `none` mode, carried as a [Mode] with value `"NONE"`. */
        val NONE: MToolChoice = Mode(value = "NONE")

        /** Specifies a function (by [name]) for the model to call. */
        fun function(name: String): MToolChoice {
            val target = MFunctionToolChoice(name)
            return Named(MToolType.FUNCTION, target)
        }
    }
}

/** Inline wrapper around the string [value] that identifies a type of tool. */
@JvmInline
@Serializable
value class MToolType(val value: String) {
    companion object {
        /** Tool type whose value is `"function"`. */
        val FUNCTION: MToolType = MToolType(value = "function")
    }
}

/** Identifies a function by [name]; used as the `function` field of [MToolChoice.Named]. */
@Serializable
data class MFunctionToolChoice(
    val name: String,
)

/**
 * Picks the deserializer for an [MToolChoice] from the shape of the JSON content:
 * a JSON primitive is read as [MToolChoice.Mode], a JSON object as [MToolChoice.Named].
 */
internal class MToolChoiceSerializer : JsonContentPolymorphicSerializer<MToolChoice>(MToolChoice::class) {
    /**
     * @param element the JSON content being decoded
     * @return the serializer matching the shape of [element]
     * @throws UnsupportedOperationException if [element] is neither a primitive nor an object
     */
    override fun selectDeserializer(element: JsonElement): DeserializationStrategy<MToolChoice> {
        return when (element) {
            is JsonPrimitive -> MToolChoice.Mode.serializer()
            is JsonObject -> MToolChoice.Named.serializer()
            else -> throw UnsupportedOperationException("Unsupported JSON element: $element")
        }
    }
}

/**
 * Definition of a tool, given by a [name], a [description], and a [jsonSchema] string.
 * NOTE(review): [jsonSchema] presumably holds the JSON schema of the tool's arguments — confirm against callers.
 */
class MTool(
    val name: String,
    val description: String,
    val jsonSchema: String,
)

/**
 * Reference to a function to execute: the call's [id], the function [name],
 * and its arguments carried as a string in [argumentsAsJson].
 */
class MToolCall(
    val id: String,
    val name: String,
    val argumentsAsJson: String,
)

/** Response formats that can be requested through [MChatParameters.responseFormat]. */
enum class MResponseFormat {
    /** JSON response format. */
    JSON,

    /** Text response format. */
    TEXT,
}

/**
 * A single part of a chat message. Which fields must be non-null depends on [partType]
 * and is enforced at construction time:
 *  - [MPartType.TEXT] requires [text]
 *  - [MPartType.IMAGE] requires [inlineData]
 *  - [MPartType.TOOL_CALL] and [MPartType.TOOL_RESPONSE] require both [functionName] and [functionArgs]
 *
 * @throws IllegalArgumentException if a field required by [partType] is `null`
 */
data class MChatMessagePart(
    val partType: MPartType = MPartType.TEXT,
    val text: String? = null,
    // TODO - support for multiple types of inline data
    val inlineData: String? = null,
    val functionName: String? = null,
    // NOTE(review): type arguments restored as <String, String> — confirm against callers.
    val functionArgs: Map<String, String>? = null
) {
    init {
        // Each check reads "this is not the part type in question, or its required fields are present".
        require(partType != MPartType.TEXT || text != null) { "Text must be provided for text part type." }
        require(partType != MPartType.IMAGE || inlineData != null) { "Inline data must be provided for image part type." }
        require(partType != MPartType.TOOL_CALL || (functionName != null && functionArgs != null)) { "Function name and arguments must be provided for tool call part type." }
        require(partType != MPartType.TOOL_RESPONSE || (functionName != null && functionArgs != null)) { "Function name and arguments must be provided for tool response part type." }
    }

    companion object {
        /** Text part holding [text]. */
        fun text(text: String) = MChatMessagePart(MPartType.TEXT, text)

        /** Image part holding [inlineData]. */
        fun image(inlineData: String) = MChatMessagePart(MPartType.IMAGE, inlineData = inlineData)

        /** Tool-call part for the function [name] with arguments [args]. */
        fun toolCall(name: String, args: Map<String, String>) =
            MChatMessagePart(MPartType.TOOL_CALL, functionName = name, functionArgs = args)

        /** Tool-response part for the function [name]; [response] is stored in [functionArgs]. */
        fun toolResponse(name: String, response: Map<String, String>) =
            MChatMessagePart(MPartType.TOOL_RESPONSE, functionName = name, functionArgs = response)
    }
}

/** Kinds of message part; [MChatMessagePart] validates its fields against the selected kind. */
enum class MPartType {
    /** Text part. */
    TEXT,

    /** Image part. */
    IMAGE,

    /** Tool call part. */
    TOOL_CALL,

    /** Tool response part. */
    TOOL_RESPONSE,
}

//region BUILDERS

/**
 * Creates a [MultimodalChatMessage] by running [block] against a fresh [MChatMessageBuilder].
 *
 * A non-null [role] is applied after [block] has run, so it takes precedence over
 * any role assigned inside the block.
 */
fun chatMessage(role: MChatRole? = null, block: MChatMessageBuilder.() -> Unit): MultimodalChatMessage {
    val builder = MChatMessageBuilder()
    builder.block()
    if (role != null) {
        builder.role = role
    }
    return builder.build()
}

/**
 * Builder object for [MultimodalChatMessage].
 *
 * The role starts as [MChatRole.User]; content and tool calls start as empty lists.
 */
class MChatMessageBuilder {
    var role = MChatRole.User
    var content = mutableListOf<MChatMessagePart>()
    // NOTE(review): params is stored here but is not passed to the message created by build().
    var params: MChatParameters? = null
    var toolCalls = mutableListOf<MToolCall>()
    var toolCallId: String? = null

    /** Sets the message role. */
    fun role(role: MChatRole) {
        this.role = role
    }

    /** Appends a text part holding [text]. */
    fun text(text: String) {
        content.add(MChatMessagePart(MPartType.TEXT, text))
    }

    /** Appends an image part whose inline data is [imageUrl]. */
    fun image(imageUrl: String) {
        content.add(MChatMessagePart(MPartType.IMAGE, inlineData = imageUrl))
    }

    /** Appends the given [parts] to the content. */
    fun content(vararg parts: MChatMessagePart) {
        content.addAll(parts.toList())
    }

    /** Appends the given [parts] to the content. */
    fun content(parts: List<MChatMessagePart>) {
        content.addAll(parts)
    }

    /**
     * Stores a default-constructed [MChatParameters], after applying [block] to it, in [params].
     * NOTE(review): MChatParameters declares only `val` properties, so [block] cannot reassign them.
     */
    fun parameters(block: MChatParameters.() -> Unit) {
        params = MChatParameters().apply(block)
    }

    /** Appends the given tool [calls]. */
    fun toolCalls(vararg calls: MToolCall) {
        toolCalls.addAll(calls.toList())
    }

    /** Appends the given tool [calls]. */
    fun toolCalls(calls: List<MToolCall>) {
        toolCalls.addAll(calls)
    }

    /** Sets the tool call id. */
    fun toolCallId(id: String?) {
        toolCallId = id
    }

    /** Creates the message from the current role, content, tool calls and tool call id. */
    fun build(): MultimodalChatMessage = MultimodalChatMessage(role, content, toolCalls, toolCallId)
}

//endregion




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy