All downloads are free. Search and download functionality uses the official Maven repository.

fi.evolver.ai.spring.provider.openai.OpenAiStreamingChatResponse Maven / Gradle / Ivy

package fi.evolver.ai.spring.provider.openai;

import java.util.*;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.CountDownLatch;
import java.util.stream.Collectors;

import fi.evolver.ai.spring.provider.openai.response.OUsage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import fi.evolver.ai.spring.ApiResponseException;
import fi.evolver.ai.spring.ContentSubscriber;
import fi.evolver.ai.spring.chat.ChatResponse;
import fi.evolver.ai.spring.chat.FunctionCall;
import fi.evolver.ai.spring.chat.prompt.ChatPrompt;
import fi.evolver.ai.spring.chat.prompt.Message;
import fi.evolver.ai.spring.provider.openai.response.ORateLimitHeaders;
import fi.evolver.ai.spring.provider.openai.response.chat.OChatResult;
import fi.evolver.ai.spring.provider.openai.response.chat.OChoice;
import fi.evolver.ai.spring.provider.openai.response.chat.ODelta;
import fi.evolver.ai.spring.provider.openai.response.chat.OFunctionCall;
import fi.evolver.ai.spring.provider.openai.response.chat.OToolCall;
import fi.evolver.ai.spring.util.TokenUtils;
import fi.evolver.utils.collection.CollectionUtils;
import fi.evolver.utils.string.StringUtils;


/**
 * A ChatResponse for streaming API calls.
 *
 * The streaming support works as follows:
 *  - Subscribers get any already received deltas on subscribe and any further deltas as they are received
 *  - Asking for the message blocks until the whole response is complete
 *  - Asking for a single function call will return as soon as we know the name of the first function call
 *  - Asking for function call arguments blocks until the whole response is complete
 *  - Asking for all function calls blocks until the whole response is complete (we do not know how many calls to wait for)
 */
public class OpenAiStreamingChatResponse extends ChatResponse {
	private static final Logger LOG = LoggerFactory.getLogger(OpenAiStreamingChatResponse.class);

	private final Deque results = new ConcurrentLinkedDeque<>();
	private List subscribers = new ArrayList<>();

	private final CountDownLatch responseCompleteLatch = new CountDownLatch(1);
	private final CountDownLatch functionCallLatch = new CountDownLatch(1);
	private final CountDownLatch rateLimitHeadersLatch = new CountDownLatch(1);

	private Optional content = Optional.empty();
	private Optional functionCall = Optional.empty();
	private List functionCalls;
	private Message responseMessage;

	private ORateLimitHeaders rateLimitHeaders;
	private volatile Throwable responseException;
	private volatile Optional usage;

	public OpenAiStreamingChatResponse(ChatPrompt prompt) {
		super(prompt);
	}


	public synchronized void addResult(OChatResult result) {
		results.add(result);
		if (result.choices().isEmpty())		// Azure returns a prompt_filter_results without any actual content as the first chunk.
			return;

		if (functionCall.isEmpty()) {
			Optional toolCall = getFunctionCall(result);
			Optional name = toolCall
					.map(OToolCall::functionCall)
					.filter(Objects::nonNull)
					.map(OFunctionCall::name);
			if (toolCall.isPresent() && name.isPresent()) {
				functionCall = toolCall.map(OpenAiStreamingFunctionCall::new);
				functionCallLatch.countDown();
			}
		}

		Optional contentUpdate = getContent(result);
		for (ContentSubscriber subscriber: subscribers) {
			try {
				contentUpdate.ifPresent(subscriber::onContent);
			}
			catch (RuntimeException e) {
				LOG.error("Subscriber failed handling content update", e);
			}
		}
	}


	public synchronized void handleError(Throwable throwable) {
		this.responseException = throwable;
		this.usage = readUsage();

		responseCompleteLatch.countDown();
		functionCallLatch.countDown();
		rateLimitHeadersLatch.countDown();

		for (ContentSubscriber subscriber: subscribers) {
			try {
				subscriber.onError(throwable);
			}
			catch (RuntimeException e) {
				LOG.error("Subscriber failed handling stream error ({})", throwable.toString(), e);
			}
		}
	}


	public synchronized void handleStreamEnd() {
		String finishReason = getFinishReason().orElse(null);
		if (finishReason == null) {
			handleError(new IllegalStateException("Stream ended without finish reason"));
			return;
		}

		this.usage = readUsage();

		if (responseCompleteLatch.getCount() == 0)
			return;

		functionCall.ifPresent(fc -> fc.setArgumentData(results.stream()
				.map(OpenAiStreamingChatResponse::getFunctionCall)
				.filter(Optional::isPresent)
				.map(Optional::get)
				.map(OToolCall::functionCall)
				.map(OFunctionCall::arguments)
				.filter(Objects::nonNull)
				.collect(Collectors.joining())));

		content = Optional.of(results.stream()
				.map(OpenAiStreamingChatResponse::getContent)
				.filter(Optional::isPresent)
				.map(Optional::get)
				.collect(Collectors.joining()))
				.filter(StringUtils::hasText);
		functionCalls = createFunctionCalls();
		responseMessage = createMessage(content, functionCalls);

		functionCallLatch.countDown();
		responseCompleteLatch.countDown();

		for (ContentSubscriber subscriber: subscribers) {
			try {
				subscriber.onComplete(finishReason);
			}
			catch (RuntimeException e) {
				LOG.error("Subscriber failed handling stream completion", e);
			}
		}
	}


	private static Optional getFunctionCall(OChatResult result) {
		return result.choices().stream()
				.map(OChoice::delta)
				.map(ODelta::toolCalls)
				.filter(Objects::nonNull)
				.flatMap(List::stream)
				.filter(call -> call.functionCall() != null)
				.findFirst();
	}


	private static Optional getContent(OChatResult result) {
		return Optional.of(result.choices().stream()
				.map(OChoice::delta)
				.map(ODelta::content)
				.filter(Objects::nonNull)
				.collect(Collectors.joining()))
				.filter(s -> !s.isEmpty());
	}


	private List createFunctionCalls() {
		Map functionCallByIndex = new TreeMap<>();
		Map argumentsByIndex = new HashMap<>();

		for (OChatResult result: results) {
			List toolCalls = CollectionUtils.first(result.choices())
					.map(OChoice::delta)
					.map(ODelta::toolCalls)
					.orElseGet(List::of);

			for (OToolCall toolCall: toolCalls) {
				OFunctionCall update = toolCall.functionCall();
				if (update == null)
					continue;
				if (update.name() != null)
					functionCallByIndex.put(toolCall.index(), new OpenAiStreamingFunctionCall(toolCall));
				if (update.arguments() != null)
					argumentsByIndex.computeIfAbsent(toolCall.index(), k -> new StringBuilder()).append(update.arguments());
			}
		}

		StringBuilder emptyBuilder = new StringBuilder();
		functionCallByIndex.forEach((i, f) -> f.setArgumentData(argumentsByIndex.getOrDefault(i, emptyBuilder).toString()));
		return new ArrayList<>(functionCallByIndex.values());
	}


	@Override
	public synchronized void addSubscriber(ContentSubscriber subscriber) {
		this.subscribers.add(subscriber);
		for (OChatResult result: results)
			getContent(result).ifPresent(subscriber::onContent);

		if (responseException != null)
			subscriber.onError(responseException);
		else
			getFinishReason().ifPresent(subscriber::onComplete);
	}

	public synchronized void addRateLimitHeaders(ORateLimitHeaders rateLimitHeaders) {
		this.rateLimitHeaders = rateLimitHeaders;
		rateLimitHeadersLatch.countDown();
	}


	@Override
	public String getResultState() {
		try {
			getResponseMessage();
			getFunctionCall().ifPresent(FunctionCall::getArgumentData);
		}
		catch (RuntimeException e) {
			return "error";
		}
		return getFinishReason().orElse("error");
	}


	@Override
	public boolean isSuccess() {
		return OpenAiService.FINISH_REASONS_OK.contains(getResultState());
	}


	private Optional readUsage() {
		Iterator iterator = results.descendingIterator();
		while (iterator.hasNext()) {
			OUsage usage = iterator.next().usage();
			if (usage != null)
				return Optional.of(usage);
		}

		return Optional.empty();
	}

	private Optional getFinishReason() {
		Iterator iterator = results.descendingIterator();
		while (iterator.hasNext()) {
			Optional finishReason = getFinishReason(iterator.next());
			if (finishReason.isPresent())
				return finishReason;
		}

		return Optional.empty();
	}


	private static Optional getFinishReason(OChatResult result) {
		return result.choices().stream().map(OChoice::finishReason).filter(Objects::nonNull).findFirst();
	}


	@Override
	public Message getResponseMessage() {
		try {
			responseCompleteLatch.await();
			if (responseException != null)
				throw new ApiResponseException(responseException, "Reading message failed unexpectedly");
			return responseMessage;
		}
		catch (InterruptedException e) {
			throw new ApiResponseException(e, "Interrupted while waiting for response");
		}
	}


	@Override
	public Optional getTextContent() {
		getResponseMessage();
		return content;
	}


	@Override
	public Optional getFunctionCall() {
		try {
			functionCallLatch.await();
			if (responseException != null)
				throw new ApiResponseException(responseException, "Reading message failed unexpectedly");
			return functionCall.map(OpenAiFunctionCall.class::cast);
		}
		catch (InterruptedException e) {
			throw new ApiResponseException(e, "Interrupted while waiting for response");
		}
	}


	@Override
	public List getFunctionCalls() {
		getResponseMessage();
		return functionCalls;
	}

	@Override
	public ORateLimitHeaders getRateLimitHeaders() {
		try {
			rateLimitHeadersLatch.await();
			return rateLimitHeaders;
		}
		catch (InterruptedException e) {
			throw new ApiResponseException(e, "Interrupted while waiting for rate limit headers");
		}
	}

	@Override
	public int getResponseTokens() {
		return usage
				.map(OUsage::completionTokens)
				.orElse(0);
	}

	@Override
	public int getPromptTokens() {
		return usage
				.map(OUsage::promptTokens)
				.orElseGet(super::getPromptTokens);
	}

	private class OpenAiStreamingFunctionCall implements OpenAiFunctionCall {
		private final String functionName;
		private final String toolCallId;
		private volatile String argumentData;

		private volatile int tokenCount;

		public OpenAiStreamingFunctionCall(OToolCall toolCall) {
			this.functionName = toolCall.functionCall().name();
			this.toolCallId = toolCall.id();
		}

		@Override
		public String getFunctionName() {
			return functionName;
		}


		private void setArgumentData(String argumentData) {
			if (this.argumentData != null)
				throw new IllegalStateException("Do not set argument data twice!");
			this.argumentData = argumentData;
			tokenCount = TokenUtils.calculateTokens(functionName, getPrompt().model()) +
					TokenUtils.calculateTokens(argumentData, getPrompt().model());
		}


		@Override
		public String getArgumentData() {
			try {
				responseCompleteLatch.await();
				if (responseException != null)
					throw new ApiResponseException(responseException, "Function call failed unexpectedly");
				return argumentData;
			}
			catch (InterruptedException e) {
				throw new ApiResponseException(e, "Interrupted while waiting for response");
			}
		}

		@Override
		public String getToolCallId() {
			return toolCallId;
		}

	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy