All downloads are free. Search and download functionality uses the official Maven repository.

fi.evolver.ai.spring.provider.anthropic.AnthropicStreamingChatResponse Maven / Gradle / Ivy

package fi.evolver.ai.spring.provider.anthropic;

import java.util.*;
import java.util.concurrent.CountDownLatch;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import fi.evolver.ai.spring.ApiResponseException;
import fi.evolver.ai.spring.ContentSubscriber;
import fi.evolver.ai.spring.chat.ChatResponse;
import fi.evolver.ai.spring.chat.FunctionCall;
import fi.evolver.ai.spring.chat.prompt.ChatPrompt;
import fi.evolver.ai.spring.chat.prompt.Message;
import fi.evolver.ai.spring.provider.anthropic.response.*;
import fi.evolver.ai.spring.provider.anthropic.response.AChatStreamingResponse.AContentBlockDelta;
import fi.evolver.ai.spring.provider.anthropic.response.AChatStreamingResponse.AContentBlockStart;
import fi.evolver.ai.spring.provider.anthropic.response.AChatStreamingResponse.AMessageDelta;
import fi.evolver.ai.spring.provider.anthropic.response.AChatStreamingResponse.AMessageStart;
import fi.evolver.ai.spring.provider.anthropic.response.AContent.AInputJson;
import fi.evolver.ai.spring.provider.anthropic.response.AContent.ATextContent;
import fi.evolver.ai.spring.provider.anthropic.response.AContent.AToolUse;
import fi.evolver.utils.NullSafetyUtils;


/**
 * A {@link ChatResponse} for streaming Anthropic API calls.
 *
 * <p>The streaming support works as follows:
 * <ul>
 *  <li>Subscribers get any already received content on subscribe and any further deltas as they are received
 *  <li>Asking for the message (or the text content / function calls) blocks until the whole
 *      response is complete or has failed
 * </ul>
 *
 * <p>Stream events are fed in through the package-private, synchronized {@code addResult},
 * {@code handleError} and {@code handleStreamEnd} methods. The blocking getters wait on a
 * {@link CountDownLatch} that is released by either of the latter two.
 */
public class AnthropicStreamingChatResponse extends ChatResponse {
	private static final Logger LOG = LoggerFactory.getLogger(AnthropicStreamingChatResponse.class);

	/** Separator placed between the text of consecutive content blocks. */
	private static final String BLOCK_SEPARATOR = "\n\n";

	/** The message metadata merged from all message start/delta events received so far. */
	private Optional<AMessage> message = Optional.empty();

	/** Accumulated text per content block index; a TreeMap keeps the blocks ordered by index. */
	private final Map<Integer, StringBuilder> contents = new TreeMap<>();

	private final List<ContentSubscriber> subscribers = new ArrayList<>();

	/** Released once the stream has ended, either successfully or with an error. */
	private final CountDownLatch readyLatch = new CountDownLatch(1);

	/** Released once rate limit headers have been added, or the stream has failed. */
	private final CountDownLatch rateLimitHeadersLatch = new CountDownLatch(1);

	private Message resultMessage;
	private Optional<String> content = Optional.empty();

	/** Tool use blocks per content block index, ordered by index. */
	private final Map<Integer, AnthropicStreamingFunctionCall> functionCallByIndex = new TreeMap<>();

	private ARateLimitHeaders rateLimitHeaders;

	private volatile Throwable responseException;


	public AnthropicStreamingChatResponse(ChatPrompt prompt) {
		super(prompt);
	}

	/**
	 * Does not block.
	 *
	 * @return the output token count of the message received so far, or 0 if no message event has been received yet
	 */
	@Override
	public int getResponseTokens() {
		if (message.isEmpty())
			return 0;

		return message.get().usage().outputTokens();
	}

	/**
	 * Does not block.
	 *
	 * @return the input token count of the message received so far, or the superclass value if no message event has been received yet
	 */
	@Override
	public int getPromptTokens() {
		if (message.isEmpty())
			return super.getPromptTokens();

		return message.get().usage().inputTokens();
	}

	/**
	 * Applies a single streaming event: message start/delta events update the message
	 * metadata, content block start/delta events update the content. Other events are ignored.
	 *
	 * @param response the received streaming event
	 */
	synchronized void addResult(AChatStreamingResponse response) {
		getMessageDelta(response).ifPresent(this::updateMessage);
		getContentUpdate(response).ifPresent(this::contentUpdate);
	}


	/**
	 * Extracts the message metadata carried by a message start or message delta event.
	 *
	 * @return the message (delta), or empty for any other event type
	 */
	private static Optional<AMessage> getMessageDelta(AChatStreamingResponse response) {
		if (response instanceof AMessageStart m)
			return Optional.of(m.message());
		if (response instanceof AMessageDelta m)
			return Optional.of(m.delta().withUsage(m.usage()));
		return Optional.empty();
	}

	private void updateMessage(AMessage delta) {
		this.message = merge(this.message.orElse(null), delta);
	}

	/**
	 * Merges a message delta on top of the message received so far. Each field is resolved
	 * with {@code NullSafetyUtils.denull(delta, original)}, which presumably returns its first
	 * non-null argument, so that values in the delta take precedence.
	 *
	 * @param original the message so far, may be null
	 * @param delta the update to apply
	 * @return the merged message; empty only when both arguments are null
	 */
	private static Optional<AMessage> merge(AMessage original, AMessage delta) {
		if (original == null)
			return Optional.ofNullable(delta);

		return Optional.of(new AMessage(
				NullSafetyUtils.denull(delta.id(), original.id()),
				NullSafetyUtils.denull(delta.type(), original.type()),
				NullSafetyUtils.denull(delta.role(), original.role()),
				NullSafetyUtils.denull(delta.content(), original.content()),
				NullSafetyUtils.denull(delta.model(), original.model()),
				NullSafetyUtils.denull(delta.stopReason(), original.stopReason()),
				NullSafetyUtils.denull(delta.stopSequence(), original.stopSequence()),
				// Argument order matters: (original, delta). Passing them swapped makes the
				// old usage win over the updated one.
				merge(original.usage(), delta.usage())));
	}


	/**
	 * Merges a usage delta on top of the usage received so far, preferring the delta's values.
	 *
	 * @param original the usage so far, may be null
	 * @param delta the update to apply, may be null
	 * @return the merged usage, or null if both arguments are null
	 */
	private static AUsage merge(AUsage original, AUsage delta) {
		if (original == null)
			return delta;
		if (delta == null)
			return original;

		return new AUsage(
				NullSafetyUtils.denull(delta.inputTokens(), original.inputTokens()),
				NullSafetyUtils.denull(delta.outputTokens(), original.outputTokens()));
	}


	/**
	 * Extracts the content carried by a content block start or content block delta event.
	 * A block start with an index above zero is flagged as a boundary between blocks.
	 *
	 * @return the content update, or empty for any other event type
	 */
	private static Optional<ContentUpdate> getContentUpdate(AChatStreamingResponse response) {
		if (response instanceof AContentBlockStart c)
			return Optional.of(new ContentUpdate(c.index(), c.contentBlock(), c.index() > 0));
		if (response instanceof AContentBlockDelta c)
			return Optional.of(new ContentUpdate(c.index(), c.delta(), false));
		return Optional.empty();
	}

	/**
	 * Applies a content update: text is appended to its block and pushed to the subscribers,
	 * a tool use starts a new function call and input JSON is appended to the function call
	 * of the same block index (and ignored if there is none).
	 */
	private void contentUpdate(ContentUpdate delta) {
		if (delta.content() instanceof ATextContent textContent) {
			contents.computeIfAbsent(delta.index(), i -> new StringBuilder()).append(textContent.text());

			for (ContentSubscriber subscriber: subscribers) {
				try {
					if (delta.boundary())
						subscriber.onContent(BLOCK_SEPARATOR);
					subscriber.onContent(textContent.text());
				}
				catch (RuntimeException e) {
					// A failing subscriber must not break the stream for the others
					LOG.error("Subscriber failed handling content update", e);
				}
			}
		}
		else if (delta.content() instanceof AToolUse toolUse) {
			functionCallByIndex.put(delta.index(), new AnthropicStreamingFunctionCall(toolUse));
		}
		else if (delta.content() instanceof AInputJson inputJson) {
			AnthropicStreamingFunctionCall functionCall = functionCallByIndex.get(delta.index());
			if (functionCall != null)
				functionCall.addArgumentData(inputJson.partial_json());
		}
	}


	/**
	 * Marks the response as failed, releases every waiting caller and notifies the subscribers.
	 *
	 * @param throwable the cause of the failure
	 */
	synchronized void handleError(Throwable throwable) {
		this.responseException = throwable;
		readyLatch.countDown();
		rateLimitHeadersLatch.countDown();

		for (ContentSubscriber subscriber: subscribers) {
			try {
				subscriber.onError(throwable);
			}
			catch (RuntimeException e) {
				LOG.error("Subscriber failed handling stream error ({})", throwable.toString(), e);
			}
		}
	}


	/**
	 * Finalizes the response when the stream ends: builds the result message, releases the
	 * callers waiting for it and notifies the subscribers. Ending without a finish reason is
	 * treated as an error. Does nothing if the response has already completed or failed.
	 */
	synchronized void handleStreamEnd() {
		// Check for completion first: after an earlier error there is typically no finish
		// reason, and falling through would replace the real cause and re-notify subscribers.
		if (readyLatch.getCount() == 0)
			return;

		String finishReason = getFinishReason().orElse(null);
		if (finishReason == null) {
			handleError(new IllegalStateException("Stream ended without finish reason"));
			return;
		}

		content = Optional.of(getCurrentContent());
		resultMessage = createMessage(content, functionCallByIndex.values().stream().toList());
		readyLatch.countDown();

		for (ContentSubscriber subscriber: subscribers) {
			try {
				subscriber.onComplete(finishReason);
			}
			catch (RuntimeException e) {
				LOG.error("Subscriber failed handling stream completion", e);
			}
		}
	}

	/**
	 * @return the text received so far: the text of each content block in index order,
	 *         separated by blank lines
	 */
	private String getCurrentContent() {
		StringBuilder builder = new StringBuilder();
		for (StringBuilder blockContent: contents.values()) {
			// Separator goes between blocks, never in front of the first one
			if (!builder.isEmpty())
				builder.append(BLOCK_SEPARATOR);
			builder.append(blockContent);
		}
		return builder.toString();
	}


	/**
	 * Registers a subscriber and immediately replays the current state to it: any text
	 * received so far, followed by the error if the response has failed, or otherwise the
	 * completion if a finish reason is already known.
	 */
	@Override
	public synchronized void addSubscriber(ContentSubscriber subscriber) {
		this.subscribers.add(subscriber);
		String value = getCurrentContent();
		if (!value.isEmpty())
			subscriber.onContent(value);

		if (responseException != null)
			subscriber.onError(responseException);
		else
			getFinishReason().ifPresent(subscriber::onComplete);
	}

	/**
	 * Stores the rate limit headers and releases callers blocked in {@link #getRateLimitHeaders()}.
	 */
	public synchronized void addRateLimitHeaders(ARateLimitHeaders rateLimitHeaders) {
		this.rateLimitHeaders = rateLimitHeaders;
		rateLimitHeadersLatch.countDown();
	}


	/**
	 * Blocks until the response is complete.
	 *
	 * @return the finish reason, or {@code "error"} if the response failed or has no finish reason
	 */
	@Override
	public String getResultState() {
		try {
			getResponseMessage();
		}
		catch (RuntimeException e) {
			return "error";
		}
		return getFinishReason().orElse("error");
	}


	/**
	 * Blocks until the response is complete.
	 *
	 * @return whether the result state is one of {@code AnthropicService.FINISH_REASONS_OK}
	 */
	@Override
	public boolean isSuccess() {
		return AnthropicService.FINISH_REASONS_OK.contains(getResultState());
	}


	private Optional<String> getFinishReason() {
		return message.map(AMessage::stopReason);
	}


	/**
	 * Blocks until the response is complete.
	 *
	 * @return the complete response message
	 * @throws ApiResponseException if the stream failed or the wait was interrupted
	 */
	@Override
	public Message getResponseMessage() {
		awaitReady("Reading message failed unexpectedly");
		return resultMessage;
	}


	/**
	 * Blocks until the response is complete.
	 *
	 * @return the complete text content of the response
	 * @throws ApiResponseException if the stream failed or the wait was interrupted
	 */
	@Override
	public Optional<String> getTextContent() {
		getResponseMessage();
		return content;
	}


	/**
	 * Blocks until the response is complete.
	 *
	 * @return the function call with the lowest content block index, if any
	 * @throws ApiResponseException if the stream failed or the wait was interrupted
	 */
	@Override
	public Optional<FunctionCall> getFunctionCall() {
		return getFunctionCalls().stream().findFirst();
	}


	/**
	 * Blocks until the response is complete.
	 *
	 * @return all function calls of the response, ordered by content block index
	 * @throws ApiResponseException if the stream failed or the wait was interrupted
	 */
	@Override
	public List<FunctionCall> getFunctionCalls() {
		awaitReady("Reading message failed unexpectedly");
		return functionCallByIndex.values().stream().map(FunctionCall.class::cast).toList();
	}

	/**
	 * Blocks until the rate limit headers have been added or the stream has failed.
	 *
	 * @return the rate limit headers; null if the stream failed before they were added
	 * @throws ApiResponseException if the wait was interrupted
	 */
	@Override
	public ARateLimitHeaders getRateLimitHeaders() {
		try {
			rateLimitHeadersLatch.await();
			return rateLimitHeaders;
		}
		catch (InterruptedException e) {
			// Restore the interrupt flag so that callers further up can still observe it
			Thread.currentThread().interrupt();
			throw new ApiResponseException(e, "Interrupted while waiting for rate limit headers");
		}
	}


	/**
	 * Blocks until the stream has ended and fails if it ended with an error.
	 *
	 * @param failureMessage the exception message to use if the stream has failed
	 * @throws ApiResponseException if the stream failed or the wait was interrupted
	 */
	private void awaitReady(String failureMessage) {
		try {
			readyLatch.await();
		}
		catch (InterruptedException e) {
			// Restore the interrupt flag so that callers further up can still observe it
			Thread.currentThread().interrupt();
			throw new ApiResponseException(e, "Interrupted while waiting for response");
		}

		if (responseException != null)
			throw new ApiResponseException(responseException, failureMessage);
	}


	/** A piece of content for the content block at the given index. */
	private record ContentUpdate(int index, AContent content, boolean boundary) {}

	/**
	 * A function call whose argument JSON is accumulated from the stream. Deliberately a
	 * non-static inner class: reading the arguments waits on the enclosing response.
	 */
	private class AnthropicStreamingFunctionCall implements FunctionCall {
		private final String functionName;
		private final String toolCallId;
		private final StringBuilder argumentDataBuilder = new StringBuilder();

		public AnthropicStreamingFunctionCall(AToolUse toolUse) {
			this.functionName = toolUse.name();
			this.toolCallId = toolUse.id();
		}

		@Override
		public String getFunctionName() {
			return functionName;
		}


		private void addArgumentData(String partialJson) {
			this.argumentDataBuilder.append(partialJson);
		}


		/**
		 * Blocks until the response is complete.
		 *
		 * @return the concatenation of all argument JSON fragments received for this call
		 * @throws ApiResponseException if the stream failed or the wait was interrupted
		 */
		@Override
		public String getArgumentData() {
			awaitReady("Function call failed unexpectedly");
			return argumentDataBuilder.toString();
		}

		@Override
		public String getToolCallId() {
			return toolCallId;
		}

	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy