// tri.ai.text.docs.DocumentQaPlanner.kt Maven / Gradle / Ivy
/*-
* #%L
* tri.promptfx:promptfx
* %%
* Copyright (C) 2023 - 2025 Johns Hopkins University Applied Physics Laboratory
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package tri.ai.text.docs
import tri.ai.core.TextChatMessage
import tri.ai.core.TextCompletion
import tri.ai.core.instructTask
import tri.ai.embedding.*
import tri.ai.pips.AiTaskList
import tri.ai.pips.task
import tri.ai.prompt.AiPrompt
import tri.ai.prompt.trace.*
import tri.ai.text.chunks.SnippetJoiner
import tri.util.ANSI_GRAY
import tri.util.ANSI_RESET
import tri.util.info
/**
 * Runs the document QA information retrieval, query, and summarization process.
 *
 * NOTE(review): the element types of [chatHistory] and of the snippet callback in [plan] were
 * reconstructed from the types used elsewhere in this file — confirm against callers.
 *
 * @property index embedding index queried for chunks similar to the question
 * @property completionEngine completion engine that generates the answer
 * @property chatHistory prior chat messages; only the trailing [historySize] are sent with a question
 * @property historySize number of trailing [chatHistory] messages passed to the completion engine
 */
class DocumentQaPlanner(
    val index: EmbeddingIndex,
    val completionEngine: TextCompletion,
    val chatHistory: List<TextChatMessage>,
    val historySize: Int
) {

    /**
     * Asynchronous tasks to execute for answering the given question.
     *
     * The task chain (1) upgrades a local-folder index if applicable, (2) forces the embeddings to load,
     * (3) retrieves the chunks most similar to the question, (4) asks the completion engine to answer
     * using those chunks as context, and (5) formats the result.
     *
     * @param question question to answer
     * @param prompt prompt to use for answering the question
     * @param chunksToRetrieve number of chunks to retrieve
     * @param minChunkSize minimum size of a chunk for use in a prompt
     * @param contextStrategy strategy for constructing the context
     * @param contextChunks how many of the retrieved chunks to use for constructing the context
     * @param maxTokens maximum number of tokens to generate
     * @param temp temperature for sampling
     * @param numResponses number of responses to generate
     * @param snippetCallback invoked with an empty list by the first task, then with the retrieved
     *   matches once the similarity search has completed
     * @return the list of tasks; nothing is executed by this call itself
     */
    fun plan(
        question: String,
        prompt: AiPrompt,
        chunksToRetrieve: Int,
        minChunkSize: Int,
        contextStrategy: SnippetJoiner,
        contextChunks: Int,
        maxTokens: Int,
        temp: Double,
        numResponses: Int,
        snippetCallback: (List<EmbeddingMatch>) -> Unit
    ): AiTaskList = task("upgrade-embeddings-file") {
        // report an empty snippet list before any retrieval has happened
        snippetCallback(emptyList())
        // only local-folder indexes have an on-disk format to upgrade
        (index as? LocalFolderEmbeddingIndex)?.upgradeEmbeddingIndex()
    }.task("load-embeddings-file-and-calculate") {
        // trigger loading of embeddings file using a similarity query
        index.findMostSimilar("a", 1)
    }.aitask("find-relevant-sections") {
        // for each question, generate a list of relevant chunks
        findRelevantSection(question, chunksToRetrieve).also { trace ->
            // a trace without values is reported as "no snippets" rather than failing
            snippetCallback(trace.values.orEmpty())
        }
    }.aitaskonlist("question-answer") { snippets ->
        // drop chunks that are too small, then cap the number used for the context
        val queryChunks = snippets
            .filter { it.chunkSize >= minChunkSize }
            .take(contextChunks)
        val context = contextStrategy.constructContext(queryChunks)
        val response = completionEngine.instructTask(
            prompt, question, context, maxTokens, temp, numResponses,
            history = chatHistory.takeLast(historySize)
        )

        val embeddingService = index.embeddingService
        val questionEmbedding = embeddingService.calculateEmbedding(question)
        val responseEmbeddings = response.values
            ?.map { embeddingService.calculateEmbedding(it) }
            ?: emptyList()

        // TODO - make this support more than one response embedding
        // add snippet response scores for first response embedding only
        // (scores are set on all retrieved snippets, not just those used in the context)
        responseEmbeddings.firstOrNull()?.let { firstResponseEmbedding ->
            snippets.forEach { snippet ->
                snippet.responseScore =
                    cosineSimilarity(firstResponseEmbedding, snippet.chunkEmbedding).toFloat()
            }
        }

        response.mapOutput {
            QuestionAnswerResult(
                query = SemanticTextQuery(question, questionEmbedding, embeddingService.modelId),
                matches = snippets,
                trace = response,
                responseEmbeddings = responseEmbeddings
            )
        }
    }.aitask("process-result") { qaResult ->
        info("$ANSI_GRAY Similarity of question to response: ${qaResult.responseScore}$ANSI_RESET")
        FormattedPromptTraceResult(
            qaResult.trace,
            qaResult.splitOutputs().map { output -> output.formatResult() }
        )
    }

    //region SIMILARITY CALCULATIONS

    /**
     * Finds the chunks in [index] most similar to the query.
     *
     * @param query text to search for
     * @param maxChunks maximum number of matches requested from the index
     * @return a trace whose output holds the matches and whose model info names the embedding model
     */
    private suspend fun findRelevantSection(query: String, maxChunks: Int): AiPromptTrace<EmbeddingMatch> {
        // documentLibrary?.let { return findRelevantSection(it, query, maxChunks) }
        val matches = index.findMostSimilar(query, maxChunks)
        // read the model id from the index's embedding service directly, as plan() does,
        // so the model is recorded for every index type rather than only local-folder indexes
        val modelId = index.embeddingService.modelId
        return AiPromptTrace(
            modelInfo = AiModelInfo(modelId),
            outputInfo = AiOutputInfo(matches)
        )
    }

    // /** Finds the most relevant section to the query. */
    // private suspend fun findRelevantSection(library: TextLibrary, query: String, maxChunks: Int): AiPromptTrace {
    //     val embeddingSvc = index.embeddingService
    //     val modelId = embeddingSvc.modelId
    //     val semanticTextQuery = SemanticTextQuery(query, embeddingSvc.calculateEmbedding(query), modelId)
    //     val matches = library.docs.flatMap { doc ->
    //         doc.calculateMissingEmbeddings(embeddingSvc)
    //         doc.chunks.map {
    //             val chunkEmbedding = it.getEmbeddingInfo(modelId)!!
    //             EmbeddingMatch(semanticTextQuery, doc, it, modelId, chunkEmbedding,
    //                 cosineSimilarity(semanticTextQuery.embedding, chunkEmbedding).toFloat()
    //             )
    //         }
    //     }.sortedByDescending { it.queryScore }.take(maxChunks)
    //     return AiPromptTrace(
    //         modelInfo = AiModelInfo(modelId),
    //         outputInfo = AiOutputInfo(matches)
    //     )
    // }

    //endregion

}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy