fi.evolver.ai.spring.provider.anthropic.AnthropicStreamingChatResponse Maven / Gradle / Ivy
package fi.evolver.ai.spring.provider.anthropic;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import fi.evolver.ai.spring.ApiResponseException;
import fi.evolver.ai.spring.ContentSubscriber;
import fi.evolver.ai.spring.chat.ChatResponse;
import fi.evolver.ai.spring.chat.FunctionCall;
import fi.evolver.ai.spring.chat.prompt.ChatPrompt;
import fi.evolver.ai.spring.chat.prompt.Message;
import fi.evolver.ai.spring.provider.anthropic.response.*;
import fi.evolver.ai.spring.provider.anthropic.response.AChatStreamingResponse.AContentBlockDelta;
import fi.evolver.ai.spring.provider.anthropic.response.AChatStreamingResponse.AContentBlockStart;
import fi.evolver.ai.spring.provider.anthropic.response.AChatStreamingResponse.AMessageDelta;
import fi.evolver.ai.spring.provider.anthropic.response.AChatStreamingResponse.AMessageStart;
import fi.evolver.ai.spring.provider.anthropic.response.AContent.AInputJson;
import fi.evolver.ai.spring.provider.anthropic.response.AContent.ATextContent;
import fi.evolver.ai.spring.provider.anthropic.response.AContent.AToolUse;
import fi.evolver.utils.NullSafetyUtils;
/**
 * A {@link ChatResponse} for streaming Anthropic API calls.
 *
 * <p>The streaming support works as follows:
 * <ul>
 * <li>Subscribers get any already received content on subscribe and any further deltas as they
 * are received.</li>
 * <li>Asking for the message (or the text content / function calls) blocks until the whole
 * response is complete or has failed.</li>
 * </ul>
 *
 * <p>Stream events are fed in through the {@code synchronized} package-private methods
 * {@code addResult}, {@code handleError} and {@code handleStreamEnd}. Readers of the final result
 * wait on an internal latch that is released exactly once, when the outcome (successful completion
 * or the first error) has been decided.
 */
public class AnthropicStreamingChatResponse extends ChatResponse {
    private static final Logger LOG = LoggerFactory.getLogger(AnthropicStreamingChatResponse.class);

    // Message metadata merged from message start/delta events. Volatile because the token getters
    // read it without holding the lock.
    private volatile Optional<AMessage> message = Optional.empty();
    // Accumulated text per content block index; TreeMap keeps the blocks in index order.
    private final Map<Integer, StringBuilder> contents = new TreeMap<>();
    private final List<ContentSubscriber> subscribers = new ArrayList<>();
    // Released once the outcome (completion or first error) is decided.
    private final CountDownLatch readyLatch = new CountDownLatch(1);
    // Released once rate limit headers are set, or when an error is reported.
    private final CountDownLatch rateLimitHeadersLatch = new CountDownLatch(1);
    private Message resultMessage;
    private Optional<String> content = Optional.empty();
    // Tool use calls per content block index; TreeMap keeps them in index order.
    private final Map<Integer, AnthropicStreamingFunctionCall> functionCallByIndex = new TreeMap<>();
    private ARateLimitHeaders rateLimitHeaders;
    private volatile Throwable responseException;

    public AnthropicStreamingChatResponse(ChatPrompt prompt) {
        super(prompt);
    }

    /**
     * Returns the number of output tokens reported so far.
     *
     * @return the output token count, or 0 if no usage has been received yet
     */
    @Override
    public int getResponseTokens() {
        return message.map(AMessage::usage).map(AUsage::outputTokens).orElse(0);
    }

    /**
     * Returns the number of input tokens reported by the API.
     *
     * @return the input token count, or the superclass estimate if no usage has been received yet
     */
    @Override
    public int getPromptTokens() {
        return message.map(AMessage::usage).map(AUsage::inputTokens).orElseGet(super::getPromptTokens);
    }

    /**
     * Applies one streamed event: message level data is merged into the message, and content block
     * data is appended to the text / function call state and forwarded to subscribers.
     */
    synchronized void addResult(AChatStreamingResponse response) {
        getMessageDelta(response).ifPresent(this::updateMessage);
        getContentUpdate(response).ifPresent(this::contentUpdate);
    }

    private static Optional<AMessage> getMessageDelta(AChatStreamingResponse response) {
        if (response instanceof AMessageStart m)
            return Optional.of(m.message());
        if (response instanceof AMessageDelta m)
            return Optional.of(m.delta().withUsage(m.usage()));
        return Optional.empty();
    }

    private void updateMessage(AMessage delta) {
        this.message = merge(this.message.orElse(null), delta);
    }

    /**
     * Merges a message delta into the message received so far. Each field is chosen with
     * {@code NullSafetyUtils.denull(delta, original)}, which presumably returns its first non-null
     * argument, so data present in the delta takes precedence.
     *
     * @param original the message received so far, or {@code null} if none yet
     * @param delta the newly received message data
     * @return the merged message
     */
    private static Optional<AMessage> merge(AMessage original, AMessage delta) {
        if (original == null)
            return Optional.ofNullable(delta);

        return Optional.of(new AMessage(
                NullSafetyUtils.denull(delta.id(), original.id()),
                NullSafetyUtils.denull(delta.type(), original.type()),
                NullSafetyUtils.denull(delta.role(), original.role()),
                NullSafetyUtils.denull(delta.content(), original.content()),
                NullSafetyUtils.denull(delta.model(), original.model()),
                NullSafetyUtils.denull(delta.stopReason(), original.stopReason()),
                NullSafetyUtils.denull(delta.stopSequence(), original.stopSequence()),
                // (original, delta) order so that newer usage figures, such as the final output
                // token count, replace the earlier ones instead of being discarded.
                merge(original.usage(), delta.usage())));
    }

    /**
     * Merges usage figures, preferring values present in {@code delta}.
     *
     * @param original the usage received so far, may be {@code null}
     * @param delta the newly received usage, may be {@code null}
     * @return the merged usage, or {@code null} if both are {@code null}
     */
    private static AUsage merge(AUsage original, AUsage delta) {
        if (original == null)
            return delta;
        if (delta == null)
            return original;

        return new AUsage(
                NullSafetyUtils.denull(delta.inputTokens(), original.inputTokens()),
                NullSafetyUtils.denull(delta.outputTokens(), original.outputTokens()));
    }

    private static Optional<ContentUpdate> getContentUpdate(AChatStreamingResponse response) {
        if (response instanceof AContentBlockStart c)
            return Optional.of(new ContentUpdate(c.index(), c.contentBlock(), c.index() > 0));
        if (response instanceof AContentBlockDelta c)
            return Optional.of(new ContentUpdate(c.index(), c.delta(), false));
        return Optional.empty();
    }

    /**
     * Applies one content block event. Text is accumulated and forwarded to subscribers; tool use
     * starts a new function call; input JSON is appended to the function call at the same index.
     */
    private void contentUpdate(ContentUpdate update) {
        if (update.content() instanceof ATextContent textContent) {
            contents.computeIfAbsent(update.index(), i -> new StringBuilder()).append(textContent.text());
            for (ContentSubscriber subscriber: subscribers) {
                try {
                    // Blocks after the first are separated by a blank line, as in getCurrentContent().
                    if (update.boundary())
                        subscriber.onContent("\n\n");
                    subscriber.onContent(textContent.text());
                }
                catch (RuntimeException e) {
                    LOG.error("Subscriber failed handling content update", e);
                }
            }
        }
        else if (update.content() instanceof AToolUse toolUse) {
            functionCallByIndex.put(update.index(), new AnthropicStreamingFunctionCall(toolUse));
        }
        else if (update.content() instanceof AInputJson inputJson && functionCallByIndex.containsKey(update.index())) {
            functionCallByIndex.get(update.index()).addArgumentData(inputJson.partial_json());
        }
    }

    /**
     * Marks the response as failed and notifies subscribers. Only the first outcome counts: an
     * error reported after the response has already completed or failed is logged and ignored, so
     * the original cause is kept and subscribers are not notified twice.
     */
    synchronized void handleError(Throwable throwable) {
        // Unblock getRateLimitHeaders() waiters; without this they would wait for headers forever.
        rateLimitHeadersLatch.countDown();

        if (readyLatch.getCount() == 0) {
            LOG.warn("Ignoring stream error reported after the response was already finished ({})", throwable.toString());
            return;
        }

        this.responseException = throwable;
        readyLatch.countDown();

        for (ContentSubscriber subscriber: subscribers) {
            try {
                subscriber.onError(throwable);
            }
            catch (RuntimeException e) {
                LOG.error("Subscriber failed handling stream error ({})", throwable.toString(), e);
            }
        }
    }

    /**
     * Completes the response when the stream ends. Does nothing if the outcome was already decided
     * (e.g. by an earlier error); reports an error if no stop reason has been received.
     */
    synchronized void handleStreamEnd() {
        // Checked first so that an earlier error is not replaced by "Stream ended without finish reason".
        if (readyLatch.getCount() == 0)
            return;

        String finishReason = getFinishReason().orElse(null);
        if (finishReason == null) {
            handleError(new IllegalStateException("Stream ended without finish reason"));
            return;
        }

        content = Optional.of(getCurrentContent());
        resultMessage = createMessage(content, currentFunctionCalls());
        readyLatch.countDown();

        for (ContentSubscriber subscriber: subscribers) {
            try {
                subscriber.onComplete(finishReason);
            }
            catch (RuntimeException e) {
                LOG.error("Subscriber failed handling stream completion", e);
            }
        }
    }

    /**
     * Joins the text received so far, in block index order, separating blocks with a blank line.
     *
     * @return the text so far, or an empty string if none has been received
     */
    private String getCurrentContent() {
        StringJoiner joiner = new StringJoiner("\n\n");
        for (StringBuilder blockText: contents.values())
            joiner.add(blockText);
        return joiner.toString();
    }

    /** Returns the function calls received so far, in block index order. */
    private List<FunctionCall> currentFunctionCalls() {
        return functionCallByIndex.values().stream().map(FunctionCall.class::cast).toList();
    }

    /**
     * Blocks until the outcome of the response has been decided.
     *
     * @param failureMessage message for the exception thrown if the response failed
     * @throws ApiResponseException if the response failed or the thread was interrupted (in which
     *         case the interrupt flag is restored)
     */
    private void awaitResponse(String failureMessage) {
        try {
            readyLatch.await();
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new ApiResponseException(e, "Interrupted while waiting for response");
        }
        if (responseException != null)
            throw new ApiResponseException(responseException, failureMessage);
    }

    /**
     * Registers a subscriber. It immediately receives the text received so far (if any), followed by
     * the error or the completion if the outcome is already known.
     */
    @Override
    public synchronized void addSubscriber(ContentSubscriber subscriber) {
        this.subscribers.add(subscriber);

        String value = getCurrentContent();
        if (!value.isEmpty())
            subscriber.onContent(value);

        if (responseException != null)
            subscriber.onError(responseException);
        else
            getFinishReason().ifPresent(subscriber::onComplete);
    }

    public synchronized void addRateLimitHeaders(ARateLimitHeaders rateLimitHeaders) {
        this.rateLimitHeaders = rateLimitHeaders;
        rateLimitHeadersLatch.countDown();
    }

    /**
     * Blocks until the response is finished and returns its state.
     *
     * @return the stop reason, or {@code "error"} if the response failed or has no stop reason
     */
    @Override
    public String getResultState() {
        try {
            getResponseMessage();
        }
        catch (RuntimeException e) {
            return "error";
        }
        return getFinishReason().orElse("error");
    }

    @Override
    public boolean isSuccess() {
        return AnthropicService.FINISH_REASONS_OK.contains(getResultState());
    }

    private Optional<String> getFinishReason() {
        return message.map(AMessage::stopReason);
    }

    /**
     * Blocks until the response is complete and returns the resulting message.
     *
     * @throws ApiResponseException if the response failed or the wait was interrupted
     */
    @Override
    public Message getResponseMessage() {
        awaitResponse("Reading message failed unexpectedly");
        return resultMessage;
    }

    @Override
    public Optional<String> getTextContent() {
        getResponseMessage();
        return content;
    }

    @Override
    public Optional<FunctionCall> getFunctionCall() {
        return getFunctionCalls().stream().findFirst();
    }

    /**
     * Blocks until the response is complete and returns its function calls in block index order.
     *
     * @throws ApiResponseException if the response failed or the wait was interrupted
     */
    @Override
    public List<FunctionCall> getFunctionCalls() {
        awaitResponse("Reading message failed unexpectedly");
        return currentFunctionCalls();
    }

    /**
     * Blocks until rate limit headers have been set or an error has been reported.
     *
     * @return the rate limit headers, or {@code null} if an error was reported before they were set
     * @throws ApiResponseException if the wait was interrupted (the interrupt flag is restored)
     */
    @Override
    public ARateLimitHeaders getRateLimitHeaders() {
        try {
            rateLimitHeadersLatch.await();
            return rateLimitHeaders;
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new ApiResponseException(e, "Interrupted while waiting for rate limit headers");
        }
    }

    /**
     * One content block event.
     *
     * @param index the content block index
     * @param content the block content or delta
     * @param boundary whether this starts a new block after the first one
     */
    private record ContentUpdate(int index, AContent content, boolean boundary) {}

    /**
     * A function call assembled from a tool use block and its streamed input JSON fragments.
     * Non-static so that reads can wait for the enclosing response to finish.
     */
    private class AnthropicStreamingFunctionCall implements FunctionCall {
        private final String functionName;
        private final String toolCallId;
        // Appended to under the outer lock; read only after the outcome latch is released.
        private final StringBuilder argumentDataBuilder = new StringBuilder();

        public AnthropicStreamingFunctionCall(AToolUse toolUse) {
            this.functionName = toolUse.name();
            this.toolCallId = toolUse.id();
        }

        @Override
        public String getFunctionName() {
            return functionName;
        }

        private void addArgumentData(String partialJson) {
            this.argumentDataBuilder.append(partialJson);
        }

        /**
         * Blocks until the response is complete and returns the concatenated argument JSON.
         *
         * @throws ApiResponseException if the response failed or the wait was interrupted
         */
        @Override
        public String getArgumentData() {
            awaitResponse("Function call failed unexpectedly");
            return argumentDataBuilder.toString();
        }

        @Override
        public String getToolCallId() {
            return toolCallId;
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy