[server] Avoid race condition between pick consumer call and actual consumer subscribe API (#1301)
sixpluszero authored Nov 13, 2024
1 parent 61cb342 commit 37a6951
Showing 2 changed files with 110 additions and 20 deletions.
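What the change closes, as read from this diff: pickConsumerForPartition previously derived per-consumer load from SharedKafkaConsumer.getAssignment() / getAssignmentSize(), and that assignment only reflects a partition once the actual consumer subscribe API has been invoked. Back-to-back picks could therefore both read the same stale assignment and pile onto the same consumer. The first changed file, StoreAwarePartitionWiseKafkaConsumerService, now keeps its own counters (consumerToBaseLoadCount and consumerToStoreLoadCount) that are incremented synchronously at pick time via increaseConsumerStoreLoad and decremented in handleUnsubscription, so the very next pick already sees the updated load. The sketch below illustrates that bookkeeping idea only; the ConsumerLoadTrackerSketch class, its method names, and the plain string consumer ids are hypothetical stand-ins, not Venice APIs.

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Illustrative stand-in for the load bookkeeping added in this commit (hypothetical names).
public class ConsumerLoadTrackerSketch {
  // Plays the role of IMPOSSIBLE_MAX_PARTITION_COUNT_PER_CONSUMER: same-store partition
  // count always dominates the consumer's overall partition count in the comparison.
  private static final int STORE_WEIGHT = 10_000;

  // consumer id -> total partitions handed to that consumer through this tracker
  private final Map<String, Integer> baseLoad = new ConcurrentHashMap<>();
  // consumer id -> (store name -> partitions of that store on the consumer)
  private final Map<String, Map<String, Integer>> storeLoad = new ConcurrentHashMap<>();

  // Load of `consumer` as seen from `store`, mirroring the getConsumerStoreLoad formula.
  int load(String consumer, String store) {
    int base = baseLoad.getOrDefault(consumer, 0);
    int sameStore = storeLoad.getOrDefault(consumer, Map.of()).getOrDefault(store, 0);
    return sameStore * STORE_WEIGHT + base;
  }

  // Counters are bumped inside the pick itself (before any asynchronous subscribe),
  // so a second pick issued immediately afterwards already sees the new load.
  synchronized String pick(List<String> consumers, String store) {
    String best = null;
    int min = Integer.MAX_VALUE;
    for (String c : consumers) {
      int l = load(c, store);
      if (l < min) {
        min = l;
        best = c;
      }
    }
    if (best == null) {
      throw new IllegalStateException("Unable to find least loaded consumer entry.");
    }
    baseLoad.merge(best, 1, Integer::sum);
    storeLoad.computeIfAbsent(best, k -> new ConcurrentHashMap<>()).merge(store, 1, Integer::sum);
    return best;
  }

  public static void main(String[] args) {
    ConsumerLoadTrackerSketch tracker = new ConsumerLoadTrackerSketch();
    List<String> consumers = List.of("consumer-1", "consumer-2");
    System.out.println(tracker.pick(consumers, "storeA")); // consumer-1 (both idle)
    System.out.println(tracker.pick(consumers, "storeA")); // consumer-2: storeA already weighs 10001 on consumer-1
  }
}
```

The real class additionally removes a store's counter entry (and eventually the whole per-consumer map) once its count drops back to zero; see decreaseConsumerStoreLoad in the diff below.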
@@ -2,14 +2,16 @@

import com.linkedin.davinci.stats.AggKafkaConsumerServiceStats;
import com.linkedin.venice.meta.ReadOnlyStoreRepository;
import com.linkedin.venice.meta.Version;
import com.linkedin.venice.pubsub.PubSubConsumerAdapterFactory;
import com.linkedin.venice.pubsub.api.PubSubMessageDeserializer;
import com.linkedin.venice.pubsub.api.PubSubTopic;
import com.linkedin.venice.pubsub.api.PubSubTopicPartition;
import com.linkedin.venice.utils.Time;
import com.linkedin.venice.utils.concurrent.VeniceConcurrentHashMap;
import io.tehuti.metrics.MetricsRepository;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Properties;


@@ -26,6 +28,9 @@
public class StoreAwarePartitionWiseKafkaConsumerService extends PartitionWiseKafkaConsumerService {
// This constant makes sure the store subscription count will always be prioritized over consumer assignment count.
private static final int IMPOSSIBLE_MAX_PARTITION_COUNT_PER_CONSUMER = 10000;
private final Map<SharedKafkaConsumer, Integer> consumerToBaseLoadCount = new VeniceConcurrentHashMap<>();
private final Map<SharedKafkaConsumer, Map<String, Integer>> consumerToStoreLoadCount =
new VeniceConcurrentHashMap<>();

StoreAwarePartitionWiseKafkaConsumerService(
final ConsumerPoolType poolType,
@@ -73,7 +78,7 @@ protected synchronized SharedKafkaConsumer pickConsumerForPartition(
PubSubTopic versionTopic,
PubSubTopicPartition topicPartition) {
String storeName = versionTopic.getStoreName();
long minLoad = Long.MAX_VALUE;
int minLoad = Integer.MAX_VALUE;
SharedKafkaConsumer minLoadConsumer = null;
for (SharedKafkaConsumer consumer: getConsumerToConsumptionTask().keySet()) {
int index = getConsumerToConsumptionTask().indexOf(consumer);
@@ -85,13 +90,13 @@ && alreadySubscribedRealtimeTopicPartition(consumer, topicPartition)) {
topicPartition);
continue;
}
long overallLoad = getConsumerStoreLoad(consumer, storeName);
int overallLoad = getConsumerStoreLoad(consumer, storeName);
if (overallLoad < minLoad) {
minLoadConsumer = consumer;
minLoad = overallLoad;
}
}
if (minLoad == Long.MAX_VALUE) {
if (minLoad == Integer.MAX_VALUE) {
throw new IllegalStateException("Unable to find least loaded consumer entry.");
}

@@ -103,19 +108,58 @@ && alreadySubscribedRealtimeTopicPartition(consumer, topicPartition)) {
getLOGGER().info(
"Picked consumer id: {}, assignment size: {}, computed load: {} for topic partition: {}, version topic: {}",
getConsumerToConsumptionTask().indexOf(minLoadConsumer),
minLoadConsumer.getAssignmentSize(),
getConsumerToBaseLoadCount().getOrDefault(minLoadConsumer, 0),
minLoad,
topicPartition,
versionTopic);
increaseConsumerStoreLoad(minLoadConsumer, storeName);
return minLoadConsumer;
}

long getConsumerStoreLoad(SharedKafkaConsumer consumer, String storeName) {
long baseAssignmentCount = consumer.getAssignmentSize();
long storeSubscriptionCount = consumer.getAssignment()
.stream()
.filter(x -> Version.parseStoreFromKafkaTopicName(x.getTopicName()).equals(storeName))
.count();
@Override
void handleUnsubscription(
SharedKafkaConsumer consumer,
PubSubTopic versionTopic,
PubSubTopicPartition pubSubTopicPartition) {
super.handleUnsubscription(consumer, versionTopic, pubSubTopicPartition);
decreaseConsumerStoreLoad(consumer, versionTopic.getStoreName());
}

int getConsumerStoreLoad(SharedKafkaConsumer consumer, String storeName) {
int baseAssignmentCount = getConsumerToBaseLoadCount().getOrDefault(consumer, 0);
int storeSubscriptionCount =
getConsumerToStoreLoadCount().getOrDefault(consumer, Collections.emptyMap()).getOrDefault(storeName, 0);
return storeSubscriptionCount * IMPOSSIBLE_MAX_PARTITION_COUNT_PER_CONSUMER + baseAssignmentCount;
}

void increaseConsumerStoreLoad(SharedKafkaConsumer consumer, String storeName) {
getConsumerToBaseLoadCount().compute(consumer, (k, v) -> (v == null) ? 1 : v + 1);
getConsumerToStoreLoadCount().computeIfAbsent(consumer, k -> new VeniceConcurrentHashMap<>())
.compute(storeName, (k, v) -> (v == null) ? 1 : v + 1);
}

void decreaseConsumerStoreLoad(SharedKafkaConsumer consumer, String storeName) {
if (!getConsumerToBaseLoadCount().containsKey(consumer)) {
throw new IllegalStateException("Consumer to base load count map does not contain consumer: " + consumer);
}
if (!getConsumerToStoreLoadCount().containsKey(consumer)) {
throw new IllegalStateException("Consumer to store load count map does not contain consumer: " + consumer);
}
if (!getConsumerToStoreLoadCount().get(consumer).containsKey(storeName)) {
throw new IllegalStateException("Consumer to store load count map does not contain store: " + storeName);
}
getConsumerToBaseLoadCount().computeIfPresent(consumer, (k, v) -> (v == 1) ? null : v - 1);
getConsumerToStoreLoadCount().computeIfPresent(consumer, (k, innerMap) -> {
innerMap.computeIfPresent(storeName, (s, c) -> (c == 1) ? null : c - 1);
return innerMap.isEmpty() ? null : innerMap;
});
}

Map<SharedKafkaConsumer, Map<String, Integer>> getConsumerToStoreLoadCount() {
return consumerToStoreLoadCount;
}

Map<SharedKafkaConsumer, Integer> getConsumerToBaseLoadCount() {
return consumerToBaseLoadCount;
}
}
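A quick check of the weighting above, using the constant from this file (10000) and the starting state the test below sets up (consumer1: base load 1 with one storeName1 partition; consumer2: base load 2 with two storeName2 partitions). The snippet is illustrative arithmetic only, not Venice code; load here is just the formula from getConsumerStoreLoad.

```java
// Plain-int walkthrough of storeSubscriptionCount * 10000 + baseAssignmentCount.
public class LoadWeightExample {
  private static final int STORE_WEIGHT = 10_000; // IMPOSSIBLE_MAX_PARTITION_COUNT_PER_CONSUMER

  static int load(int sameStorePartitions, int baseAssignments) {
    return sameStorePartitions * STORE_WEIGHT + baseAssignments;
  }

  public static void main(String[] args) {
    System.out.println(load(1, 1)); // consumer1 vs storeName1 -> 10001
    System.out.println(load(0, 1)); // consumer1 vs storeName2 or storeName3 -> 1
    System.out.println(load(0, 2)); // consumer2 vs storeName1 -> 2
    System.out.println(load(2, 2)); // consumer2 vs storeName2 -> 20002
    // A new storeName1 partition therefore goes to consumer2 (2 < 10001); after that pick
    // consumer2 has base load 3 and one storeName1 partition, i.e. load 10003.
  }
}
```

The second changed file, below, is the unit test that pins down exactly these values along with the increase/decrease bookkeeping.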
@@ -41,7 +41,6 @@
import io.tehuti.metrics.MetricsRepository;
import io.tehuti.metrics.Sensor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -436,6 +435,8 @@ public void testStoreAwarePartitionWiseGetConsumer() {
String topicForStoreName3 = Version.composeKafkaTopic(storeName3, 1);
PubSubTopic pubSubTopicForStoreName3 = pubSubTopicRepository.getTopic(topicForStoreName3);

String storeName4 = Utils.getUniqueString("test_consumer_service4");

SharedKafkaConsumer consumer1 = mock(SharedKafkaConsumer.class);
SharedKafkaConsumer consumer2 = mock(SharedKafkaConsumer.class);
ConsumptionTask consumptionTask = mock(ConsumptionTask.class);
@@ -449,22 +450,30 @@ public void testStoreAwarePartitionWiseGetConsumer() {
consumptionTaskIndexedMap.put(consumer2, consumptionTask);
when(consumerService.getConsumerToConsumptionTask()).thenReturn(consumptionTaskIndexedMap);

Map<SharedKafkaConsumer, Integer> consumerToBasicLoadMap = new VeniceConcurrentHashMap<>();
when(consumerService.getConsumerToBaseLoadCount()).thenReturn(consumerToBasicLoadMap);
Map<SharedKafkaConsumer, Map<String, Integer>> consumerToStoreLoadMap = new VeniceConcurrentHashMap<>();
when(consumerService.getConsumerToStoreLoadCount()).thenReturn(consumerToStoreLoadMap);

Map<PubSubTopicPartition, Set<PubSubConsumerAdapter>> rtTopicPartitionToConsumerMap =
new VeniceConcurrentHashMap<>();
when(consumerService.getRtTopicPartitionToConsumerMap()).thenReturn(rtTopicPartitionToConsumerMap);
when(consumerService.getLOGGER())
.thenReturn(LogManager.getLogger(StoreAwarePartitionWiseKafkaConsumerService.class));
doCallRealMethod().when(consumerService).pickConsumerForPartition(any(), any());
doCallRealMethod().when(consumerService).getConsumerStoreLoad(any(), anyString());
doCallRealMethod().when(consumerService).increaseConsumerStoreLoad(any(), anyString());
doCallRealMethod().when(consumerService).decreaseConsumerStoreLoad(any(), anyString());

consumerToBasicLoadMap.put(consumer1, 1);
Map<String, Integer> innerMap1 = new VeniceConcurrentHashMap<>();
innerMap1.put(storeName1, 1);
consumerToStoreLoadMap.put(consumer1, innerMap1);
consumerToBasicLoadMap.put(consumer2, 2);
Map<String, Integer> innerMap2 = new VeniceConcurrentHashMap<>();
innerMap2.put(storeName2, 2);
consumerToStoreLoadMap.put(consumer2, innerMap2);

when(consumer1.getAssignmentSize()).thenReturn(1);
when(consumer1.getAssignment())
.thenReturn(Collections.singleton(new PubSubTopicPartitionImpl(pubSubTopicForStoreName1, 100)));
Set<PubSubTopicPartition> tpSet = new HashSet<>();
tpSet.add(new PubSubTopicPartitionImpl(pubSubTopicForStoreName2, 100));
tpSet.add(new PubSubTopicPartitionImpl(pubSubTopicForStoreName2, 101));
when(consumer2.getAssignmentSize()).thenReturn(2);
when(consumer2.getAssignment()).thenReturn(tpSet);
Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer1, storeName1), 10001);
Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer1, storeName2), 1);
Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer1, storeName3), 1);
@@ -478,16 +487,53 @@ public void testStoreAwarePartitionWiseGetConsumer() {
pubSubTopicForStoreName1,
new PubSubTopicPartitionImpl(pubSubTopicForStoreName1, 0)),
consumer2);
Assert.assertEquals(consumerToBasicLoadMap.get(consumer2).intValue(), 3);
Assert.assertEquals(consumerToStoreLoadMap.get(consumer2).get(storeName1).intValue(), 1);
Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer2, storeName1), 10003);
Assert.assertEquals(
consumerService.pickConsumerForPartition(
pubSubTopicForStoreName2,
new PubSubTopicPartitionImpl(pubSubTopicForStoreName2, 0)),
consumer1);
Assert.assertEquals(consumerToBasicLoadMap.get(consumer1).intValue(), 2);
Assert.assertEquals(consumerToStoreLoadMap.get(consumer1).get(storeName2).intValue(), 1);
Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer1, storeName2), 10002);
Assert.assertEquals(
consumerService.pickConsumerForPartition(
pubSubTopicForStoreName3,
new PubSubTopicPartitionImpl(pubSubTopicForStoreName3, 0)),
consumer1);
Assert.assertEquals(consumerToBasicLoadMap.get(consumer1).intValue(), 3);
Assert.assertEquals(consumerToStoreLoadMap.get(consumer1).get(storeName3).intValue(), 1);
Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer1, storeName3), 10003);

// Validate decrease consumer entry
Assert.assertThrows(() -> consumerService.decreaseConsumerStoreLoad(consumer1, storeName4));

consumerService.decreaseConsumerStoreLoad(consumer1, storeName1);
Assert.assertEquals(consumerToBasicLoadMap.get(consumer1).intValue(), 2);
Assert.assertNull(consumerToStoreLoadMap.get(consumer1).get(storeName1));
Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer1, storeName1), 2);
Assert.assertThrows(() -> consumerService.decreaseConsumerStoreLoad(consumer1, storeName1));

consumerService.decreaseConsumerStoreLoad(consumer1, storeName2);
Assert.assertEquals(consumerToBasicLoadMap.get(consumer1).intValue(), 1);
Assert.assertNull(consumerToStoreLoadMap.get(consumer1).get(storeName2));
Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer1, storeName2), 1);
Assert.assertThrows(() -> consumerService.decreaseConsumerStoreLoad(consumer1, storeName2));

consumerService.decreaseConsumerStoreLoad(consumer1, storeName3);
Assert.assertNull(consumerToBasicLoadMap.get(consumer1));
Assert.assertNull(consumerToStoreLoadMap.get(consumer1));
Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer1, storeName3), 0);
Assert.assertThrows(() -> consumerService.decreaseConsumerStoreLoad(consumer1, storeName3));

// Validate increase consumer entry
consumerService.increaseConsumerStoreLoad(consumer1, storeName1);
Assert.assertEquals(consumerToBasicLoadMap.get(consumer1).intValue(), 1);
Assert.assertEquals(consumerToStoreLoadMap.get(consumer1).get(storeName1).intValue(), 1);
Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer1, storeName1), 10001);
Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer1, storeName2), 1);
}

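The assertNull checks above rely on a standard java.util.Map contract: computeIfPresent removes the mapping when the remapping function returns null, which is how decreaseConsumerStoreLoad drops a store's entry (and eventually the whole per-consumer map) once its count reaches zero, while the assertThrows calls confirm that decrementing a missing entry fails fast. A minimal stand-alone demonstration with generic names, not Venice code:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Demonstrates the computeIfPresent behavior the decrement path depends on.
public class ComputeIfPresentDemo {
  public static void main(String[] args) {
    Map<String, Integer> storeCount = new ConcurrentHashMap<>();
    storeCount.put("store_a", 1);

    // Returning null from the remapping function removes the key entirely,
    // so a count of 1 decrements straight to "no entry" rather than to 0.
    storeCount.computeIfPresent("store_a", (k, v) -> (v == 1) ? null : v - 1);

    System.out.println(storeCount.containsKey("store_a")); // false
    System.out.println(storeCount.get("store_a"));         // null
  }
}
```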
