Please wait. This can take a few minutes...
Downloading a project requires considerable resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download or modify it as often as you want.
fi.evolver.ai.vaadin.component.AiChatComponent Maven / Gradle / Ivy
package fi.evolver.ai.vaadin.component;
import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.vaadin.flow.component.AttachEvent;
import com.vaadin.flow.component.Component;
import com.vaadin.flow.component.UI;
import com.vaadin.flow.component.messages.MessageInput.SubmitEvent;
import com.vaadin.flow.component.notification.Notification;
import com.vaadin.flow.theme.lumo.LumoUtility;
import fi.evolver.ai.spring.AsyncRunner;
import fi.evolver.ai.spring.ContentSubscriber;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.chat.ChatApi;
import fi.evolver.ai.spring.chat.ChatResponse;
import fi.evolver.ai.spring.chat.FunctionCall;
import fi.evolver.ai.spring.chat.prompt.ChatPrompt;
import fi.evolver.ai.spring.chat.prompt.Message;
import fi.evolver.ai.spring.chat.prompt.Message.Role;
import fi.evolver.ai.spring.provider.openai.OpenAiRequestGenerator;
import fi.evolver.ai.vaadin.ChatRepository;
import fi.evolver.ai.vaadin.PromptRepository;
import fi.evolver.ai.vaadin.entity.Chat;
import fi.evolver.ai.vaadin.entity.ChatMessage;
import fi.evolver.ai.vaadin.entity.ChatMessage.ChatMessageRole;
import fi.evolver.ai.vaadin.entity.Prompt;
import fi.evolver.ai.vaadin.util.AuthUtils;
import fi.evolver.ai.vaadin.util.ChatUtils;
import fi.evolver.ai.vaadin.util.ChatUtils.CommandType;
import fi.evolver.ai.vaadin.view.HistoryAwareChat;
import fi.evolver.basics.spring.log.LogMetadataAttribute;
import fi.evolver.basics.spring.util.MessageChainUtils;
import fi.evolver.basics.spring.util.MessageChainUtils.MessageChain;
import fi.evolver.utils.ContextUtils;
import fi.evolver.utils.string.StringUtils;
public class AiChatComponent extends BaseChatComponent implements HistoryAwareChat {
@Serial
private static final long serialVersionUID = 1L;
private static final Logger LOG = LoggerFactory.getLogger(AiChatComponent.class);
public static final LogMetadataAttribute CHAT_ID = new LogMetadataAttribute("ChatId");
private final List chatMessages = new ArrayList<>();
private final ChatApi api;
private final ChatRepository chatRepository;
private final PromptRepository promptRepository;
private final Optional> preprocessPromptCreator;
private final Function, ChatPrompt> promptCreator;
private final ChatSummaryPromptFunction chatSummaryPromptFunction;
private final Function> commandCreator;
private final Optional> formatResponseFunction;
private final Function> functionCallHandler;
private final boolean allowRecursiveFunctionCalls;
private final AsyncRunner asyncRunner;
private Chat chatSession;
private int maxLength = Integer.MAX_VALUE;
private String chatType;
private ChatMessageContainer chatMessageContainer;
public AiChatComponent(
ChatApi api,
ChatRepository chatRepository,
PromptRepository promptRepository,
Class viewClass,
Function preprocessPromptCreator,
Function, ChatPrompt> promptCreator,
ChatSummaryPromptFunction chatSummaryFunction,
Function> commandCreator,
Function formatResponseFunction,
Function> functionCallHandler,
boolean allowRecursiveFunctionCalls,
AsyncRunner asyncRunner,
ChatMessageContainer chatMessageContainer) {
super(viewClass);
this.api = api;
this.chatRepository = chatRepository;
this.promptRepository = promptRepository;
this.preprocessPromptCreator = Optional.ofNullable(preprocessPromptCreator);
this.promptCreator = promptCreator;
this.chatSummaryPromptFunction = chatSummaryFunction;
this.commandCreator = commandCreator;
this.formatResponseFunction = Optional.ofNullable(formatResponseFunction);
this.functionCallHandler = functionCallHandler;
this.allowRecursiveFunctionCalls = allowRecursiveFunctionCalls;
this.asyncRunner = asyncRunner;
this.chatMessageContainer = chatMessageContainer;
addClassNames(LumoUtility.Width.FULL, LumoUtility.Display.FLEX, LumoUtility.Flex.AUTO, LumoUtility.Overflow.HIDDEN);
setSpacing(false);
setSizeFull();
reset();
setupChatMessageComponents();
}
public AiChatComponent(
ChatApi api,
ChatRepository chatRepository,
PromptRepository promptRepository,
Class viewClass,
Function preprocessPromptCreator,
Function, ChatPrompt> promptCreator,
ChatSummaryPromptFunction chatSummaryFunction,
Function> commandCreator,
Function formatResponseFunction,
AsyncRunner asyncRunner,
ChatMessageContainer chatMessageContainer) {
this(
api,
chatRepository,
promptRepository,
viewClass,
preprocessPromptCreator,
promptCreator,
chatSummaryFunction,
commandCreator,
formatResponseFunction,
null,
false,
asyncRunner,
chatMessageContainer);
}
public AiChatComponent(
ChatApi api,
ChatRepository chatRepository,
PromptRepository promptRepository,
Class viewClass,
Function preprocessPromptCreator,
Function, ChatPrompt> promptCreator,
ChatSummaryPromptFunction chatSummaryFunction,
Function> commandCreator,
Function formatResponseFunction,
AsyncRunner asyncRunner) {
this(
api,
chatRepository,
promptRepository,
viewClass,
preprocessPromptCreator,
promptCreator,
chatSummaryFunction,
commandCreator,
formatResponseFunction,
asyncRunner,
new ChatMessageContainer());
}
public AiChatComponent(
ChatApi api,
ChatRepository chatRepository,
PromptRepository promptRepository,
Class viewClass,
Function, ChatPrompt> promptCreator,
ChatSummaryPromptFunction chatSummaryFunction,
AsyncRunner asyncRunner) {
this(
api,
chatRepository,
promptRepository,
viewClass,
null,
promptCreator,
chatSummaryFunction,
c -> Optional.empty(),
null,
asyncRunner);
}
public AiChatComponent(
ChatApi api,
ChatRepository chatRepository,
PromptRepository promptRepository,
Class viewClass,
Function preprocessPromptCreator,
Function, ChatPrompt> promptCreator,
AsyncRunner asyncRunner) {
this(
api,
chatRepository,
promptRepository,
viewClass,
preprocessPromptCreator,
promptCreator,
ChatUtils::createSummaryPrompt,
c -> Optional.empty(),
null,
asyncRunner);
}
public AiChatComponent(
ChatApi api,
ChatRepository chatRepository,
PromptRepository promptRepository,
Class viewClass,
Function, ChatPrompt> promptCreator,
AsyncRunner asyncRunner) {
this(
api,
chatRepository,
promptRepository,
viewClass,
null,
promptCreator,
ChatUtils::createSummaryPrompt,
c -> Optional.empty(),
null,
asyncRunner);
}
@Override
protected void onAttach(AttachEvent attachEvent) {
super.onAttach(attachEvent);
chatMessageContainer.scrollToEnd();
}
@Override
public void startChatWithHistory(String chatId) {
createChatSession(chatId);
actionMenu.setEnabled(true);
}
@Override
public void reset() {
chatSession = createChat();
actionMenu.setEnabled(false);
chatMessages.clear();
chatMessageContainer.reset();
chatRating.setValue(null);
}
@Override
protected void onRatingChange(Integer newValue) {
chatSession.setChatRating(newValue);
chatRepository.save(chatSession);
}
public void setInputEnabled(boolean enabled) {
chatMessageContainer.setInputEnabled(enabled);
}
public void setMaxLength(int maxLength) {
this.maxLength = maxLength;
}
public void setChatType(String chatType) {
this.chatType = chatType;
if (chatSession != null)
chatSession.setChatType(chatType);
}
public Chat getChatSession() {
return chatSession;
}
public void saveMessage(ChatMessageRole role, String message, Prompt prompt, Model model) {
ChatMessage chatMessage = new ChatMessage(role, message, prompt, model);
chatSession.addChatMessage(chatMessage);
chatSession = chatRepository.save(chatSession);
}
public void saveMessage(ChatMessageRole role, String message) {
saveMessage(role, message, null, null);
}
private void setupChatMessageComponents() {
chatMessageContainer.addSubmitListener(this::handleChatMessageInput);
add(chatMessageContainer);
}
private void handleChatMessageInput(SubmitEvent event) {
try (var c = CHAT_ID.setForScope(chatSession.getChatId())) {
String text = event.getValue();
if (text == null || text.isEmpty())
return;
if (text.length() > this.maxLength) {
Notification.show(
getTranslation("common.messageTooLong", this.maxLength),
6000,
Notification.Position.MIDDLE);
return;
}
if (!actionMenu.isEnabled())
actionMenu.setEnabled(true);
showMessage(text, AuthUtils.getUsername());
getUI().ifPresent(UI::push);
Optional command = ChatUtils.parseCommand(text);
command.flatMap(commandCreator).ifPresentOrElse(
this::displayCommandResponse,
() -> handleChatMessage(text));
}
}
private void createChatSession(String chatId) {
if (chatId == null)
return;
chatSession = chatRepository.findChatByChatIdAndUsername(chatId, AuthUtils.getEmail());
if (chatSession != null) {
chatSession.getChatMessages().stream()
.filter(m -> m.getRole() == ChatMessageRole.ASSISTANT || m.getRole() == ChatMessageRole.USER)
.forEach(m -> {
chatMessageContainer.addItem(m.getSendTime(), m.getMessage(), ChatUtils.inferUsername(chatSession, m), false, m.getRole() == ChatMessageRole.ASSISTANT);
chatMessages.add(new Message(Role.valueOf(m.getRole().name()), m.getMessage()));
});
chatRating.setValue(chatSession.getChatRating());
getUI().ifPresent(UI::push);
}
else {
chatSession = createChat();
}
}
private void handleChatMessage(String message) {
try (MessageChain mc = MessageChainUtils.startMessageChain()) {
boolean isFirstChatMessage = !chatSession.hasMessages();
String preprocessedInput = preprocessMessage(message);
chatMessages.add(Message.user(preprocessedInput));
ChatPrompt chatPrompt = promptCreator.apply(chatMessages);
Prompt prompt = new Prompt(OpenAiRequestGenerator.generate(chatPrompt), chatPrompt.tokenCount());
promptRepository.save(prompt);
if (isFirstChatMessage) {
chatPrompt.messages().stream()
.filter(m -> Role.SYSTEM == m.getRole())
.findFirst()
.ifPresent(s -> saveMessage(ChatMessageRole.SYSTEM, s.getContent()));
chatSession.setSummary("%s%s".formatted(
message.substring(0, Math.min(message.length(), 47)),
message.length() > 47 ? "..." : ""));
}
saveMessage(ChatMessageRole.USER, message, prompt, chatPrompt.model());
String assistantResponse = handleResponse(
chatPrompt,
prompt,
chatPrompt.model());
if (isFirstChatMessage)
asyncRunner.run(() -> summarizeChat(preprocessedInput, assistantResponse, chatPrompt), ContextUtils.getContext());
}
}
private String preprocessMessage(String message) {
return preprocessPromptCreator
.map(c -> api.send(c.apply(message)))
.filter(ChatResponse::isSuccess)
.flatMap(ChatResponse::getMessage)
.map(Message::getContent)
.orElse(message);
}
private String handleResponse(ChatPrompt chatPrompt, Prompt prompt, Model model) {
ChatResponse response = api.send(chatPrompt);
String assistantResponse = null;
StringBuilder builder = new StringBuilder();
response.addSubscriber(new ContentSubscriber() {
@Override
public void onContent(String chunk) {
if (!chunk.isEmpty()) {
builder.append(chunk);
chatMessageContainer.addItem("%s...".formatted(builder.toString()), "AI", builder.length() > chunk.length(), false);
chatMessageContainer.scrollToEnd();
getUI().ifPresent(UI::push);
}
}
@Override
public void onError(Throwable throwable) {
LOG.error("Error with response: " + throwable.getMessage());
getUI().ifPresent(ui -> ui.access(() ->
chatMessageContainer.addItem(getTranslation("component.baseAi.error"), "AI", false, false)
));
}
});
if (response.isSuccess()) {
assistantResponse = response.getMessage()
.map(Message::getContent)
.orElse(null);
if (StringUtils.hasText(assistantResponse)) {
String finalAssistantResponse = assistantResponse;
assistantResponse = formatResponseFunction.map(f -> f.apply(finalAssistantResponse)).orElse(assistantResponse);
chatMessageContainer.addItem(assistantResponse, "AI", true, true);
chatMessages.add(Message.assistant(assistantResponse));
chatMessageContainer.scrollToEnd();
getUI().ifPresent(UI::push);
saveMessage(ChatMessageRole.ASSISTANT, assistantResponse, prompt, model);
}
if (functionCallHandler != null) {
List functionCallResults = response.getFunctionCalls()
.stream()
.map(functionCallHandler::apply)
.filter(Optional::isPresent)
.map(Optional::get)
.toList();
if (!functionCallResults.isEmpty()) {
chatMessages.addAll(functionCallResults);
ChatPrompt.Builder afterFunctionsPromptBuilder = allowRecursiveFunctionCalls ?
chatPrompt.builder() :
cloneChatPromtWithoutFunctions(chatPrompt);
ChatPrompt afterFunctionsPrompt = afterFunctionsPromptBuilder
.addAll(functionCallResults)
.build();
return handleResponse(afterFunctionsPrompt, prompt, model);
}
}
}
else {
saveMessage(ChatMessageRole.ERROR, response.getResultState());
}
return assistantResponse;
}
private static ChatPrompt.Builder cloneChatPromtWithoutFunctions(ChatPrompt original) {
ChatPrompt.Builder result = ChatPrompt.builder(original.model());
result.addAll(original.messages());
for (Map.Entry entry : original.parameters().entrySet())
result.setParameter(entry.getKey(), entry.getValue());
result.addStopList(original.stopList());
result.addLogitBias(original.logitBias());
return result;
}
private void summarizeChat(String userMessage, String assistantResponse, ChatPrompt chatPrompt) {
ChatPrompt summaryPrompt = chatSummaryPromptFunction.apply(
chatPrompt.model(),
chatPrompt.getStringProperty("provider"),
userMessage,
assistantResponse);
try {
ChatResponse response = api.send(summaryPrompt);
if (response.isSuccess()) {
response.getMessage()
.map(Message::getContent)
.ifPresent(summary -> {
chatSession.setSummary(summary);
chatRepository.saveAndFlush(chatSession);
});
}
}
catch (Exception e) {
LOG.error("Error creating summary: " + e.getMessage());
}
}
private void displayCommandResponse(String commandResponse) {
showMessage(commandResponse, "AI");
getUI().ifPresent(UI::push);
}
private void showMessage(String messageText, String sender) {
chatMessageContainer.addItem(messageText, sender, false, false);
chatMessageContainer.scrollToEnd();
}
private Chat createChat() {
return new Chat(chatType != null ? chatType : viewClass.getSimpleName());
}
}