All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.platon.pulsar.external.impl.ChatModelImpl.kt Maven / Gradle / Ivy

package ai.platon.pulsar.external.impl

import ai.platon.pulsar.common.getLogger
import ai.platon.pulsar.dom.FeaturedDocument
import ai.platon.pulsar.external.ChatModel
import ai.platon.pulsar.external.ModelResponse
import ai.platon.pulsar.external.ResponseState
import ai.platon.pulsar.external.TokenUsage
import dev.langchain4j.data.message.SystemMessage
import dev.langchain4j.data.message.UserMessage
import dev.langchain4j.model.chat.ChatLanguageModel
import dev.langchain4j.model.output.FinishReason
import org.jsoup.nodes.Element

open class ChatModelImpl(
    private val langchainModel: ChatLanguageModel
) : ChatModel {
    private val logger = getLogger(this)

    /**
     * Generates a response from the model for a single prompt.
     *
     * NOTE(review): the prompt is forwarded as the *system* message with an
     * empty user message — this mirrors the Element/Document overloads below,
     * where the content is the user message and the prompt instructs the model.
     * Some providers reject empty user messages; confirm against the
     * [ChatModel] interface contract.
     *
     * @param prompt The prompt to send to the model.
     * @return The response generated by the model.
     */
    override fun call(prompt: String) = call("", prompt)

    /**
     * Generates a response from the model based on a user message and a
     * system message. Typically, a conversation contains messages in the
     * following order: System (optional) - User - AI - User - AI - User ...
     *
     * Any exception thrown by the underlying model is logged and mapped to a
     * [ModelResponse] with empty text and [ResponseState.OTHER], so callers
     * never see the provider's exception.
     *
     * @param userMessage The user message.
     * @param systemMessage The system message.
     * @return The response generated by the model.
     */
    override fun call(userMessage: String, systemMessage: String): ModelResponse {
        val um = UserMessage.userMessage(userMessage)
        val sm = SystemMessage.systemMessage(systemMessage)

        val response = try {
            langchainModel.generate(um, sm)
        } catch (e: Exception) {
            // Best-effort: degrade to an empty OTHER response instead of propagating.
            logger.warn("Model call interrupted. | {}", e.message)
            return ModelResponse("", ResponseState.OTHER)
        }

        val u = response.tokenUsage()
        val tokenUsage = TokenUsage(u.inputTokenCount(), u.outputTokenCount(), u.totalTokenCount())
        // Map langchain4j finish reasons onto our provider-agnostic states;
        // `else` also covers a null finish reason from the provider.
        val state = when (response.finishReason()) {
            FinishReason.STOP -> ResponseState.STOP
            FinishReason.LENGTH -> ResponseState.LENGTH
            FinishReason.TOOL_EXECUTION -> ResponseState.TOOL_EXECUTION
            FinishReason.CONTENT_FILTER -> ResponseState.CONTENT_FILTER
            else -> ResponseState.OTHER
        }
        return ModelResponse(response.content().text(), state, tokenUsage)
    }

    /**
     * Generates a response about a document: the document's text becomes the
     * user message and [prompt] the system message (delegates to the
     * [Element] overload via the underlying jsoup document).
     *
     * @param document The document to ask about.
     * @param prompt The instruction for the model.
     * @return The response generated by the model.
     */
    override fun call(document: FeaturedDocument, prompt: String) = call(document.document, prompt)

    /**
     * Generates a response about a DOM element: the element's text becomes
     * the user message and [prompt] the system message.
     *
     * @param ele The element to ask about.
     * @param prompt The instruction for the model.
     * @return The response generated by the model.
     */
    override fun call(ele: Element, prompt: String) = call(ele.text(), prompt)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy