All Downloads are FREE. Search and download functionality uses the official Maven repository.

tri.ai.text.docs.LocalDocumentQaDriver.kt Maven / Gradle / Ivy

/*-
 * #%L
 * tri.promptfx:promptkt
 * %%
 * Copyright (C) 2023 - 2025 Johns Hopkins University Applied Physics Laboratory
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */
package tri.ai.text.docs

import tri.ai.core.TextPlugin
import tri.ai.embedding.LocalFolderEmbeddingIndex
import tri.ai.pips.AiPipelineExecutor
import tri.ai.pips.AiPipelineResult
import tri.ai.pips.PrintMonitor
import tri.ai.prompt.AiPromptLibrary
import tri.ai.prompt.trace.AiPromptTraceSupport
import tri.ai.text.chunks.GroupingTemplateJoiner
import java.io.File

/**
 * A local file driver for document Q&A, using plugins for completion and embedding models.
 * This driver requires a root folder, using the child folders as the set of available document sets.
 * Documents and embeddings within a folder are managed by [LocalFolderEmbeddingIndex].
 *
 * Construction fails with [IllegalArgumentException] if [root] does not exist, is not a directory,
 * or its contents cannot be listed, and with [NoSuchElementException] if [TextPlugin] offers no
 * text completion model or no embedding model.
 *
 * @property root the directory whose immediate subdirectories are offered as document sets
 */
class LocalDocumentQaDriver(val root: File) : DocumentQaDriver {

    init {
        require(root.exists()) { "Root directory does not exist" }
        require(root.isDirectory) { "Root must be a directory" }
    }

    /** Names of the immediate subdirectories of [root], captured once when the driver is constructed. */
    override val folders: List<String> =
        // File.listFiles() returns null on an I/O error; fail with a descriptive message instead of a bare NPE.
        requireNotNull(root.listFiles()) { "Unable to list the contents of root directory '$root'." }
            .filter { it.isDirectory }
            .map { it.name }

    /**
     * Name of the currently selected document set. The empty string selects [root] itself;
     * any other value must be one of [folders], otherwise the setter throws [IllegalArgumentException].
     */
    override var folder: String = ""
        set(value) {
            require(value == "" || value in folders) { "Expected blank folder (to use root folder) or value to be a subfolder of root, but was '$value'." }
            field = value
        }

    /** Directory that is indexed when answering a question: [root] when [folder] is empty, else the selected child. */
    val docsFolder: File
        get() = if (folder.isEmpty()) root else File(root, folder)

    // Both default to the first model reported by TextPlugin.
    private var completionModelInst = TextPlugin.textCompletionModels().firstOrNull()
        ?: throw NoSuchElementException("No text completion models are available.")
    private var embeddingModelInst = TextPlugin.embeddingModels().firstOrNull()
        ?: throw NoSuchElementException("No embedding models are available.")

    /**
     * Id of the completion model in use. Assigning an id switches to the first model from
     * [TextPlugin.textCompletionModels] with that id, or throws [NoSuchElementException] if none matches.
     */
    override var completionModel
        get() = completionModelInst.modelId
        set(value) {
            completionModelInst = TextPlugin.textCompletionModels().firstOrNull { it.modelId == value }
                ?: throw NoSuchElementException("No text completion model found with id '$value'.")
        }

    /**
     * Id of the embedding model in use. Assigning an id switches to the first model from
     * [TextPlugin.embeddingModels] with that id, or throws [NoSuchElementException] if none matches.
     */
    override var embeddingModel
        get() = embeddingModelInst.modelId
        set(value) {
            embeddingModelInst = TextPlugin.embeddingModels().firstOrNull { it.modelId == value }
                ?: throw NoSuchElementException("No embedding model found with id '$value'.")
        }

    /** Value forwarded as `temp` when planning an answer. */
    override var temp: Double = 1.0

    /** Value forwarded as `maxTokens` when planning an answer. */
    override var maxTokens: Int = 2000

    // Looked up once per driver: prompt id "question-answer-docs" and joiner id "snippet-joiner-citations".
    private val prompt = AiPromptLibrary.lookupPrompt("$PROMPT_PREFIX-docs")
    private val joiner = GroupingTemplateJoiner("$JOINER_PREFIX-citations")

    /** No setup is required for this driver. */
    override fun initialize() {
        // intentionally empty
    }

    /** Closes every plugin in [TextPlugin.orderedPlugins]. */
    override fun close() {
        TextPlugin.orderedPlugins.forEach { it.close() }
    }

    /**
     * Answers [input] against the documents in [docsFolder].
     *
     * Builds a [LocalFolderEmbeddingIndex] over [docsFolder] with the current embedding model, asks
     * [DocumentQaPlanner] for a plan using the current completion model, [temp] and [maxTokens],
     * executes that plan with a [PrintMonitor], and converts the final result to a pipeline result.
     *
     * NOTE(review): `AiPipelineResult` and `AiPromptTraceSupport` are written here without type
     * arguments; if they are generic in the library version in use, the arguments need restoring —
     * confirm against their declarations.
     *
     * @param input the question to answer
     * @return the pipeline result produced from the executed plan's final result
     * @throws ClassCastException if the final result is not an [AiPromptTraceSupport]
     */
    override suspend fun answerQuestion(input: String): AiPipelineResult {
        val index = LocalFolderEmbeddingIndex(docsFolder, embeddingModelInst)
        val planned = DocumentQaPlanner(index, completionModelInst, listOf(), 1).plan(
            question = input,
            prompt = prompt,
            chunksToRetrieve = CHUNKS_TO_RETRIEVE,
            minChunkSize = MIN_CHUNK_SIZE,
            contextStrategy = joiner,
            contextChunks = CONTEXT_CHUNKS,
            maxTokens = maxTokens,
            temp = temp,
            numResponses = NUM_RESPONSES,
            snippetCallback = { }
        )
        val executed = AiPipelineExecutor.execute(planned.plan, PrintMonitor())
        val trace = executed.finalResult as AiPromptTraceSupport
        return trace.asPipelineResult()
    }

    companion object {
        const val PROMPT_PREFIX = "question-answer"
        const val JOINER_PREFIX = "snippet-joiner"

        // Fixed planner arguments used by answerQuestion, named after the parameters they are passed to.
        private const val CHUNKS_TO_RETRIEVE = 8
        private const val MIN_CHUNK_SIZE = 50
        private const val CONTEXT_CHUNKS = 10
        private const val NUM_RESPONSES = 1
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy