// Source: fi.evolver.ai.spring.provider.openai.OpenAiStreamingChatResponse (extracted from a Maven/Gradle/Ivy artifact listing)
package fi.evolver.ai.spring.provider.openai;
import java.util.*;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.CountDownLatch;
import java.util.stream.Collectors;
import fi.evolver.ai.spring.provider.openai.response.OUsage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import fi.evolver.ai.spring.ApiResponseException;
import fi.evolver.ai.spring.ContentSubscriber;
import fi.evolver.ai.spring.chat.ChatResponse;
import fi.evolver.ai.spring.chat.FunctionCall;
import fi.evolver.ai.spring.chat.prompt.ChatPrompt;
import fi.evolver.ai.spring.chat.prompt.Message;
import fi.evolver.ai.spring.provider.openai.response.ORateLimitHeaders;
import fi.evolver.ai.spring.provider.openai.response.chat.OChatResult;
import fi.evolver.ai.spring.provider.openai.response.chat.OChoice;
import fi.evolver.ai.spring.provider.openai.response.chat.ODelta;
import fi.evolver.ai.spring.provider.openai.response.chat.OFunctionCall;
import fi.evolver.ai.spring.provider.openai.response.chat.OToolCall;
import fi.evolver.ai.spring.util.TokenUtils;
import fi.evolver.utils.collection.CollectionUtils;
import fi.evolver.utils.string.StringUtils;
/**
* A ChatResponse for streaming API calls.
*
* The streaming support works as follows:
* - Subscribers get any already received deltas on subscribe and any further deltas as they are received
* - Asking for the message blocks until the whole response is complete
* - Asking for a single function call will return as soon as we know the name of the first function call
* - Asking for function call arguments blocks until the whole response is complete
* - Asking for all function calls blocks until the whole response is complete (we do not know how many calls to wait for)
*/
public class OpenAiStreamingChatResponse extends ChatResponse {
private static final Logger LOG = LoggerFactory.getLogger(OpenAiStreamingChatResponse.class);
private final Deque results = new ConcurrentLinkedDeque<>();
private List subscribers = new ArrayList<>();
private final CountDownLatch responseCompleteLatch = new CountDownLatch(1);
private final CountDownLatch functionCallLatch = new CountDownLatch(1);
private final CountDownLatch rateLimitHeadersLatch = new CountDownLatch(1);
private Optional content = Optional.empty();
private Optional functionCall = Optional.empty();
private List functionCalls;
private Message responseMessage;
private ORateLimitHeaders rateLimitHeaders;
private volatile Throwable responseException;
private volatile Optional usage;
public OpenAiStreamingChatResponse(ChatPrompt prompt) {
super(prompt);
}
public synchronized void addResult(OChatResult result) {
results.add(result);
if (result.choices().isEmpty()) // Azure returns a prompt_filter_results without any actual content as the first chunk.
return;
if (functionCall.isEmpty()) {
Optional toolCall = getFunctionCall(result);
Optional name = toolCall
.map(OToolCall::functionCall)
.filter(Objects::nonNull)
.map(OFunctionCall::name);
if (toolCall.isPresent() && name.isPresent()) {
functionCall = toolCall.map(OpenAiStreamingFunctionCall::new);
functionCallLatch.countDown();
}
}
Optional contentUpdate = getContent(result);
for (ContentSubscriber subscriber: subscribers) {
try {
contentUpdate.ifPresent(subscriber::onContent);
}
catch (RuntimeException e) {
LOG.error("Subscriber failed handling content update", e);
}
}
}
public synchronized void handleError(Throwable throwable) {
this.responseException = throwable;
this.usage = readUsage();
responseCompleteLatch.countDown();
functionCallLatch.countDown();
rateLimitHeadersLatch.countDown();
for (ContentSubscriber subscriber: subscribers) {
try {
subscriber.onError(throwable);
}
catch (RuntimeException e) {
LOG.error("Subscriber failed handling stream error ({})", throwable.toString(), e);
}
}
}
public synchronized void handleStreamEnd() {
String finishReason = getFinishReason().orElse(null);
if (finishReason == null) {
handleError(new IllegalStateException("Stream ended without finish reason"));
return;
}
this.usage = readUsage();
if (responseCompleteLatch.getCount() == 0)
return;
functionCall.ifPresent(fc -> fc.setArgumentData(results.stream()
.map(OpenAiStreamingChatResponse::getFunctionCall)
.filter(Optional::isPresent)
.map(Optional::get)
.map(OToolCall::functionCall)
.map(OFunctionCall::arguments)
.filter(Objects::nonNull)
.collect(Collectors.joining())));
content = Optional.of(results.stream()
.map(OpenAiStreamingChatResponse::getContent)
.filter(Optional::isPresent)
.map(Optional::get)
.collect(Collectors.joining()))
.filter(StringUtils::hasText);
functionCalls = createFunctionCalls();
responseMessage = createMessage(content, functionCalls);
functionCallLatch.countDown();
responseCompleteLatch.countDown();
for (ContentSubscriber subscriber: subscribers) {
try {
subscriber.onComplete(finishReason);
}
catch (RuntimeException e) {
LOG.error("Subscriber failed handling stream completion", e);
}
}
}
private static Optional getFunctionCall(OChatResult result) {
return result.choices().stream()
.map(OChoice::delta)
.map(ODelta::toolCalls)
.filter(Objects::nonNull)
.flatMap(List::stream)
.filter(call -> call.functionCall() != null)
.findFirst();
}
private static Optional getContent(OChatResult result) {
return Optional.of(result.choices().stream()
.map(OChoice::delta)
.map(ODelta::content)
.filter(Objects::nonNull)
.collect(Collectors.joining()))
.filter(s -> !s.isEmpty());
}
private List createFunctionCalls() {
Map functionCallByIndex = new TreeMap<>();
Map argumentsByIndex = new HashMap<>();
for (OChatResult result: results) {
List toolCalls = CollectionUtils.first(result.choices())
.map(OChoice::delta)
.map(ODelta::toolCalls)
.orElseGet(List::of);
for (OToolCall toolCall: toolCalls) {
OFunctionCall update = toolCall.functionCall();
if (update == null)
continue;
if (update.name() != null)
functionCallByIndex.put(toolCall.index(), new OpenAiStreamingFunctionCall(toolCall));
if (update.arguments() != null)
argumentsByIndex.computeIfAbsent(toolCall.index(), k -> new StringBuilder()).append(update.arguments());
}
}
StringBuilder emptyBuilder = new StringBuilder();
functionCallByIndex.forEach((i, f) -> f.setArgumentData(argumentsByIndex.getOrDefault(i, emptyBuilder).toString()));
return new ArrayList<>(functionCallByIndex.values());
}
@Override
public synchronized void addSubscriber(ContentSubscriber subscriber) {
this.subscribers.add(subscriber);
for (OChatResult result: results)
getContent(result).ifPresent(subscriber::onContent);
if (responseException != null)
subscriber.onError(responseException);
else
getFinishReason().ifPresent(subscriber::onComplete);
}
public synchronized void addRateLimitHeaders(ORateLimitHeaders rateLimitHeaders) {
this.rateLimitHeaders = rateLimitHeaders;
rateLimitHeadersLatch.countDown();
}
@Override
public String getResultState() {
try {
getResponseMessage();
getFunctionCall().ifPresent(FunctionCall::getArgumentData);
}
catch (RuntimeException e) {
return "error";
}
return getFinishReason().orElse("error");
}
@Override
public boolean isSuccess() {
return OpenAiService.FINISH_REASONS_OK.contains(getResultState());
}
private Optional readUsage() {
Iterator iterator = results.descendingIterator();
while (iterator.hasNext()) {
OUsage usage = iterator.next().usage();
if (usage != null)
return Optional.of(usage);
}
return Optional.empty();
}
private Optional getFinishReason() {
Iterator iterator = results.descendingIterator();
while (iterator.hasNext()) {
Optional finishReason = getFinishReason(iterator.next());
if (finishReason.isPresent())
return finishReason;
}
return Optional.empty();
}
private static Optional getFinishReason(OChatResult result) {
return result.choices().stream().map(OChoice::finishReason).filter(Objects::nonNull).findFirst();
}
@Override
public Message getResponseMessage() {
try {
responseCompleteLatch.await();
if (responseException != null)
throw new ApiResponseException(responseException, "Reading message failed unexpectedly");
return responseMessage;
}
catch (InterruptedException e) {
throw new ApiResponseException(e, "Interrupted while waiting for response");
}
}
@Override
public Optional getTextContent() {
getResponseMessage();
return content;
}
@Override
public Optional getFunctionCall() {
try {
functionCallLatch.await();
if (responseException != null)
throw new ApiResponseException(responseException, "Reading message failed unexpectedly");
return functionCall.map(OpenAiFunctionCall.class::cast);
}
catch (InterruptedException e) {
throw new ApiResponseException(e, "Interrupted while waiting for response");
}
}
@Override
public List getFunctionCalls() {
getResponseMessage();
return functionCalls;
}
@Override
public ORateLimitHeaders getRateLimitHeaders() {
try {
rateLimitHeadersLatch.await();
return rateLimitHeaders;
}
catch (InterruptedException e) {
throw new ApiResponseException(e, "Interrupted while waiting for rate limit headers");
}
}
@Override
public int getResponseTokens() {
return usage
.map(OUsage::completionTokens)
.orElse(0);
}
@Override
public int getPromptTokens() {
return usage
.map(OUsage::promptTokens)
.orElseGet(super::getPromptTokens);
}
private class OpenAiStreamingFunctionCall implements OpenAiFunctionCall {
private final String functionName;
private final String toolCallId;
private volatile String argumentData;
private volatile int tokenCount;
public OpenAiStreamingFunctionCall(OToolCall toolCall) {
this.functionName = toolCall.functionCall().name();
this.toolCallId = toolCall.id();
}
@Override
public String getFunctionName() {
return functionName;
}
private void setArgumentData(String argumentData) {
if (this.argumentData != null)
throw new IllegalStateException("Do not set argument data twice!");
this.argumentData = argumentData;
tokenCount = TokenUtils.calculateTokens(functionName, getPrompt().model()) +
TokenUtils.calculateTokens(argumentData, getPrompt().model());
}
@Override
public String getArgumentData() {
try {
responseCompleteLatch.await();
if (responseException != null)
throw new ApiResponseException(responseException, "Function call failed unexpectedly");
return argumentData;
}
catch (InterruptedException e) {
throw new ApiResponseException(e, "Interrupted while waiting for response");
}
}
@Override
public String getToolCallId() {
return toolCallId;
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy