All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.quarkiverse.langchain4j.runtime.devui.ChatJsonRPCService Maven / Gradle / Ivy

There is a newer version: 0.21.0
Show newest version
package io.quarkiverse.langchain4j.runtime.devui;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import jakarta.enterprise.context.control.ActivateRequestContext;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.tool.ToolExecutor;
import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.service.tool.ToolProviderRequest;
import dev.langchain4j.service.tool.ToolProviderResult;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.devui.json.ChatMessagePojo;
import io.quarkiverse.langchain4j.runtime.devui.json.ChatResultPojo;
import io.quarkiverse.langchain4j.runtime.devui.json.ToolExecutionRequestPojo;
import io.quarkiverse.langchain4j.runtime.devui.json.ToolExecutionResultPojo;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutorFactory;
import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
import io.quarkus.arc.All;
import io.quarkus.arc.Arc;
import io.quarkus.logging.Log;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.infrastructure.Infrastructure;
import io.smallrye.mutiny.subscription.MultiEmitter;
import io.vertx.core.json.JsonObject;

@ActivateRequestContext
public class ChatJsonRPCService {

    public static final int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 20;
    private final ChatLanguageModel model;
    private final Optional streamingModel;

    private final ChatMemoryProvider memoryProvider;

    // The augmentor to use, if any is found in the application. Only augmentors that are CDI beans
    // can be found. If more than one is found, for now we choose the first one offered by the CDI container.
    // FIXME: perhaps the UI could offer choosing between available augmentors when there are more
    private RetrievalAugmentor retrievalAugmentor;

    private final List toolSpecifications;
    private final Map toolExecutors;
    private final ToolProvider toolProvider;

    public ChatJsonRPCService(@All List models, // don't use ChatLanguageModel model because it results in the default model not being configured
            @All List streamingModels,
            @All List> retrievalAugmentorSuppliers,
            @All List retrievalAugmentors,
            ChatMemoryProvider memoryProvider,
            QuarkusToolExecutorFactory toolExecutorFactory,
            @All List> toolProviders) {
        this.model = models.get(0);
        this.toolProvider = getToolProvider(toolProviders);
        this.streamingModel = streamingModels.isEmpty() ? Optional.empty() : Optional.of(streamingModels.get(0));
        this.retrievalAugmentor = null;
        for (Supplier supplier : retrievalAugmentorSuppliers) {
            this.retrievalAugmentor = supplier.get();
            if (this.retrievalAugmentor != null) {
                break;
            }
        }
        if (this.retrievalAugmentor == null) {
            for (RetrievalAugmentor augmentorFromCdi : retrievalAugmentors) {
                this.retrievalAugmentor = augmentorFromCdi;
                if (this.retrievalAugmentor != null) {
                    break;
                }
            }
        }
        this.memoryProvider = memoryProvider;
        // retrieve available tools
        Map> toolsMetadata = ToolsRecorder.getMetadata();
        if (toolsMetadata != null && this.toolProvider == null) {
            toolExecutors = new HashMap<>();
            toolSpecifications = new ArrayList<>();
            for (Map.Entry> entry : toolsMetadata.entrySet()) {
                for (ToolMethodCreateInfo methodCreateInfo : entry.getValue()) {
                    Object objectWithTool = null;
                    try {
                        objectWithTool = Arc.container().select(
                                Thread.currentThread().getContextClassLoader().loadClass(entry.getKey())).get();
                    } catch (ClassNotFoundException e) {
                        throw new RuntimeException(e);
                    }
                    QuarkusToolExecutor.Context executorContext = new QuarkusToolExecutor.Context(objectWithTool,
                            methodCreateInfo.invokerClassName(), methodCreateInfo.methodName(),
                            methodCreateInfo.argumentMapperClassName(), methodCreateInfo.executionModel());
                    toolExecutors.put(methodCreateInfo.toolSpecification().name(),
                            toolExecutorFactory.create(executorContext));
                    toolSpecifications.add(methodCreateInfo.toolSpecification());
                }
            }

        } else if (this.toolProvider != null) {
            // mutable list / map
            toolExecutors = new HashMap<>();
            toolSpecifications = new ArrayList<>();
        } else {
            toolSpecifications = List.of();
            toolExecutors = Map.of();
        }
    }

    private final AtomicReference currentMemory = new AtomicReference<>();
    private final AtomicLong currentMemoryId = new AtomicLong();

    public String reset(String systemMessage) {
        if (currentMemory.get() != null) {
            currentMemory.get().clear();
        }
        long memoryId = ThreadLocalRandom.current().nextLong();
        currentMemoryId.set(memoryId);
        ChatMemory memory = memoryProvider.get(memoryId);
        currentMemory.set(memory);
        if (systemMessage != null && !systemMessage.isEmpty()) {
            memory.add(new SystemMessage(systemMessage));
        }
        return "OK";
    }

    public boolean isStreamingChatSupported() {
        return streamingModel.isPresent();
    }

    public Multi streamingChat(String message, boolean ragEnabled) {
        ChatMemory m = currentMemory.get();
        if (m == null) {
            reset("");
            m = currentMemory.get();
        }
        final ChatMemory memory = m;

        // create a backup of the chat memory, because we are now going to
        // add a new message to it, and might have to remove it if the chat
        // request fails - unfortunately the ChatMemory API doesn't allow
        // removing single messages
        List chatMemoryBackup = memory.messages();

        Multi stream = Multi.createFrom().emitter(em -> {
            try {
                // invoke RAG is applicable
                UserMessage userMessage = UserMessage.from(message);
                if (retrievalAugmentor != null && ragEnabled) {
                    Metadata metadata = Metadata.from(userMessage, currentMemoryId.get(), memory.messages());
                    AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
                    ChatMessage augmentedMessage = retrievalAugmentor.augment(augmentationRequest).chatMessage();
                    memory.add(augmentedMessage);
                    em.emit(new JsonObject().put("augmentedMessage", augmentedMessage.text()));
                } else {
                    memory.add(new UserMessage(message));
                }

                StreamingChatLanguageModel streamingModel = this.streamingModel.orElseThrow(IllegalStateException::new);
                boolean hasToolProvider = setToolsViaProviderIfAvailable(memory, userMessage);

                // invoke tools if applicable
                Response modelResponse;
                if (toolSpecifications.isEmpty()) {
                    streamingModel.generate(memory.messages(), new StreamingResponseHandler() {
                        @Override
                        public void onComplete(Response response) {
                            memory.add(response.content());
                            String message = response.content().text();
                            em.emit(new JsonObject().put("message", message));
                            em.complete();
                        }

                        @Override
                        public void onNext(String token) {
                            em.emit(new JsonObject().put("token", token));
                        }

                        @Override
                        public void onError(Throwable error) {
                            em.fail(error);
                        }
                    });
                } else {
                    executeWithToolsAndStreaming(memory, em, MAX_SEQUENTIAL_TOOL_EXECUTIONS);
                }
            } catch (Throwable t) {
                // restore the memory from the backup
                memory.clear();
                chatMemoryBackup.forEach(memory::add);
                Log.warn(t);
                em.fail(t);
            }
        });
        // run on a worker thread because the retrieval augmentor might be blocking
        return stream.runSubscriptionOn(Infrastructure.getDefaultWorkerPool());
    }

    public ChatResultPojo chat(String message, boolean ragEnabled) {
        ChatMemory memory = currentMemory.get();
        if (memory == null) {
            reset("");
            memory = currentMemory.get();
        }
        // create a backup of the chat memory, because we are now going to
        // add a new message to it, and might have to remove it if the chat
        // request fails - unfortunately the ChatMemory API doesn't allow
        // removing single messages
        List chatMemoryBackup = memory.messages();
        try {
            UserMessage userMessage = UserMessage.from(message);
            if (retrievalAugmentor != null && ragEnabled) {
                Metadata metadata = Metadata.from(userMessage, currentMemoryId.get(), memory.messages());
                AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
                ChatMessage augmentedMessage = retrievalAugmentor.augment(augmentationRequest).chatMessage();
                memory.add(augmentedMessage);
            } else {
                memory.add(new UserMessage(message));
            }

            boolean hasToolProvider = setToolsViaProviderIfAvailable(memory, userMessage);

            Response modelResponse;
            if (toolSpecifications.isEmpty()) {
                modelResponse = model.generate(memory.messages());
                memory.add(modelResponse.content());
            } else {
                executeWithTools(memory);
            }
            // Remove toolProviderSupplier tools again
            if (hasToolProvider) {
                toolSpecifications.clear();
                toolExecutors.clear();
            }
            List response = ChatMessagePojo.listFromMemory(memory);
            return new ChatResultPojo(response, null);
        } catch (Throwable t) {
            // restore the memory from the backup
            memory.clear();
            chatMemoryBackup.forEach(memory::add);
            Log.warn(t);
            return new ChatResultPojo(null, t.getMessage());
        }
    }

    // FIXME: this was basically copied from `dev.langchain4j.service.DefaultAiServices`,
    // maybe it could be extracted into a reusable piece of code
    private Response executeWithTools(ChatMemory memory) {
        Response response = model.generate(memory.messages(), toolSpecifications);
        int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 20;
        int executionsLeft = MAX_SEQUENTIAL_TOOL_EXECUTIONS;
        while (true) {
            if (executionsLeft-- == 0) {
                throw new RuntimeException(
                        "Something is wrong, exceeded " + MAX_SEQUENTIAL_TOOL_EXECUTIONS + " sequential tool executions");
            }
            AiMessage aiMessage = response.content();
            memory.add(aiMessage);
            if (!aiMessage.hasToolExecutionRequests()) {
                break;
            }
            for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
                ToolExecutor toolExecutor = toolExecutors.get(toolExecutionRequest.name());
                String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, currentMemoryId.get());
                ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(
                        toolExecutionRequest,
                        toolExecutionResult);
                memory.add(toolExecutionResultMessage);
            }
            response = model.generate(memory.messages(), toolSpecifications);
        }
        return Response.from(response.content(), new TokenUsage(), response.finishReason());
    }

    private void executeWithToolsAndStreaming(ChatMemory memory,
            MultiEmitter em,
            int toolExecutionsLeft) {
        toolExecutionsLeft--;
        if (toolExecutionsLeft == 0) {
            throw new RuntimeException(
                    "Something is wrong, exceeded " + MAX_SEQUENTIAL_TOOL_EXECUTIONS + " sequential tool executions");
        }
        int finalToolExecutionsLeft = toolExecutionsLeft;
        streamingModel.get().generate(memory.messages(), toolSpecifications, new StreamingResponseHandler() {
            @Override
            public void onComplete(Response response) {
                // run on a worker thread because the tool might be blocking
                Infrastructure.getDefaultExecutor().execute(() -> {
                    AiMessage aiMessage = response.content();
                    memory.add(aiMessage);
                    if (!aiMessage.hasToolExecutionRequests()) {
                        em.emit(new JsonObject().put("message", aiMessage.text()));
                        em.complete();
                    } else {
                        for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
                            ToolExecutor toolExecutor = toolExecutors.get(toolExecutionRequest.name());
                            ToolExecutionRequestPojo toolExecutionRequestPojo = new ToolExecutionRequestPojo(
                                    toolExecutionRequest.id(), toolExecutionRequest.name(), toolExecutionRequest.arguments());
                            em.emit(new JsonObject().put("toolExecutionRequest", toolExecutionRequestPojo));
                            String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, currentMemoryId.get());
                            ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage
                                    .from(toolExecutionRequest, toolExecutionResult);
                            memory.add(toolExecutionResultMessage);
                            ToolExecutionResultPojo toolExecutionResultPojo = new ToolExecutionResultPojo(
                                    toolExecutionResultMessage.id(), toolExecutionResultMessage.toolName(),
                                    toolExecutionResultMessage.text());
                            em.emit(new JsonObject().put("toolExecutionResult", toolExecutionResultPojo));
                        }
                        executeWithToolsAndStreaming(memory, em, finalToolExecutionsLeft);
                    }
                    // Remove toolProviderSupplier tools again
                    if (toolProvider != null) {
                        toolSpecifications.clear();
                        toolExecutors.clear();
                    }
                });
            }

            @Override
            public void onNext(String token) {
                em.emit(new JsonObject().put("token", token));
            }

            @Override
            public void onError(Throwable error) {
                throw new RuntimeException(error);
            }
        });
    }

    private ToolProvider getToolProvider(List> toolProviders) {
        for (Supplier provider : toolProviders) {
            if (provider.get() != null) {
                return provider.get();
            }
        }
        return null;
    }

    private boolean setToolsViaProviderIfAvailable(ChatMemory memory, UserMessage userMessage) {
        boolean hasToolProvider = toolProvider != null;
        if (hasToolProvider) {
            ToolProviderRequest toolRequest = new ToolProviderRequest(memory, userMessage);
            ToolProviderResult toolsResult = toolProvider.provideTools(toolRequest);
            for (ToolSpecification specification : toolsResult.tools().keySet()) {
                toolSpecifications.add(specification);
                toolExecutors.put(specification.name(), toolsResult.tools().get(specification));
            }
        }
        return hasToolProvider;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy