All Downloads are FREE. Search and download functionalities are using the official Maven repository.

fi.evolver.ai.vaadin.component.AiImageComponent Maven / Gradle / Ivy

The newest version!
package fi.evolver.ai.vaadin.component;

import java.io.ByteArrayInputStream;
import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;

import org.vaadin.olli.FileDownloadWrapper;

import com.vaadin.flow.component.Component;
import com.vaadin.flow.component.ScrollOptions;
import com.vaadin.flow.component.UI;
import com.vaadin.flow.component.button.Button;
import com.vaadin.flow.component.html.Image;
import com.vaadin.flow.component.messages.MessageInput;
import com.vaadin.flow.component.messages.MessageList;
import com.vaadin.flow.component.messages.MessageListItem;
import com.vaadin.flow.component.notification.Notification;
import com.vaadin.flow.component.orderedlayout.HorizontalLayout;
import com.vaadin.flow.component.orderedlayout.VerticalLayout;
import com.vaadin.flow.component.select.Select;
import com.vaadin.flow.server.StreamResource;
import com.vaadin.flow.theme.lumo.LumoUtility;

import fi.evolver.ai.spring.ApiResponseException;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.chat.prompt.Message;
import fi.evolver.ai.spring.image.ImageApi;
import fi.evolver.ai.spring.image.ImageResponse;
import fi.evolver.ai.spring.image.prompt.ImageGenerationPrompt;
import fi.evolver.ai.spring.provider.openai.OpenAiRequestGenerator;
import fi.evolver.ai.spring.provider.openai.OpenAiRequestParameters;
import fi.evolver.ai.spring.provider.openai.OpenAiService;
import fi.evolver.ai.vaadin.ChatAttachmentRepository;
import fi.evolver.ai.vaadin.ChatRepository;
import fi.evolver.ai.vaadin.PromptRepository;
import fi.evolver.ai.vaadin.component.i18n.MessageInputI18nFactory;
import fi.evolver.ai.vaadin.entity.Chat;
import fi.evolver.ai.vaadin.entity.ChatAttachment;
import fi.evolver.ai.vaadin.entity.ChatMessage;
import fi.evolver.ai.vaadin.entity.ChatMessage.ChatMessageRole;
import fi.evolver.ai.vaadin.entity.Prompt;
import fi.evolver.ai.vaadin.util.AuthUtils;
import fi.evolver.ai.vaadin.util.ChatUtils;
import fi.evolver.ai.vaadin.view.HistoryAwareChat;
import fi.evolver.utils.NullSafetyUtils;

/**
 * Vaadin component that generates a single image with DALL-E 3 from a text prompt.
 * <p>
 * The user picks a size, quality and style, enters a prompt and gets back one generated image that can be downloaded
 * or recreated with the same prompt. The prompt, the chat messages and the generated image are persisted through the
 * given repositories, so an earlier session can be reopened with {@link #startChatWithHistory(String)}.
 */
public class AiImageComponent extends BaseChatComponent implements HistoryAwareChat {
	@Serial
	private static final long serialVersionUID = 1L;

	private static final List<String> STYLES = List.of("natural", "vivid");
	private static final List<String> SIZES = List.of("1024x1024", "1024x1792", "1792x1024");
	private static final List<String> QUALITIES = List.of("standard", "hd");
	/** Maximum accepted prompt length in characters. */
	private static final int MAX_LENGTH = 1000;
	/** File name used for both the displayed and the downloadable image resource. */
	private static final String IMAGE_FILE_NAME = "image.png";

	private final MessageInput chatMessageInput = new MessageInput();
	private final MessageList chatMessageList = new MessageList();
	/** User prompts and assistant replies of the current session as API-level messages. */
	private final List<Message> chatMessages = new ArrayList<>();
	private final Button downloadImageButton = new Button();
	private final Button recreateImageButton = new Button();

	private final ImageApi api;
	private final ChatRepository chatRepository;
	private final PromptRepository promptRepository;
	private final ChatAttachmentRepository chatAttachmentRepository;

	private FileDownloadWrapper buttonWrapper;
	private Select<String> qualitySelect;
	private Select<String> sizeSelect;
	private Select<String> styleSelect;
	/** Prompt of the current session; reused when the image is recreated. May be {@code null}. */
	private String promptText;
	/** The most recently displayed image, or {@code null} if none has been shown yet. */
	private Image image;

	private Chat chatSession;
	private String chatType;


	/**
	 * Creates the component and builds its layout: option selects, message list and input, and the recreate button.
	 *
	 * @param api the image generation API prompts are sent to
	 * @param chatRepository repository used to persist chat sessions
	 * @param promptRepository repository used to persist the generated API requests
	 * @param chatAttachmentRepository repository used to persist and look up generated images
	 * @param viewClass the hosting view class, passed on to the base component
	 */
	public AiImageComponent(
			ImageApi api,
			ChatRepository chatRepository,
			PromptRepository promptRepository,
			ChatAttachmentRepository chatAttachmentRepository,
			Class<? extends Component> viewClass) {
		super(viewClass);
		this.api = api;
		this.chatRepository = chatRepository;
		this.promptRepository = promptRepository;
		this.chatAttachmentRepository = chatAttachmentRepository;

		addClassNames(LumoUtility.Width.FULL, LumoUtility.Display.FLEX, LumoUtility.Flex.AUTO);
		setSpacing(false);
		setSizeFull();

		reset();

		setupImageRequestOptions();
		setupChatMessageComponents();
		setupRecreateImageButton();
	}


	/**
	 * Loads the chat with the given id for the current user and shows its messages and stored image.
	 *
	 * @param chatId id of the chat to load; {@code null} leaves the current session untouched
	 */
	@Override
	public void startChatWithHistory(String chatId) {
		createChatSession(chatId);
		actionMenu.setEnabled(true);
		chatMessageInput.scrollIntoView(new ScrollOptions(ScrollOptions.Behavior.SMOOTH));
	}


	/**
	 * Starts a new, empty session.
	 * <p>
	 * Clears the messages and rating, removes the image, re-enables the option selects with their default values and
	 * shows the prompt input again.
	 */
	@Override
	public void reset() {
		chatSession = createChat();
		actionMenu.setEnabled(false);
		recreateImageButton.setVisible(false);
		chatMessages.clear();
		chatMessageList.setItems();
		chatRating.setValue(null);

		removeImage();

		toggleSelectInputs(true);
		setDefaultSelectValues();
		displayChatMessageInput(true);
		displayDownload(false);
	}

	/** Stores the new rating on the current session and persists it. */
	@Override
	protected void onRatingChange(Integer newValue) {
		chatSession.setChatRating(newValue);
		chatRepository.save(chatSession);
	}


	/** Hides the prompt input and disables the option selects. Only one image is generated per session. */
	private void disableImageGenerationInputs() {
		displayChatMessageInput(false);
		toggleSelectInputs(false);
	}


	private void displayChatMessageInput(boolean isVisible) {
		chatMessageInput.setVisible(isVisible);
	}


	/** Shows or hides the download button and, once it exists, its download wrapper. */
	private void displayDownload(boolean isVisible) {
		downloadImageButton.setVisible(isVisible);
		if (buttonWrapper != null)
			buttonWrapper.setVisible(isVisible);
	}


	/** Appends a message to the current session without persisting the session itself. */
	private void saveMessage(ChatMessage.ChatMessageRole role, String message, Prompt prompt, Model model) {
		ChatMessage chatMessage = new ChatMessage(role, message, prompt, model);
		chatMessage.setTokenCount(null);
		chatSession.addChatMessage(chatMessage);
	}


	private void saveMessage(ChatMessage.ChatMessageRole role, String message) {
		saveMessage(role, message, null, null);
	}


	/**
	 * Restores a stored session into the UI.
	 * <p>
	 * Shows its user and assistant messages and its stored image, if any. Falls back to a fresh session when the
	 * chat is not found for the current user.
	 */
	private void createChatSession(String chatId) {
		if (chatId == null)
			return;

		chatSession = chatRepository.findChatByChatIdAndUsername(chatId, AuthUtils.getEmail());
		if (chatSession == null) {
			chatSession = createChat();
			return;
		}

		List<MessageListItem> existingMessages = new ArrayList<>();
		for (ChatMessage m : chatSession.getChatMessages()) {
			if (m.getRole() != ChatMessageRole.ASSISTANT && m.getRole() != ChatMessageRole.USER)
				continue;

			existingMessages.add(new MessageListItem(
					m.getMessage(),
					ChatUtils.convertToInstantFi(m.getSendTime()),
					inferUsername(chatSession, m)));
			chatMessages.add(new Message(Message.Role.valueOf(m.getRole().name()), m.getMessage()));
		}

		// Assigned unconditionally so that a prompt from a previously shown session cannot leak into "recreate"
		promptText = chatSession.getChatMessages().stream()
				.filter(m -> m.getRole() == ChatMessageRole.USER)
				.findFirst()
				.map(ChatMessage::getMessage)
				.orElse(null);

		chatMessageList.setItems(existingMessages);

		disableImageGenerationInputs();
		actionMenu.setEnabled(false);

		ChatAttachment chatAttachment = chatAttachmentRepository.findChatAttachmentByChat(chatSession);
		if (chatAttachment != null) {
			byte[] imageData = chatAttachment.getData();
			showImage(imageData);
			handleDownloadingImage(imageData);
			recreateImageButton.setVisible(true);
		}

		getUI().ifPresent(UI::push);
	}


	private void setupChatMessageComponents() {
		chatMessageInput.setI18n(MessageInputI18nFactory.getI18nForImageComponent(this::getTranslation));
		chatMessageInput.addSubmitListener(this::handleChatMessageInput);

		chatMessageInput.setWidthFull();
		chatMessageList.setSizeFull();

		VerticalLayout chatContainer = new VerticalLayout();
		chatContainer.addClassNames(LumoUtility.Flex.AUTO, LumoUtility.Overflow.HIDDEN);
		chatContainer.add(chatMessageList, chatMessageInput);

		add(chatContainer);
		expand(chatMessageList);
	}


	/** Creates the size, quality and style selects in a horizontal row. */
	private void setupImageRequestOptions() {
		sizeSelect = createSelect(getTranslation("component.aiImage.sizeSelect"), SIZES);
		qualitySelect = createSelect(getTranslation("component.aiImage.qualitySelect"), QUALITIES);
		styleSelect = createSelect(getTranslation("component.aiImage.styleSelect"), STYLES);

		setDefaultSelectValues();

		HorizontalLayout selectContainer = new HorizontalLayout();
		selectContainer.setDefaultVerticalComponentAlignment(Alignment.AUTO);
		selectContainer.add(sizeSelect, qualitySelect, styleSelect);
		selectContainer.setSpacing(true);
		add(selectContainer);
	}


	private static Select<String> createSelect(String label, List<String> items) {
		Select<String> select = new Select<>();
		select.setLabel(label);
		select.setItems(items);
		select.setWidthFull();

		return select;
	}


	/** Selects the first option of each select. The null checks cover calls made before the selects exist. */
	private void setDefaultSelectValues() {
		if (sizeSelect != null)
			sizeSelect.setValue(SIZES.get(0));
		if (qualitySelect != null)
			qualitySelect.setValue(QUALITIES.get(0));
		if (styleSelect != null)
			styleSelect.setValue(STYLES.get(0));
	}


	private void toggleSelectInputs(boolean isEnabled) {
		if (sizeSelect != null)
			sizeSelect.setEnabled(isEnabled);
		if (qualitySelect != null)
			qualitySelect.setEnabled(isEnabled);
		if (styleSelect != null)
			styleSelect.setEnabled(isEnabled);
	}


	/**
	 * Validates a submitted prompt and, if it is acceptable, generates an image for it.
	 * <p>
	 * Blank prompts are ignored, and prompts longer than {@link #MAX_LENGTH} are rejected with a notification.
	 */
	private void handleChatMessageInput(MessageInput.SubmitEvent event) {
		try (var ignored = AiChatComponent.CHAT_ID.setForScope(chatSession.getChatId())) {
			String text = event.getValue();
			if (text == null || text.isBlank())
				return;

			if (text.length() > MAX_LENGTH) {
				Notification.show(
							getTranslation("common.messageTooLong", MAX_LENGTH),
							6000,
							Notification.Position.MIDDLE);
				return;
			}

			// Only remember prompts that were actually sent, so "recreate" never reuses rejected input
			promptText = text;

			if (!actionMenu.isEnabled())
				actionMenu.setEnabled(true);

			disableImageGenerationInputs();

			showMessage(promptText, AuthUtils.getUsername());
			getUI().ifPresent(UI::push);
			handleImageRequest(promptText);
		}
	}


	/**
	 * Sends the prompt with the currently selected options and shows the result.
	 * <p>
	 * Persists the request, the messages and, on success, the generated image. A failed request, or a response
	 * without an image, is reported to the user and recorded on the session.
	 */
	private void handleImageRequest(String userMessage) {
		chatMessages.add(Message.user(userMessage));
		ImageGenerationPrompt imagePrompt = ImageGenerationPrompt.builder(OpenAiService.DALL_E_3)
				.setPrompt(userMessage)
				.setParameter(OpenAiRequestParameters.QUALITY, qualitySelect.getValue())
				.setParameter(OpenAiRequestParameters.SIZE, sizeSelect.getValue())
				.setParameter(OpenAiRequestParameters.STYLE, styleSelect.getValue())
				.build();

		Prompt prompt = new Prompt(OpenAiRequestGenerator.generate(imagePrompt), null);
		promptRepository.save(prompt);
		saveMessage(ChatMessage.ChatMessageRole.USER, userMessage, prompt, imagePrompt.model());

		try {
			byte[] imageResponseBytes = handleResponse(api.send(imagePrompt), prompt, imagePrompt.model());
			if (imageResponseBytes == null) {
				// No image came back: report it instead of leaving the user without any feedback
				handleGenerationFailure();
			}
			else {
				ChatAttachment chatAttachment = new ChatAttachment("image/png", imageResponseBytes, imagePrompt.provider(), chatSession);
				handleDownloadingImage(imageResponseBytes);
				chatAttachmentRepository.save(chatAttachment);
				recreateImageButton.setVisible(true);
			}
		}
		catch (ApiResponseException e) {
			// Reported to the user and persisted as an ERROR message on the session
			handleGenerationFailure();
		}

		actionMenu.setEnabled(true);
	}


	/** Shows the failure message and records the failure on the session, then persists the session. */
	private void handleGenerationFailure() {
		String failureText = getTranslation("component.aiImage.generationFailed");
		showMessage(failureText, "AI");
		saveMessage(ChatMessage.ChatMessageRole.ERROR, "ERROR");
		chatSession.setSummary(failureText);
		chatRepository.save(chatSession);
	}


	/**
	 * Displays the generated image and records the assistant reply on the session.
	 *
	 * @return the image bytes, or {@code null} if the response contains no image
	 */
	private byte[] handleResponse(ImageResponse response, Prompt prompt, Model model) {
		var generatedImage = response.getImage();
		if (generatedImage == null)
			return null;

		String successText = getTranslation("component.aiImage.generationSuccess");
		chatMessages.add(Message.assistant(successText));
		showMessage(successText, "AI");
		byte[] imageBytes = generatedImage.data();
		showImage(imageBytes);
		getUI().ifPresent(UI::push);
		saveMessage(ChatMessage.ChatMessageRole.ASSISTANT, successText, prompt, model);
		chatSession.setSummary(successText);
		chatRepository.save(chatSession);

		return imageBytes;
	}


	/** Replaces any currently displayed image with one rendered from the given bytes. */
	private void showImage(byte[] imageBytes) {
		removeImage();
		image = new Image(createImageResource(imageBytes), "image");
		add(image);
	}


	/** Removes the displayed image from this layout if it is currently part of it. */
	private void removeImage() {
		if (image != null && image.getParent().isPresent())
			remove(image);
	}


	private static StreamResource createImageResource(byte[] imageBytes) {
		return new StreamResource(IMAGE_FILE_NAME, () -> new ByteArrayInputStream(imageBytes));
	}


	/** Shows the download button wired to the given image bytes, creating the download wrapper on first use. */
	private void handleDownloadingImage(byte[] imageBytes) {
		displayDownload(true);
		downloadImageButton.setText(getTranslation("component.aiImage.downloadImage"));
		downloadImageButton.getStyle().setMarginTop("15px");

		StreamResource resource = createImageResource(imageBytes);
		if (buttonWrapper == null)
			buttonWrapper = new FileDownloadWrapper(resource);
		else
			buttonWrapper.setResource(resource);

		buttonWrapper.wrapComponent(downloadImageButton);
		buttonWrapper.setVisible(true);
		add(buttonWrapper);
	}


	private void setupRecreateImageButton() {
		recreateImageButton.setText(getTranslation("component.aiImage.recreateImage"));
		recreateImageButton.addClickListener(event -> recreateImage());

		recreateImageButton.getStyle().setMarginBottom("15px");
		recreateImageButton.setVisible(image != null && image.isAttached());
		add(recreateImageButton);
	}


	/**
	 * Starts a new session that generates a new image from the previous prompt.
	 * <p>
	 * The currently selected size, quality and style are kept.
	 */
	private void recreateImage() {
		if (promptText == null || promptText.isBlank())
			return;

		// reset() restores the default options; keep the ones currently selected
		String quality = qualitySelect.getValue();
		String size = sizeSelect.getValue();
		String style = styleSelect.getValue();

		reset();
		qualitySelect.setValue(quality);
		sizeSelect.setValue(size);
		styleSelect.setValue(style);

		disableImageGenerationInputs();
		showMessage(getTranslation("component.aiImage.generating"), "AI");
		getUI().ifPresent(UI::push);

		// Same chat-id scope as a regular submission
		try (var ignored = AiChatComponent.CHAT_ID.setForScope(chatSession.getChatId())) {
			handleImageRequest(promptText);
		}
	}


	/** Appends a message to the visible message list. */
	private void showMessage(String messageText, String sender) {
		MessageListItem message = new MessageListItem(messageText, ChatUtils.currentTimeHelsinki(), sender);
		chatMessageList.setItems(Stream.concat(chatMessageList.getItems().stream(), Stream.of(message)).toList());
	}


	private Chat createChat() {
		return new Chat(chatType != null ? chatType : viewClass.getSimpleName());
	}


	/** Display name for a stored message's sender: "AI" for the assistant, the chat's display name for the user. */
	private static String inferUsername(Chat chat, ChatMessage message) {
		return switch (message.getRole()) {
			case ASSISTANT	-> "AI";
			case USER		-> NullSafetyUtils.denull(chat.getDisplayName(), "User");
			default			-> "Unknown";
		};
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy