GH-3696: DeserializationEx support for KafkaMS
Fixes spring-projects#3696

* Add internal logic to `KafkaMessageSource` to react properly to an
`ErrorHandlingDeserializer` configuration and re-throw the `DeserializationException`
artembilan committed Jul 26, 2023
1 parent bd013e0 commit eb3e8f6
Showing 4 changed files with 178 additions and 3 deletions.
KafkaMessageSource.java
@@ -26,6 +26,7 @@
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
@@ -46,6 +47,7 @@
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.errors.WakeupException;

import org.springframework.beans.factory.BeanClassLoaderAware;
import org.springframework.core.log.LogAccessor;
import org.springframework.integration.IntegrationMessageHeaderAccessor;
import org.springframework.integration.acks.AcknowledgmentCallback;
@@ -58,6 +60,7 @@
import org.springframework.kafka.core.DefaultKafkaConsumerFactory;
import org.springframework.kafka.listener.ConsumerAwareRebalanceListener;
import org.springframework.kafka.listener.ConsumerProperties;
import org.springframework.kafka.listener.ListenerUtils;
import org.springframework.kafka.listener.LoggingCommitCallback;
import org.springframework.kafka.support.Acknowledgment;
import org.springframework.kafka.support.DefaultKafkaHeaderMapper;
@@ -69,9 +72,13 @@
import org.springframework.kafka.support.converter.KafkaMessageHeaders;
import org.springframework.kafka.support.converter.MessagingMessageConverter;
import org.springframework.kafka.support.converter.RecordMessageConverter;
import org.springframework.kafka.support.serializer.DeserializationException;
import org.springframework.kafka.support.serializer.ErrorHandlingDeserializer;
import org.springframework.kafka.support.serializer.SerializationUtils;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;

@@ -102,7 +109,8 @@
* @since 5.4
*
*/
-public class KafkaMessageSource<K, V> extends AbstractMessageSource<Object> implements Pausable {
+public class KafkaMessageSource<K, V> extends AbstractMessageSource<Object>
+		implements Pausable, BeanClassLoaderAware {

private static final long MIN_ASSIGN_TIMEOUT = 2000L;

@@ -146,6 +154,10 @@ public class KafkaMessageSource<K, V> extends AbstractMessageSource<Object> impl

private Duration closeTimeout = Duration.ofSeconds(DEFAULT_CLOSE_TIMEOUT);

private boolean checkNullKeyForExceptions;

private boolean checkNullValueForExceptions;

private volatile Consumer<K, V> consumer;

private volatile boolean pausing;
Expand All @@ -158,6 +170,8 @@ public class KafkaMessageSource<K, V> extends AbstractMessageSource<Object> impl

public volatile boolean newAssignment; // NOSONAR - direct access from inner

private ClassLoader classLoader;

/**
* Construct an instance with the supplied parameters. Fetching multiple
* records per poll will be disabled.
@@ -257,11 +271,68 @@ public Collection<TopicPartition> getAssignedPartitions() {
return Collections.unmodifiableCollection(this.assignedPartitions);
}

@Override
public void setBeanClassLoader(ClassLoader classLoader) {
this.classLoader = classLoader;
}

@Override
protected void onInit() {
if (!StringUtils.hasText(this.consumerProperties.getClientId())) {
this.consumerProperties.setClientId(getComponentName());
}

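// Enable the null-key/value checks when either the explicit ConsumerProperties
// flag is set or an ErrorHandlingDeserializer is configured for the key/value.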
Map<String, Object> props = this.consumerFactory.getConfigurationProperties();
Properties kafkaConsumerProperties = this.consumerProperties.getKafkaConsumerProperties();
this.checkNullKeyForExceptions =
this.consumerProperties.isCheckDeserExWhenKeyNull() ||
checkDeserializer(findDeserializerClass(props, kafkaConsumerProperties, false));
this.checkNullValueForExceptions =
this.consumerProperties.isCheckDeserExWhenValueNull() ||
checkDeserializer(findDeserializerClass(props, kafkaConsumerProperties, true));
}

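/**
 * Resolve the configured key or value deserializer: the class of the consumer
 * factory's explicit deserializer, if any; otherwise the (Class or class name)
 * entry from the consumer property overrides or the base configuration properties.
 */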
@Nullable
private Object findDeserializerClass(Map<String, Object> props, Properties consumerOverrides, boolean isValue) {
Object configuredDeserializer =
isValue
? this.consumerFactory.getValueDeserializer()
: this.consumerFactory.getKeyDeserializer();
if (configuredDeserializer == null) {
Object deser = consumerOverrides.get(
isValue
? ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG
: ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG);
if (deser == null) {
deser = props.get(
isValue
? ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG
: ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG);
}
return deser;
}
else {
return configuredDeserializer.getClass();
}
}

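/**
 * Return true if the supplied deserializer (a Class, or a class name resolvable
 * against the bean ClassLoader) is an {@link ErrorHandlingDeserializer}.
 */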
private boolean checkDeserializer(@Nullable Object deser) {
Class<?> deserializer = null;
if (deser instanceof Class<?> deserClass) {
deserializer = deserClass;
}
else if (deser instanceof String str) {
try {
deserializer = ClassUtils.forName(str, this.classLoader);
}
catch (ClassNotFoundException | LinkageError e) {
throw new IllegalStateException(e);
}
}
else if (deser != null) {
throw new IllegalStateException("Deserializer must be a class or class name, not a " + deser.getClass());
}
return deserializer != null && ErrorHandlingDeserializer.class.isAssignableFrom(deserializer);
}

/**
@@ -609,6 +680,13 @@ record = this.recordsIterator.next();
}

private Object recordToMessage(ConsumerRecord<K, V> record) {
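// A null value or key may mean the ErrorHandlingDeserializer caught an exception;
// if its header is present on the record, re-throw instead of building a message.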
if (record.value() == null && this.checkNullValueForExceptions) {
checkDeserializationException(record, SerializationUtils.VALUE_DESERIALIZER_EXCEPTION_HEADER);
}
if (record.key() == null && this.checkNullKeyForExceptions) {
checkDeserializationException(record, SerializationUtils.KEY_DESERIALIZER_EXCEPTION_HEADER);
}

TopicPartition topicPartition = new TopicPartition(record.topic(), record.partition());
KafkaAckInfo<K, V> ackInfo = new KafkaAckInfoImpl(record, topicPartition);
AcknowledgmentCallback ackCallback = this.ackCallbackFactory.createCallback(ackInfo);
@@ -639,6 +717,13 @@ private Object recordToMessage(ConsumerRecord<K, V> record) {
}
}

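/**
 * Extract a {@link DeserializationException} from the given record header
 * (populated by the {@link ErrorHandlingDeserializer}) and re-throw it, if present.
 */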
private void checkDeserializationException(ConsumerRecord<K, V> cRecord, String headerName) {
DeserializationException exception = ListenerUtils.getExceptionFromHeader(cRecord, headerName, this.logger);
if (exception != null) {
throw exception;
}
}

@Override
public void destroy() {
this.lock.lock();
MessageSourceIntegrationTests.java
@@ -24,6 +24,8 @@
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.serialization.Deserializer;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import org.springframework.integration.channel.NullChannel;
@@ -32,11 +34,17 @@
import org.springframework.kafka.core.DefaultKafkaProducerFactory;
import org.springframework.kafka.core.KafkaTemplate;
import org.springframework.kafka.listener.ConsumerProperties;
import org.springframework.kafka.support.serializer.DeserializationException;
import org.springframework.kafka.support.serializer.ErrorHandlingDeserializer;
import org.springframework.kafka.test.utils.KafkaTestUtils;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.util.ClassUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.awaitility.Awaitility.await;
import static org.mockito.Mockito.mock;

/**
* @author Gary Russell
@@ -50,9 +58,17 @@ class MessageSourceIntegrationTests {

static final String TOPIC1 = "MessageSourceIntegrationTests1";

static final String TOPIC2 = "MessageSourceIntegrationTests2";

static String brokers;

@BeforeAll
static void setup() {
brokers = System.getProperty("spring.global.embedded.kafka.brokers");
}

@Test
void testSource() throws Exception {
-String brokers = System.getProperty("spring.global.embedded.kafka.brokers");
Map<String, Object> consumerProps = KafkaTestUtils.consumerProps(brokers, "testSource", "false");
consumerProps.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, 2);
consumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
@@ -122,4 +138,66 @@ public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
template.destroy();
}

@Test
void deserializationErrorIsThrownFromSource() throws Exception {
Map<String, Object> consumerProps = KafkaTestUtils.consumerProps(brokers, "testErrorChannelSource", "false");
consumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
consumerProps.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ErrorHandlingDeserializer.class);
consumerProps.put(ErrorHandlingDeserializer.VALUE_DESERIALIZER_CLASS, FailingDeserializer.class);

DefaultKafkaConsumerFactory<Integer, String> consumerFactory = new DefaultKafkaConsumerFactory<>(consumerProps);
ConsumerProperties consumerProperties = new ConsumerProperties(TOPIC2);
CountDownLatch assigned = new CountDownLatch(1);
consumerProperties.setConsumerRebalanceListener(
new ConsumerRebalanceListener() {

@Override
public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
}

@Override
public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
assigned.countDown();
}

});

consumerProperties.setPollTimeout(10);

KafkaMessageSource<Integer, String> source = new KafkaMessageSource<>(consumerFactory, consumerProperties);
source.setBeanClassLoader(ClassUtils.getDefaultClassLoader());
source.setBeanFactory(mock());
source.afterPropertiesSet();
source.start();

Map<String, Object> producerProps = KafkaTestUtils.producerProps(brokers);
DefaultKafkaProducerFactory<Object, Object> producerFactory = new DefaultKafkaProducerFactory<>(producerProps);
KafkaTemplate<Object, Object> template = new KafkaTemplate<>(producerFactory);

String testData = "test data";
template.send(TOPIC2, testData);

source.receive(); // Trigger Kafka Consumer creation and poll()
assertThat(assigned.await(10, TimeUnit.SECONDS)).isTrue();

await().untilAsserted(() ->
assertThatExceptionOfType(DeserializationException.class)
.isThrownBy(source::receive)
.hasFieldOrPropertyWithValue("data", testData.getBytes())
.withMessage("failed to deserialize")
.withStackTraceContaining("failed deserialization"));

source.destroy();
template.destroy();
}

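/**
 * A deserializer that always fails, to exercise the ErrorHandlingDeserializer path.
 */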
public static class FailingDeserializer implements Deserializer<String> {

@Override
public String deserialize(String topic, byte[] data) {
throw new RuntimeException("failed deserialization");
}

}

}
7 changes: 6 additions & 1 deletion src/reference/asciidoc/kafka.adoc
@@ -436,6 +436,11 @@ If you set `allowMultiFetch` to `true` you must process all the retrieved record

Messages emitted by this adapter contain a header `kafka_remainingRecords` with a count of records remaining from the previous poll.

Starting with version `6.2`, the `KafkaMessageSource` supports an `ErrorHandlingDeserializer` provided in the consumer properties.
A `DeserializationException` is extracted from record headers and thrown to the caller.
With a `SourcePollingChannelAdapter`, this exception is wrapped into an `ErrorMessage` and published to its `errorChannel`.
See the https://docs.spring.io/spring-kafka/reference/html/#error-handling-deserializer[`ErrorHandlingDeserializer`] documentation for more information.
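
The following is a minimal sketch of such a configuration; the broker address, group id, topic, and delegate deserializer are placeholders:

[source, java]
----
Map<String, Object> consumerProps = new HashMap<>();
consumerProps.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, "myGroup");
consumerProps.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, IntegerDeserializer.class);
// Wrap value deserialization with the ErrorHandlingDeserializer and delegate
// to the real deserializer via 'spring.deserializer.value.delegate.class'
consumerProps.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ErrorHandlingDeserializer.class);
consumerProps.put(ErrorHandlingDeserializer.VALUE_DESERIALIZER_CLASS, StringDeserializer.class);

KafkaMessageSource<Integer, String> source =
        new KafkaMessageSource<>(new DefaultKafkaConsumerFactory<>(consumerProps),
                new ConsumerProperties("myTopic"));
// receive() now re-throws a DeserializationException found in a failed record's
// headers instead of emitting a message with a null payload
----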

[[kafka-outbound-gateway]]
=== Outbound Gateway

@@ -448,7 +453,7 @@ It is suggested that you add a `ConsumerRebalanceListener` to the template's rep

The `KafkaProducerMessageHandler` `sendTimeoutExpression` default is `delivery.timeout.ms` Kafka producer property `+ 5000` so that the actual Kafka error after a timeout is propagated to the application, instead of a timeout generated by this framework.
This has been changed for consistency because you may get unexpected behavior (Spring may time out the `send()`, while it is actually, eventually, successful).
-IMPORTANT: That timeout is 120 seconds by default so you may wish to reduce it to get more timely failures.
+IMPORTANT: That timeout is 120 seconds by default, so you may wish to reduce it to get more timely failures.
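
For example, to surface send failures sooner, you might reduce the producer's `delivery.timeout.ms` (a sketch; the 15-second value is arbitrary):

[source, java]
----
Map<String, Object> producerProps = new HashMap<>();
producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
// Fail a send after 15 seconds instead of the 120-second default;
// the framework's send timeout then defaults to 15000 + 5000 ms
producerProps.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, 15000);
----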

[[kafka-outbound-gateway-configuration]]
==== Configuration
7 changes: 7 additions & 0 deletions src/reference/asciidoc/whats-new.adoc
@@ -38,3 +38,10 @@ See, for example, `transformWith()`, `splitWith()` in <<./dsl.adoc#java-dsl, Jav
- For the server and client WebSocket containers, the send buffer overflow strategy is now configurable in `IntegrationWebSocketContainer` and in XML via `send-buffer-overflow-strategy`.
This strategy determines the behavior when a session's outbound message buffer has reached the configured limit.
See <<./web-sockets.adoc#websocket-client-container-attributes, WebSockets Support>> for more information.


[[x6.2-kafka]]
=== Apache Kafka Support Changes

The `KafkaMessageSource` now honors an `ErrorHandlingDeserializer` configured in the consumer properties and re-throws the `DeserializationException` extracted from a failed record's headers.
See <<./kafka.adoc#kafka-inbound-pollable, Kafka Inbound Channel Adapter>> for more information.
