All Downloads are FREE. Search and download functionalities are using the official Maven repository.

fi.evolver.ai.vaadin.component.AiAssistantComponent Maven / Gradle / Ivy

There is a newer version: 1.5.5
Show newest version
package fi.evolver.ai.vaadin.component;

import java.io.IOException;
import java.io.InputStream;
import java.io.Serial;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.vaadin.flow.component.Component;
import com.vaadin.flow.component.DetachEvent;
import com.vaadin.flow.component.UI;
import com.vaadin.flow.component.button.Button;
import com.vaadin.flow.component.button.ButtonVariant;
import com.vaadin.flow.component.dialog.Dialog;
import com.vaadin.flow.component.icon.Icon;
import com.vaadin.flow.component.icon.VaadinIcon;
import com.vaadin.flow.component.messages.MessageInput.SubmitEvent;
import com.vaadin.flow.component.notification.Notification;
import com.vaadin.flow.component.notification.NotificationVariant;
import com.vaadin.flow.component.orderedlayout.HorizontalLayout;
import com.vaadin.flow.component.upload.Upload;
import com.vaadin.flow.component.upload.UploadI18N;
import com.vaadin.flow.component.upload.receivers.MultiFileMemoryBuffer;
import com.vaadin.flow.theme.lumo.LumoUtility;

import elemental.json.JsonObject;
import fi.evolver.ai.spring.AsyncRunner;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.assistant.Assistant;
import fi.evolver.ai.spring.assistant.AssistantApi;
import fi.evolver.ai.spring.assistant.AssistantPrompt;
import fi.evolver.ai.spring.assistant.AssistantResponse;
import fi.evolver.ai.spring.chat.ChatApi;
import fi.evolver.ai.spring.chat.ChatResponse;
import fi.evolver.ai.spring.chat.prompt.ChatPrompt;
import fi.evolver.ai.spring.chat.prompt.Message;
import fi.evolver.ai.spring.file.AiFile;
import fi.evolver.ai.spring.provider.openai.OpenAiAssistantResponse.AssistantContent;
import fi.evolver.ai.spring.provider.openai.OpenAiAssistantResponse.Attachment;
import fi.evolver.ai.spring.provider.openai.OpenAiAssistantResponse.AttachmentType;
import fi.evolver.ai.spring.provider.openai.OpenAiRequestParameters;
import fi.evolver.ai.vaadin.ChatAttachmentRepository;
import fi.evolver.ai.vaadin.ChatRepository;
import fi.evolver.ai.vaadin.component.i18n.UploadI18nFactory;
import fi.evolver.ai.vaadin.entity.Chat;
import fi.evolver.ai.vaadin.entity.ChatMessage;
import fi.evolver.ai.vaadin.entity.ChatMessage.ChatMessageRole;
import fi.evolver.ai.vaadin.util.AuthUtils;
import fi.evolver.ai.vaadin.util.ChatUtils;
import fi.evolver.ai.vaadin.view.HistoryAwareChat;
import fi.evolver.basics.spring.log.LogMetadataAttribute;
import fi.evolver.basics.spring.util.MessageChainUtils;
import fi.evolver.basics.spring.util.MessageChainUtils.MessageChain;
import fi.evolver.utils.ContextUtils;

/**
 * Chat component that lets the user upload files to an {@link Assistant} and ask questions about them.
 *
 * <p>Chat input stays disabled until at least one file has been uploaded successfully. User and assistant
 * messages are persisted through the {@link ChatRepository}; after the first exchange a short summary of
 * the conversation is generated asynchronously with the summary model.</p>
 */
public class AiAssistantComponent extends BaseChatComponent implements HistoryAwareChat {
	@Serial
	private static final long serialVersionUID = 1L;

	private static final Logger LOG = LoggerFactory.getLogger(AiAssistantComponent.class);
	private static final LogMetadataAttribute CHAT_ID = new LogMetadataAttribute("ChatId");

	/** Sender name shown for messages produced by this component or by the assistant. */
	private static final String AI_SENDER = "AI";

	/** Maximum number of characters of the first user message used as the provisional chat summary. */
	private static final int PROVISIONAL_SUMMARY_LENGTH = 47;

	private final Dialog fileInfoDialog;

	private final ChatApi chatApi;
	private final AssistantApi assistantApi;
	private final AssistantPrompt assistantPrompt;
	private final Model summaryModel;
	private final ChatRepository chatRepository;
	@SuppressWarnings("unused")
	private final ChatAttachmentRepository chatAttachmentRepository;
	private final AsyncRunner asyncRunner;

	private Chat chatSession;
	// Never reassigned within this class, so the length check in handleChatMessageInput is effectively disabled.
	private int maxLength = Integer.MAX_VALUE;
	// Never assigned within this class, so createChat() falls back to the view class' simple name.
	private String chatType;
	private ChatMessageContainer chatMessageContainer;
	private Assistant assistant;
	/** Identifier returned by {@link Assistant#addFile} for each uploaded file, keyed by the upload's file name. */
	private final Map<String, String> fileIdByName = new HashMap<>();
	private Upload upload;

	/**
	 * Creates the component, builds its upload and message UI and starts a fresh assistant session.
	 *
	 * @param chatApi chat API used for generating the conversation summary
	 * @param assistantApi API used for creating the assistant
	 * @param assistantPrompt prompt the assistant is created from; its model is recorded on saved messages
	 * @param summaryModel model used for summarizing the first exchange
	 * @param chatRepository repository the chat and its messages are saved to
	 * @param chatAttachmentRepository stored but currently unused
	 * @param viewClass passed to the superclass; its simple name is the default chat type
	 * @param asyncRunner runner used for executing the summary generation asynchronously
	 */
	public AiAssistantComponent(
			ChatApi chatApi,
			AssistantApi assistantApi,
			AssistantPrompt assistantPrompt,
			Model summaryModel,
			ChatRepository chatRepository,
			ChatAttachmentRepository chatAttachmentRepository,
			// NOTE(review): raw type kept on purpose; the type argument must match the (not visible) superclass constructor.
			Class viewClass,
			AsyncRunner asyncRunner) {
		super(viewClass);
		this.chatApi = chatApi;
		this.assistantApi = assistantApi;
		this.assistantPrompt = assistantPrompt;
		this.summaryModel = summaryModel;
		this.chatRepository = chatRepository;
		this.chatAttachmentRepository = chatAttachmentRepository;
		this.fileInfoDialog = FileInfoDialogUtils.createFileInfoDialog(this::getTranslation);
		this.asyncRunner = asyncRunner;
		// createFileUploadComponent() also initializes the upload field that reset() clears below.
		this.chatMessageContainer = new FileChatMessageContainer(createFileUploadComponent());

		addClassNames(LumoUtility.Width.FULL, LumoUtility.Display.FLEX, LumoUtility.Flex.AUTO);
		setSpacing(false);
		setSizeFull();

		reset();
		setupChatMessageComponents();
	}


	/**
	 * Closes the current assistant when the component is detached from the UI.
	 */
	@Override
	protected void onDetach(DetachEvent detachEvent) {
		super.onDetach(detachEvent);
		assistantCleanUp();
	}


	/**
	 * Shows a previously stored chat in read-only mode.
	 *
	 * @param chatId identifier of the stored chat to display
	 */
	@Override
	public void startChatWithHistory(String chatId) {
		createChatSession(chatId);
		chatMessageContainer.scrollToEnd();
		actionMenu.setEnabled(true);
		chatMessageContainer.setInputEnabled(false); // Continuing an assistant based chat from history is not yet supported!
	}


	/**
	 * Discards the current conversation: closes the old assistant, creates a new assistant and chat session,
	 * and clears the messages, uploaded files and rating. Input stays disabled until a file is uploaded.
	 */
	@Override
	public void reset() {
		assistantCleanUp();

		assistant = assistantApi.createAssistant(assistantPrompt);
		chatSession = createChat();
		actionMenu.setEnabled(false);
		chatMessageContainer.reset();
		fileIdByName.clear();
		if (upload != null)
			upload.clearFileList();
		chatMessageContainer.setInputEnabled(false);
		chatRating.setValue(null);
	}


	/**
	 * Stores the new rating on the current chat session and persists it.
	 */
	@Override
	protected void onRatingChange(Integer newValue) {
		chatSession.setChatRating(newValue);
		// Keep working with the instance returned by the repository, as saveMessage() does.
		chatSession = chatRepository.save(chatSession);
	}


	/** Closes the current assistant, if one has been created. */
	private void assistantCleanUp() {
		if (assistant != null)
			assistant.close();
	}


	/**
	 * Appends a message to the current chat session and persists the session.
	 *
	 * @param role role of the message author
	 * @param message message text
	 * @param model model associated with the message, may be {@code null}
	 */
	private void saveMessage(ChatMessageRole role, String message, Model model) {
		ChatMessage chatMessage = new ChatMessage(role, message, null, model);
		chatSession.addChatMessage(chatMessage);
		chatSession = chatRepository.save(chatSession);
	}


	/** Wires the submit handler to the message container and adds the container to this component. */
	private void setupChatMessageComponents() {
		chatMessageContainer.addSubmitListener(this::handleChatMessageInput);
		add(chatMessageContainer);
	}


	/**
	 * Builds the upload area: an in-memory multi-file upload (max 10 files, 20MB each) and an info button
	 * that opens the file info dialog.
	 *
	 * <p>Successful uploads are registered with the assistant and enable chat input; removing the last
	 * registered file disables it again.</p>
	 *
	 * @return layout containing the upload and the info button
	 */
	private Component createFileUploadComponent() {
		HorizontalLayout uploadLayout = new HorizontalLayout();

		MultiFileMemoryBuffer buffer = new MultiFileMemoryBuffer();
		upload = new Upload(buffer);
		upload.getStyle().set("overflow", "unset");
		upload.setMaxFiles(10);

		int maxFileSizeInBytes = 20 * 1024 * 1024; // 20MB
		upload.setMaxFileSize(maxFileSizeInBytes);
		UploadI18N i18n = UploadI18nFactory.getI18n(this::getTranslation, "20MB");
		upload.setI18n(i18n);

		upload.addFileRejectedListener(event -> {
			Notification notification = Notification.show(event.getErrorMessage(), 5000, Notification.Position.MIDDLE);
			notification.addThemeVariants(NotificationVariant.LUMO_ERROR);
		});

		upload.addSucceededListener(event -> {
			String filename = event.getFileName();
			try (InputStream inputStream = buffer.getInputStream(filename)) {
				AiFile aiFile = new AiFile(inputStream.readAllBytes(), event.getMIMEType(), filename);
				fileIdByName.put(filename, assistant.addFile(aiFile));
				addChatMessage(getTranslation("component.aiAssistant.fileAdded", filename), AI_SENDER);
				chatMessageContainer.setInputEnabled(true);
			}
			catch (IOException e) {
				LOG.error("Failed handling file: {}", filename, e);
				addChatMessage(getTranslation("component.aiAssistant.fileUploadFailed"), AI_SENDER);
			}
		});

		// The Upload component has no server-side removal listener here, so the client-side DOM event is used.
		upload.getElement().addEventListener("file-remove", event -> {
			JsonObject eventData = event.getEventData();
			String filename = eventData.getString("event.detail.file.name");
			String fileId = fileIdByName.remove(filename);
			if (fileId != null)
				assistant.deleteFile(fileId);
			if (fileIdByName.isEmpty())
				chatMessageContainer.setInputEnabled(false);
			addChatMessage(getTranslation("component.aiAssistant.fileRemoved", filename), AI_SENDER);
		}).addEventData("event.detail.file.name");

		Button infoButton = new Button(new Icon(VaadinIcon.INFO_CIRCLE));
		infoButton.addThemeVariants(ButtonVariant.LUMO_TERTIARY_INLINE, ButtonVariant.LUMO_ICON);
		infoButton.addClickListener(event -> fileInfoDialog.open());

		uploadLayout.add(upload, infoButton);
		return uploadLayout;
	}


	/**
	 * Handles a submitted chat message: validates it, shows it in the chat and forwards it to the assistant.
	 * Empty messages are ignored and messages longer than {@code maxLength} are rejected with a notification.
	 */
	private void handleChatMessageInput(SubmitEvent event) {
		try (var c = CHAT_ID.setForScope(chatSession.getChatId())) {
			String text = event.getValue();

			if (text == null || text.isEmpty())
				return;

			if (text.length() > this.maxLength) {
				Notification.show(
						getTranslation("common.messageTooLong", this.maxLength),
						6000,
						Notification.Position.MIDDLE);
				return;
			}

			if (!actionMenu.isEnabled())
				actionMenu.setEnabled(true);

			addChatMessage(text, AuthUtils.getUsername());
			// Push so the user's own message is visible before the assistant call is made.
			getUI().ifPresent(UI::push);

			handleChatMessage(text);
		}
	}


	/**
	 * Loads the stored chat with the given id for the current user and renders its user and assistant
	 * messages. Falls back to a new chat session when no such chat is found; does nothing for a
	 * {@code null} id.
	 */
	private void createChatSession(String chatId) {
		if (chatId == null)
			return;

		chatSession = chatRepository.findChatByChatIdAndUsername(chatId, AuthUtils.getEmail());
		if (chatSession != null) {
			chatSession.getChatMessages().stream()
				.filter(m -> m.getRole() == ChatMessageRole.ASSISTANT || m.getRole() == ChatMessageRole.USER)
				.forEach(m -> chatMessageContainer.addItem(
						m.getSendTime(),
						m.getMessage(),
						ChatUtils.inferUsername(chatSession, m),
						false,
						m.getRole() == ChatMessageRole.ASSISTANT));
			getUI().ifPresent(UI::push);
		}
		else {
			chatSession = createChat();
		}
	}


	/**
	 * Persists the user's message, asks the assistant and processes the response. For the first message of
	 * a chat, the beginning of the text is stored as a provisional summary.
	 */
	private void handleChatMessage(String text) {
		try (MessageChain mc = MessageChainUtils.startMessageChain()) {
			boolean isFirstChatMessage = chatSession.getChatMessages().isEmpty();
			if (isFirstChatMessage)
				chatSession.setSummary(text.substring(0, Math.min(text.length(), PROVISIONAL_SUMMARY_LENGTH)));

			saveMessage(ChatMessageRole.USER, text, assistantPrompt.model());

			handleResponse(assistant.ask(createAssistantQuestion(text)), assistantPrompt.model(), text, isFirstChatMessage);
		}
	}


	/**
	 * Builds the question sent to the assistant: the user's text followed by a list of the uploaded files
	 * as {@code name = id} pairs.
	 */
	private String createAssistantQuestion(String text) {
		return "%s\n\nUploaded files: %s".formatted(
				text,
				fileIdByName.entrySet().stream()
					.map(e -> "%s = %s".formatted(e.getKey(), e.getValue()))
					.collect(Collectors.joining(",")));
	}


	/**
	 * Renders the assistant's response as it arrives and persists the final result.
	 *
	 * <p>On success with non-empty content, the final message (with attachment links rewritten) replaces the
	 * partial one, is saved, and for the first exchange a summary is generated asynchronously. Otherwise the
	 * response's result state is saved as an error message.</p>
	 *
	 * @param response response to the question sent to the assistant
	 * @param model model recorded on the saved assistant message
	 * @param text the user's original message, used for the summary
	 * @param isFirstChatMessage whether this was the first message of the chat
	 */
	private void handleResponse(AssistantResponse response, Model model, String text, boolean isFirstChatMessage) {
		StringBuilder builder = new StringBuilder();
		List<AssistantContent> contentList = new ArrayList<>();

		response.addSubscriber(r -> {
			contentList.add(r);
			if (r.content() != null) {
				builder.append(r.content());
				// Replace the previous partial message unless this is the first piece of content.
				addChatMessage("%s...".formatted(builder.toString()), AI_SENDER, builder.length() > r.content().length(), false);
				getUI().ifPresent(UI::push);
			}
		});

		// NOTE(review): presumably isSuccess() waits for the response to complete, so builder is final here - confirm.
		if (response.isSuccess() && !builder.isEmpty()) {
			String formattedMessage = handleAttachments(
					assistant.getProvider(),
					builder.toString(),
					contentList.stream().flatMap(c -> c.attachments().stream()).toList());
			addChatMessage(formattedMessage, AI_SENDER, true, true);
			getUI().ifPresent(UI::push);

			saveMessage(ChatMessageRole.ASSISTANT, formattedMessage, model);

			if (isFirstChatMessage)
				asyncRunner.run(() -> summarizeChat(text, builder.toString()), ContextUtils.getContext());
		}
		else {
			saveMessage(ChatMessageRole.ERROR, response.getResultState(), null);
		}
	}


	/**
	 * Replaces the anchors of file attachments in the message with download links.
	 *
	 * <p>Each anchor is replaced literally and in a single pass, so text inserted by the replacement is not
	 * scanned again (the generated link may itself contain the anchor's file name). The file name is
	 * URL-encoded because it is placed in the query string. Attachments that are not files, or that lack
	 * an anchor or an external reference, are skipped.</p>
	 *
	 * @param provider provider name added to the download link
	 * @param message message text containing the anchors
	 * @param attachments attachments reported by the assistant
	 * @return the message with file anchors replaced by download links
	 */
	private static String handleAttachments(String provider, String message, List<Attachment> attachments) {
		for (Attachment attachment : attachments) {
			String anchor = attachment.anchor();
			// An empty anchor would match between every character of the message.
			if (!AttachmentType.FILE.equals(attachment.type()) || anchor == null || anchor.isEmpty() || attachment.externalReference() == null)
				continue;

			String downloadLink = "/api/oai/file/%s?provider=%s&filename=%s".formatted(
					attachment.externalReference(),
					provider,
					URLEncoder.encode(parseFilename(anchor), StandardCharsets.UTF_8));
			message = message.replace(anchor, downloadLink);
		}

		return message;
	}


	/**
	 * Returns the part of the anchor after its last {@code '/'}: the whole anchor when it contains no
	 * slash, and an empty string when it ends with one.
	 */
	private static String parseFilename(String anchor) {
		return anchor.substring(anchor.lastIndexOf('/') + 1);
	}


	/**
	 * Asks the summary model for a short summary of the first exchange and, when one is returned, stores it
	 * on the chat session and persists it.
	 *
	 * @param userMessage the user's first message
	 * @param assistantResponse the assistant's unformatted answer
	 */
	private void summarizeChat(String userMessage, String assistantResponse) {
		ChatPrompt.Builder builder = ChatPrompt.builder(summaryModel)
				.add(Message.system("Summarize the following conversation in less than 15 words. Ensure carefully that you understand the language of the question words the user is writing! This is the main language in which you MUST write the summary. The summary MUST be shorter than 15 words!"))
				.add(Message.user(userMessage))
				.add(Message.assistant(assistantResponse));
		// Use the same provider for the summary as the assistant prompt, when one is configured.
		assistantPrompt.getStringProperty(OpenAiRequestParameters.PROVIDER)
			.ifPresent(p -> builder.setParameter(OpenAiRequestParameters.PROVIDER, p));

		ChatResponse response = chatApi.send(builder.build());
		if (response.isSuccess()) {
			String summary = response.getMessage().map(Message::getContent).orElse(null);
			if (summary != null) {
				chatSession.setSummary(summary);
				chatRepository.saveAndFlush(chatSession);
			}
		}
	}


	/**
	 * Adds a message to the chat view and scrolls to the end.
	 *
	 * @param messageText text to show
	 * @param sender name shown as the sender
	 * @param isReplacement passed to the message container; used here for updating a partial message in place
	 * @param convertToHtml passed to the message container; set for the final assistant message
	 */
	private void addChatMessage(String messageText, String sender, boolean isReplacement, boolean convertToHtml) {
		chatMessageContainer.addItem(messageText, sender, isReplacement, convertToHtml);
		chatMessageContainer.scrollToEnd();
	}


	/** Adds a new plain message to the chat view and scrolls to the end. */
	private void addChatMessage(String messageText, String sender) {
		addChatMessage(messageText, sender, false, false);
	}


	/** Creates a new chat whose type is {@code chatType}, or the view class' simple name when unset. */
	private Chat createChat() {
		return new Chat(chatType != null ? chatType : viewClass.getSimpleName());
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy