All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.springframework.kafka.requestreply.ReplyingKafkaTemplate Maven / Gradle / Ivy

/*
 * Copyright 2018-2020 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.kafka.requestreply;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.Collection;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Function;

import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.header.Headers;
import org.apache.kafka.common.header.internals.RecordHeader;

import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.SmartLifecycle;
import org.springframework.core.log.LogAccessor;
import org.springframework.kafka.KafkaException;
import org.springframework.kafka.core.KafkaTemplate;
import org.springframework.kafka.core.ProducerFactory;
import org.springframework.kafka.listener.BatchMessageListener;
import org.springframework.kafka.listener.ContainerProperties;
import org.springframework.kafka.listener.GenericMessageListenerContainer;
import org.springframework.kafka.listener.ListenerUtils;
import org.springframework.kafka.support.KafkaHeaders;
import org.springframework.kafka.support.TopicPartitionOffset;
import org.springframework.kafka.support.serializer.DeserializationException;
import org.springframework.kafka.support.serializer.ErrorHandlingDeserializer;
import org.springframework.lang.Nullable;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.Assert;

/**
 * A KafkaTemplate that implements request/reply semantics.
 *
 * @param  the key type.
 * @param  the outbound data type.
 * @param  the reply data type.
 *
 * @author Gary Russell
 * @author Artem Bilan
 *
 * @since 2.1.3
 *
 */
public class ReplyingKafkaTemplate extends KafkaTemplate implements BatchMessageListener,
		InitializingBean, SmartLifecycle, DisposableBean, ReplyingKafkaOperations {

	private static final String WITH_CORRELATION_ID = " with correlationId: ";

	private static final int FIVE = 5;

	private static final Duration DEFAULT_REPLY_TIMEOUT = Duration.ofSeconds(FIVE);

	private final GenericMessageListenerContainer replyContainer;

	private final ConcurrentMap> futures = new ConcurrentHashMap<>();

	private final byte[] replyTopic;

	private final byte[] replyPartition;

	private TaskScheduler scheduler = new ThreadPoolTaskScheduler();

	private int phase;

	private boolean autoStartup = true;

	private Duration defaultReplyTimeout = DEFAULT_REPLY_TIMEOUT;

	private boolean schedulerSet;

	private boolean sharedReplyTopic;

	private Function, CorrelationKey> correlationStrategy = ReplyingKafkaTemplate::defaultCorrelationIdStrategy;

	private String correlationHeaderName = KafkaHeaders.CORRELATION_ID;

	private String replyTopicHeaderName = KafkaHeaders.REPLY_TOPIC;

	private String replyPartitionHeaderName = KafkaHeaders.REPLY_PARTITION;

	private volatile boolean running;

	private volatile boolean schedulerInitialized;

	public ReplyingKafkaTemplate(ProducerFactory producerFactory,
			GenericMessageListenerContainer replyContainer) {

		this(producerFactory, replyContainer, false);
	}

	public ReplyingKafkaTemplate(ProducerFactory producerFactory,
			GenericMessageListenerContainer replyContainer, boolean autoFlush) {

		super(producerFactory, autoFlush);
		Assert.notNull(replyContainer, "'replyContainer' cannot be null");
		this.replyContainer = replyContainer;
		this.replyContainer.setupMessageListener(this);
		ContainerProperties properties = this.replyContainer.getContainerProperties();
		String tempReplyTopic = null;
		byte[] tempReplyPartition = null;
		TopicPartitionOffset[] topicPartitionsToAssign = properties.getTopicPartitions();
		String[] topics = properties.getTopics();
		if (topics != null && topics.length == 1) {
			tempReplyTopic = topics[0];
		}
		else if (topicPartitionsToAssign != null && topicPartitionsToAssign.length == 1) {
			TopicPartitionOffset topicPartitionOffset = topicPartitionsToAssign[0];
			Assert.notNull(topicPartitionOffset, "'topicPartitionsToAssign' must not be null");
			tempReplyTopic = topicPartitionOffset.getTopic();
			ByteBuffer buffer = ByteBuffer.allocate(4); // NOSONAR magic #
			buffer.putInt(topicPartitionOffset.getPartition());
			tempReplyPartition = buffer.array();
		}
		if (tempReplyTopic == null) {
			this.replyTopic = null;
			this.replyPartition = null;
			this.logger.debug(() -> "Could not determine container's reply topic/partition; senders must populate "
					+ "at least the " + KafkaHeaders.REPLY_TOPIC + " header, and optionally the "
					+ KafkaHeaders.REPLY_PARTITION + " header");
		}
		else {
			this.replyTopic = tempReplyTopic.getBytes(StandardCharsets.UTF_8);
			this.replyPartition = tempReplyPartition;
		}
	}

	public void setTaskScheduler(TaskScheduler scheduler) {
		Assert.notNull(scheduler, "'scheduler' cannot be null");
		this.scheduler = scheduler;
		this.schedulerSet = true;
	}

	/**
	 * Return the reply timeout used if no replyTimeout is provided in the
	 * {@link #sendAndReceive(ProducerRecord, Duration)} call.
	 * @return the timeout.
	 * @since 2.3
	 */
	protected Duration getDefaultReplyTimeout() {
		return this.defaultReplyTimeout;
	}

	/**
	 * Set the reply timeout used if no replyTimeout is provided in the
	 * {@link #sendAndReceive(ProducerRecord, Duration)} call.
	 * @param defaultReplyTimeout the timeout.
	 * @since 2.3
	 */
	public void setDefaultReplyTimeout(Duration defaultReplyTimeout) {
		Assert.notNull(defaultReplyTimeout, "'defaultReplyTimeout' cannot be null");
		Assert.isTrue(defaultReplyTimeout.toMillis() >= 0, "'replyTimeout' must be >= 0");
		this.defaultReplyTimeout = defaultReplyTimeout;
	}

	@Override
	public boolean isRunning() {
		return this.running;
	}

	@Override
	public int getPhase() {
		return this.phase;
	}

	public void setPhase(int phase) {
		this.phase = phase;
	}

	@Override
	public boolean isAutoStartup() {
		return this.autoStartup;
	}

	public void setAutoStartup(boolean autoStartup) {
		this.autoStartup = autoStartup;
	}

	/**
	 * Return the topics/partitions assigned to the replying listener container.
	 * @return the topics/partitions.
	 */
	public Collection getAssignedReplyTopicPartitions() {
		return this.replyContainer.getAssignedPartitions();
	}

	/**
	 * Set to true when multiple templates are using the same topic for replies. This
	 * simply changes logs for unexpected replies to debug instead of error.
	 * @param sharedReplyTopic true if using a shared topic.
	 * @since 2.2
	 */
	public void setSharedReplyTopic(boolean sharedReplyTopic) {
		this.sharedReplyTopic = sharedReplyTopic;
	}

	/**
	 * Set a function to be called to establish a unique correlation key for each request
	 * record.
	 * @param correlationStrategy the function.
	 * @since 2.3
	 */
	public void setCorrelationIdStrategy(Function, CorrelationKey> correlationStrategy) {
		Assert.notNull(correlationStrategy, "'correlationStrategy' cannot be null");
		this.correlationStrategy = correlationStrategy;
	}

	/**
	 * Set a custom header name for the correlation id. Default
	 * {@link KafkaHeaders#CORRELATION_ID}.
	 * @param correlationHeaderName the header name.
	 * @since 2.3
	 */
	public void setCorrelationHeaderName(String correlationHeaderName) {
		Assert.notNull(correlationHeaderName, "'correlationHeaderName' cannot be null");
		this.correlationHeaderName = correlationHeaderName;
	}

	/**
	 * Set a custom header name for the reply topic. Default
	 * {@link KafkaHeaders#REPLY_TOPIC}.
	 * @param replyTopicHeaderName the header name.
	 * @since 2.3
	 */
	public void setReplyTopicHeaderName(String replyTopicHeaderName) {
		Assert.notNull(replyTopicHeaderName, "'replyTopicHeaderName' cannot be null");
		this.replyTopicHeaderName = replyTopicHeaderName;
	}

	/**
	 * Set a custom header name for the reply partition. Default
	 * {@link KafkaHeaders#REPLY_PARTITION}.
	 * @param replyPartitionHeaderName the reply partition header name.
	 * @since 2.3
	 */
	public void setReplyPartitionHeaderName(String replyPartitionHeaderName) {
		Assert.notNull(replyPartitionHeaderName, "'replyPartitionHeaderName' cannot be null");
		this.replyPartitionHeaderName = replyPartitionHeaderName;
	}

	@Override
	public void afterPropertiesSet() {
		if (!this.schedulerSet && !this.schedulerInitialized) {
			((ThreadPoolTaskScheduler) this.scheduler).initialize();
			this.schedulerInitialized = true;
		}
	}

	@Override
	public synchronized void start() {
		if (!this.running) {
			try {
				afterPropertiesSet();
			}
			catch (Exception e) {
				throw new KafkaException("Failed to initialize", e);
			}
			this.replyContainer.start();
			this.running = true;
		}
	}

	@Override
	public synchronized void stop() {
		if (this.running) {
			this.running = false;
			this.replyContainer.stop();
			this.futures.clear();
		}
	}

	@Override
	public void stop(Runnable callback) {
		stop();
		callback.run();
	}

	@Override
	public RequestReplyFuture sendAndReceive(ProducerRecord record) {
		return sendAndReceive(record, null);
	}

	@Override
	public RequestReplyFuture sendAndReceive(ProducerRecord record, @Nullable Duration replyTimeout) {
		Assert.state(this.running, "Template has not been start()ed"); // NOSONAR (sync)
		CorrelationKey correlationId = this.correlationStrategy.apply(record);
		Assert.notNull(correlationId, "the created 'correlationId' cannot be null");
		Headers headers = record.headers();
		boolean hasReplyTopic = headers.lastHeader(KafkaHeaders.REPLY_TOPIC) != null;
		if (!hasReplyTopic && this.replyTopic != null) {
			headers.add(new RecordHeader(this.replyTopicHeaderName, this.replyTopic));
			if (this.replyPartition != null) {
				headers.add(new RecordHeader(this.replyPartitionHeaderName, this.replyPartition));
			}
		}
		headers.add(new RecordHeader(this.correlationHeaderName, correlationId.getCorrelationId()));
		this.logger.debug(() -> "Sending: " + record + WITH_CORRELATION_ID + correlationId);
		RequestReplyFuture future = new RequestReplyFuture<>();
		this.futures.put(correlationId, future);
		try {
			future.setSendFuture(send(record));
		}
		catch (Exception e) {
			this.futures.remove(correlationId);
			throw new KafkaException("Send failed", e);
		}
		scheduleTimeout(record, correlationId, replyTimeout == null ? this.defaultReplyTimeout : replyTimeout);
		return future;
	}

	private void scheduleTimeout(ProducerRecord record, CorrelationKey correlationId, Duration replyTimeout) {
		this.scheduler.schedule(() -> {
			RequestReplyFuture removed = this.futures.remove(correlationId);
			if (removed != null) {
				this.logger.warn(() -> "Reply timed out for: " + record + WITH_CORRELATION_ID + correlationId);
				if (!handleTimeout(correlationId, removed)) {
					removed.setException(new KafkaReplyTimeoutException("Reply timed out"));
				}
			}
		}, Instant.now().plus(replyTimeout));
	}

	/**
	 * Used to inform subclasses that a request has timed out so they can clean up state
	 * and, optionally, complete the future.
	 * @param correlationId the correlation id.
	 * @param future the future.
	 * @return true to indicate the future has been completed.
	 * @since 2.3
	 */
	protected boolean handleTimeout(@SuppressWarnings("unused") CorrelationKey correlationId,
			@SuppressWarnings("unused") RequestReplyFuture future) {

		return false;
	}

	/**
	 * Return true if this correlation id is still active.
	 * @param correlationId the correlation id.
	 * @return true if pending.
	 * @since 2.3
	 */
	protected boolean isPending(CorrelationKey correlationId) {
		return this.futures.containsKey(correlationId);
	}

	@Override
	public void destroy() {
		if (!this.schedulerSet) {
			((ThreadPoolTaskScheduler) this.scheduler).destroy();
		}
	}

	private static  CorrelationKey defaultCorrelationIdStrategy(
			@SuppressWarnings("unused") ProducerRecord record) {

		UUID uuid = UUID.randomUUID();
		byte[] bytes = new byte[16]; // NOSONAR magic #
		ByteBuffer bb = ByteBuffer.wrap(bytes);
		bb.putLong(uuid.getMostSignificantBits());
		bb.putLong(uuid.getLeastSignificantBits());
		return new CorrelationKey(bytes);
	}

	@Override
	public void onMessage(List> data) {
		data.forEach(record -> {
			Header correlationHeader = record.headers().lastHeader(this.correlationHeaderName);
			CorrelationKey correlationId = null;
			if (correlationHeader != null) {
				correlationId = new CorrelationKey(correlationHeader.value());
			}
			if (correlationId == null) {
				this.logger.error(() -> "No correlationId found in reply: " + record
						+ " - to use request/reply semantics, the responding server must return the correlation id "
						+ " in the '" + this.correlationHeaderName + "' header");
			}
			else {
				RequestReplyFuture future = this.futures.remove(correlationId);
				CorrelationKey correlationKey = correlationId;
				if (future == null) {
					logLateArrival(record, correlationId);
				}
				else {
					boolean ok = true;
					if (record.value() == null) {
						DeserializationException de = checkDeserialization(record, this.logger);
						if (de != null) {
							ok = false;
							future.setException(de);
						}
					}
					if (ok) {
						this.logger.debug(() -> "Received: " + record + WITH_CORRELATION_ID + correlationKey);
						future.set(record);
					}
				}
			}
		});
	}

	/**
	 * Return a {@link DeserializationException} if either the key or value failed
	 * deserialization; null otherwise. If you need to determine whether it was the key or
	 * value, call
	 * {@link ListenerUtils#getExceptionFromHeader(ConsumerRecord, String, LogAccessor)}
	 * with {@link ErrorHandlingDeserializer#KEY_DESERIALIZER_EXCEPTION_HEADER} and
	 * {@link ErrorHandlingDeserializer#VALUE_DESERIALIZER_EXCEPTION_HEADER} instead.
	 * @param record the record.
	 * @param logger a {@link LogAccessor}.
	 * @return the {@link DeserializationException} or {@code null}.
	 * @since 2.2.15
	 */
	@Nullable
	public static DeserializationException checkDeserialization(ConsumerRecord record, LogAccessor logger) {
		DeserializationException exception = ListenerUtils.getExceptionFromHeader(record,
				ErrorHandlingDeserializer.VALUE_DESERIALIZER_EXCEPTION_HEADER, logger);
		if (exception != null) {
			logger.error(exception, () -> "Reply value deserialization failed for " + record.topic() + "-"
					+ record.partition() + "@" + record.offset());
			return exception;
		}
		exception = ListenerUtils.getExceptionFromHeader(record,
				ErrorHandlingDeserializer.KEY_DESERIALIZER_EXCEPTION_HEADER, logger);
		if (exception != null) {
			logger.error(exception, () -> "Reply key deserialization failed for " + record.topic() + "-"
					+ record.partition() + "@" + record.offset());
			return exception;
		}
		return null;
	}

	protected void logLateArrival(ConsumerRecord record, CorrelationKey correlationId) {
		if (this.sharedReplyTopic) {
			this.logger.debug(() -> missingCorrelationLogMessage(record, correlationId));
		}
		else {
			this.logger.error(() -> missingCorrelationLogMessage(record, correlationId));
		}
	}

	private String missingCorrelationLogMessage(ConsumerRecord record, CorrelationKey correlationId) {
		return "No pending reply: " + record + WITH_CORRELATION_ID
				+ correlationId + ", perhaps timed out, or using a shared reply topic";
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy