Skip to content

Enhance thread-safety in ClientSideCaching key retrieval #3268

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 73 additions & 29 deletions src/main/java/io/lettuce/core/support/caching/ClientSideCaching.java
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
package io.lettuce.core.support.caching;

import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;

import io.lettuce.core.StatefulRedisConnectionImpl;
import io.lettuce.core.TrackingArgs;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.codec.RedisCodec;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.Supplier;

/**
* Utility to provide server-side assistance for client-side caches. This is a {@link CacheFrontend} that represents a two-level
Expand All @@ -31,6 +33,7 @@
* @param <K> Key type.
* @param <V> Value type.
* @author Mark Paluch
* @author Yoobin Yoon
* @since 6.0
*/
public class ClientSideCaching<K, V> implements CacheFrontend<K, V> {
Expand All @@ -41,6 +44,8 @@ public class ClientSideCaching<K, V> implements CacheFrontend<K, V> {

private final List<Consumer<K>> invalidationListeners = new CopyOnWriteArrayList<>();

private final ConcurrentHashMap<K, ReentrantLock> keyLocks = new ConcurrentHashMap<>();

private ClientSideCaching(CacheAccessor<K, V> cacheAccessor, RedisCache<K, V> redisCache) {
this.cacheAccessor = cacheAccessor;
this.redisCache = redisCache;
Expand Down Expand Up @@ -103,6 +108,7 @@ private static <K, V> CacheFrontend<K, V> create(CacheAccessor<K, V> cacheAccess
}

private void notifyInvalidate(K key) {
keyLocks.remove(key);

for (java.util.function.Consumer<K> invalidationListener : invalidationListeners) {
invalidationListener.accept(key);
Expand All @@ -111,24 +117,52 @@ private void notifyInvalidate(K key) {

@Override
public void close() {
    // Drop all per-key lock references before closing the Redis-side cache so the
    // lock map does not outlive the frontend.
    keyLocks.clear();
    redisCache.close();
}

/**
 * Register a listener that is invoked with the key each time an invalidation for that key is processed.
 *
 * @param invalidationListener the callback to notify on invalidation.
 */
public void addInvalidationListener(java.util.function.Consumer<K> invalidationListener) {
    invalidationListeners.add(invalidationListener);
}

/**
 * Execute the supplied function while holding the lock for the given key.
 * <p>
 * Locks are created on demand ({@code computeIfAbsent}), so the first caller for a key installs the
 * {@link ReentrantLock} that subsequent callers for the same key contend on.
 * <p>
 * NOTE(review): {@code notifyInvalidate} removes the key's lock from {@code keyLocks} unconditionally,
 * possibly while a thread is still holding it here. A concurrent caller can then install a fresh lock
 * for the same key and enter this section in parallel with the old holder, so mutual exclusion is not
 * absolute across an invalidation — confirm whether this window is acceptable or whether locks should
 * not be removed on invalidation.
 *
 * @param key the key to lock
 * @param supplier the function to execute under the lock
 * @return the result of the supplied function
 */
private <T> T withKeyLock(K key, Supplier<T> supplier) {
    ReentrantLock keyLock = keyLocks.computeIfAbsent(key, k -> new ReentrantLock());
    keyLock.lock();
    try {
        return supplier.get();
    } finally {
        // Always release, even when the supplier throws.
        keyLock.unlock();
    }
}

@Override
public V get(K key) {

V value = cacheAccessor.get(key);

if (value == null) {
value = redisCache.get(key);
value = withKeyLock(key, () -> {
V cachedValue = cacheAccessor.get(key);

if (cachedValue == null) {
V redisValue = redisCache.get(key);

if (redisValue != null) {
cacheAccessor.put(key, redisValue);
}

return redisValue;
}

if (value != null) {
cacheAccessor.put(key, value);
}
return cachedValue;
});
}

return value;
Expand All @@ -140,28 +174,38 @@ public V get(K key, Callable<V> valueLoader) {
V value = cacheAccessor.get(key);

if (value == null) {
value = redisCache.get(key);

if (value == null) {

try {
value = valueLoader.call();
} catch (Exception e) {
throw new ValueRetrievalException(
String.format("Value loader %s failed with an exception for key %s", valueLoader, key), e);
value = withKeyLock(key, () -> {
V cachedValue = cacheAccessor.get(key);

if (cachedValue == null) {
V redisValue = redisCache.get(key);

if (redisValue == null) {
try {
V loadedValue = valueLoader.call();

if (loadedValue == null) {
throw new ValueRetrievalException(
String.format("Value loader %s returned a null value for key %s", valueLoader, key));
}

redisCache.put(key, loadedValue);
redisCache.get(key);
cacheAccessor.put(key, loadedValue);

return loadedValue;
} catch (Exception e) {
throw new ValueRetrievalException(
String.format("Value loader %s failed with an exception for key %s", valueLoader, key), e);
}
} else {
cacheAccessor.put(key, redisValue);
return redisValue;
}
}

if (value == null) {
throw new ValueRetrievalException(
String.format("Value loader %s returned a null value for key %s", valueLoader, key));
}
redisCache.put(key, value);

// register interest in key
redisCache.get(key);
}

cacheAccessor.put(key, value);
return cachedValue;
});
}

return value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,6 @@
import static io.lettuce.TestTags.INTEGRATION_TEST;
import static org.assertj.core.api.Assertions.assertThat;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

import javax.inject.Inject;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import io.lettuce.core.ClientOptions;
import io.lettuce.core.RedisClient;
import io.lettuce.core.TestSupport;
Expand All @@ -29,11 +16,29 @@
import io.lettuce.test.LettuceExtension;
import io.lettuce.test.Wait;
import io.lettuce.test.condition.EnabledOnCommand;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import javax.inject.Inject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

/**
* Integration tests for server-side assisted cache invalidation.
*
* @author Mark Paluch
* @author Yoobin Yoon
*/
@Tag(INTEGRATION_TEST)
@ExtendWith(LettuceExtension.class)
Expand Down Expand Up @@ -227,4 +232,127 @@ void serverAssistedCachingShouldUseValueLoader() throws InterruptedException {
frontend.close();
}

@Test
void valueLoaderShouldBeInvokedOnceForConcurrentRequests() throws Exception {

    Map<String, String> clientCache = new ConcurrentHashMap<>();

    StatefulRedisConnection<String, String> connection = redisClient.connect();

    final String testKey = "concurrent-loader-key";
    connection.sync().del(testKey);

    // Counts loader executions; with per-key locking this must end up at exactly 1.
    AtomicInteger loaderCallCount = new AtomicInteger(0);

    CacheFrontend<String, String> frontend = ClientSideCaching.enable(CacheAccessor.forMap(clientCache), connection,
            TrackingArgs.Builder.enabled());

    int threadCount = 10;
    ExecutorService executor = Executors.newFixedThreadPool(threadCount);

    try {
        CountDownLatch startLatch = new CountDownLatch(1);
        CountDownLatch finishLatch = new CountDownLatch(threadCount);
        List<String> results = new CopyOnWriteArrayList<>();

        for (int i = 0; i < threadCount; i++) {
            executor.submit(() -> {
                try {
                    // Hold every worker behind a single gate to maximize contention on the key.
                    startLatch.await();

                    String result = frontend.get(testKey, () -> {
                        loaderCallCount.incrementAndGet();

                        // Keep the loader slow enough that all other threads arrive while it runs.
                        try {
                            Thread.sleep(100);
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                        }

                        return "loaded-value";
                    });

                    results.add(result);
                } catch (Exception e) {
                    e.printStackTrace();
                } finally {
                    finishLatch.countDown();
                }
            });
        }

        startLatch.countDown();

        // Fail loudly on timeout instead of silently asserting against a half-finished run
        // (the original ignored the boolean returned by await()).
        assertThat(finishLatch.await(5, TimeUnit.SECONDS)).isTrue();

        assertThat(loaderCallCount.get()).isEqualTo(1);

        assertThat(results).hasSize(threadCount);
        assertThat(results).containsOnly("loaded-value");

        assertThat(connection.sync().get(testKey)).isEqualTo("loaded-value");

        assertThat(clientCache).containsEntry(testKey, "loaded-value");
    } finally {
        // Release the pool even when an assertion above fails; the original only
        // shut the executor down on the success path.
        executor.shutdownNow();
        frontend.close();
        connection.close();
    }
}

@Test
void locksShouldBeProperlyCleanedUp() throws Exception {

    Map<String, String> clientCache = new ConcurrentHashMap<>();

    StatefulRedisConnection<String, String> connection = redisClient.connect();
    // Second connection used to trigger server-side invalidations from "another client".
    StatefulRedisConnection<String, String> otherClient = redisClient.connect();

    final String testKey1 = "lock-test-key1";
    final String testKey2 = "lock-test-key2";
    final String initialValue = "initial-value";
    final String updatedValue = "updated-value";

    connection.sync().del(testKey1, testKey2);
    connection.sync().set(testKey1, initialValue);
    connection.sync().set(testKey2, initialValue);

    ClientSideCaching<String, String> frontend = (ClientSideCaching<String, String>) ClientSideCaching
            .enable(CacheAccessor.forMap(clientCache), connection, TrackingArgs.Builder.enabled());

    // Peek at the private keyLocks map via reflection to observe lock lifecycle.
    Field keyLocksField = ClientSideCaching.class.getDeclaredField("keyLocks");
    keyLocksField.setAccessible(true);
    ConcurrentHashMap<String, ReentrantLock> keyLocks = (ConcurrentHashMap<String, ReentrantLock>) keyLocksField
            .get(frontend);

    try {
        // Each get() installs a lock for its key.
        frontend.get(testKey1);
        frontend.get(testKey2);

        assertThat(keyLocks).containsKey(testKey1);
        assertThat(keyLocks).containsKey(testKey2);
        assertThat(keyLocks).hasSize(2);

        // External write invalidates testKey1; its lock should be removed on invalidation.
        otherClient.sync().set(testKey1, updatedValue);

        // NOTE(review): fixed sleep to let the invalidation arrive — potentially flaky;
        // consider the project's Wait utility instead. TODO confirm.
        Thread.sleep(200);

        assertThat(keyLocks).doesNotContainKey(testKey1);
        assertThat(keyLocks).containsKey(testKey2);
        assertThat(keyLocks).hasSize(1);

        // A subsequent read re-creates the lock for the invalidated key.
        frontend.get(testKey1);

        assertThat(keyLocks).containsKey(testKey1);
        assertThat(keyLocks).hasSize(2);

        // close() must clear all remaining locks.
        frontend.close();

        assertThat(keyLocks).isEmpty();

    } finally {
        connection.close();
        otherClient.close();
    }
}

}