All downloads are free. Search and download functionality uses the official Maven repository.

io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceTokenStream Maven / Gradle / Ivy

There is a newer version: 0.21.0
Show newest version
package io.quarkiverse.langchain4j.runtime.aiservice;

import static dev.langchain4j.internal.Utils.copyIfNotNull;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static java.util.Collections.emptyList;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.tool.ToolExecution;
import dev.langchain4j.service.tool.ToolExecutor;
import io.vertx.core.Context;

/**
 * An implementation of token stream for Quarkus.
 * The only difference with the upstream implementation is the usage of the custom
 * {@link QuarkusAiServiceStreamingResponseHandler} instead of the upstream one.
 * It allows handling blocking tools execution, when we are invoked on the event loop.
 */
public class QuarkusAiServiceTokenStream implements TokenStream {

    private final List messages;
    private final List toolSpecifications;
    private final Map toolExecutors;
    private final List retrievedContents;
    private final AiServiceContext context;
    private final Object memoryId;
    private final Context cxtx;
    private final boolean mustSwitchToWorkerThread;

    private Consumer tokenHandler;
    private Consumer> contentsHandler;
    private Consumer errorHandler;
    private Consumer> completionHandler;
    private Consumer toolExecuteHandler;

    private int onNextInvoked;
    private int onCompleteInvoked;
    private int onRetrievedInvoked;
    private int onErrorInvoked;
    private int ignoreErrorsInvoked;
    private int toolExecuteInvoked;

    public QuarkusAiServiceTokenStream(List messages,
            List toolSpecifications,
            Map toolExecutors,
            List retrievedContents,
            AiServiceContext context,
            Object memoryId, Context ctxt, boolean mustSwitchToWorkerThread) {
        this.messages = ensureNotEmpty(messages, "messages");
        this.toolSpecifications = copyIfNotNull(toolSpecifications);
        this.toolExecutors = copyIfNotNull(toolExecutors);
        this.retrievedContents = retrievedContents;
        this.context = ensureNotNull(context, "context");
        this.memoryId = ensureNotNull(memoryId, "memoryId");
        ensureNotNull(context.streamingChatModel, "streamingChatModel");
        this.cxtx = ctxt; // If set, it means we need to handle the context propagation.
        this.mustSwitchToWorkerThread = mustSwitchToWorkerThread; // If true, we need to switch to a worker thread to execute tools.
    }

    @Override
    public TokenStream onNext(Consumer tokenHandler) {
        this.tokenHandler = tokenHandler;
        this.onNextInvoked++;
        return this;
    }

    @Override
    public TokenStream onRetrieved(Consumer> contentsHandler) {
        this.contentsHandler = contentsHandler;
        this.onRetrievedInvoked++;
        return this;
    }

    @Override
    public TokenStream onToolExecuted(Consumer toolExecuteHandler) {
        this.toolExecuteHandler = toolExecuteHandler;
        this.toolExecuteInvoked++;
        return this;
    }

    @Override
    public TokenStream onComplete(Consumer> completionHandler) {
        this.completionHandler = completionHandler;
        this.onCompleteInvoked++;
        return this;
    }

    @Override
    public TokenStream onError(Consumer errorHandler) {
        this.errorHandler = errorHandler;
        this.onErrorInvoked++;
        return this;
    }

    @Override
    public TokenStream ignoreErrors() {
        this.errorHandler = null;
        this.ignoreErrorsInvoked++;
        return this;
    }

    @Override
    public void start() {
        validateConfiguration();
        QuarkusAiServiceStreamingResponseHandler handler = new QuarkusAiServiceStreamingResponseHandler(
                context,
                memoryId,
                tokenHandler,
                toolExecuteHandler,
                completionHandler,
                errorHandler,
                initTemporaryMemory(context, messages),
                new TokenUsage(),
                toolSpecifications,
                toolExecutors,
                mustSwitchToWorkerThread,
                cxtx);

        if (contentsHandler != null && retrievedContents != null) {
            contentsHandler.accept(retrievedContents);
        }

        if (isNullOrEmpty(toolSpecifications)) {
            context.streamingChatModel.generate(messages, handler);
        } else {
            try {
                // Some model do not support function calling with tool specifications
                context.streamingChatModel.generate(messages, toolSpecifications, handler);
            } catch (Exception e) {
                if (errorHandler != null) {
                    errorHandler.accept(e);
                }
            }
        }
    }

    private void validateConfiguration() {
        if (onNextInvoked != 1) {
            throw new IllegalConfigurationException("onNext must be invoked exactly 1 time");
        }

        if (onCompleteInvoked > 1) {
            throw new IllegalConfigurationException("onComplete must be invoked at most 1 time");
        }

        if (onRetrievedInvoked > 1) {
            throw new IllegalConfigurationException("onRetrieved must be invoked at most 1 time");
        }

        if (toolExecuteInvoked > 1) {
            throw new IllegalConfigurationException("onToolExecuted must be invoked at most 1 time");
        }

        if (onErrorInvoked + ignoreErrorsInvoked != 1) {
            throw new IllegalConfigurationException("One of onError or ignoreErrors must be invoked exactly 1 time");
        }
    }

    private List initTemporaryMemory(AiServiceContext context, List messagesToSend) {
        if (context.hasChatMemory()) {
            return emptyList();
        } else {
            return new ArrayList<>(messagesToSend);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy