Skip to content

Commit cc036c2

Browse files
committed
adding reauth support for both pubsub and shardedpubsub
1 parent 9717c9a commit cc036c2

File tree

6 files changed

+347
-75
lines changed

6 files changed

+347
-75
lines changed

src/main/java/redis/clients/jedis/Connection.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,8 @@ public class Connection implements Closeable {
4646
private String strVal;
4747
protected String server;
4848
protected String version;
49-
protected AtomicReference<RedisCredentials> currentCredentials = new AtomicReference<RedisCredentials>(
50-
null);
51-
private boolean isTokenBasedAuthenticationEnabled = false;
49+
private AtomicReference<RedisCredentials> currentCredentials = new AtomicReference<>(null);
50+
private AuthXManager authXManager;
5251

5352
public Connection() {
5453
this(Protocol.DEFAULT_HOST, Protocol.DEFAULT_PORT);
@@ -68,6 +67,7 @@ public Connection(final HostAndPort hostAndPort, final JedisClientConfig clientC
6867

6968
public Connection(final JedisSocketFactory socketFactory) {
7069
this.socketFactory = socketFactory;
70+
this.authXManager = null;
7171
}
7272

7373
public Connection(final JedisSocketFactory socketFactory, JedisClientConfig clientConfig) {
@@ -458,9 +458,8 @@ protected void initializeFromClientConfig(final JedisClientConfig config) {
458458

459459
Supplier<RedisCredentials> credentialsProvider = config.getCredentialsProvider();
460460

461-
AuthXManager authXManager = config.getAuthXManager();
461+
authXManager = config.getAuthXManager();
462462
if (authXManager != null) {
463-
isTokenBasedAuthenticationEnabled = true;
464463
credentialsProvider = authXManager;
465464
}
466465

@@ -608,7 +607,11 @@ public boolean ping() {
608607
return true;
609608
}
610609

611-
public boolean isTokenBasedAuthenticationEnabled() {
612-
return isTokenBasedAuthenticationEnabled;
610+
protected boolean isTokenBasedAuthenticationEnabled() {
611+
return authXManager != null;
612+
}
613+
614+
protected AuthXManager getAuthXManager() {
615+
return authXManager;
613616
}
614617
}

src/main/java/redis/clients/jedis/JedisPubSubBase.java

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import java.util.Arrays;
66
import java.util.List;
7+
import java.util.function.Consumer;
78

89
import redis.clients.jedis.Protocol.Command;
910
import redis.clients.jedis.exceptions.JedisException;
@@ -12,7 +13,8 @@
1213
public abstract class JedisPubSubBase<T> {
1314

1415
private int subscribedChannels = 0;
15-
private volatile Connection client;
16+
private final JedisSafeAuthenticator authenticator = new JedisSafeAuthenticator();
17+
private final Consumer<Object> pingResultHandler = this::processPingReply;
1618

1719
public void onMessage(T channel, T message) {
1820
}
@@ -36,12 +38,7 @@ public void onPong(T pattern) {
3638
}
3739

3840
private void sendAndFlushCommand(Command command, T... args) {
39-
if (client == null) {
40-
throw new JedisException(getClass() + " is not connected to a Connection.");
41-
}
42-
CommandArguments cargs = new CommandArguments(command).addObjects(args);
43-
client.sendCommand(cargs);
44-
client.flush();
41+
authenticator.sendAndFlushCommand(command, args);
4542
}
4643

4744
public final void unsubscribe() {
@@ -63,7 +60,8 @@ public final void psubscribe(T... patterns) {
6360
}
6461

6562
private void checkConnectionSuitableForPubSub() {
66-
if (client.protocol == RedisProtocol.RESP2 && client.isTokenBasedAuthenticationEnabled()) {
63+
if (authenticator.client.protocol != RedisProtocol.RESP3
64+
&& authenticator.client.isTokenBasedAuthenticationEnabled()) {
6765
throw new JedisException(
6866
"Blocking pub/sub operations are not supported on token-based authentication enabled connections with RESP2 protocol!");
6967
}
@@ -78,7 +76,13 @@ public final void punsubscribe(T... patterns) {
7876
}
7977

8078
public final void ping() {
81-
sendAndFlushCommand(Command.PING);
79+
authenticator.commandSync.lock();
80+
try {
81+
sendAndFlushCommand(Command.PING);
82+
authenticator.resultHandler.add(pingResultHandler);
83+
} finally {
84+
authenticator.commandSync.unlock();
85+
}
8286
}
8387

8488
public final void ping(T argument) {
@@ -94,24 +98,24 @@ public final int getSubscribedChannels() {
9498
}
9599

96100
public final void proceed(Connection client, T... channels) {
97-
this.client = client;
98-
this.client.setTimeoutInfinite();
101+
authenticator.registerForAuthentication(client);
102+
authenticator.client.setTimeoutInfinite();
99103
try {
100104
subscribe(channels);
101105
process();
102106
} finally {
103-
this.client.rollbackTimeout();
107+
authenticator.client.rollbackTimeout();
104108
}
105109
}
106110

107111
public final void proceedWithPatterns(Connection client, T... patterns) {
108-
this.client = client;
109-
this.client.setTimeoutInfinite();
112+
authenticator.registerForAuthentication(client);
113+
authenticator.client.setTimeoutInfinite();
110114
try {
111115
psubscribe(patterns);
112116
process();
113117
} finally {
114-
this.client.rollbackTimeout();
118+
authenticator.client.rollbackTimeout();
115119
}
116120
}
117121

@@ -121,7 +125,7 @@ public final void proceedWithPatterns(Connection client, T... patterns) {
121125
private void process() {
122126

123127
do {
124-
Object reply = client.getUnflushedObject();
128+
Object reply = authenticator.client.getUnflushedObject();
125129

126130
if (reply instanceof List) {
127131
List<Object> listReply = (List<Object>) reply;
@@ -175,12 +179,8 @@ private void process() {
175179
throw new JedisException("Unknown message type: " + firstObj);
176180
}
177181
} else if (reply instanceof byte[]) {
178-
byte[] resp = (byte[]) reply;
179-
if ("PONG".equals(SafeEncoder.encode(resp))) {
180-
onPong(null);
181-
} else {
182-
onPong(encode(resp));
183-
}
182+
Consumer<Object> resultHandler = authenticator.resultHandler.remove();
183+
resultHandler.accept(reply);
184184
} else {
185185
throw new JedisException("Unknown message type: " + reply);
186186
}
@@ -189,4 +189,13 @@ private void process() {
189189
// /* Invalidate instance since this thread is no longer listening */
190190
// this.client = null;
191191
}
192+
193+
private void processPingReply(Object reply) {
194+
byte[] resp = (byte[]) reply;
195+
if ("PONG".equals(SafeEncoder.encode(resp))) {
196+
onPong(null);
197+
} else {
198+
onPong(encode(resp));
199+
}
200+
}
192201
}
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package redis.clients.jedis;
2+
3+
import java.util.Queue;
4+
import java.util.concurrent.ConcurrentLinkedQueue;
5+
import java.util.concurrent.atomic.AtomicReference;
6+
import java.util.concurrent.locks.ReentrantLock;
7+
import java.util.function.Consumer;
8+
9+
import org.slf4j.Logger;
10+
import org.slf4j.LoggerFactory;
11+
12+
import redis.clients.authentication.core.SimpleToken;
13+
import redis.clients.authentication.core.Token;
14+
import redis.clients.jedis.Protocol.Command;
15+
import redis.clients.jedis.authentication.JedisAuthenticationException;
16+
import redis.clients.jedis.exceptions.JedisException;
17+
import redis.clients.jedis.util.SafeEncoder;
18+
19+
public class JedisSafeAuthenticator {
20+
21+
private static final Token PLACEHOLDER_TOKEN = new SimpleToken(null, null, 0, 0, null);
22+
private static final Logger logger = LoggerFactory.getLogger(JedisSafeAuthenticator.class);
23+
24+
protected volatile Connection client;
25+
protected final Consumer<Object> authResultHandler = this::processAuthReply;
26+
protected final Consumer<Token> authenticationHandler = this::safeReAuthenticate;
27+
28+
protected final AtomicReference<Token> pendingTokenRef = new AtomicReference<Token>(null);
29+
protected final ReentrantLock commandSync = new ReentrantLock();
30+
protected final Queue<Consumer<Object>> resultHandler = new ConcurrentLinkedQueue<Consumer<Object>>();
31+
32+
protected void sendAndFlushCommand(Command command, Object... args) {
33+
if (client == null) {
34+
throw new JedisException(getClass() + " is not connected to a Connection.");
35+
}
36+
CommandArguments cargs = new CommandArguments(command).addObjects(args);
37+
38+
Token newToken = pendingTokenRef.getAndSet(PLACEHOLDER_TOKEN);
39+
40+
// lets send the command without locking !!IF!! we know that pendingTokenRef is null replaced with PLACEHOLDER_TOKEN and no re-auth will go into action
41+
// !!ELSE!! we are locking since we already know a re-auth is still in progress in another thread and we need to wait for it to complete, we do nothing but wait on it!
42+
if (newToken != null) {
43+
commandSync.lock();
44+
}
45+
try {
46+
System.out.println("Sending command: " + command.toString());
47+
client.sendCommand(cargs);
48+
client.flush();
49+
} finally {
50+
Token newerToken = pendingTokenRef.getAndSet(null);
51+
// lets check if a newer token received since the beginning of this sendAndFlushCommand call
52+
if (newerToken != null && newerToken != PLACEHOLDER_TOKEN) {
53+
safeReAuthenticate(newerToken);
54+
}
55+
if (newToken != null) {
56+
commandSync.unlock();
57+
}
58+
}
59+
}
60+
61+
protected void registerForAuthentication(Connection newClient) {
62+
Connection oldClient = this.client;
63+
if (oldClient == newClient) return;
64+
if (oldClient != null && oldClient.getAuthXManager() != null) {
65+
oldClient.getAuthXManager().removePostAuthenticationHook(authenticationHandler);
66+
}
67+
if (newClient != null && newClient.getAuthXManager() != null) {
68+
newClient.getAuthXManager().addPostAuthenticationHook(authenticationHandler);
69+
}
70+
this.client = newClient;
71+
}
72+
73+
private void safeReAuthenticate(Token token) {
74+
try {
75+
byte[] rawPass = client.encodeToBytes(token.getValue().toCharArray());
76+
byte[] rawUser = client.encodeToBytes(token.getUser().toCharArray());
77+
78+
Token newToken = pendingTokenRef.getAndSet(token);
79+
if (newToken == null) {
80+
commandSync.lock();
81+
try {
82+
sendAndFlushCommand(Command.AUTH, rawUser, rawPass);
83+
resultHandler.add(this.authResultHandler);
84+
} finally {
85+
pendingTokenRef.set(null);
86+
commandSync.unlock();
87+
}
88+
}
89+
} catch (Exception e) {
90+
logger.error("Error while re-authenticating connection", e);
91+
client.getAuthXManager().getListener().onConnectionAuthenticationError(e);
92+
}
93+
}
94+
95+
protected void processAuthReply(Object reply) {
96+
byte[] resp = (byte[]) reply;
97+
String response = SafeEncoder.encode(resp);
98+
if (!"OK".equals(response)) {
99+
String msg = "Re-authentication failed with server response: " + response;
100+
Exception failedAuth = new JedisAuthenticationException(msg);
101+
logger.error(failedAuth.getMessage(), failedAuth);
102+
client.getAuthXManager().getListener().onConnectionAuthenticationError(failedAuth);
103+
}
104+
}
105+
}

src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44

55
import java.util.Arrays;
66
import java.util.List;
7+
import java.util.function.Consumer;
78

89
import redis.clients.jedis.Protocol.Command;
910
import redis.clients.jedis.exceptions.JedisException;
1011

1112
public abstract class JedisShardedPubSubBase<T> {
1213

1314
private int subscribedChannels = 0;
14-
private volatile Connection client;
15+
private final JedisSafeAuthenticator authenticator = new JedisSafeAuthenticator();
1516

1617
public void onSMessage(T channel, T message) {
1718
}
@@ -23,12 +24,7 @@ public void onSUnsubscribe(T channel, int subscribedChannels) {
2324
}
2425

2526
private void sendAndFlushCommand(Command command, T... args) {
26-
if (client == null) {
27-
throw new JedisException(getClass() + " is not connected to a Connection.");
28-
}
29-
CommandArguments cargs = new CommandArguments(command).addObjects(args);
30-
client.sendCommand(cargs);
31-
client.flush();
27+
authenticator.sendAndFlushCommand(command, args);
3228
}
3329

3430
public final void sunsubscribe() {
@@ -40,9 +36,18 @@ public final void sunsubscribe(T... channels) {
4036
}
4137

4238
public final void ssubscribe(T... channels) {
39+
checkConnectionSuitableForPubSub();
4340
sendAndFlushCommand(Command.SSUBSCRIBE, channels);
4441
}
4542

43+
private void checkConnectionSuitableForPubSub() {
44+
if (authenticator.client.protocol != RedisProtocol.RESP3
45+
&& authenticator.client.isTokenBasedAuthenticationEnabled()) {
46+
throw new JedisException(
47+
"Blocking pub/sub operations are not supported on token-based authentication enabled connections with RESP2 protocol!");
48+
}
49+
}
50+
4651
public final boolean isSubscribed() {
4752
return subscribedChannels > 0;
4853
}
@@ -52,23 +57,22 @@ public final int getSubscribedChannels() {
5257
}
5358

5459
public final void proceed(Connection client, T... channels) {
55-
this.client = client;
56-
this.client.setTimeoutInfinite();
60+
authenticator.registerForAuthentication(client);
61+
authenticator.client.setTimeoutInfinite();
5762
try {
5863
ssubscribe(channels);
5964
process();
6065
} finally {
61-
this.client.rollbackTimeout();
66+
authenticator.client.rollbackTimeout();
6267
}
6368
}
6469

6570
protected abstract T encode(byte[] raw);
6671

67-
// private void process(Client client) {
6872
private void process() {
6973

7074
do {
71-
Object reply = client.getUnflushedObject();
75+
Object reply = authenticator.client.getUnflushedObject();
7276

7377
if (reply instanceof List) {
7478
List<Object> listReply = (List<Object>) reply;
@@ -96,6 +100,9 @@ private void process() {
96100
} else {
97101
throw new JedisException("Unknown message type: " + firstObj);
98102
}
103+
} else if (reply instanceof byte[]) {
104+
Consumer<Object> resultHandler = authenticator.resultHandler.remove();
105+
resultHandler.accept(reply);
99106
} else {
100107
throw new JedisException("Unknown message type: " + reply);
101108
}

src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -270,13 +270,6 @@ public void allConnectionsReauthTest() throws InterruptedException, ExecutionExc
270270
}
271271
}
272272

273-
// T.3.2
274-
// Test system behavior when some connections fail to re-authenticate during bulk authentication. e.g when a network partition occurs for 1 or more of them
275-
@Test
276-
public void partialReauthFailureTest() {
277-
278-
}
279-
280273
// T.3.3
281274
// Verify behavior when attempting to authenticate a single connection with an expired token.
282275
@Test

0 commit comments

Comments
 (0)