tri.ai.memory.BotMemory.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of promptkt Show documentation
Show all versions of promptkt Show documentation
LLM and prompt engineering.
The newest version!
/*-
* #%L
* tri.promptfx:promptkt
* %%
* Copyright (C) 2023 - 2024 Johns Hopkins University Applied Physics Laboratory
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package tri.ai.memory
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.KotlinModule
import com.fasterxml.jackson.module.kotlin.readValue
import tri.ai.core.TextChat
import tri.ai.core.TextChatMessage
import tri.ai.core.TextChatRole
import tri.ai.embedding.EmbeddingService
import tri.ai.embedding.dot
import java.io.File
/**
 * A memory of previous conversations. Uses a chat engine to summarize memories of previous conversations,
 * which are then periodically ingested into the chat engine's memory.
 *
 * Messages are kept in [chatHistory] together with their embeddings and persisted as JSON in [memoryFile].
 * Summaries are stored as assistant messages whose content starts with `[MEMORY]`.
 *
 * @property persona bot persona; its name labels assistant messages in the conversation sent for summarization
 * @property chatEngine chat engine asked to summarize the conversation since the last memory
 * @property embeddingService service used to compute embeddings of stored messages
 */
class BotMemory(val persona: BotPersona, val chatEngine: TextChat, val embeddingService: EmbeddingService) : MemoryService {

    /** Maximum number of embedding-matched messages, and of summary memories, considered for a contextual history. */
    val memoryHistoryLimit = 5
    /** Maximum number of most recent messages included in a contextual history. */
    val historyLimit = 20
    /** Number of steps between interim saves. */
    val stepsToSaveMemory = 20
    /** JSON file the memory is loaded from and saved to. */
    val memoryFile = File("memory.json")
    /** All stored messages and generated memories, in insertion order. */
    val chatHistory = mutableListOf<MemoryItem>()

    /** Shared JSON mapper used for both loading and saving [memoryFile]. */
    private val jsonMapper: ObjectMapper = ObjectMapper().registerModule(KotlinModule.Builder().build())

    //region API IMPLEMENTATION

    /** Loads the items stored in [memoryFile] into [chatHistory], first creating an empty file if none exists. */
    override fun initMemory() {
        if (!memoryFile.exists()) {
            memoryFile.writeText("[]")
        }
        chatHistory.addAll(jsonMapper.readValue<List<MemoryItem>>(memoryFile))
    }

    /**
     * Generates a summary memory of the conversation since the last memory and writes the history to [memoryFile].
     *
     * @param interimSave when true, nothing happens unless at least [stepsToSaveMemory] messages were added
     * since the last memory; when false, the memory is always saved
     */
    override suspend fun saveMemory(interimSave: Boolean) {
        if (!interimSave || stepsSinceLastMemory() >= stepsToSaveMemory) {
            println("\u001B[90mSaving memory...\u001B[0m")
            generateMemories()
            // items loaded from file or added elsewhere may lack embeddings, so fill them in before writing
            val memories = chatHistory.map { it.withEmbedding() }
            jsonMapper
                .writerWithDefaultPrettyPrinter()
                .writeValue(memoryFile, memories)
        }
    }

    /** Appends [chatMessage] to the history, computing its embedding if it has none. */
    override suspend fun addChat(chatMessage: MemoryItem) {
        chatHistory += chatMessage.withEmbedding()
    }

    /**
     * Builds the conversation history to provide as context: messages most similar to the latest exchange,
     * followed by the most recent memories, followed by the most recent chat messages, without duplicates.
     *
     * @param userInput not used by this implementation; similarity is computed from the last two stored
     * messages (presumably the caller has already added the user input via [addChat] — verify against caller)
     * @return the contextual history, or an empty list when nothing has been stored yet
     */
    override fun buildContextualConversationHistory(userInput: MemoryItem): List<MemoryItem> {
        // nothing to search or return yet; also avoids reducing an empty list below
        if (chatHistory.isEmpty()) return emptyList()

        // use embedding index with recent chat for relevant messages;
        // weighted sum of the last two embeddings, with the most recent message weighted highest
        val queryEmbedding = chatHistory.takeLast(2)
            .mapIndexed { i, item -> item.embedding.map { it * (i + 1) } }
            .reduce { acc, floats -> acc.zip(floats) { a, b -> a + b } }
        val relevant = chatHistory.map { it to it.embedding.dot(queryEmbedding) }
            .sortedByDescending { it.second }
            .take(memoryHistoryLimit)
            // very short messages are dropped from the similarity matches
            .filter { (it.first.content ?: "").length > 50 }
            .map { it.first }

        // gather more recent memories
        val memories = chatHistory.filter { it.isMemory() }.takeLast(memoryHistoryLimit).toSet()

        // gather more recent chat messages
        val recentChat = chatHistory.takeLast(historyLimit).toSet()

        return (relevant - memories - recentChat) + (memories - recentChat) + recentChat
    }

    //endregion

    /**
     * Returns this item if it already has an embedding, otherwise a copy whose embedding is computed from
     * its content, with each component rounded to four decimal places.
     */
    private suspend fun MemoryItem.withEmbedding(): MemoryItem =
        if (embedding.isEmpty()) {
            val rounded = embeddingService.calculateEmbedding(content ?: "")
                // fixed locale so the formatted number always uses '.' and can be parsed back
                .map { String.format(java.util.Locale.ROOT, "%.4f", it).toFloat() }
            MemoryItem(role, content, rounded)
        } else {
            this
        }

    /** Number of messages added after the most recent memory (all messages if there is no memory yet). */
    private fun stepsSinceLastMemory() = chatHistory.lastIndex - chatHistory.indexOfLast { it.isMemory() }

    /**
     * Asks [chatEngine] to summarize the conversation since the last memory and appends the summary to
     * [chatHistory] as a `[MEMORY]` message. Does nothing if there are no messages since the last memory.
     */
    private suspend fun generateMemories() {
        // collect chat since last memory
        val chatSinceLastMemory = chatHistory.drop(chatHistory.indexOfLast { it.isMemory() } + 1)
        if (chatSinceLastMemory.isEmpty()) return

        // summarize content for memory
        val conversation = chatSinceLastMemory.joinToString("\n") {
            val speaker = if (it.role == TextChatRole.Assistant) persona.name else it.role.toString()
            "$speaker: ${it.content}"
        }
        // assembled line by line so the multi-line conversation cannot affect the prompt's indentation
        val query = listOf(
            "Please summarize the following conversation:",
            "'''",
            conversation,
            "'''",
            "Include any notable topics discussed, specific facts, and in particular things you learned about the user."
        ).joinToString("\n")
        val response = chatEngine.chat(
            listOf(
                TextChatMessage(TextChatRole.System, "You are a chatbot that summarizes key content from prior conversations."),
                TextChatMessage(TextChatRole.User, query)
            )
        )
        val summary = checkNotNull(response.firstValue) { "Chat engine returned no response while summarizing the conversation." }
        val summaryMessage = TextChatMessage(TextChatRole.Assistant, "[MEMORY] " + (summary.content ?: "").trim())
        // embed right away so the in-memory item can take part in similarity search
        chatHistory.add(MemoryItem(summaryMessage).withEmbedding())
    }

    /** True if this item is a generated summary: an assistant message whose content starts with `[MEMORY]`. */
    private fun MemoryItem.isMemory() =
        role == TextChatRole.Assistant && (content ?: "").startsWith("[MEMORY]")
}