// All downloads are free. Search and download functionality uses the official Maven repository.
//
// tri.ai.gemini.GenerateContentRequest.kt (Maven / Gradle / Ivy)

/*-
 * #%L
 * tri.promptfx:promptkt
 * %%
 * Copyright (C) 2023 - 2025 Johns Hopkins University Applied Physics Laboratory
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */
package tri.ai.gemini

import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import java.net.URI

// https://ai.google.dev/api/generate-content#method:-models.generatecontent

/**
 * Request body for the Gemini `generateContent` call.
 *
 * Only [contents] is required; every other field defaults to `null`.
 *
 * @property contents the content items sent with the request
 * @property tools optional tool declarations
 * @property toolConfig optional configuration for [tools]
 * @property safetySettings optional per-category safety settings
 * @property systemInstruction optional system instruction (marked as a beta feature in the original source)
 * @property generationConfig optional generation parameters
 * @property cachedContent optional cached-content reference, passed through as a string
 */
@Serializable
data class GenerateContentRequest(
    val contents: List<Content>,
    val tools: List<Tool>? = null,
    val toolConfig: ToolConfig? = null,
    val safetySettings: List<SafetySetting>? = null,
    val systemInstruction: Content? = null, // this is a beta feature
    val generationConfig: GenerationConfig? = null,
    val cachedContent: String? = null
) {
    /** Convenience constructor that wraps a single [content] item in a one-element list. */
    constructor(content: Content, systemInstruction: Content? = null, generationConfig: GenerationConfig? = null) :
            this(listOf(content), systemInstruction = systemInstruction, generationConfig = generationConfig)
}

//region [Content] and [Part]

/**
 * A message made of one or more [Part]s, optionally tagged with the [ContentRole] that produced it.
 *
 * @property parts the parts that make up this content
 * @property role the author role; defaults to `null`
 */
@Serializable
data class Content(
    val parts: List<Part>,
    val role: ContentRole? = null
) {
    companion object {
        /** Content with a single text part, attributed to [ContentRole.user]. */
        fun text(text: String): Content = Content(listOf(Part(text)), ContentRole.user)

        /** Content with a system message; currently identical to [text]. */
        fun systemMessage(text: String): Content = text(text) // TODO - support for system messages if Gemini supports it
    }
}

/** Author role attached to a [Content]; constants are lowercase, so they serialize as `user` / `model`. */
@Serializable
enum class ContentRole {
    user,
    model,
}

/**
 * One piece of a [Content].
 *
 * Modeled as a union type: every field is optional and defaults to `null`, and an instance is
 * expected to carry only one of them. The class itself does not enforce that constraint.
 */
@Serializable
data class Part(
    val text: String? = null,
    val inlineData: Blob? = null,
    val functionCall: FunctionCall? = null,
    val functionResponse: FunctionResponse? = null,
    val fileData: FileData? = null,
    val executableCode: ExecutableCode? = null,
    val codeExecutionResult: CodeExecutionResult? = null,
)

// MIME type string constants, grouped under the Gemini documentation page each group was taken from.

// Document and text formats.
// https://ai.google.dev/gemini-api/docs/document-processing
const val MIME_TYPE_TEXT = "text/plain"
// NOTE(review): presumably meant for enum-style responses rather than document input — confirm.
const val MIME_TYPE_ENUM = "text/x.enum"
const val MIME_TYPE_CSV = "text/csv"
const val MIME_TYPE_HTML = "text/html"
const val MIME_TYPE_MD = "text/md"
const val MIME_TYPE_RTF = "text/rtf"
const val MIME_TYPE_XML = "text/xml"
const val MIME_TYPE_JSON = "application/json"
const val MIME_TYPE_PDF = "application/pdf"

// Image formats.
// https://ai.google.dev/gemini-api/docs/vision
const val MIME_TYPE_JPEG = "image/jpeg"
const val MIME_TYPE_PNG = "image/png"
const val MIME_TYPE_HEIC = "image/heic"

// Audio formats.
// https://ai.google.dev/gemini-api/docs/audio
const val MIME_TYPE_WAV = "audio/wav"
const val MIME_TYPE_MP3 = "audio/mp3"

// Video formats.
// https://ai.google.dev/gemini-api/docs/vision#prompting-video
const val MIME_TYPE_MP4 = "video/mp4"
const val MIME_TYPE_MOV = "video/mov"
const val MIME_TYPE_MPEG = "video/mpeg"
const val MIME_TYPE_MPG = "video/mpg"

/**
 * Inline payload: a MIME type plus the data carried as a string.
 *
 * @property mimeType MIME type of [data]
 * @property data the payload; [fromDataUrl] fills it with the base64 portion of a data URL
 */
@Serializable
data class Blob(
    val mimeType: String,
    val data: String
) {
    companion object {
        private const val DATA_SCHEME = "data:"
        private const val BASE64_MARKER = ";base64,"

        /**
         * Generate a blob from a base64 data URL.
         *
         * @throws UnsupportedOperationException if [url] is not a base64 data URL
         */
        fun fromDataUrl(url: URI): Blob = fromDataUrl(url.toASCIIString())

        /**
         * Generate a blob from a base64 data URL of the form `data:<mimeType>;base64,<payload>`.
         *
         * The text between `data:` and the first `;base64,` becomes [mimeType]; everything after that
         * marker becomes [data] and is not decoded or validated.
         *
         * @throws UnsupportedOperationException if [urlStr] does not start with `data:` or has no
         *   `;base64,` marker
         */
        fun fromDataUrl(urlStr: String): Blob {
            if (!urlStr.startsWith(DATA_SCHEME)) {
                throw UnsupportedOperationException("Expected a data URL but was $urlStr")
            }
            // Without the marker, substringBefore/substringAfter return the whole input, which used to
            // produce a Blob whose data was the entire URL. Reject such input explicitly.
            if (BASE64_MARKER !in urlStr) {
                throw UnsupportedOperationException("Expected a base64-encoded data URL but was $urlStr")
            }
            val mimeType = urlStr.substringBefore(BASE64_MARKER).substringAfter(DATA_SCHEME)
            val base64 = urlStr.substringAfter(BASE64_MARKER)
            return Blob(mimeType, base64)
        }
    }
}

// TODO - [args] is a "Struct" which is essentially a map of key-value pairs, unsure how to serialize properly with kotlin
/**
 * A function call carried in a [Part].
 *
 * @property name name of the function
 * @property args arguments as string key-value pairs.
 *   NOTE(review): the type arguments were missing from this source; `Map<String, String>` is assumed — confirm.
 */
@Serializable
data class FunctionCall(
    val name: String,
    val args: Map<String, String>
)

// TODO - [response] is a "Struct" which is essentially a map of key-value pairs, unsure how to serialize properly with kotlin
/**
 * A function result carried in a [Part].
 *
 * @property name name of the function
 * @property response result as string key-value pairs.
 *   NOTE(review): the type arguments were missing from this source; `Map<String, String>` is assumed — confirm.
 */
@Serializable
data class FunctionResponse(
    val name: String,
    val response: Map<String, String>
)

/**
 * Reference to file content by URI.
 *
 * @property mimeType MIME type of the referenced file; defaults to `null`
 * @property fileUri URI of the file
 */
@Serializable
data class FileData(
    val mimeType: String? = null,
    val fileUri: String,
)

/**
 * Source code carried in a [Part].
 *
 * @property language language of [code]
 * @property code the code text
 */
@Serializable
data class ExecutableCode(
    val language: CodeLanguage,
    val code: String
)

/** Languages that can be named by [ExecutableCode.language]. */
@Serializable
enum class CodeLanguage {
    LANGUAGE_UNSPECIFIED,
    PYTHON,
}

/**
 * Result of a code execution carried in a [Part].
 *
 * @property outcome how the execution ended
 * @property output optional output text; defaults to `null`
 */
@Serializable
data class CodeExecutionResult(
    val outcome: CodeExecutionOutcome,
    val output: String? = null,
)

/** Possible values of [CodeExecutionResult.outcome]. */
@Serializable
enum class CodeExecutionOutcome {
    OUTCOME_UNSPECIFIED,
    OUTCOME_OK,
    OUTCOME_FAILED,
    OUTCOME_DEADLINE_EXCEEDED,
}

//endregion

//region [Tool] and [ToolConfig]

/**
 * A tool made available to the model. All fields are optional and default to `null`.
 *
 * @property functionDeclarations functions the model may call
 * @property googleSearchRetrieval Google Search retrieval settings
 * @property codeExecution code-execution tool marker
 */
@Serializable
data class Tool(
    val functionDeclarations: List<FunctionDeclaration>? = null,
    val googleSearchRetrieval: GoogleSearchRetrieval? = null,
    val codeExecution: CodeExecution? = null
)

/**
 * Declaration of a callable function.
 *
 * @property name function name
 * @property description description of the function
 * @property parameters optional [Schema] for the parameters; defaults to `null`
 */
@Serializable
data class FunctionDeclaration(
    val name: String,
    val description: String,
    val parameters: Schema? = null,
)

/**
 * Description of a data type, used for function parameters and structured responses.
 *
 * Only [type] is required; every other field defaults to `null`.
 *
 * @property type the data type being described
 * @property enum allowed values, as strings
 * @property properties nested schemas keyed by property name
 * @property required names of required properties
 * @property items schemas for array elements.
 *   NOTE(review): kept as a list to match the original declaration; the public API reference
 *   appears to document `items` as a single schema — confirm.
 */
@Serializable
data class Schema(
    val type: Type,
    val format: String? = null,
    val description: String? = null,
    val nullable: Boolean? = null,
    val `enum`: List<String>? = null,
    val maxItems: Int? = null,
    val minItems: Int? = null,
    val properties: Map<String, Schema>? = null,
    val required: List<String>? = null,
    val items: List<Schema>? = null
)

/** Data types usable in a [Schema]; serialized through [TypeSerializer]. */
@Serializable(with = TypeSerializer::class)
enum class Type {
    TYPE_UNSPECIFIED,
    STRING,
    NUMBER,
    INTEGER,
    BOOLEAN,
    ARRAY,
    OBJECT;

    companion object {
        /**
         * Looks up the constant whose name equals [value], ignoring case.
         *
         * @throws SerializationException if no constant matches
         */
        fun fromString(value: String): Type {
            for (candidate in values()) {
                if (candidate.name.equals(value, ignoreCase = true)) {
                    return candidate
                }
            }
            throw SerializationException("Unknown type: $value")
        }
    }
}

/**
 * Serializer for [Type] that writes the constant name in lowercase and reads it back
 * case-insensitively via [Type.fromString].
 */
object TypeSerializer : KSerializer<Type> {
    override val descriptor: SerialDescriptor =
        PrimitiveSerialDescriptor("Type", PrimitiveKind.STRING)

    /** Encodes [value] as its lowercase name. */
    override fun serialize(encoder: Encoder, value: Type) {
        encoder.encodeString(value.name.lowercase()) // Preserve lowercase formatting
    }

    /** Decodes a string into a [Type]; unknown names raise a [SerializationException]. */
    override fun deserialize(decoder: Decoder): Type =
        Type.fromString(decoder.decodeString())
}

/**
 * Google Search retrieval settings for a [Tool].
 *
 * @property dynamicRetrievalConfig optional dynamic-retrieval settings; defaults to `null`
 */
@Serializable
data class GoogleSearchRetrieval(
    val dynamicRetrievalConfig: DynamicRetrievalConfig? = null,
)

/**
 * Dynamic-retrieval settings used by [GoogleSearchRetrieval].
 *
 * @property mode retrieval mode
 * @property dynamicThreshold optional threshold value; defaults to `null`
 */
@Serializable
data class DynamicRetrievalConfig(
    val mode: DynamicRetrievalConfigMode,
    val dynamicThreshold: Float? = null,
)

/** Possible values of [DynamicRetrievalConfig.mode]. */
@Serializable
enum class DynamicRetrievalConfigMode {
    MODE_UNSPECIFIED,
    MODE_DYNAMIC,
}

/** Marker type with no fields, referenced by [Tool.codeExecution]. */
@Serializable
class CodeExecution

/**
 * Configuration for the tools of a [GenerateContentRequest].
 *
 * @property functionCallingConfig optional function-calling settings; defaults to `null`
 */
@Serializable
data class ToolConfig(
    val functionCallingConfig: FunctionCallingConfig? = null,
)

/**
 * Function-calling settings used by [ToolConfig]. Both fields default to `null`.
 *
 * @property mode function-calling mode
 * @property allowedFunctionNames names of the functions that may be called
 */
@Serializable
data class FunctionCallingConfig(
    val mode: FunctionCallingConfigMode? = null,
    val allowedFunctionNames: List<String>? = null
)

/**
 * Possible values of [FunctionCallingConfig.mode].
 *
 * Annotated with `@Serializable` for consistency with every other enum in this file.
 */
@Serializable
enum class FunctionCallingConfigMode {
    MODE_UNSPECIFIED,
    AUTO,
    ANY,
    NONE
}

//endregion

//region CONFIGS

/**
 * A safety threshold for one harm category.
 *
 * @property category the harm category the setting applies to
 * @property threshold the blocking threshold for that category
 */
@Serializable
data class SafetySetting(
    val category: HarmCategory,
    val threshold: HarmBlockThreshold,
)

/** Harm categories referenced by [SafetySetting] and [SafetyRating]. */
@Serializable
enum class HarmCategory {
    HARM_CATEGORY_UNSPECIFIED,
    HARM_CATEGORY_DEROGATORY,
    HARM_CATEGORY_TOXICITY,
    HARM_CATEGORY_VIOLENCE,
    HARM_CATEGORY_SEXUAL,
    HARM_CATEGORY_MEDICAL,
    HARM_CATEGORY_DANGEROUS,
    HARM_CATEGORY_HARASSMENT,
    HARM_CATEGORY_HATE_SPEECH,
    HARM_CATEGORY_SEXUALLY_EXPLICIT,
    HARM_CATEGORY_DANGEROUS_CONTENT,
    HARM_CATEGORY_CIVIC_INTEGRITY,
}

/** Possible values of [SafetySetting.threshold]. */
@Serializable
enum class HarmBlockThreshold {
    HARM_BLOCK_THRESHOLD_UNSPECIFIED,
    BLOCK_LOW_AND_ABOVE,
    BLOCK_MEDIUM_AND_ABOVE,
    BLOCK_ONLY_HIGH,
    BLOCK_NONE,
    OFF,
}

/**
 * Values accepted for `GenerationConfig.responseMimeType`; `null` means "not set".
 *
 * [MIME_TYPE_ENUM] is included because the Gemini API reference lists `text/x.enum` as a response
 * MIME type alongside `text/plain` and `application/json`.
 * NOTE(review): [MIME_TYPE_JPEG] is kept only so existing callers keep working; it does not look
 * like a valid response MIME type — confirm and remove.
 */
private val ALLOWED_MIMES: Set<String?> =
    setOf(null, MIME_TYPE_TEXT, MIME_TYPE_JSON, MIME_TYPE_ENUM, MIME_TYPE_JPEG)

/**
 * Generation parameters for a [GenerateContentRequest]. Every field is optional and defaults to `null`.
 *
 * @property stopSequences strings at which generation should stop
 * @property responseMimeType MIME type of the response; must be one of `ALLOWED_MIMES`
 * @property responseSchema schema for the response
 * @property candidateCount number of candidates to generate (the original source notes only 1 is allowed for now)
 * @throws IllegalArgumentException at construction if [responseMimeType] is not in `ALLOWED_MIMES`
 */
@Serializable
data class GenerationConfig(
    val stopSequences: List<String>? = null,
    val responseMimeType: String? = null,
    val responseSchema: Schema? = null,
    val candidateCount: Int? = null, // only 1 allowed for now
    val maxOutputTokens: Int? = null,
    val temperature: Double? = null,
    val topP: Double? = null,
    val topK: Int? = null,
    val presencePenalty: Double? = null,
    val frequencyPenalty: Double? = null,
    val responseLogprobs: Boolean? = null,
    val logprobs: Int? = null
) {
    init {
        // Fail fast on an unsupported MIME type instead of sending it to the service.
        require(responseMimeType in ALLOWED_MIMES) { "Unexpected responseMimeType: $responseMimeType" }
    }
}

//endregion

//region [GenerateContentResponse]

/**
 * Response body of the Gemini `generateContent` call.
 *
 * All properties are nullable and default to `null`. [candidates] has a default so that a response
 * without a `candidates` key can still be decoded; a nullable property without a default is
 * otherwise required by kotlinx.serialization unless the format is configured to tolerate it.
 * Properties stay `var` so existing callers that assign to them keep compiling.
 *
 * @property candidates generated candidates, if any
 * @property promptFeedback feedback on the prompt
 * @property usageMetadata token usage information
 */
@Serializable
data class GenerateContentResponse(
    var candidates: List<Candidate>? = null,
    var promptFeedback: PromptFeedback? = null,
    var usageMetadata: UsageMetadata? = null
)

/**
 * One generated candidate in a [GenerateContentResponse].
 *
 * [content] and [finishReason] are required; every other field defaults to `null`.
 *
 * @property content the generated content
 * @property finishReason why generation stopped
 * @property safetyRatings safety ratings for the candidate
 * @property citationMetadata citation information.
 *   NOTE(review): kept as a list to match the original declaration; the public API reference
 *   appears to document this as a single object — confirm.
 * @property groundingAttributions attributions for grounded content
 */
@Serializable
data class Candidate(
    val content: Content,
    val finishReason: FinishReason,
    val safetyRatings: List<SafetyRating>? = null,
    val citationMetadata: List<CitationMetadata>? = null,
    val tokenCount: Int? = null,
    val groundingAttributions: List<GroundingAttribution>? = null,
    val groundingMetadata: GroundingMetadata? = null,
    val avgLogprobs: Double? = null,
    val logprobsResult: LogprobsResult? = null,
    val index: Int? = null
)

/** Possible values of [Candidate.finishReason]. */
@Serializable
enum class FinishReason {
    FINISH_REASON_UNSPECIFIED,
    STOP,
    MAX_TOKENS,
    SAFETY,
    RECITATION,
    LANGUAGE,
    OTHER,
    BLOCKLIST,
    PROHIBITED_CONTENT,
    SPII,
    MALFORMED_FUNCTION_CALL,
}

/**
 * Citation information for generated content.
 *
 * @property citationSources the cited sources
 */
@Serializable
data class CitationMetadata(
    val citationSources: List<CitationSource>
)

/**
 * A single cited source. Every field is optional and defaults to `null`.
 *
 * @property startIndex start index of the cited span
 * @property endIndex end index of the cited span
 * @property uri URI of the source
 * @property license license of the source
 */
@Serializable
data class CitationSource(
    val startIndex: Int? = null,
    val endIndex: Int? = null,
    val uri: String? = null,
    val license: String? = null,
)

/**
 * Attribution of content to a source.
 *
 * @property sourceId identifier of the source
 * @property content the attributed content
 */
@Serializable
data class GroundingAttribution(
    val sourceId: AttributionSourceId,
    val content: Content,
)

/**
 * Identifier of the source behind a [GroundingAttribution].
 *
 * NOTE(review): both fields are declared as required here; if the service sends only one of them,
 * decoding will fail — confirm against real responses.
 *
 * @property groundingPassage identifier of an inline passage
 * @property semanticRetrieverChunk identifier of a semantic-retriever chunk
 */
@Serializable
data class AttributionSourceId(
    val groundingPassage: GroundingPassageId,
    val semanticRetrieverChunk: SemanticRetrieverChunk,
)

/**
 * Identifier of a part within a grounding passage.
 *
 * @property passageID identifier of the passage
 * @property partIndex index of the part
 */
@Serializable
data class GroundingPassageId(
    val passageID: String,
    val partIndex: Int,
)

/**
 * Identifier of a semantic-retriever chunk.
 *
 * @property source name of the source
 * @property chunk name of the chunk
 */
@Serializable
data class SemanticRetrieverChunk(
    val source: String,
    val chunk: String,
)

/**
 * Grounding information attached to a [Candidate].
 *
 * Every field except [searchEntryPoint] is required by this declaration.
 *
 * @property groundingChunks the chunks that supported the answer
 * @property groundingSupports the supports linking answer segments to chunks
 * @property webSearchQueries the web search queries, as strings
 * @property searchEntryPoint optional search entry point; defaults to `null`
 * @property retrievalMetadata retrieval metadata
 */
@Serializable
data class GroundingMetadata(
    val groundingChunks: List<GroundingChunk>,
    val groundingSupports: List<GroundingSupport>,
    val webSearchQueries: List<String>,
    val searchEntryPoint: SearchEntryPoint? = null,
    val retrievalMetadata: RetrievalMetadata
)

/**
 * A grounding chunk.
 *
 * @property web optional web source of the chunk; defaults to `null`
 */
@Serializable
data class GroundingChunk(
    val web: GroundingChunkWeb? = null,
)

/**
 * Web source of a [GroundingChunk].
 *
 * @property uri URI of the web source
 * @property title title of the web source
 */
@Serializable
data class GroundingChunkWeb(
    val uri: String,
    val title: String,
)

/**
 * Links a [Segment] of the answer to the grounding chunks that support it.
 *
 * @property groundingChunkIndices indices of the supporting chunks
 * @property confidenceScores confidence scores.
 *   NOTE(review): the element type was missing from this source; `Float` is assumed to match the
 *   other grounding-related fields in this file — confirm.
 * @property segment the supported segment
 */
@Serializable
data class GroundingSupport(
    val groundingChunkIndices: List<Int>,
    val confidenceScores: List<Float>,
    val segment: Segment
)

/**
 * A span of text within a part.
 *
 * @property partIndex index of the part that holds the segment
 * @property startIndex start index of the segment
 * @property endIndex end index of the segment
 * @property text text of the segment
 */
@Serializable
data class Segment(
    val partIndex: Int,
    val startIndex: Int,
    val endIndex: Int,
    val text: String,
)

/**
 * Retrieval metadata used by [GroundingMetadata].
 *
 * @property googleSearchDynamicRetrievalScore optional retrieval score; defaults to `null`
 */
@Serializable
data class RetrievalMetadata(
    val googleSearchDynamicRetrievalScore: Float? = null,
)

/**
 * Search entry point used by [GroundingMetadata]. Both fields default to `null`.
 *
 * @property renderedContent rendered content, as a string
 * @property sdkBlob SDK blob, as a string
 */
@Serializable
data class SearchEntryPoint(
    val renderedContent: String? = null,
    val sdkBlob: String? = null,
)

/**
 * Log-probability information attached to a [Candidate].
 *
 * @property topCandidates the top candidates
 * @property chosenCandidates the chosen candidates
 */
@Serializable
data class LogprobsResult(
    val topCandidates: List<TopCandidate>,
    val chosenCandidates: List<LogprobsCandidate>
)

/**
 * A group of log-probability candidates.
 *
 * @property candidates the candidates in this group
 */
@Serializable
data class TopCandidate(
    val candidates: List<LogprobsCandidate>
)

/**
 * A token together with its log probability.
 *
 * @property token the token text
 * @property tokenId the token identifier
 * @property logProbability log probability of the token
 */
@Serializable
data class LogprobsCandidate(
    val token: String,
    val tokenId: Int,
    val logProbability: Double,
)

/**
 * Feedback on the prompt in a [GenerateContentResponse]. Both fields default to `null`.
 *
 * @property blockReason why the prompt was blocked, if it was
 * @property safetyRatings safety ratings for the prompt
 */
@Serializable
data class PromptFeedback(
    val blockReason: BlockReason? = null,
    val safetyRatings: List<SafetyRating>? = null
)

/**
 * Possible values of [PromptFeedback.blockReason].
 *
 * The [SAFTEY] identifier is misspelled but kept so existing callers still compile; `@SerialName`
 * maps it to the wire value `SAFETY`, which could not be decoded before.
 */
@Serializable
enum class BlockReason {
    BLOCK_REASON_UNSPECIFIED,

    /** Serialized as `SAFETY`; the Kotlin name keeps its historical spelling. */
    @SerialName("SAFETY")
    SAFTEY,
    OTHER,
    BLOCKLIST,
    PROHIBITED_CONTENT
}

/**
 * Safety rating for one harm category.
 *
 * @property category the rated harm category
 * @property probability probability of harm for that category
 * @property blocked optional blocked flag; defaults to `null`
 */
@Serializable
data class SafetyRating(
    val category: HarmCategory,
    val probability: HarmProbability,
    val blocked: Boolean? = null,
)

/** Possible values of [SafetyRating.probability]. */
@Serializable
enum class HarmProbability {
    HARM_PROBABILITY_UNSPECIFIED,
    NEGLIGIBLE,
    LOW,
    MEDIUM,
    HIGH,
}

/**
 * Error payload that carries a message.
 *
 * Within this package the name shadows Kotlin's built-in `Error` type.
 *
 * @property message the error message
 */
@Serializable
data class Error(
    val message: String,
)

/**
 * Token usage reported in a [GenerateContentResponse].
 *
 * @property promptTokenCount number of tokens in the prompt
 * @property cachedContentTokenCount optional number of cached-content tokens; defaults to `null`
 * @property candidatesTokenCount number of tokens in the candidates
 * @property totalTokenCount total number of tokens
 */
@Serializable
data class UsageMetadata(
    val promptTokenCount: Int,
    val cachedContentTokenCount: Int? = null,
    val candidatesTokenCount: Int,
    val totalTokenCount: Int,
)

//endregion




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy