Skip to content

Introduce credentials provider (#3224) #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/main/java/redis/clients/jedis/CommandArguments.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ public ProtocolCommand getCommand() {
}

public CommandArguments add(Object arg) {
if (arg instanceof Rawable) {
if (arg == null) {
throw new IllegalArgumentException("null is not a valid argument.");
} else if (arg instanceof Rawable) {
args.add((Rawable) arg);
} else if (arg instanceof byte[]) {
args.add(RawableFactory.from((byte[]) arg));
Expand All @@ -37,9 +39,6 @@ public CommandArguments add(Object arg) {
} else if (arg instanceof Boolean) {
args.add(RawableFactory.from(Integer.toString((Boolean) arg ? 1 : 0)));
} else {
if (arg == null) {
throw new IllegalArgumentException("null is not a valid argument.");
}
args.add(RawableFactory.from(String.valueOf(arg)));
}
return this;
Expand Down
51 changes: 35 additions & 16 deletions src/main/java/redis/clients/jedis/Connection.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import java.io.IOException;
import java.net.Socket;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;

import redis.clients.jedis.args.Rawable;
import redis.clients.jedis.commands.ProtocolCommand;
Expand Down Expand Up @@ -336,15 +340,16 @@ public List<Object> getMany(final int count) {
private void initializeFromClientConfig(JedisClientConfig config) {
try {
connect();
String password = config.getPassword();
if (password != null) {
String user = config.getUser();
if (user != null) {
auth(user, password);
} else {
auth(password);
}

Supplier<RedisCredentials> credentialsProvider = config.getCredentialsProvider();
if (credentialsProvider instanceof RedisCredentialsProvider) {
((RedisCredentialsProvider) credentialsProvider).prepare();
auth(credentialsProvider);
((RedisCredentialsProvider) credentialsProvider).cleanUp();
} else {
auth(credentialsProvider);
}

int dbIndex = config.getDatabase();
if (dbIndex > 0) {
select(dbIndex);
Expand All @@ -354,27 +359,41 @@ private void initializeFromClientConfig(JedisClientConfig config) {
// TODO: need to figure out something without encoding
clientSetname(clientName);
}

} catch (JedisException je) {
try {
if (isConnected()) {
quit();
}
disconnect();
} catch (Exception e) {
//
// the first exception 'je' will be thrown
}
throw je;
}
}

private String auth(final String password) {
sendCommand(Protocol.Command.AUTH, password);
return getStatusCodeReply();
}
private void auth(final Supplier<RedisCredentials> credentialsProvider) {
RedisCredentials credentials = credentialsProvider.get();
if (credentials == null || credentials.getPassword() == null) return;

private String auth(final String user, final String password) {
sendCommand(Protocol.Command.AUTH, user, password);
return getStatusCodeReply();
// Source: https://stackoverflow.com/a/9670279/4021802
ByteBuffer passBuf = Protocol.CHARSET.encode(CharBuffer.wrap(credentials.getPassword()));
byte[] rawPass = Arrays.copyOfRange(passBuf.array(), passBuf.position(), passBuf.limit());
Arrays.fill(passBuf.array(), (byte) 0); // clear sensitive data

if (credentials.getUser() != null) {
sendCommand(Protocol.Command.AUTH, SafeEncoder.encode(credentials.getUser()), rawPass);
} else {
sendCommand(Protocol.Command.AUTH, rawPass);
}

Arrays.fill(rawPass, (byte) 0); // clear sensitive data

// clearing 'char[] credentials.getPassword()' should be
// handled in RedisCredentialsProvider.cleanUp()

getStatusCodeReply(); // OK
}

public String select(final int index) {
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/redis/clients/jedis/ConnectionFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ public ConnectionFactory(final JedisSocketFactory jedisSocketFactory, final Jedi
this.jedisSocketFactory = jedisSocketFactory;
}

/**
* @deprecated Use {@link RedisCredentialsProvider} through
* {@link JedisClientConfig#getCredentialsProvider()}.
*/
@Deprecated
public void setPassword(final String password) {
this.clientConfig.updatePassword(password);
}
Expand Down
64 changes: 43 additions & 21 deletions src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package redis.clients.jedis;

import java.util.Objects;
import java.util.function.Supplier;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocketFactory;
Expand All @@ -11,8 +11,7 @@ public final class DefaultJedisClientConfig implements JedisClientConfig {
private final int socketTimeoutMillis;
private final int blockingSocketTimeoutMillis;

private final String user;
private volatile String password;
private volatile Supplier<RedisCredentials> credentialsProvider;
private final int database;
private final String clientName;

Expand All @@ -24,14 +23,13 @@ public final class DefaultJedisClientConfig implements JedisClientConfig {
private final HostAndPortMapper hostAndPortMapper;

private DefaultJedisClientConfig(int connectionTimeoutMillis, int soTimeoutMillis,
int blockingSocketTimeoutMillis, String user, String password, int database, String clientName,
boolean ssl, SSLSocketFactory sslSocketFactory, SSLParameters sslParameters,
int blockingSocketTimeoutMillis, Supplier<RedisCredentials> credentialsProvider, int database,
String clientName, boolean ssl, SSLSocketFactory sslSocketFactory, SSLParameters sslParameters,
HostnameVerifier hostnameVerifier, HostAndPortMapper hostAndPortMapper) {
this.connectionTimeoutMillis = connectionTimeoutMillis;
this.socketTimeoutMillis = soTimeoutMillis;
this.blockingSocketTimeoutMillis = blockingSocketTimeoutMillis;
this.user = user;
this.password = password;
this.credentialsProvider = credentialsProvider;
this.database = database;
this.clientName = clientName;
this.ssl = ssl;
Expand All @@ -58,19 +56,25 @@ public int getBlockingSocketTimeoutMillis() {

@Override
public String getUser() {
return user;
return credentialsProvider.get().getUser();
}

@Override
public String getPassword() {
return password;
char[] password = credentialsProvider.get().getPassword();
return password == null ? null : new String(password);
}

@Override
public Supplier<RedisCredentials> getCredentialsProvider() {
return credentialsProvider;
}

@Override
@Deprecated
public synchronized void updatePassword(String password) {
if (!Objects.equals(this.password, password)) {
this.password = password;
}
((DefaultRedisCredentialsProvider) this.credentialsProvider)
.setCredentials(new DefaultRedisCredentials(getUser(), password));
}

@Override
Expand Down Expand Up @@ -120,6 +124,7 @@ public static class Builder {

private String user = null;
private String password = null;
private Supplier<RedisCredentials> credentialsProvider;
private int database = Protocol.DEFAULT_DATABASE;
private String clientName = null;

Expand All @@ -134,9 +139,14 @@ private Builder() {
}

public DefaultJedisClientConfig build() {
if (credentialsProvider == null) {
credentialsProvider = new DefaultRedisCredentialsProvider(
new DefaultRedisCredentials(user, password));
}

return new DefaultJedisClientConfig(connectionTimeoutMillis, socketTimeoutMillis,
blockingSocketTimeoutMillis, user, password, database, clientName, ssl, sslSocketFactory,
sslParameters, hostnameVerifier, hostAndPortMapper);
blockingSocketTimeoutMillis, credentialsProvider, database, clientName, ssl,
sslSocketFactory, sslParameters, hostnameVerifier, hostAndPortMapper);
}

public Builder timeoutMillis(int timeoutMillis) {
Expand Down Expand Up @@ -170,6 +180,16 @@ public Builder password(String password) {
return this;
}

public Builder credentials(RedisCredentials credentials) {
this.credentialsProvider = new DefaultRedisCredentialsProvider(credentials);
return this;
}

public Builder credentialsProvider(Supplier<RedisCredentials> credentials) {
this.credentialsProvider = credentials;
return this;
}

public Builder database(int database) {
this.database = database;
return this;
Expand Down Expand Up @@ -210,16 +230,18 @@ public static DefaultJedisClientConfig create(int connectionTimeoutMillis, int s
int blockingSocketTimeoutMillis, String user, String password, int database, String clientName,
boolean ssl, SSLSocketFactory sslSocketFactory, SSLParameters sslParameters,
HostnameVerifier hostnameVerifier, HostAndPortMapper hostAndPortMapper) {
return new DefaultJedisClientConfig(connectionTimeoutMillis, soTimeoutMillis,
blockingSocketTimeoutMillis, user, password, database, clientName, ssl,
sslSocketFactory, sslParameters, hostnameVerifier, hostAndPortMapper);
return new DefaultJedisClientConfig(
connectionTimeoutMillis, soTimeoutMillis, blockingSocketTimeoutMillis,
new DefaultRedisCredentialsProvider(new DefaultRedisCredentials(user, password)),
database, clientName, ssl, sslSocketFactory, sslParameters,
hostnameVerifier, hostAndPortMapper);
}

public static DefaultJedisClientConfig copyConfig(JedisClientConfig copy) {
return new DefaultJedisClientConfig(copy.getConnectionTimeoutMillis(),
copy.getSocketTimeoutMillis(), copy.getBlockingSocketTimeoutMillis(), copy.getUser(),
copy.getPassword(), copy.getDatabase(), copy.getClientName(), copy.isSsl(),
copy.getSslSocketFactory(), copy.getSslParameters(), copy.getHostnameVerifier(),
copy.getHostAndPortMapper());
copy.getSocketTimeoutMillis(), copy.getBlockingSocketTimeoutMillis(),
copy.getCredentialsProvider(), copy.getDatabase(), copy.getClientName(),
copy.isSsl(), copy.getSslSocketFactory(), copy.getSslParameters(),
copy.getHostnameVerifier(), copy.getHostAndPortMapper());
}
}
38 changes: 38 additions & 0 deletions src/main/java/redis/clients/jedis/DefaultRedisCredentials.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package redis.clients.jedis;

public final class DefaultRedisCredentials implements RedisCredentials {

private final String user;
private final char[] password;

public DefaultRedisCredentials(String user, char[] password) {
this.user = user;
this.password = password;
}

public DefaultRedisCredentials(String user, CharSequence password) {
this.user = user;
this.password = password == null ? null
: password instanceof String ? ((String) password).toCharArray()
: toCharArray(password);
}

@Override
public String getUser() {
return user;
}

@Override
public char[] getPassword() {
return password;
}

private static char[] toCharArray(CharSequence seq) {
final int len = seq.length();
char[] arr = new char[len];
for (int i = 0; i < len; i++) {
arr[i] = seq.charAt(i);
}
return arr;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package redis.clients.jedis;

public final class DefaultRedisCredentialsProvider implements RedisCredentialsProvider {

private volatile RedisCredentials credentials;

public DefaultRedisCredentialsProvider(RedisCredentials credentials) {
this.credentials = credentials;
}

public void setCredentials(RedisCredentials credentials) {
this.credentials = credentials;
}

@Override
public RedisCredentials get() {
return this.credentials;
}
}
7 changes: 7 additions & 0 deletions src/main/java/redis/clients/jedis/JedisClientConfig.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package redis.clients.jedis;

import java.util.function.Supplier;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocketFactory;
Expand Down Expand Up @@ -39,9 +40,15 @@ default String getPassword() {
return null;
}

@Deprecated
default void updatePassword(String password) {
}

default Supplier<RedisCredentials> getCredentialsProvider() {
return new DefaultRedisCredentialsProvider(
new DefaultRedisCredentials(getUser(), getPassword()));
}

default int getDatabase() {
return Protocol.DEFAULT_DATABASE;
}
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/redis/clients/jedis/JedisFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ void setHostAndPort(final HostAndPort hostAndPort) {
((DefaultJedisSocketFactory) jedisSocketFactory).updateHostAndPort(hostAndPort);
}

/**
* @deprecated Use {@link RedisCredentialsProvider} through
* {@link JedisClientConfig#getCredentialsProvider()}.
*/
@Deprecated
public void setPassword(final String password) {
this.clientConfig.updatePassword(password);
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/redis/clients/jedis/Protocol.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public final class Protocol {
private static final String NOPERM_PREFIX = "NOPERM";

private Protocol() {
// this prevent the class from instantiation
throw new InstantiationError("Must not instantiate this class");
}

public static void sendCommand(final RedisOutputStream os, CommandArguments args) {
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/redis/clients/jedis/RedisCredentials.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package redis.clients.jedis;

public interface RedisCredentials {

/**
* @return Redis ACL user
*/
default String getUser() {
return null;
}

default char[] getPassword() {
return null;
}
}
Loading