All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.springframework.kafka.requestreply.ReplyingKafkaTemplate Maven / Gradle / Ivy

/*
 * Copyright 2018-2023 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.kafka.requestreply;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.Collection;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.header.Headers;
import org.apache.kafka.common.header.internals.RecordHeader;

import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.SmartLifecycle;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.log.LogAccessor;
import org.springframework.kafka.KafkaException;
import org.springframework.kafka.core.KafkaTemplate;
import org.springframework.kafka.core.ProducerFactory;
import org.springframework.kafka.listener.BatchMessageListener;
import org.springframework.kafka.listener.ConsumerSeekAware;
import org.springframework.kafka.listener.ContainerProperties;
import org.springframework.kafka.listener.GenericMessageListenerContainer;
import org.springframework.kafka.support.KafkaHeaders;
import org.springframework.kafka.support.KafkaUtils;
import org.springframework.kafka.support.TopicPartitionOffset;
import org.springframework.kafka.support.serializer.DeserializationException;
import org.springframework.kafka.support.serializer.SerializationUtils;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.Assert;

/**
 * A KafkaTemplate that implements request/reply semantics.
 *
 * @param  the key type.
 * @param  the outbound data type.
 * @param  the reply data type.
 *
 * @author Gary Russell
 * @author Artem Bilan
 *
 * @since 2.1.3
 *
 */
public class ReplyingKafkaTemplate extends KafkaTemplate implements BatchMessageListener,
		InitializingBean, SmartLifecycle, DisposableBean, ReplyingKafkaOperations, ConsumerSeekAware {

	private static final String WITH_CORRELATION_ID = " with correlationId: ";

	private static final int FIVE = 5;

	private static final Duration DEFAULT_REPLY_TIMEOUT = Duration.ofSeconds(FIVE);

	private final GenericMessageListenerContainer replyContainer;

	private final ConcurrentMap> futures = new ConcurrentHashMap<>();

	private final byte[] replyTopic;

	private final byte[] replyPartition;

	private TaskScheduler scheduler = new ThreadPoolTaskScheduler();

	private int phase;

	private boolean autoStartup = true;

	private Duration defaultReplyTimeout = DEFAULT_REPLY_TIMEOUT;

	private boolean schedulerSet;

	private boolean sharedReplyTopic;

	private Function, CorrelationKey> correlationStrategy =
			ReplyingKafkaTemplate::defaultCorrelationIdStrategy;

	private boolean binaryCorrelation = true;

	private String correlationHeaderName = KafkaHeaders.CORRELATION_ID;

	private String replyTopicHeaderName = KafkaHeaders.REPLY_TOPIC;

	private String replyPartitionHeaderName = KafkaHeaders.REPLY_PARTITION;

	private Function, Exception> replyErrorChecker = rec -> null;

	private CountDownLatch assignLatch = new CountDownLatch(1);

	private volatile boolean running;

	private volatile boolean schedulerInitialized;

	public ReplyingKafkaTemplate(ProducerFactory producerFactory,
			GenericMessageListenerContainer replyContainer) {

		this(producerFactory, replyContainer, false);
	}

	public ReplyingKafkaTemplate(ProducerFactory producerFactory,
			GenericMessageListenerContainer replyContainer, boolean autoFlush) {

		super(producerFactory, autoFlush);
		Assert.notNull(replyContainer, "'replyContainer' cannot be null");
		this.replyContainer = replyContainer;
		this.replyContainer.setupMessageListener(this);
		ContainerProperties properties = this.replyContainer.getContainerProperties();
		String tempReplyTopic = null;
		byte[] tempReplyPartition = null;
		TopicPartitionOffset[] topicPartitionsToAssign = properties.getTopicPartitions();
		String[] topics = properties.getTopics();
		if (topics != null && topics.length == 1) {
			tempReplyTopic = topics[0];
		}
		else if (topicPartitionsToAssign != null && topicPartitionsToAssign.length == 1) {
			TopicPartitionOffset topicPartitionOffset = topicPartitionsToAssign[0];
			Assert.notNull(topicPartitionOffset, "'topicPartitionsToAssign' must not be null");
			tempReplyTopic = topicPartitionOffset.getTopic();
			ByteBuffer buffer = ByteBuffer.allocate(4); // NOSONAR magic #
			buffer.putInt(topicPartitionOffset.getPartition());
			tempReplyPartition = buffer.array();
		}
		if (tempReplyTopic == null) {
			this.replyTopic = null;
			this.replyPartition = null;
			this.logger.debug(() -> "Could not determine container's reply topic/partition; senders must populate "
					+ "at least the " + KafkaHeaders.REPLY_TOPIC + " header, and optionally the "
					+ KafkaHeaders.REPLY_PARTITION + " header");
		}
		else {
			this.replyTopic = tempReplyTopic.getBytes(StandardCharsets.UTF_8);
			this.replyPartition = tempReplyPartition;
		}
	}

	public void setTaskScheduler(TaskScheduler scheduler) {
		Assert.notNull(scheduler, "'scheduler' cannot be null");
		this.scheduler = scheduler;
		this.schedulerSet = true;
	}

	/**
	 * Return the reply timeout used if no replyTimeout is provided in the
	 * {@link #sendAndReceive(ProducerRecord, Duration)} call.
	 * @return the timeout.
	 * @since 2.3
	 */
	protected Duration getDefaultReplyTimeout() {
		return this.defaultReplyTimeout;
	}

	/**
	 * Set the reply timeout used if no replyTimeout is provided in the
	 * {@link #sendAndReceive(ProducerRecord, Duration)} call.
	 * @param defaultReplyTimeout the timeout.
	 * @since 2.3
	 */
	public void setDefaultReplyTimeout(Duration defaultReplyTimeout) {
		Assert.notNull(defaultReplyTimeout, "'defaultReplyTimeout' cannot be null");
		Assert.isTrue(defaultReplyTimeout.toMillis() >= 0, "'replyTimeout' must be >= 0");
		this.defaultReplyTimeout = defaultReplyTimeout;
	}

	@Override
	public boolean isRunning() {
		return this.running;
	}

	@Override
	public int getPhase() {
		return this.phase;
	}

	public void setPhase(int phase) {
		this.phase = phase;
	}

	@Override
	public boolean isAutoStartup() {
		return this.autoStartup;
	}

	public void setAutoStartup(boolean autoStartup) {
		this.autoStartup = autoStartup;
	}

	/**
	 * Return the topics/partitions assigned to the replying listener container.
	 * @return the topics/partitions.
	 */
	public Collection getAssignedReplyTopicPartitions() {
		return this.replyContainer.getAssignedPartitions();
	}

	/**
	 * Set to true when multiple templates are using the same topic for replies. This
	 * simply changes logs for unexpected replies to debug instead of error.
	 * @param sharedReplyTopic true if using a shared topic.
	 * @since 2.2
	 */
	public void setSharedReplyTopic(boolean sharedReplyTopic) {
		this.sharedReplyTopic = sharedReplyTopic;
	}

	/**
	 * Set a function to be called to establish a unique correlation key for each request
	 * record.
	 * @param correlationStrategy the function.
	 * @since 2.3
	 */
	public void setCorrelationIdStrategy(Function, CorrelationKey> correlationStrategy) {
		Assert.notNull(correlationStrategy, "'correlationStrategy' cannot be null");
		this.correlationStrategy = correlationStrategy;
	}

	/**
	 * Set a custom header name for the correlation id. Default
	 * {@link KafkaHeaders#CORRELATION_ID}.
	 * @param correlationHeaderName the header name.
	 * @since 2.3
	 */
	public void setCorrelationHeaderName(String correlationHeaderName) {
		Assert.notNull(correlationHeaderName, "'correlationHeaderName' cannot be null");
		this.correlationHeaderName = correlationHeaderName;
	}

	/**
	 * Return the correlation header name.
	 * @return the header name.
	 * @since 2.8.8
	 */
	protected String getCorrelationHeaderName() {
		return this.correlationHeaderName;
	}

	/**
	 * Set a custom header name for the reply topic. Default
	 * {@link KafkaHeaders#REPLY_TOPIC}.
	 * @param replyTopicHeaderName the header name.
	 * @since 2.3
	 */
	public void setReplyTopicHeaderName(String replyTopicHeaderName) {
		Assert.notNull(replyTopicHeaderName, "'replyTopicHeaderName' cannot be null");
		this.replyTopicHeaderName = replyTopicHeaderName;
	}

	/**
	 * Set a custom header name for the reply partition. Default
	 * {@link KafkaHeaders#REPLY_PARTITION}.
	 * @param replyPartitionHeaderName the reply partition header name.
	 * @since 2.3
	 */
	public void setReplyPartitionHeaderName(String replyPartitionHeaderName) {
		Assert.notNull(replyPartitionHeaderName, "'replyPartitionHeaderName' cannot be null");
		this.replyPartitionHeaderName = replyPartitionHeaderName;
	}

	/**
	 * Set a function to examine replies for an error returned by the server.
	 * @param replyErrorChecker the error checker function.
	 * @since 2.6.7
	 */
	public void setReplyErrorChecker(Function, Exception> replyErrorChecker) {
		Assert.notNull(replyErrorChecker, "'replyErrorChecker' cannot be null");
		this.replyErrorChecker = replyErrorChecker;
	}

	/**
	 * Set to false to use the String representation of the correlation as the
	 * correlationId rather than the binary representation. Default true.
	 * @param binaryCorrelation false for String.
	 * @since 3.0
	 */
	public void setBinaryCorrelation(boolean binaryCorrelation) {
		this.binaryCorrelation = binaryCorrelation;
	}

	protected boolean isBinaryCorrelation() {
		return this.binaryCorrelation;
	}

	@Override
	public void afterPropertiesSet() {
		if (!this.schedulerSet && !this.schedulerInitialized) {
			((ThreadPoolTaskScheduler) this.scheduler).initialize();
			this.schedulerInitialized = true;
		}
	}

	@Override
	public synchronized void start() {
		if (!this.running) {
			try {
				afterPropertiesSet();
			}
			catch (Exception e) {
				throw new KafkaException("Failed to initialize", e);
			}
			this.assignLatch = new CountDownLatch(1);
			this.replyContainer.start();
			this.running = true;
		}
	}

	@Override
	public synchronized void stop() {
		if (this.running) {
			this.running = false;
			this.replyContainer.stop();
			this.futures.clear();
		}
	}

	@Override
	public void stop(Runnable callback) {
		stop();
		callback.run();
	}

	@Override
	public void onFirstPoll() {
		this.assignLatch.countDown();
	}

	@Override
	public boolean waitForAssignment(Duration duration) throws InterruptedException {
		return this.assignLatch.await(duration.toMillis(), TimeUnit.MILLISECONDS);
	}

	@Override
	public RequestReplyMessageFuture sendAndReceive(Message message) {
		return sendAndReceive(message, this.defaultReplyTimeout, null);
	}

	@Override
	public RequestReplyMessageFuture sendAndReceive(Message message, Duration replyTimeout) {
		return sendAndReceive(message, replyTimeout, null);
	}

	@Override
	public 

RequestReplyTypedMessageFuture sendAndReceive(Message message, @Nullable ParameterizedTypeReference

returnType) { return sendAndReceive(message, this.defaultReplyTimeout, returnType); } @SuppressWarnings("unchecked") @Override public

RequestReplyTypedMessageFuture sendAndReceive(Message message, @Nullable Duration replyTimeout, @Nullable ParameterizedTypeReference

returnType) { RequestReplyFuture future = sendAndReceive((ProducerRecord) getMessageConverter() .fromMessage(message, getDefaultTopic()), replyTimeout); RequestReplyTypedMessageFuture replyFuture = new RequestReplyTypedMessageFuture<>(future.getSendFuture()); future.whenComplete((result, ex) -> { if (ex == null) { try { replyFuture.complete(getMessageConverter() .toMessage(result, null, null, returnType == null ? null : returnType.getType())); } catch (Exception ex2) { // NOSONAR replyFuture.completeExceptionally(ex2); } } else { replyFuture.completeExceptionally(ex); } }); return replyFuture; } @Override public RequestReplyFuture sendAndReceive(ProducerRecord record) { return sendAndReceive(record, this.defaultReplyTimeout); } @Override public RequestReplyFuture sendAndReceive(ProducerRecord record, @Nullable Duration replyTimeout) { Assert.state(this.running, "Template has not been start()ed"); // NOSONAR (sync) Duration timeout = replyTimeout; if (timeout == null) { timeout = this.defaultReplyTimeout; } CorrelationKey correlationId = this.correlationStrategy.apply(record); Assert.notNull(correlationId, "the created 'correlationId' cannot be null"); Headers headers = record.headers(); boolean hasReplyTopic = headers.lastHeader(KafkaHeaders.REPLY_TOPIC) != null; if (!hasReplyTopic && this.replyTopic != null) { headers.add(new RecordHeader(this.replyTopicHeaderName, this.replyTopic)); if (this.replyPartition != null) { headers.add(new RecordHeader(this.replyPartitionHeaderName, this.replyPartition)); } } Object correlation = this.binaryCorrelation ? correlationId : correlationId.toString(); byte[] correlationValue = this.binaryCorrelation ? correlationId.getCorrelationId() : ((String) correlation).getBytes(StandardCharsets.UTF_8); headers.add(new RecordHeader(this.correlationHeaderName, correlationValue)); this.logger.debug(() -> "Sending: " + KafkaUtils.format(record) + WITH_CORRELATION_ID + correlationId); RequestReplyFuture future = new RequestReplyFuture<>(); this.futures.put(correlation, future); try { future.setSendFuture(send(record)); } catch (Exception e) { this.futures.remove(correlation); throw new KafkaException("Send failed", e); } scheduleTimeout(record, correlation, timeout); return future; } private void scheduleTimeout(ProducerRecord record, Object correlationId, Duration replyTimeout) { this.scheduler.schedule(() -> { RequestReplyFuture removed = this.futures.remove(correlationId); if (removed != null) { this.logger.warn(() -> "Reply timed out for: " + KafkaUtils.format(record) + WITH_CORRELATION_ID + correlationId); if (!handleTimeout(correlationId, removed)) { removed.completeExceptionally(new KafkaReplyTimeoutException("Reply timed out")); } } }, Instant.now().plus(replyTimeout)); } /** * Used to inform subclasses that a request has timed out so they can clean up state * and, optionally, complete the future. * @param correlationId the correlation id. * @param future the future. * @return true to indicate the future has been completed. * @since 2.3 */ protected boolean handleTimeout(@SuppressWarnings("unused") Object correlationId, @SuppressWarnings("unused") RequestReplyFuture future) { return false; } /** * Return true if this correlation id is still active. * @param correlationId the correlation id. * @return true if pending. * @since 2.3 */ protected boolean isPending(Object correlationId) { return this.futures.containsKey(correlationId); } @Override public void destroy() { if (!this.schedulerSet) { ((ThreadPoolTaskScheduler) this.scheduler).destroy(); } } private static CorrelationKey defaultCorrelationIdStrategy( @SuppressWarnings("unused") ProducerRecord record) { UUID uuid = UUID.randomUUID(); byte[] bytes = new byte[16]; // NOSONAR magic # ByteBuffer bb = ByteBuffer.wrap(bytes); bb.putLong(uuid.getMostSignificantBits()); bb.putLong(uuid.getLeastSignificantBits()); return new CorrelationKey(bytes); } @Override public void onMessage(List> data) { data.forEach(record -> { Header correlationHeader = record.headers().lastHeader(this.correlationHeaderName); Object correlationId = null; if (correlationHeader != null) { correlationId = this.binaryCorrelation ? new CorrelationKey(correlationHeader.value()) : new String(correlationHeader.value(), StandardCharsets.UTF_8); } if (correlationId == null) { this.logger.error(() -> "No correlationId found in reply: " + KafkaUtils.format(record) + " - to use request/reply semantics, the responding server must return the correlation id " + " in the '" + this.correlationHeaderName + "' header"); } else { RequestReplyFuture future = this.futures.remove(correlationId); Object correlationKey = correlationId; if (future == null) { logLateArrival(record, correlationId); } else { boolean ok = true; Exception exception = checkForErrors(record); if (exception != null) { ok = false; future.completeExceptionally(exception); } if (ok) { this.logger.debug(() -> "Received: " + KafkaUtils.format(record) + WITH_CORRELATION_ID + correlationKey); future.complete(record); } } } }); } /** * Check for errors in a reply. The default implementation checks for {@link DeserializationException}s * and invokes the {@link #setReplyErrorChecker(Function) replyErrorChecker} function. * @param record the record. * @return the exception, or null if none. * @since 2.6.7 */ @Nullable protected Exception checkForErrors(ConsumerRecord record) { if (record.value() == null || record.key() == null) { DeserializationException de = checkDeserialization(record, this.logger); if (de != null) { return de; } } return this.replyErrorChecker.apply(record); } /** * Return a {@link DeserializationException} if either the key or value failed * deserialization; null otherwise. If you need to determine whether it was the key or * value, call * {@link SerializationUtils#getExceptionFromHeader(ConsumerRecord, String, LogAccessor)} * with {@link SerializationUtils#KEY_DESERIALIZER_EXCEPTION_HEADER} and * {@link SerializationUtils#VALUE_DESERIALIZER_EXCEPTION_HEADER} instead. * @param record the record. * @param logger a {@link LogAccessor}. * @return the {@link DeserializationException} or {@code null}. * @since 2.2.15 */ @Nullable public static DeserializationException checkDeserialization(ConsumerRecord record, LogAccessor logger) { DeserializationException exception = SerializationUtils.getExceptionFromHeader(record, SerializationUtils.VALUE_DESERIALIZER_EXCEPTION_HEADER, logger); if (exception != null) { logger.error(exception, () -> "Reply value deserialization failed for " + record.topic() + "-" + record.partition() + "@" + record.offset()); return exception; } exception = SerializationUtils.getExceptionFromHeader(record, SerializationUtils.KEY_DESERIALIZER_EXCEPTION_HEADER, logger); if (exception != null) { logger.error(exception, () -> "Reply key deserialization failed for " + record.topic() + "-" + record.partition() + "@" + record.offset()); return exception; } return null; } protected void logLateArrival(ConsumerRecord record, Object correlationId) { if (this.sharedReplyTopic) { this.logger.debug(() -> missingCorrelationLogMessage(record, correlationId)); } else { this.logger.error(() -> missingCorrelationLogMessage(record, correlationId)); } } private String missingCorrelationLogMessage(ConsumerRecord record, Object correlationId) { return "No pending reply: " + KafkaUtils.format(record) + WITH_CORRELATION_ID + correlationId + ", perhaps timed out, or using a shared reply topic"; } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy