commonMain.dev.shreyaspatil.ai.client.generativeai.GenerativeModel.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of generativeai-google Show documentation
Google's Generative AI Multiplatform SDK
The newest version!
/*
* Copyright 2023 Shreyas Patil
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dev.shreyaspatil.ai.client.generativeai
import dev.shreyaspatil.ai.client.generativeai.common.APIController
import dev.shreyaspatil.ai.client.generativeai.common.CountTokensRequest
import dev.shreyaspatil.ai.client.generativeai.common.GenerateContentRequest
import dev.shreyaspatil.ai.client.generativeai.common.util.fullModelName
import dev.shreyaspatil.ai.client.generativeai.internal.util.toInternal
import dev.shreyaspatil.ai.client.generativeai.internal.util.toPublic
import dev.shreyaspatil.ai.client.generativeai.type.Bitmap
import dev.shreyaspatil.ai.client.generativeai.type.Content
import dev.shreyaspatil.ai.client.generativeai.type.CountTokensResponse
import dev.shreyaspatil.ai.client.generativeai.type.FinishReason
import dev.shreyaspatil.ai.client.generativeai.type.GenerateContentResponse
import dev.shreyaspatil.ai.client.generativeai.type.GenerationConfig
import dev.shreyaspatil.ai.client.generativeai.type.GoogleGenerativeAIException
import dev.shreyaspatil.ai.client.generativeai.type.PromptBlockedException
import dev.shreyaspatil.ai.client.generativeai.type.RequestOptions
import dev.shreyaspatil.ai.client.generativeai.type.ResponseStoppedException
import dev.shreyaspatil.ai.client.generativeai.type.SafetySetting
import dev.shreyaspatil.ai.client.generativeai.type.SerializationException
import dev.shreyaspatil.ai.client.generativeai.type.Tool
import dev.shreyaspatil.ai.client.generativeai.type.ToolConfig
import dev.shreyaspatil.ai.client.generativeai.type.content
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.map
import kotlinx.serialization.ExperimentalSerializationApi
import kotlin.jvm.JvmOverloads
/**
 * A facilitator for a given multimodal model (eg; Gemini).
 *
 * @property modelName name of the model in the backend
 * @property apiKey authentication key for interacting with the backend
 * @property generationConfig configuration parameters to use for content generation
 * @property safetySettings the safety bounds to use during alongside prompts during content
 * generation
 * @property tools a list of [Tool]s the model may use to generate responses
 * @property toolConfig configuration that controls how the model uses the provided [tools]
 * @property systemInstruction contains a [Content] that directs the model to behave a certain way
 * @property requestOptions configuration options to utilize during backend communication
 */
@OptIn(ExperimentalSerializationApi::class)
class GenerativeModel
internal constructor(
    val modelName: String,
    val apiKey: String,
    val generationConfig: GenerationConfig? = null,
    val safetySettings: List<SafetySetting>? = null,
    val tools: List<Tool>? = null,
    val toolConfig: ToolConfig? = null,
    val systemInstruction: Content? = null,
    val requestOptions: RequestOptions = RequestOptions(),
    private val controller: APIController,
) {

    @JvmOverloads
    constructor(
        modelName: String,
        apiKey: String,
        generationConfig: GenerationConfig? = null,
        safetySettings: List<SafetySetting>? = null,
        requestOptions: RequestOptions = RequestOptions(),
        tools: List<Tool>? = null,
        toolConfig: ToolConfig? = null,
        systemInstruction: Content? = null,
    ) : this(
        fullModelName(modelName),
        apiKey,
        generationConfig,
        safetySettings,
        tools,
        toolConfig,
        // System instructions must be delivered under the "system" role.
        systemInstruction?.let { Content("system", it.parts) },
        requestOptions,
        APIController(
            apiKey,
            modelName,
            requestOptions.toInternal(),
            "genai-android",
        ),
    )

    /**
     * Generates a response from the backend with the provided [Content]s.
     *
     * @param prompt A group of [Content]s to send to the model.
     * @return A [GenerateContentResponse] after some delay. Function should be called within a
     * suspend context to properly manage concurrency.
     * @throws GoogleGenerativeAIException wrapping any failure raised while generating content.
     */
    suspend fun generateContent(vararg prompt: Content): GenerateContentResponse =
        try {
            controller.generateContent(constructRequest(*prompt)).toPublic().validate()
        } catch (e: Throwable) {
            throw GoogleGenerativeAIException.from(e)
        }

    /**
     * Generates a streaming response from the backend with the provided [Content]s.
     *
     * @param prompt A group of [Content]s to send to the model.
     * @return A [Flow] which will emit responses as they are returned from the model.
     */
    fun generateContentStream(vararg prompt: Content): Flow<GenerateContentResponse> =
        controller
            .generateContentStream(constructRequest(*prompt))
            .catch { throw GoogleGenerativeAIException.from(it) }
            .map { it.toPublic().validate() }

    /**
     * Generates a response from the backend with the provided text represented [Content].
     *
     * @param prompt The text to be converted into a single piece of [Content] to send to the model.
     * @return A [GenerateContentResponse] after some delay. Function should be called within a
     * suspend context to properly manage concurrency.
     */
    suspend fun generateContent(prompt: String): GenerateContentResponse =
        generateContent(content { text(prompt) })

    /**
     * Generates a streaming response from the backend with the provided text represented [Content].
     *
     * @param prompt The text to be converted into a single piece of [Content] to send to the model.
     * @return A [Flow] which will emit responses as they are returned from the model.
     */
    fun generateContentStream(prompt: String): Flow<GenerateContentResponse> =
        generateContentStream(content { text(prompt) })

    /**
     * Generates a response from the backend with the provided bitmap represented [Content].
     *
     * @param prompt The bitmap to be converted into a single piece of [Content] to send to the model.
     * @return A [GenerateContentResponse] after some delay. Function should be called within a
     * suspend context to properly manage concurrency.
     */
    suspend fun generateContent(prompt: Bitmap): GenerateContentResponse =
        generateContent(content { image(prompt) })

    /**
     * Generates a streaming response from the backend with the provided bitmap represented [Content].
     *
     * @param prompt The bitmap to be converted into a single piece of [Content] to send to the model.
     * @return A [Flow] which will emit responses as they are returned from the model.
     */
    fun generateContentStream(prompt: Bitmap): Flow<GenerateContentResponse> =
        generateContentStream(content { image(prompt) })

    /** Creates a chat instance which internally tracks the ongoing conversation with the model */
    fun startChat(history: List<Content> = emptyList()): Chat = Chat(this, history.toMutableList())

    /**
     * Counts the number of tokens used in a prompt.
     *
     * @param prompt A group of [Content]s to count tokens of.
     * @return A [CountTokensResponse] containing the number of tokens in the prompt.
     */
    suspend fun countTokens(vararg prompt: Content): CountTokensResponse {
        return controller.countTokens(constructCountTokensRequest(*prompt)).toPublic()
    }

    /**
     * Counts the number of tokens used in a prompt.
     *
     * @param prompt The text to be converted to a single piece of [Content] to count the tokens of.
     * @return A [CountTokensResponse] containing the number of tokens in the prompt.
     */
    suspend fun countTokens(prompt: String): CountTokensResponse {
        return countTokens(content { text(prompt) })
    }

    /**
     * Counts the number of tokens used in a prompt.
     *
     * @param prompt The image to be converted to a single piece of [Content] to count the tokens of.
     * @return A [CountTokensResponse] containing the number of tokens in the prompt.
     */
    suspend fun countTokens(prompt: Bitmap): CountTokensResponse {
        return countTokens(content { image(prompt) })
    }

    /** Builds the internal request payload shared by generation and token-counting calls. */
    private fun constructRequest(vararg prompt: Content) =
        GenerateContentRequest(
            modelName,
            prompt.map { it.toInternal() },
            safetySettings?.map { it.toInternal() },
            generationConfig?.toInternal(),
            tools?.map { it.toInternal() },
            toolConfig?.toInternal(),
            systemInstruction?.toInternal(),
        )

    /** Wraps a generation request into the token-counting request shape. */
    private fun constructCountTokensRequest(vararg prompt: Content) =
        CountTokensRequest.forGenAI(constructRequest(*prompt))

    /**
     * Validates a deserialized response, throwing when it is unusable:
     * - no candidates and no prompt feedback at all (deserialization produced nothing),
     * - the prompt itself was blocked, or
     * - any candidate finished for a reason other than a normal [FinishReason.STOP].
     */
    private fun GenerateContentResponse.validate() = apply {
        if (candidates.isEmpty() && promptFeedback == null) {
            throw SerializationException("Error deserializing response, found no valid fields")
        }
        promptFeedback?.blockReason?.let { throw PromptBlockedException(this) }
        candidates
            .mapNotNull { it.finishReason }
            .firstOrNull { it != FinishReason.STOP }
            ?.let { throw ResponseStoppedException(this) }
    }
}