All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.skyenet.apps.plan.PlanCoordinator.kt Maven / Gradle / Ivy

There is a newer version: 1.2.21
Show newest version
package com.simiacryptus.skyenet.apps.plan


import com.simiacryptus.diff.FileValidationUtils
import com.simiacryptus.jopenai.API
import com.simiacryptus.jopenai.ApiModel
import com.simiacryptus.jopenai.describe.Description
import com.simiacryptus.jopenai.util.ClientUtil.toContentList
import com.simiacryptus.jopenai.util.JsonUtil
import com.simiacryptus.skyenet.AgentPatterns
import com.simiacryptus.skyenet.Discussable
import com.simiacryptus.skyenet.TabbedDisplay
import com.simiacryptus.skyenet.core.actors.ParsedResponse
import com.simiacryptus.skyenet.core.platform.ApplicationServices
import com.simiacryptus.skyenet.core.platform.Session
import com.simiacryptus.skyenet.core.platform.StorageInterface
import com.simiacryptus.skyenet.core.platform.User
import com.simiacryptus.skyenet.set
import com.simiacryptus.skyenet.webui.application.ApplicationInterface
import com.simiacryptus.skyenet.webui.session.SessionTask
import com.simiacryptus.skyenet.webui.util.MarkdownUtil
import org.slf4j.LoggerFactory
import java.io.File
import java.nio.file.Path
import java.util.*
import java.util.concurrent.Future
import java.util.concurrent.ThreadPoolExecutor

class PlanCoordinator(
    val user: User?,
    val session: Session,
    val dataStorage: StorageInterface,
    val ui: ApplicationInterface,
    val api: API,
    val settings: Settings,
    val root: Path
) {
    /** Actor that produces the task breakdown; obtained from [settings] on first access. */
    private val taskBreakdownActor by lazy {
        settings.planningActor()
    }

    /**
     * Structured result of the planning step.
     *
     * @property tasksByID planned tasks keyed by task id; `null` when the response carried none.
     * @property finalTaskID not read anywhere in this file; presumably the id of the plan's last task — confirm.
     */
    data class TaskBreakdownResult(
        val tasksByID: Map<String, Task>? = null,
        val finalTaskID: String? = null,
    )

    /** Executor for this [session] and [user], fetched from the client manager on first access. */
    val pool: ThreadPoolExecutor by lazy {
        ApplicationServices.clientManager.getPool(session, user)
    }

    /**
     * A single unit of work in a plan.
     *
     * @property description free-text description; used as the task's tab label when present.
     * @property taskType kind of task; its name selects the node style in the dependency graph.
     * @property task_dependencies ids of the tasks that are awaited before this one runs.
     * @property input_files NOTE(review): element type restored as `String` (presumably file paths) — confirm.
     * @property output_files NOTE(review): element type restored as `String` (presumably file paths) — confirm.
     * @property state execution state; updated while the plan is being executed.
     * @property command command and arguments in list form.
     *   NOTE(review): element type restored as `String` — confirm.
     */
    data class Task(
        val description: String? = null,
        val taskType: TaskType? = null,
        var task_dependencies: List<String>? = null,
        val input_files: List<String>? = null,
        val output_files: List<String>? = null,
        var state: AbstractTask.TaskState? = null,
        @Description("Command and arguments (in list form) for the task")
        val command: List<String>? = null,
    )

    /** Files returned by [FileValidationUtils.expandFileList] for [root]; computed once, on first access. */
    val virtualFiles: Array<File> by lazy {
        FileValidationUtils.expandFileList(root.toFile())
    }

    /**
     * Text of every existing, regular entry of [virtualFiles] whose name does not start with `.`,
     * keyed by its path relative to [root].
     *
     * The map is rebuilt (and the files re-read) on every access.
     */
    private val codeFiles: Map<Path, String>
        get() = virtualFiles
            .filter { it.exists() && it.isFile }
            .filter { !it.name.startsWith(".") }
            .associate { file -> getKey(file) to getValue(file) }


    /** Reads [file] as text; on any [Exception] a warning is logged and an empty string is returned. */
    private fun getValue(file: File): String {
        return try {
            file.inputStream().bufferedReader().use { reader -> reader.readText() }
        } catch (e: Exception) {
            log.warn("Error reading file", e)
            ""
        }
    }

    /** Returns the path of [file] relative to [root]. */
    private fun getKey(file: File): Path {
        val filePath = file.toPath()
        return root.relativize(filePath)
    }

    /**
     * Starts a planning session for [userMessage].
     *
     * Builds a textual summary of the files under [root], asks [taskBreakdownActor] for a task breakdown
     * through a [Discussable] dialog, and passes the resulting plan to [initPlan].
     *
     * The summary is a bare file listing when there are more than two code files (or when one of them is
     * not a regular file); otherwise it contains the content of each entry of [virtualFiles].
     *
     * @param userMessage the request to plan for; it is also rendered as the heading of the discussion.
     */
    fun startProcess(userMessage: String) {
        val codeFiles = codeFiles
        // The keys of codeFiles are relative to root, so they have to be resolved against root before the
        // file system is consulted; a bare `key.toFile()` is interpreted relative to the JVM working directory.
        val eventStatus = if (!codeFiles.all { root.resolve(it.key).toFile().isFile } || codeFiles.size > 2) """
 Files:
 ${codeFiles.keys.joinToString("\n") { "* $it" }}  
     """.trimMargin() else {
            """
            |${
                virtualFiles.joinToString("\n\n") { file ->
                    val path = root.relativize(file.toPath())
                    """
 ## $path
              |
 ${(codeFiles[path] ?: "").let { "$TRIPLE_TILDE\n${it/*.indent("  ")*/}\n$TRIPLE_TILDE" }}
             """.trimMargin()
                }
            }
           """.trimMargin()
        }
        val task = ui.newTask()
        // Every actor call receives the file summary followed by the message being answered.
        val toInput = { it: String ->
            listOf(
                eventStatus,
                it
            )
        }
        val highLevelPlan = Discussable(
            task = task,
            heading = MarkdownUtil.renderMarkdown(userMessage, ui = ui),
            userMessage = { userMessage },
            initialResponse = { it: String -> taskBreakdownActor.answer(toInput(it), api = api) },
            outputFn = { design: ParsedResponse<TaskBreakdownResult> ->
                AgentPatterns.displayMapInTabs(
                    mapOf(
                        "Text" to MarkdownUtil.renderMarkdown(design.text, ui = ui),
                        "JSON" to MarkdownUtil.renderMarkdown(
                            "${TRIPLE_TILDE}json\n${JsonUtil.toJson(design.obj)/*.indent("  ")*/}\n$TRIPLE_TILDE",
                            ui = ui
                        ),
                    )
                )
            },
            ui = ui,
            // NOTE(review): the element type was lost in the extracted source; `Pair<String, ApiModel.Role>` is
            // restored from the usage below (second = chat role, first = message text) — confirm.
            reviseResponse = { userMessages: List<Pair<String, ApiModel.Role>> ->
                taskBreakdownActor.respond(
                    messages = (userMessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) }
                        .toTypedArray()),
                    input = toInput(userMessage),
                    api = api
                )
            },
        ).call()

        initPlan(highLevelPlan, userMessage, task)
    }

    /**
     * Sets up the execution state for [plan] and runs it.
     *
     * Creates the [GenState] for the planned tasks, renders the initial task dependency graph below [task]
     * and hands over to [executePlan]. Any [Throwable] raised on the way is logged and reported on [task]
     * instead of being rethrown.
     *
     * @param plan parsed planning response; `plan.obj.tasksByID` supplies the tasks, a missing map is treated
     *   as empty.
     * @param userMessage the original user request, forwarded to [executePlan].
     * @param task UI task that hosts the diagram and receives error output.
     */
    fun initPlan(
        plan: ParsedResponse<TaskBreakdownResult>,
        userMessage: String,
        task: SessionTask
    ) {
        try {
            // GenState gets its own mutable copy of the task map.
            val genState = GenState(plan.obj.tasksByID.orEmpty().toMutableMap())
            val diagramTask = ui.newTask(false).apply { task.add(placeholder) }
            val diagramBuffer =
                diagramTask.add(
                    MarkdownUtil.renderMarkdown(
                        "## Task Dependency Graph\n${TRIPLE_TILDE}mermaid\n${buildMermaidGraph(genState.subTasks)}\n$TRIPLE_TILDE",
                        ui = ui
                    )
                )
            executePlan(
                task,
                diagramBuffer,
                genState.subTasks,
                diagramTask,
                genState,
                genState.taskIdProcessingQueue,
                pool,
                userMessage,
                plan
            )
        } catch (e: Throwable) {
            log.warn("Error during incremental code generation process", e)
            task.error(ui, e)
        }
    }

    /**
     * Runs every task in [taskIdProcessingQueue] on [pool] and blocks until all of them have finished.
     *
     * A tab is created for each queued task, then the queue is drained in order. Each task is submitted as
     * a pool job that first waits for the futures of its dependencies, then runs the implementation
     * returned by `settings.getImpl`, and finally records itself as completed. Failures inside a job are
     * logged and shown on that task's UI element; they do not stop the other jobs.
     *
     * @param task parent UI task under which the tab display is placed.
     * @param diagramBuffer buffer holding the rendered dependency graph; refreshed whenever the tab buttons
     *   are rendered. May be `null`, in which case the graph is not refreshed.
     * @param subTasks tasks by id, used to draw the dependency graph.
     * @param diagramTask UI task owning [diagramBuffer]; completed on every tab-button render.
     * @param genState shared mutable state of this plan execution.
     * @param taskIdProcessingQueue ids to execute, in execution order; emptied by this call.
     * @param pool executor the task jobs are submitted to.
     * @param userMessage the original user request, forwarded to each task implementation.
     * @param plan the parsed plan, forwarded to each task implementation.
     * @throws RuntimeException if a queued id has no entry in `genState.subTasks`.
     */
    fun executePlan(
        task: SessionTask,
        diagramBuffer: StringBuilder?,
        subTasks: Map<String, Task>,
        diagramTask: SessionTask,
        genState: GenState,
        taskIdProcessingQueue: MutableList<String>,
        pool: ThreadPoolExecutor,
        userMessage: String,
        plan: ParsedResponse<TaskBreakdownResult>
    ) {
        val taskTabs = object : TabbedDisplay(ui.newTask(false).apply { task.add(placeholder) }) {
            override fun renderTabButtons(): String {
                // Rendering the buttons doubles as the refresh hook for the dependency graph.
                diagramBuffer?.set(
                    MarkdownUtil.renderMarkdown(
                        """
                                |## Task Dependency Graph
                                |${TRIPLE_TILDE}mermaid
                                |${buildMermaidGraph(subTasks)}
                                |$TRIPLE_TILDE
                                """.trimMargin(), ui = ui
                    )
                )
                diagramTask.complete()
                // NOTE(review): the HTML inside the three append() literals below is missing from the extracted
                // source (only the trailing "\n" survives). The markup here is reconstructed from the values
                // computed for it (idx, style, isChecked, taskId) and must be checked against the upstream file.
                return buildString {
                    append("<div class='tabs'>\n")
                    super.tabs.withIndex().forEach { (idx, t) ->
                        val (taskId, _) = t
                        val subTask = genState.tasksByDescription[taskId]
                        if (null == subTask) {
                            log.warn("Task tab not found: $taskId")
                        }
                        val isChecked = if (taskId in taskIdProcessingQueue) "checked" else ""
                        val style = when (subTask?.state) {
                            AbstractTask.TaskState.Completed -> " style='text-decoration: line-through;'"
                            null -> " style='opacity: 20%;'"
                            AbstractTask.TaskState.Pending -> " style='opacity: 30%;'"
                            else -> ""
                        }
                        append("<label class='tab-button' data-for-tab='${idx}'$style><input type='checkbox' $isChecked disabled /> $taskId</label><br/>\n")
                    }
                    append("</div>")
                }
            }
        }
        // Initialize task tabs
        taskIdProcessingQueue.forEach { taskId ->
            val newTask = ui.newTask(false)
            genState.uitaskMap[taskId] = newTask
            val subtask = genState.subTasks[taskId]
            val description = subtask?.description
            log.debug("Creating task tab: $taskId ${System.identityHashCode(subtask)} $description")
            taskTabs[description ?: taskId] = newTask.placeholder
        }
        Thread.sleep(100)
        while (taskIdProcessingQueue.isNotEmpty()) {
            val taskId = taskIdProcessingQueue.removeAt(0)
            val subTask = genState.subTasks[taskId] ?: throw RuntimeException("Task not found: $taskId")
            genState.taskFutures[taskId] = pool.submit {
                subTask.state = AbstractTask.TaskState.Pending
                taskTabs.update()
                log.debug("Awaiting dependencies: ${subTask.task_dependencies?.joinToString(", ") ?: ""}")
                subTask.task_dependencies
                    ?.associateWith { genState.taskFutures[it] }
                    ?.forEach { (id, future) ->
                        try {
                            // Only a missing future means the dependency is unknown; a future that
                            // completes with a null result is a finished dependency, not a missing one.
                            if (future == null) log.warn("Dependency not found: $id") else future.get()
                        } catch (e: Throwable) {
                            log.warn("Error", e)
                        }
                    }
                subTask.state = AbstractTask.TaskState.InProgress
                taskTabs.update()
                log.debug("Running task: ${System.identityHashCode(subTask)} ${subTask.description}")
                val task1 = genState.uitaskMap[taskId] ?: ui.newTask(false).apply {
                    taskTabs[taskId] = placeholder
                }
                try {
                    val dependencies = subTask.task_dependencies?.toMutableSet() ?: mutableSetOf()
                    dependencies += getAllDependencies(
                        subTask = subTask,
                        subTasks = genState.subTasks,
                        visited = mutableSetOf()
                    )
                    // NOTE(review): the original line layout of this template was lost in extraction and its
                    // first two lines carried no margin marker; every line is given one here so that the
                    // heading is not indented — confirm against the upstream file.
                    task1.add(
                        MarkdownUtil.renderMarkdown(
                            """
                            |## Task `${taskId}`
                            |${subTask.description ?: ""}
                            |
                            |${TRIPLE_TILDE}json
                            |${JsonUtil.toJson(data = subTask)/*.indent("  ")*/}
                            |$TRIPLE_TILDE
                            |
                            |### Dependencies:
                            |${dependencies.joinToString("\n") { "- $it" }}
                            |
                            """.trimMargin(), ui = ui
                        )
                    )
                    settings.getImpl(subTask).run(
                        agent = this,
                        taskId = taskId,
                        userMessage = userMessage,
                        plan = plan,
                        genState = genState,
                        task = task1,
                        taskTabs = taskTabs
                    )
                } catch (e: Throwable) {
                    log.warn("Error during task execution", e)
                    task1.error(ui, e)
                } finally {
                    // A task counts as completed whether it succeeded or failed, so dependents are released.
                    genState.completedTasks.add(element = taskId)
                    subTask.state = AbstractTask.TaskState.Completed
                    log.debug("Completed task: $taskId ${System.identityHashCode(subTask)}")
                    taskTabs.update()
                }
            }
        }
        // Block until every submitted job has finished.
        genState.taskFutures.values.forEach { future ->
            try {
                future.get()
            } catch (e: Throwable) {
                log.warn("Error", e)
            }
        }
    }

    /**
     * Collects the dependency ids of [subTask], following dependencies recursively through [subTasks].
     *
     * @param subTask task whose dependencies are collected.
     * @param subTasks all known tasks by id; ids without an entry are listed but not followed.
     * @param visited ids already expanded; updated in place so that cycles terminate.
     * @return the direct dependencies followed by those found transitively; may contain duplicates.
     */
    private fun getAllDependencies(
        subTask: Task,
        subTasks: Map<String, Task>,
        visited: MutableSet<String>
    ): List<String> {
        val direct = subTask.task_dependencies
        val dependencies = direct?.toMutableList() ?: mutableListOf()
        for (dep in direct.orEmpty()) {
            if (dep in visited) continue
            val depTask = subTasks[dep] ?: continue
            visited.add(dep)
            dependencies.addAll(getAllDependencies(depTask, subTasks, visited))
        }
        return dependencies
    }

    companion object {
        val log = LoggerFactory.getLogger(PlanCoordinator::class.java)

        /**
         * Orders the ids of [tasks] so that every task comes after all of its dependencies.
         *
         * @throws RuntimeException when no remaining task has all of its dependencies scheduled. This
         *   happens for dependency cycles and also for a dependency id that is not a key of [tasks].
         */
        fun executionOrder(tasks: Map<String, Task>): List<String> {
            val taskIds: MutableList<String> = mutableListOf()
            val taskMap = tasks.toMutableMap()
            while (taskMap.isNotEmpty()) {
                val nextTasks = taskMap.filter { (_, task) ->
                    task.task_dependencies?.all { taskIds.contains(it) } ?: true
                }
                if (nextTasks.isEmpty()) {
                    throw RuntimeException("Circular dependency detected in task breakdown")
                }
                taskIds.addAll(nextTasks.keys)
                nextTasks.keys.forEach { taskMap.remove(it) }
            }
            return taskIds
        }

        /** `true` when the `os.name` system property contains "windows" (case-insensitive). */
        val isWindows = System.getProperty("os.name").lowercase(Locale.getDefault()).contains("windows")

        /** Turns [input] into a backtick-wrapped node id: spaces become `_`; quotes, brackets and parentheses are backslash-escaped. */
        private fun sanitizeForMermaid(input: String) = input
            .replace(" ", "_")
            .replace("\"", "\\\"")
            .replace("[", "\\[")
            .replace("]", "\\]")
            .replace("(", "\\(")
            .replace(")", "\\)")
            .let { "`$it`" }

        /** Wraps [input] in double quotes after backslash-escaping the double quotes it contains. */
        private fun escapeMermaidCharacters(input: String) = input
            .replace("\"", "\\\"")
            .let { '"' + it + '"' }

        /**
         * Builds the source of a `graph TD` diagram with one node per task and one edge per dependency.
         *
         * Completed and in-progress tasks get the `completed` / `inProgress` class; every other node is
         * classed by its task type name (`Unknown` when the type is missing).
         */
        fun buildMermaidGraph(subTasks: Map<String, Task>): String {
            val graphBuilder = StringBuilder("graph TD;\n")
            subTasks.forEach { (taskId, task) ->
                val sanitizedTaskId = sanitizeForMermaid(taskId)
                val taskType = task.taskType?.name ?: "Unknown"
                val escapedDescription = escapeMermaidCharacters(task.description ?: "")
                val style = when (task.state) {
                    AbstractTask.TaskState.Completed -> ":::completed"
                    AbstractTask.TaskState.InProgress -> ":::inProgress"
                    else -> ":::$taskType"
                }
                graphBuilder.append(" ${sanitizedTaskId}[$escapedDescription]$style;\n")
                task.task_dependencies?.forEach { dependency ->
                    val sanitizedDependency = sanitizeForMermaid(dependency)
                    graphBuilder.append(" $sanitizedDependency --> ${sanitizedTaskId};\n")
                }
            }
            graphBuilder.append(" classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;\n")
            graphBuilder.append(" classDef NewFile fill:lightblue,stroke:#333,stroke-width:2px;\n")
            graphBuilder.append(" classDef EditFile fill:lightgreen,stroke:#333,stroke-width:2px;\n")
            graphBuilder.append(" classDef Documentation fill:lightyellow,stroke:#333,stroke-width:2px;\n")
            graphBuilder.append(" classDef Inquiry fill:orange,stroke:#333,stroke-width:2px;\n")
            graphBuilder.append(" classDef TaskPlanning fill:lightgrey,stroke:#333,stroke-width:2px;\n")
            graphBuilder.append(" classDef completed fill:#90EE90,stroke:#333,stroke-width:2px;\n")
            graphBuilder.append(" classDef inProgress fill:#FFA500,stroke:#333,stroke-width:2px;\n")
            return graphBuilder.toString()
        }
    }

    /**
     * Mutable state shared by all tasks of one plan execution.
     *
     * @property subTasks all tasks of the plan by id.
     * @property tasksByDescription the same tasks keyed by their description; used to look up a task from its
     *   tab label.
     * @property taskIdProcessingQueue ids still to be submitted, in execution order.
     * @property taskResult not written in this file.
     *   NOTE(review): key and value types restored as `String` — confirm.
     * @property completedTasks ids of the tasks whose job has finished, successfully or not.
     * @property taskFutures future of each submitted job by task id.
     * @property uitaskMap UI task of each task id.
     */
    data class GenState(
        val subTasks: Map<String, Task>,
        val tasksByDescription: MutableMap<String?, Task> = subTasks.entries.toTypedArray()
            .associate { it.value.description to it.value }.toMutableMap(),
        val taskIdProcessingQueue: MutableList<String> = executionOrder(subTasks).toMutableList(),
        val taskResult: MutableMap<String, String> = mutableMapOf(),
        val completedTasks: MutableList<String> = mutableListOf(),
        val taskFutures: MutableMap<String, Future<*>> = mutableMapOf(),
        val uitaskMap: MutableMap<String, SessionTask> = mutableMapOf()
    )
}

/** Markdown code-fence delimiter, kept in a constant so that it can be embedded in string templates. */
const val TRIPLE_TILDE = "```"




© 2015 - 2024 Weber Informatics LLC | Privacy Policy