Skip to content

Commit 009d8b9

Browse files
committed
1 parent c35a2bb commit 009d8b9

File tree

2 files changed

+128
-3
lines changed

2 files changed

+128
-3
lines changed

spring-kafka/src/main/java/org/springframework/kafka/listener/AbstractConsumerSeekAware.java

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2019-2023 the original author or authors.
2+
* Copyright 2019-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,15 +16,14 @@
1616

1717
package org.springframework.kafka.listener;
1818

19+
import java.util.ArrayList;
1920
import java.util.Collection;
2021
import java.util.Collections;
2122
import java.util.LinkedList;
2223
import java.util.List;
2324
import java.util.Map;
2425
import java.util.concurrent.ConcurrentHashMap;
25-
2626
import org.apache.kafka.common.TopicPartition;
27-
2827
import org.springframework.lang.Nullable;
2928

3029
/**
@@ -41,6 +40,8 @@ public abstract class AbstractConsumerSeekAware implements ConsumerSeekAware {
4140
private final Map<Thread, ConsumerSeekCallback> callbackForThread = new ConcurrentHashMap<>();
4241

4342
private final Map<TopicPartition, ConsumerSeekCallback> callbacks = new ConcurrentHashMap<>();
43+
// [Suggestion]
44+
private final Map<TopicPartition, List<ConsumerSeekCallback>> callbacksV2 = new ConcurrentHashMap<>();
4445

4546
private final Map<ConsumerSeekCallback, List<TopicPartition>> callbacksToTopic = new ConcurrentHashMap<>();
4647

@@ -60,6 +61,17 @@ public void onPartitionsAssigned(Map<TopicPartition, Long> assignments, Consumer
6061
}
6162
}
6263

64+
// [Suggestion]
65+
public void onPartitionsAssignedV2(Map<TopicPartition, Long> assignments, ConsumerSeekCallback callback) {
66+
ConsumerSeekCallback threadCallback = this.callbackForThread.get(Thread.currentThread());
67+
if (threadCallback != null) {
68+
assignments.keySet().forEach(tp -> {
69+
this.callbacksV2.computeIfAbsent(tp, key -> new ArrayList<>()).add(threadCallback);
70+
this.callbacksToTopic.computeIfAbsent(threadCallback, key -> new LinkedList<>()).add(tp);
71+
});
72+
}
73+
}
74+
6375
@Override
6476
public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
6577
partitions.forEach(tp -> {
@@ -76,6 +88,24 @@ public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
7688
});
7789
}
7890

91+
// [Suggestion]
92+
public void onPartitionsRevokedV2(Collection<TopicPartition> partitions) {
93+
partitions.forEach(tp -> {
94+
List<ConsumerSeekCallback> removed = this.callbacksV2.remove(tp);
95+
if (removed != null && !removed.isEmpty()) {
96+
removed.forEach(cb -> {
97+
List<TopicPartition> topics = this.callbacksToTopic.get(cb);
98+
if (topics != null) {
99+
topics.remove(tp);
100+
if (topics.isEmpty()) {
101+
this.callbacksToTopic.remove(cb);
102+
}
103+
}
104+
});
105+
}
106+
});
107+
}
108+
79109
@Override
80110
public void unregisterSeekCallback() {
81111
this.callbackForThread.remove(Thread.currentThread());
@@ -91,6 +121,11 @@ protected ConsumerSeekCallback getSeekCallbackFor(TopicPartition topicPartition)
91121
return this.callbacks.get(topicPartition);
92122
}
93123

124+
// [Suggestion]
125+
protected List<ConsumerSeekCallback> getSeekCallbackForV2(TopicPartition topicPartition) {
126+
return this.callbacksV2.get(topicPartition);
127+
}
128+
94129
/**
95130
* The map of callbacks for all currently assigned partitions.
96131
* @return the map.
@@ -99,6 +134,11 @@ protected Map<TopicPartition, ConsumerSeekCallback> getSeekCallbacks() {
99134
return Collections.unmodifiableMap(this.callbacks);
100135
}
101136

137+
// [Suggestion]
138+
protected Map<TopicPartition, List<ConsumerSeekCallback>> getSeekCallbacksV2() {
139+
return Collections.unmodifiableMap(this.callbacksV2);
140+
}
141+
102142
/**
103143
* Return the currently registered callbacks and their associated {@link TopicPartition}(s).
104144
* @return the map of callbacks and partitions.
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package org.springframework.kafka.listener;
2+
3+
import java.util.Collection;
4+
import java.util.List;
5+
import java.util.Map;
6+
import java.util.Set;
7+
import org.apache.kafka.common.TopicPartition;
8+
import org.junit.jupiter.api.Test;
9+
import org.springframework.beans.factory.annotation.Autowired;
10+
import org.springframework.context.annotation.Bean;
11+
import org.springframework.context.annotation.Configuration;
12+
import org.springframework.kafka.annotation.EnableKafka;
13+
import org.springframework.kafka.annotation.KafkaListener;
14+
import org.springframework.kafka.config.ConcurrentKafkaListenerContainerFactory;
15+
import org.springframework.kafka.core.ConsumerFactory;
16+
import org.springframework.kafka.core.DefaultKafkaConsumerFactory;
17+
import org.springframework.kafka.listener.ConsumerSeekAware.ConsumerSeekCallback;
18+
import org.springframework.kafka.test.EmbeddedKafkaBroker;
19+
import org.springframework.kafka.test.context.EmbeddedKafka;
20+
import org.springframework.kafka.test.utils.KafkaTestUtils;
21+
import org.springframework.stereotype.Component;
22+
import org.springframework.test.annotation.DirtiesContext;
23+
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
24+
25+
import static org.assertj.core.api.Assertions.assertThat;
26+
27+
@DirtiesContext
28+
@SpringJUnitConfig
29+
@EmbeddedKafka(topics = {AbstractConsumerSeekAwareTests.TOPIC}, partitions = 1)
30+
public class AbstractConsumerSeekAwareTests {
31+
static final String TOPIC = "Seek";
32+
33+
@Autowired
34+
Config config;
35+
36+
@Autowired
37+
Config.MultiGroupListener multiGroupListener;
38+
39+
@Test
40+
public void sizeOfCallbacksIsNotSame() {
41+
// Check the size of registered callbacks
42+
Map<ConsumerSeekCallback, List<TopicPartition>> callbacksAndTopics = multiGroupListener.getCallbacksAndTopics();
43+
Set<ConsumerSeekCallback> registeredCallbacks = callbacksAndTopics.keySet();
44+
assertThat(registeredCallbacks).hasSize(2);
45+
46+
// Get the size of all seek callbacks
47+
Map<TopicPartition, ConsumerSeekCallback> topicsToCallback = multiGroupListener.getSeekCallbacks();
48+
Collection<ConsumerSeekCallback> callbacks = topicsToCallback.values();
49+
assertThat(callbacks).hasSize(1); // <- I think the result should be two because two callbacks are registered.
50+
}
51+
52+
@EnableKafka
53+
@Configuration
54+
static class Config {
55+
56+
@Autowired
57+
EmbeddedKafkaBroker broker;
58+
59+
@Bean
60+
public ConcurrentKafkaListenerContainerFactory<String, String> kafkaListenerContainerFactory(
61+
ConsumerFactory<String, String> consumerFactory) {
62+
ConcurrentKafkaListenerContainerFactory<String, String> factory = new ConcurrentKafkaListenerContainerFactory<>();
63+
factory.setConsumerFactory(consumerFactory);
64+
return factory;
65+
}
66+
67+
@Bean
68+
ConsumerFactory<String, String> consumerFactory() {
69+
return new DefaultKafkaConsumerFactory<>(KafkaTestUtils.consumerProps("test-group", "false", this.broker));
70+
}
71+
72+
@Component
73+
static class MultiGroupListener extends AbstractConsumerSeekAware {
74+
75+
@KafkaListener(groupId = "group1", topics = TOPIC)
76+
void listenForGroup1(String in) {
77+
}
78+
79+
@KafkaListener(groupId = "group2", topics = TOPIC)
80+
void listenForGroup2(String in) {
81+
}
82+
}
83+
}
84+
85+
}

0 commit comments

Comments
 (0)