All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.unitmesh.prompt.executor.base.JobStrategyExecutor.kt Maven / Gradle / Ivy

package cc.unitmesh.prompt.executor.base

import cc.unitmesh.cf.core.llms.LlmProvider
import cc.unitmesh.cf.core.llms.MockLlmProvider
import cc.unitmesh.connection.ConnectionConfig
import cc.unitmesh.connection.MockLlmConnection
import cc.unitmesh.connection.OpenAiConnection
import cc.unitmesh.openai.OpenAiProvider
import cc.unitmesh.prompt.model.Job
import com.charleskorn.kaml.PolymorphismStyle
import com.charleskorn.kaml.Yaml
import com.charleskorn.kaml.YamlConfiguration
import kotlinx.datetime.*
import kotlinx.serialization.decodeFromString
import java.math.BigDecimal
import java.nio.file.Path

/**
 * A strategy for executing prompt [Job]s.
 *
 * Besides the abstract [execute], it provides shared helpers for implementations:
 * - building an [LlmProvider] from the connection file a job references;
 * - validating an LLM result and logging each validator's outcome;
 * - persisting the result as a timestamped file under [basePath].
 */
interface JobStrategyExecutor {
    companion object {
        val log: org.slf4j.Logger = org.slf4j.LoggerFactory.getLogger(JobStrategyExecutor::class.java)
    }

    /** Directory against which connection files are resolved and into which result files are written. */
    val basePath: Path

    /** Runs this strategy; what exactly is executed is defined by the implementation. */
    fun execute()

    /**
     * Builds an [LlmProvider] for [job] from the connection file referenced by `job.connection`.
     *
     * @param job the job whose connection file (resolved against [basePath]) selects the provider.
     * @param temperature optional sampling temperature; applied only to the OpenAI provider, and only when non-null.
     * @return an [OpenAiProvider] for an [OpenAiConnection], or a [MockLlmProvider] for a [MockLlmConnection].
     * @throws IllegalStateException if the connection is of any other type.
     */
    fun createLlmProvider(job: Job, temperature: BigDecimal?): LlmProvider =
        when (val connection = initConnectionConfig(job)) {
            is OpenAiConnection -> {
                val provider = OpenAiProvider(connection.apiKey, connection.apiHost)
                // Leave the provider's own setting untouched unless the caller supplied a value.
                if (temperature != null) provider.temperature = temperature.toDouble()
                provider
            }

            is MockLlmConnection -> MockLlmProvider(connection.response)
            else -> error("unsupported connection type: ${connection.type}")
        }

    /**
     * Validates [llmResult] with the validators built by [job], logs each outcome, and then writes
     * the result to a timestamped file under [basePath] (see [createFileName]).
     *
     * A failed validation is logged at error level only; it does not prevent the result from being written.
     */
    fun handleJobResult(jobName: String, job: Job, llmResult: String) {
        // Placeholders ({}) defer formatting, so large LLM output is not stringified when the level is disabled.
        log.debug("execute job: {}", jobName)
        job.buildValidators(llmResult).forEach { validator ->
            val isSuccess = validator.validate()
            val simpleName = validator.javaClass.simpleName
            if (isSuccess) {
                log.debug("{} validate success: {}", simpleName, validator.input)
            } else {
                log.error("{} validate failed: {}", simpleName, validator.input)
            }
        }

        writeToFile(createFileName(jobName), llmResult)
    }

    /**
     * Writes [llmResult] as UTF-8 text to [resultFileName] resolved against [basePath],
     * replacing any existing content of that file.
     */
    fun writeToFile(resultFileName: String, llmResult: String) {
        val resultPath = basePath.resolve(resultFileName)
        log.info("write result to file: {}", basePath.relativize(resultPath))
        resultPath.toFile().writeText(llmResult)
    }

    /**
     * Builds a result file name of the form `<name>-<UTC ISO-8601 timestamp>.txt`.
     *
     * Spaces, path separators (`/`, `\`) and characters reserved in Windows file names
     * (`: * ? " < > |`) in [name] are replaced with `-`, as are the colons of the timestamp,
     * so the result is always a single, portable file name directly under [basePath].
     */
    fun createFileName(name: String): String {
        val timeStr = Clock.System.now()
            .toLocalDateTime(TimeZone.UTC)
            .toString()
            .replace(":", "-")
        // Without this, e.g. "a/b" would resolve into a subdirectory "a" that may not exist.
        val unsafeChars = Regex("""[ /\\:*?"<>|]""")
        val jobName = name.replace(unsafeChars, "-")

        return "$jobName-$timeStr.txt"
    }

    /**
     * Reads the YAML connection file named by `job.connection` (resolved against [basePath])
     * and converts the decoded document into a [ConnectionConfig].
     */
    private fun initConnectionConfig(job: Job): ConnectionConfig {
        val connectionFile = basePath.resolve(job.connection).toFile()
        log.info("connection file: {}", connectionFile.absolutePath)
        val text = connectionFile.readText(Charsets.UTF_8)

        // kaml: polymorphic subtypes are selected by a property inside the YAML map rather than a YAML tag.
        val configuration = YamlConfiguration(polymorphismStyle = PolymorphismStyle.Property)
        // NOTE(review): no type argument is visible on `decodeFromString` in this listing. It is a reified
        // generic and cannot infer its target from the later `convert()` call, so the real source presumably
        // reads `decodeFromString<SomeConnectionType>(text)` and the `<...>` was lost when the listing was
        // rendered as HTML — confirm against the repository.
        val connection = Yaml(configuration = configuration).decodeFromString(text)
        return connection.convert()
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy