All downloads are free. Search and download functionality uses the official Maven repository.

fi.evolver.ai.vaadin.component.AiChatComponent Maven / Gradle / Ivy

The newest version!
package fi.evolver.ai.vaadin.component;

import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.vaadin.flow.component.AttachEvent;
import com.vaadin.flow.component.Component;
import com.vaadin.flow.component.UI;
import com.vaadin.flow.component.messages.MessageInput.SubmitEvent;
import com.vaadin.flow.component.notification.Notification;
import com.vaadin.flow.theme.lumo.LumoUtility;

import fi.evolver.ai.spring.AsyncRunner;
import fi.evolver.ai.spring.ContentSubscriber;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.chat.ChatApi;
import fi.evolver.ai.spring.chat.ChatResponse;
import fi.evolver.ai.spring.chat.FunctionCall;
import fi.evolver.ai.spring.chat.prompt.ChatPrompt;
import fi.evolver.ai.spring.chat.prompt.Message;
import fi.evolver.ai.spring.chat.prompt.Message.Role;
import fi.evolver.ai.spring.provider.openai.OpenAiRequestGenerator;
import fi.evolver.ai.vaadin.ChatRepository;
import fi.evolver.ai.vaadin.PromptRepository;
import fi.evolver.ai.vaadin.entity.Chat;
import fi.evolver.ai.vaadin.entity.ChatMessage;
import fi.evolver.ai.vaadin.entity.ChatMessage.ChatMessageRole;
import fi.evolver.ai.vaadin.entity.Prompt;
import fi.evolver.ai.vaadin.util.AuthUtils;
import fi.evolver.ai.vaadin.util.ChatUtils;
import fi.evolver.ai.vaadin.util.ChatUtils.CommandType;
import fi.evolver.ai.vaadin.view.HistoryAwareChat;
import fi.evolver.basics.spring.log.LogMetadataAttribute;
import fi.evolver.basics.spring.util.MessageChainUtils;
import fi.evolver.basics.spring.util.MessageChainUtils.MessageChain;
import fi.evolver.utils.ContextUtils;
import fi.evolver.utils.string.StringUtils;


public class AiChatComponent extends BaseChatComponent implements HistoryAwareChat {
	@Serial
	private static final long serialVersionUID = 1L;
	private static final Logger LOG = LoggerFactory.getLogger(AiChatComponent.class);

	public static final LogMetadataAttribute CHAT_ID = new LogMetadataAttribute("ChatId");

	private final List chatMessages = new ArrayList<>();

	private final ChatApi api;
	private final ChatRepository chatRepository;
	private final PromptRepository promptRepository;
	private final Optional> preprocessPromptCreator;
	private final Function, ChatPrompt> promptCreator;
	private final ChatSummaryPromptFunction chatSummaryPromptFunction;
	private final Function> commandCreator;
	private final Optional> formatResponseFunction;
	private final Function> functionCallHandler;
	private final boolean allowRecursiveFunctionCalls;
	private final AsyncRunner asyncRunner;

	private Chat chatSession;
	private int maxLength = Integer.MAX_VALUE;
	private String chatType;
	private ChatMessageContainer chatMessageContainer;

	public AiChatComponent(
			ChatApi api,
			ChatRepository chatRepository,
			PromptRepository promptRepository,
			Class viewClass,
			Function preprocessPromptCreator,
			Function, ChatPrompt> promptCreator,
			ChatSummaryPromptFunction chatSummaryFunction,
			Function> commandCreator,
			Function formatResponseFunction,
			Function> functionCallHandler,
			boolean allowRecursiveFunctionCalls,
			AsyncRunner asyncRunner,
			ChatMessageContainer chatMessageContainer) {

		super(viewClass);

		this.api = api;
		this.chatRepository = chatRepository;
		this.promptRepository = promptRepository;
		this.preprocessPromptCreator = Optional.ofNullable(preprocessPromptCreator);
		this.promptCreator = promptCreator;
		this.chatSummaryPromptFunction = chatSummaryFunction;
		this.commandCreator = commandCreator;
		this.formatResponseFunction = Optional.ofNullable(formatResponseFunction);
		this.functionCallHandler = functionCallHandler;
		this.allowRecursiveFunctionCalls = allowRecursiveFunctionCalls;
		this.asyncRunner = asyncRunner;
		this.chatMessageContainer = chatMessageContainer;

		addClassNames(LumoUtility.Width.FULL, LumoUtility.Display.FLEX, LumoUtility.Flex.AUTO, LumoUtility.Overflow.HIDDEN);
		setSpacing(false);
		setSizeFull();

		reset();
		setupChatMessageComponents();
	}

	public AiChatComponent(
			ChatApi api,
			ChatRepository chatRepository,
			PromptRepository promptRepository,
			Class viewClass,
			Function preprocessPromptCreator,
			Function, ChatPrompt> promptCreator,
			ChatSummaryPromptFunction chatSummaryFunction,
			Function> commandCreator,
			Function formatResponseFunction,
			AsyncRunner asyncRunner,
			ChatMessageContainer chatMessageContainer) {

		this(
				api,
				chatRepository,
				promptRepository,
				viewClass,
				preprocessPromptCreator,
				promptCreator,
				chatSummaryFunction,
				commandCreator,
				formatResponseFunction,
				null,
				false,
				asyncRunner,
				chatMessageContainer);
	}

	public AiChatComponent(
			ChatApi api,
			ChatRepository chatRepository,
			PromptRepository promptRepository,
			Class viewClass,
			Function preprocessPromptCreator,
			Function, ChatPrompt> promptCreator,
			ChatSummaryPromptFunction chatSummaryFunction,
			Function> commandCreator,
			Function formatResponseFunction,
			AsyncRunner asyncRunner) {

		this(
				api,
				chatRepository,
				promptRepository,
				viewClass,
				preprocessPromptCreator,
				promptCreator,
				chatSummaryFunction,
				commandCreator,
				formatResponseFunction,
				asyncRunner,
				new ChatMessageContainer());
	}

	public AiChatComponent(
			ChatApi api,
			ChatRepository chatRepository,
			PromptRepository promptRepository,
			Class viewClass,
			Function, ChatPrompt> promptCreator,
			ChatSummaryPromptFunction chatSummaryFunction,
			AsyncRunner asyncRunner) {

		this(
				api,
				chatRepository,
				promptRepository,
				viewClass,
				null,
				promptCreator,
				chatSummaryFunction,
				c -> Optional.empty(),
				null,
				asyncRunner);
	}

	public AiChatComponent(
			ChatApi api,
			ChatRepository chatRepository,
			PromptRepository promptRepository,
			Class viewClass,
			Function preprocessPromptCreator,
			Function, ChatPrompt> promptCreator,
			AsyncRunner asyncRunner) {

		this(
				api,
				chatRepository,
				promptRepository,
				viewClass,
				preprocessPromptCreator,
				promptCreator,
				ChatUtils::createSummaryPrompt,
				c -> Optional.empty(),
				null,
				asyncRunner);
	}

	public AiChatComponent(
			ChatApi api,
			ChatRepository chatRepository,
			PromptRepository promptRepository,
			Class viewClass,
			Function, ChatPrompt> promptCreator,
			AsyncRunner asyncRunner) {

		this(
				api,
				chatRepository,
				promptRepository,
				viewClass,
				null,
				promptCreator,
				ChatUtils::createSummaryPrompt,
				c -> Optional.empty(),
				null,
				asyncRunner);
	}


	@Override
	protected void onAttach(AttachEvent attachEvent) {
		super.onAttach(attachEvent);
		chatMessageContainer.scrollToEnd();
	}


	@Override
	public void startChatWithHistory(String chatId) {
		createChatSession(chatId);
		actionMenu.setEnabled(true);
	}


	@Override
	public void reset() {
		chatSession = createChat();
		actionMenu.setEnabled(false);
		chatMessages.clear();
		chatMessageContainer.reset();
		chatRating.setValue(null);
	}

	@Override
	protected void onRatingChange(Integer newValue) {
		chatSession.setChatRating(newValue);
		chatRepository.save(chatSession);
	}


	public void setInputEnabled(boolean enabled) {
		chatMessageContainer.setInputEnabled(enabled);
	}


	public void setMaxLength(int maxLength) {
		this.maxLength = maxLength;
	}


	public void setChatType(String chatType) {
		this.chatType = chatType;
		if (chatSession != null)
			chatSession.setChatType(chatType);
	}


	public Chat getChatSession() {
		return chatSession;
	}


	public void saveMessage(ChatMessageRole role, String message, Prompt prompt, Model model) {
		ChatMessage chatMessage = new ChatMessage(role, message, prompt, model);
		chatSession.addChatMessage(chatMessage);
		chatSession = chatRepository.save(chatSession);
	}


	public void saveMessage(ChatMessageRole role, String message) {
		saveMessage(role, message, null, null);
	}


	private void setupChatMessageComponents() {
		chatMessageContainer.addSubmitListener(this::handleChatMessageInput);
		add(chatMessageContainer);
	}


	private void handleChatMessageInput(SubmitEvent event) {
		try (var c = CHAT_ID.setForScope(chatSession.getChatId())) {
			String text = event.getValue();

			if (text == null || text.isEmpty())
				return;

			if (text.length() > this.maxLength) {
				Notification.show(
						getTranslation("common.messageTooLong", this.maxLength),
						6000,
						Notification.Position.MIDDLE);
				return;
			}

			if (!actionMenu.isEnabled())
				actionMenu.setEnabled(true);

			showMessage(text, AuthUtils.getUsername());
			getUI().ifPresent(UI::push);

			Optional command = ChatUtils.parseCommand(text);
			command.flatMap(commandCreator).ifPresentOrElse(
					this::displayCommandResponse,
					() -> handleChatMessage(text));
		}
	}

	private void createChatSession(String chatId) {
		if (chatId == null)
			return;

		chatSession = chatRepository.findChatByChatIdAndUsername(chatId, AuthUtils.getEmail());
		if (chatSession != null) {
			chatSession.getChatMessages().stream()
				.filter(m -> m.getRole() == ChatMessageRole.ASSISTANT || m.getRole() == ChatMessageRole.USER)
				.forEach(m -> {
					chatMessageContainer.addItem(m.getSendTime(), m.getMessage(), ChatUtils.inferUsername(chatSession, m), false, m.getRole() == ChatMessageRole.ASSISTANT);
					chatMessages.add(new Message(Role.valueOf(m.getRole().name()), m.getMessage()));
			});
			chatRating.setValue(chatSession.getChatRating());
			getUI().ifPresent(UI::push);
		}
		else {
			chatSession = createChat();
		}
	}


	private void handleChatMessage(String message) {
		try (MessageChain mc = MessageChainUtils.startMessageChain()) {
			boolean isFirstChatMessage = !chatSession.hasMessages();

			String preprocessedInput = preprocessMessage(message);
			chatMessages.add(Message.user(preprocessedInput));
			ChatPrompt chatPrompt = promptCreator.apply(chatMessages);

			Prompt prompt = new Prompt(OpenAiRequestGenerator.generate(chatPrompt), chatPrompt.tokenCount());
			promptRepository.save(prompt);

			if (isFirstChatMessage) {
				chatPrompt.messages().stream()
						.filter(m -> Role.SYSTEM == m.getRole())
						.findFirst()
						.ifPresent(s -> saveMessage(ChatMessageRole.SYSTEM, s.getContent()));

				chatSession.setSummary("%s%s".formatted(
						message.substring(0, Math.min(message.length(), 47)),
						message.length() > 47 ? "..." : ""));
			}

			saveMessage(ChatMessageRole.USER, message, prompt, chatPrompt.model());
			String assistantResponse = handleResponse(
					chatPrompt,
					prompt,
					chatPrompt.model());
			if (isFirstChatMessage)
				asyncRunner.run(() -> summarizeChat(preprocessedInput, assistantResponse, chatPrompt), ContextUtils.getContext());
		}
	}


	private String preprocessMessage(String message) {
		return preprocessPromptCreator
			.map(c -> api.send(c.apply(message)))
			.filter(ChatResponse::isSuccess)
			.flatMap(ChatResponse::getMessage)
			.map(Message::getContent)
			.orElse(message);
	}


	private String handleResponse(ChatPrompt chatPrompt, Prompt prompt, Model model) {
		ChatResponse response = api.send(chatPrompt);
		String assistantResponse = null;
		StringBuilder builder = new StringBuilder();

		response.addSubscriber(new ContentSubscriber() {
			@Override
			public void onContent(String chunk) {
				if (!chunk.isEmpty()) {
					builder.append(chunk);
					chatMessageContainer.addItem("%s...".formatted(builder.toString()), "AI", builder.length() > chunk.length(), false);
					chatMessageContainer.scrollToEnd();
					getUI().ifPresent(UI::push);
				}
			}

			@Override
			public void onError(Throwable throwable) {
				LOG.error("Error with response: " + throwable.getMessage());
				getUI().ifPresent(ui -> ui.access(() ->
					chatMessageContainer.addItem(getTranslation("component.baseAi.error"), "AI", false, false)
				));
			}
		});

		if (response.isSuccess()) {
			assistantResponse = response.getMessage()
					.map(Message::getContent)
					.orElse(null);

			if (StringUtils.hasText(assistantResponse)) {
				String finalAssistantResponse = assistantResponse;
				assistantResponse = formatResponseFunction.map(f -> f.apply(finalAssistantResponse)).orElse(assistantResponse);
				chatMessageContainer.addItem(assistantResponse, "AI", true, true);
				chatMessages.add(Message.assistant(assistantResponse));
				chatMessageContainer.scrollToEnd();
				getUI().ifPresent(UI::push);

				saveMessage(ChatMessageRole.ASSISTANT, assistantResponse, prompt, model);
			}

			if (functionCallHandler != null) {
				List functionCallResults = response.getFunctionCalls()
					.stream()
					.map(functionCallHandler::apply)
					.filter(Optional::isPresent)
					.map(Optional::get)
					.toList();

				if (!functionCallResults.isEmpty()) {
					chatMessages.addAll(functionCallResults);
					ChatPrompt.Builder afterFunctionsPromptBuilder = allowRecursiveFunctionCalls ?
							chatPrompt.builder() :
							cloneChatPromtWithoutFunctions(chatPrompt);

					ChatPrompt afterFunctionsPrompt = afterFunctionsPromptBuilder
							.addAll(functionCallResults)
							.build();

					return handleResponse(afterFunctionsPrompt, prompt, model);
				}
			}
		}
		else {
			saveMessage(ChatMessageRole.ERROR, response.getResultState());
		}

		return assistantResponse;
	}

	private static ChatPrompt.Builder cloneChatPromtWithoutFunctions(ChatPrompt original) {
		ChatPrompt.Builder result = ChatPrompt.builder(original.model());

		result.addAll(original.messages());
		for (Map.Entry entry : original.parameters().entrySet())
			result.setParameter(entry.getKey(), entry.getValue());
		result.addStopList(original.stopList());
		result.addLogitBias(original.logitBias());

		return result;
	}


	private void summarizeChat(String userMessage, String assistantResponse, ChatPrompt chatPrompt) {
		ChatPrompt summaryPrompt = chatSummaryPromptFunction.apply(
				chatPrompt.model(),
				chatPrompt.getStringProperty("provider"),
				userMessage,
				assistantResponse);
		try {
			ChatResponse response = api.send(summaryPrompt);
			if (response.isSuccess()) {
				response.getMessage()
						.map(Message::getContent)
						.ifPresent(summary -> {
							chatSession.setSummary(summary);
							chatRepository.saveAndFlush(chatSession);
						});
			}
		}
		catch (Exception e) {
			LOG.error("Error creating summary: " + e.getMessage());
		}
	}


	private void displayCommandResponse(String commandResponse) {
		showMessage(commandResponse, "AI");
		getUI().ifPresent(UI::push);
	}


	private void showMessage(String messageText, String sender) {
		chatMessageContainer.addItem(messageText, sender, false, false);
		chatMessageContainer.scrollToEnd();
	}


	private Chat createChat() {
		return new Chat(chatType != null ? chatType : viewClass.getSimpleName());
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy