All downloads are FREE. Search and download functionality uses the official Maven repository.

fi.evolver.ai.vaadin.view.ImageHelperBaseView Maven / Gradle / Ivy

The newest version!
package fi.evolver.ai.vaadin.view;

import java.io.Serial;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.vaadin.flow.component.Component;
import com.vaadin.flow.component.UI;
import com.vaadin.flow.component.html.Paragraph;
import com.vaadin.flow.router.BeforeEnterEvent;

import fi.evolver.ai.spring.ApiResponseException;
import fi.evolver.ai.spring.AsyncRunner;
import fi.evolver.ai.spring.chat.ChatApi;
import fi.evolver.ai.spring.chat.FunctionCall;
import fi.evolver.ai.spring.chat.function.FunctionSpec;
import fi.evolver.ai.spring.chat.function.annotation.FunctionName;
import fi.evolver.ai.spring.chat.prompt.ChatPrompt;
import fi.evolver.ai.spring.chat.prompt.Message;
import fi.evolver.ai.spring.file.AiFile;
import fi.evolver.ai.spring.image.ImageApi;
import fi.evolver.ai.spring.image.ImageResponse;
import fi.evolver.ai.spring.image.prompt.ImageGenerationPrompt;
import fi.evolver.ai.vaadin.ChatAttachmentRepository;
import fi.evolver.ai.vaadin.ChatRepository;
import fi.evolver.ai.vaadin.PromptRepository;
import fi.evolver.ai.vaadin.component.AiChatComponent;
import fi.evolver.ai.vaadin.component.ChatAvatarImageItem;
import fi.evolver.ai.vaadin.component.ChatAvatarItem;
import fi.evolver.ai.vaadin.component.ChatAvatarLoadingItem;
import fi.evolver.ai.vaadin.component.ChatMessageContainer;
import fi.evolver.ai.vaadin.entity.ChatAttachment;
import fi.evolver.ai.vaadin.util.ChatUtils;
import fi.evolver.ai.vaadin.util.ChatUtils.CommandType;

/**
 * Base class for chat views in which the AI model can generate images.
 *
 * <p>The view wires an {@link AiChatComponent} whose function calls are routed to
 * {@link #handleFunctionCall(FunctionCall)}. When the model calls the function named by the
 * {@link FunctionName} annotation of {@code imageGenerationTool}, the prompt supplied by the model
 * is sent to the {@link ImageApi}. The resulting images are shown in the chat container and stored
 * as {@link ChatAttachment}s of the current chat session. When an existing chat is opened, its
 * stored images are shown again.
 *
 * <p>NOTE(review): several signatures in this copy (e.g. the {@code commandCreator} constructor
 * parameter, the {@code imageGenerationTool} field, {@link #createChatPrompt}) appear to have lost
 * their generic type arguments. Restore them from the canonical source before compiling.
 */
public abstract class ImageHelperBaseView extends ChatHistoryAwareView {
	@Serial
	private static final long serialVersionUID = 1L;

	protected static final Logger LOG = LoggerFactory.getLogger(ImageHelperBaseView.class);

	/** Matches {@code image/<subtype>} MIME types, capturing the subtype for use as a file extension. */
	private static final Pattern IMAGE_MIME_TYPE = Pattern.compile("^image\\/(\\w+)$");

	/** Chat component driving the conversation; its function calls are handled by this view. */
	protected final AiChatComponent chatComponent;
	/** API used to generate images. */
	protected final ImageApi imageApi;
	/** Repository used to load and store generated images as chat attachments. */
	protected final ChatAttachmentRepository chatAttachmentRepository;
	/** Function-call result type; its {@link FunctionName} annotation names the image generation function. */
	protected final Class imageGenerationTool;

	/**
	 * Creates the view and its chat component.
	 *
	 * <p>{@link #preInit()} and {@link #getChatContainer()} are invoked from this constructor. They
	 * therefore run before any subclass field initializers or subclass constructor bodies.
	 *
	 * @param chatApi                  chat API passed to the {@link AiChatComponent}
	 * @param imageApi                 API used to generate images
	 * @param chatMessageRepository    repository passed to the {@link AiChatComponent}
	 * @param promptRepository         repository passed to the {@link AiChatComponent}
	 * @param commandCreator           command factory passed to the {@link AiChatComponent}
	 * @param chatAttachmentRepository repository used to load and store generated images
	 * @param imageGenerationTool      function-call result type; must carry a {@link FunctionName} annotation
	 * @param asyncRunner              runner passed to the {@link AiChatComponent}
	 */
	public ImageHelperBaseView(
			ChatApi chatApi,
			ImageApi imageApi,
			ChatRepository chatMessageRepository,
			PromptRepository promptRepository,
			Function> commandCreator,
			ChatAttachmentRepository chatAttachmentRepository,
			Class imageGenerationTool,
			AsyncRunner asyncRunner) {
		preInit();
		this.imageApi = imageApi;
		this.chatComponent = new AiChatComponent(
				chatApi,
				chatMessageRepository,
				promptRepository,
				getClass(),
				null,
				this::createChatPrompt,
				ChatUtils::createSummaryPrompt,
				commandCreator,
				null,
				this::handleFunctionCall,
				true,
				asyncRunner,
				getChatContainer());
		this.chatAttachmentRepository = chatAttachmentRepository;
		this.imageGenerationTool = imageGenerationTool;
	}

	/**
	 * Creates the view with a command creator that never produces a command.
	 *
	 * @see #ImageHelperBaseView(ChatApi, ImageApi, ChatRepository, PromptRepository, Function,
	 *      ChatAttachmentRepository, Class, AsyncRunner)
	 */
	public ImageHelperBaseView(
			ChatApi chatApi,
			ImageApi imageApi,
			ChatRepository chatMessageRepository,
			PromptRepository promptRepository,
			ChatAttachmentRepository chatAttachmentRepository,
			Class imageGenerationTool,
			AsyncRunner asyncRunner) {
		this(
				chatApi,
				imageApi,
				chatMessageRepository,
				promptRepository,
				c -> Optional.empty(),
				chatAttachmentRepository,
				imageGenerationTool,
				asyncRunner);
	}

	/**
	 * Hook invoked first thing in the constructor, before the chat component is created and before
	 * {@link #getChatContainer()} is first called. Override it to prepare state those calls need.
	 * Subclass field initializers have not run yet at this point. Does nothing by default.
	 */
	protected void preInit() {}

	/** Returns the container that chat messages and image items are added to. */
	protected abstract ChatMessageContainer getChatContainer();

	/** Builds the chat prompt sent to the chat API from the given messages. */
	protected abstract ChatPrompt createChatPrompt(List chatMessages);

	/** Builds the image generation request for the given text prompt. */
	protected abstract ImageGenerationPrompt createImagePrompt(String prompt);

	@Override
	public HistoryAwareChat getChatComponent() {
		return chatComponent;
	}

	/**
	 * Resumes an existing chat if the navigation event refers to one. If the resumed chat has
	 * stored image attachments, they are shown in the chat container.
	 */
	@Override
	public void beforeEnter(BeforeEnterEvent event) {
		super.startChatIfExists(event);
		// No stored chat to load images for (presumably a new, unsaved session — confirm id semantics).
		if (chatComponent.getChatSession().getId() <= 0)
			return;

		var existingImages = chatAttachmentRepository.findChatAttachmentsByChat(chatComponent.getChatSession());
		if (existingImages.isEmpty())
			return;

		List<AiFile> imageFiles = existingImages.stream()
				.map(ImageHelperBaseView::toAiFile)
				.toList();
		getChatContainer().addItem(new ChatAvatarImageItem(
				imageGeneratorName(),
				getTranslation("view.baseImageHelper.previousImages"),
				imageFiles),
				false);
	}

	/**
	 * Converts a stored attachment into an {@link AiFile}. The file is named
	 * {@code image.<subtype>} for {@code image/<subtype>} MIME types and plain {@code image}
	 * otherwise, including when the MIME type is null.
	 */
	private static AiFile toAiFile(ChatAttachment attachment) {
		String mimeType = attachment.getMimeType();
		String filename = "image";
		if (mimeType != null) {
			Matcher matcher = IMAGE_MIME_TYPE.matcher(mimeType);
			if (matcher.matches())
				filename = "image." + matcher.group(1);
		}
		return new AiFile(attachment.getData(), mimeType, filename);
	}

	/**
	 * Handles a function call requested by the chat model.
	 *
	 * <p>A call to the function named by the {@link FunctionName} annotation of
	 * {@code imageGenerationTool} triggers {@link #generateImage(String)} with the prompt supplied
	 * by the model. All other calls are ignored.
	 *
	 * @param call the function call requested by the model
	 * @return always an empty {@link Optional}
	 * @throws IllegalStateException if {@code imageGenerationTool} has no {@link FunctionName} annotation
	 */
	protected Optional handleFunctionCall(FunctionCall call) {
		// Constant-first comparison: a call without a function name simply does not match.
		if (imageGenerationFunctionName().equals(call.getFunctionName())) {
			ImageGenerationTool toolCall = call.toResult(FunctionSpec.of(imageGenerationTool));
			generateImage(toolCall.getPrompt());
		}
		return Optional.empty();
	}

	/** Returns the function name declared by {@code imageGenerationTool}'s {@link FunctionName} annotation. */
	private String imageGenerationFunctionName() {
		FunctionName functionName = imageGenerationTool.getAnnotation(FunctionName.class);
		if (functionName == null)
			throw new IllegalStateException(
					"Image generation tool " + imageGenerationTool.getName() + " is missing the @FunctionName annotation");
		return functionName.value();
	}

	/** Generates images for {@code prompt}, first adding the prompt itself to the chat. */
	protected void generateImage(String prompt) {
		generateImage(prompt, true);
	}

	/**
	 * Generates images for {@code prompt} and shows them in the chat container.
	 *
	 * <p>A loading item is added and pushed to the client before the synchronous call to the
	 * {@link ImageApi}. Generated images are displayed and then stored as {@link ChatAttachment}s
	 * of the current chat session. Exceptions thrown while generating, displaying, or storing the
	 * images are logged and reported in the chat instead of being propagated. Exceptions from
	 * {@link #createImagePrompt(String)} do propagate.
	 *
	 * @param prompt     text passed to {@link #createImagePrompt(String)}
	 * @param showPrompt whether to add {@code prompt} to the chat as an "AI" message first
	 */
	protected void generateImage(String prompt, boolean showPrompt) {
		ImageGenerationPrompt imagePrompt = createImagePrompt(prompt);
		if (showPrompt)
			getChatContainer().addItem(prompt, "AI", false, false);
		getChatContainer().addItem(
				new ChatAvatarLoadingItem(imageGeneratorName(), getTranslation("component.aiImage.generating")),
				false);
		getChatContainer().scrollToEnd();
		// Push pending UI changes (the loading item) before blocking on the image request.
		getUI().ifPresent(UI::push);
		try {
			ImageResponse response = imageApi.send(imagePrompt);
			getChatContainer().addItem(new ChatAvatarImageItem(imageGeneratorName(), response.getImages()), true);
			for (AiFile image : response.getImages()) {
				chatAttachmentRepository.save(new ChatAttachment(
						image.mimeType(),
						image.data(),
						imagePrompt.provider(),
						chatComponent.getChatSession()));
			}
		}
		catch (Exception e) {
			String errorMsg = errorMessageFor(e);
			LOG.error("Error creating image", e);
			getChatContainer().addItem(new ChatAvatarItem(imageGeneratorName(), new Paragraph(errorMsg)), true);
		}
		getChatContainer().scrollToEnd();
	}

	/**
	 * Selects the localized message for a failed image generation. An {@link ApiResponseException}
	 * caused by an {@link InterruptedException} is reported with the timeout message. Anything else
	 * gets the generic failure message.
	 *
	 * <p>NOTE(review): the thread's interrupt status is not restored here. Confirm that the code
	 * wrapping the InterruptedException does so if callers rely on it.
	 */
	private String errorMessageFor(Exception e) {
		if (e instanceof ApiResponseException apiException && apiException.getCause() instanceof InterruptedException)
			return getTranslation("component.aiImage.generationTimeout");
		return getTranslation("component.aiImage.generationFailed");
	}

	/** Returns the localized display name of the image generator's chat avatar. */
	private String imageGeneratorName() {
		return getTranslation("view.baseImageHelper.imageGenerator");
	}

	/**
	 * Result type of the image generation function call. {@link #handleFunctionCall(FunctionCall)}
	 * assigns the call parsed via {@code imageGenerationTool} to this type, so that class is
	 * expected to implement it.
	 */
	protected interface ImageGenerationTool {
		/** Returns the image prompt supplied by the model. */
		String getPrompt();
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy