io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceTokenStream Maven / Gradle / Ivy
package io.quarkiverse.langchain4j.runtime.aiservice;
import static dev.langchain4j.internal.Utils.copyIfNotNull;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static java.util.Collections.emptyList;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.tool.ToolExecution;
import dev.langchain4j.service.tool.ToolExecutor;
import io.vertx.core.Context;
/**
* An implementation of token stream for Quarkus.
* The only difference with the upstream implementation is the usage of the custom
* {@link QuarkusAiServiceStreamingResponseHandler} instead of the upstream one.
* It allows handling blocking tools execution, when we are invoked on the event loop.
*/
public class QuarkusAiServiceTokenStream implements TokenStream {
private final List messages;
private final List toolSpecifications;
private final Map toolExecutors;
private final List retrievedContents;
private final AiServiceContext context;
private final Object memoryId;
private final Context cxtx;
private final boolean mustSwitchToWorkerThread;
private Consumer tokenHandler;
private Consumer> contentsHandler;
private Consumer errorHandler;
private Consumer> completionHandler;
private Consumer toolExecuteHandler;
private int onNextInvoked;
private int onCompleteInvoked;
private int onRetrievedInvoked;
private int onErrorInvoked;
private int ignoreErrorsInvoked;
private int toolExecuteInvoked;
public QuarkusAiServiceTokenStream(List messages,
List toolSpecifications,
Map toolExecutors,
List retrievedContents,
AiServiceContext context,
Object memoryId, Context ctxt, boolean mustSwitchToWorkerThread) {
this.messages = ensureNotEmpty(messages, "messages");
this.toolSpecifications = copyIfNotNull(toolSpecifications);
this.toolExecutors = copyIfNotNull(toolExecutors);
this.retrievedContents = retrievedContents;
this.context = ensureNotNull(context, "context");
this.memoryId = ensureNotNull(memoryId, "memoryId");
ensureNotNull(context.streamingChatModel, "streamingChatModel");
this.cxtx = ctxt; // If set, it means we need to handle the context propagation.
this.mustSwitchToWorkerThread = mustSwitchToWorkerThread; // If true, we need to switch to a worker thread to execute tools.
}
@Override
public TokenStream onNext(Consumer tokenHandler) {
this.tokenHandler = tokenHandler;
this.onNextInvoked++;
return this;
}
@Override
public TokenStream onRetrieved(Consumer> contentsHandler) {
this.contentsHandler = contentsHandler;
this.onRetrievedInvoked++;
return this;
}
@Override
public TokenStream onToolExecuted(Consumer toolExecuteHandler) {
this.toolExecuteHandler = toolExecuteHandler;
this.toolExecuteInvoked++;
return this;
}
@Override
public TokenStream onComplete(Consumer> completionHandler) {
this.completionHandler = completionHandler;
this.onCompleteInvoked++;
return this;
}
@Override
public TokenStream onError(Consumer errorHandler) {
this.errorHandler = errorHandler;
this.onErrorInvoked++;
return this;
}
@Override
public TokenStream ignoreErrors() {
this.errorHandler = null;
this.ignoreErrorsInvoked++;
return this;
}
@Override
public void start() {
validateConfiguration();
QuarkusAiServiceStreamingResponseHandler handler = new QuarkusAiServiceStreamingResponseHandler(
context,
memoryId,
tokenHandler,
toolExecuteHandler,
completionHandler,
errorHandler,
initTemporaryMemory(context, messages),
new TokenUsage(),
toolSpecifications,
toolExecutors,
mustSwitchToWorkerThread,
cxtx);
if (contentsHandler != null && retrievedContents != null) {
contentsHandler.accept(retrievedContents);
}
if (isNullOrEmpty(toolSpecifications)) {
context.streamingChatModel.generate(messages, handler);
} else {
try {
// Some model do not support function calling with tool specifications
context.streamingChatModel.generate(messages, toolSpecifications, handler);
} catch (Exception e) {
if (errorHandler != null) {
errorHandler.accept(e);
}
}
}
}
private void validateConfiguration() {
if (onNextInvoked != 1) {
throw new IllegalConfigurationException("onNext must be invoked exactly 1 time");
}
if (onCompleteInvoked > 1) {
throw new IllegalConfigurationException("onComplete must be invoked at most 1 time");
}
if (onRetrievedInvoked > 1) {
throw new IllegalConfigurationException("onRetrieved must be invoked at most 1 time");
}
if (toolExecuteInvoked > 1) {
throw new IllegalConfigurationException("onToolExecuted must be invoked at most 1 time");
}
if (onErrorInvoked + ignoreErrorsInvoked != 1) {
throw new IllegalConfigurationException("One of onError or ignoreErrors must be invoked exactly 1 time");
}
}
private List initTemporaryMemory(AiServiceContext context, List messagesToSend) {
if (context.hasChatMemory()) {
return emptyList();
} else {
return new ArrayList<>(messagesToSend);
}
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy