Skip to content

Commit 8fcb0d2

Browse files
committed
feat: add credentialProvider option when creating clients
In some instances, credentials for the redis client will be short-lived and need to be fetched on-demand when connecting to redis. This is the case when connecting in AWS using IAM authentication or Entra ID in Azure. This feature allows for a credentialProvider to be provided which is a callable function returning a Promise that resolves to a username/password object.
1 parent 5ace34b commit 8fcb0d2

File tree

3 files changed

+47
-10
lines changed

3 files changed

+47
-10
lines changed

docs/client-configuration.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
| scripts | | Script definitions (see [Lua Scripts](../README.md#lua-scripts)) |
2323
| functions | | Function definitions (see [Functions](../README.md#functions)) |
2424
| commandsQueueMaxLength | | Maximum length of the client's internal command queue |
25+
| credentialSupplier | | A callable function that returns a Promise which resolves to an object with username and password properties |
2526
| disableOfflineQueue | `false` | Disables offline queuing, see [FAQ](./FAQ.md#what-happens-when-the-network-goes-down) |
2627
| readonly | `false` | Connect in [`READONLY`](https://redis.io/commands/readonly) mode |
2728
| legacyMode | `false` | Maintain some backwards compatibility (see the [Migration Guide](./v3-to-v4.md)) |

packages/client/lib/client/index.spec.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { strict as assert } from 'node:assert';
2+
import { setTimeout } from 'node:timers/promises';
23
import testUtils, { GLOBAL, waitTillBeenCalled } from '../test-utils';
34
import RedisClient, { RedisClientType } from '.';
45
import { AbortError, ClientClosedError, ClientOfflineError, ConnectionTimeoutError, DisconnectsClientError, ErrorReply, MultiErrorReply, SocketClosedUnexpectedlyError, WatchError } from '../errors';
@@ -102,6 +103,21 @@ describe('Client', () => {
102103
},
103104
minimumDockerVersion: [6, 2]
104105
});
106+
107+
testUtils.testWithClient('should accept a credentialSupplier', async client => {
108+
assert.equal(
109+
await client.ping(),
110+
'PONG'
111+
);
112+
}, {
113+
...GLOBAL.SERVERS.PASSWORD,
114+
clientOptions: {
115+
// simulate a slight pause to fetch the credentials
116+
credentialSupplier: () => setTimeout(50).then(() => Promise.resolve({
117+
...GLOBAL.SERVERS.PASSWORD.clientOptions,
118+
})),
119+
}
120+
});
105121
});
106122

107123
testUtils.testWithClient('should set connection name', async client => {

packages/client/lib/client/index.ts

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@ import { Command, CommandSignature, TypeMapping, CommanderConfig, RedisFunction,
1111
import RedisClientMultiCommand, { RedisClientMultiCommandType } from './multi-command';
1212
import { RedisMultiQueuedCommand } from '../multi-command';
1313
import HELLO, { HelloOptions } from '../commands/HELLO';
14+
import { AuthOptions } from '../commands/AUTH';
1415
import { ScanOptions, ScanCommonOptions } from '../commands/SCAN';
1516
import { RedisLegacyClient, RedisLegacyClientType } from './legacy-mode';
1617
import { RedisPoolOptions, RedisClientPool } from './pool';
1718
import { RedisVariadicArgument, pushVariadicArguments } from '../commands/generic-transformers';
1819

20+
export type RedisCredentialSupplier = () => Promise<AuthOptions>;
21+
1922
export interface RedisClientOptions<
2023
M extends RedisModules = RedisModules,
2124
F extends RedisFunctions = RedisFunctions,
@@ -33,6 +36,10 @@ export interface RedisClientOptions<
3336
* Socket connection properties
3437
*/
3538
socket?: SocketOptions;
39+
/**
40+
* Credential supplier callback function
41+
*/
42+
credentialSupplier?: RedisCredentialSupplier;
3643
/**
3744
* ACL username ([see ACL guide](https://redis.io/topics/acl))
3845
*/
@@ -289,6 +296,7 @@ export default class RedisClient<
289296
readonly #options?: RedisClientOptions<M, F, S, RESP, TYPE_MAPPING>;
290297
readonly #socket: RedisSocket;
291298
readonly #queue: RedisCommandsQueue;
299+
#credentialSupplier?: RedisCredentialSupplier;
292300
#selectedDB = 0;
293301
#monitorCallback?: MonitorCallback<TYPE_MAPPING>;
294302
private _self = this;
@@ -347,6 +355,13 @@ export default class RedisClient<
347355
this._commandOptions = options.commandOptions;
348356
}
349357

358+
// if a credential supplier has been provided, use it, otherwise create a provider from the
359+
// supplier username and password (if provided), otherwise leave it undefined
360+
this.#credentialSupplier = options?.credentialSupplier ?? ((options?.username || options?.password) ? () => Promise.resolve({
361+
username: options?.username,
362+
password: options?.password ?? '',
363+
}) : undefined);
364+
350365
return options;
351366
}
352367

@@ -358,16 +373,16 @@ export default class RedisClient<
358373
);
359374
}
360375

361-
#handshake(selectedDB: number) {
376+
#handshake(selectedDB: number, credential?: AuthOptions) {
362377
const commands = [];
363378

364379
if (this.#options?.RESP) {
365380
const hello: HelloOptions = {};
366381

367-
if (this.#options.password) {
382+
if (credential?.password) {
368383
hello.AUTH = {
369-
username: this.#options.username ?? 'default',
370-
password: this.#options.password
384+
username: credential?.username ?? 'default',
385+
password: credential?.password
371386
};
372387
}
373388

@@ -379,11 +394,11 @@ export default class RedisClient<
379394
HELLO.transformArguments(this.#options.RESP, hello)
380395
);
381396
} else {
382-
if (this.#options?.username || this.#options?.password) {
397+
if (credential) {
383398
commands.push(
384399
COMMANDS.AUTH.transformArguments({
385-
username: this.#options.username,
386-
password: this.#options.password ?? ''
400+
username: credential.username,
401+
password: credential.password ?? ''
387402
})
388403
);
389404
}
@@ -409,7 +424,11 @@ export default class RedisClient<
409424
}
410425

411426
#initiateSocket(): RedisSocket {
412-
const socketInitiator = () => {
427+
const socketInitiator = async () => {
428+
// we have to call the credential fetch before pushing any commands into the queue,
429+
// so fetch the credentials before doing anything else.
430+
const credential: AuthOptions | undefined = await this.#credentialSupplier?.();
431+
413432
const promises = [],
414433
chainId = Symbol('Socket Initiator');
415434

@@ -431,7 +450,7 @@ export default class RedisClient<
431450
);
432451
}
433452

434-
const commands = this.#handshake(this.#selectedDB);
453+
const commands = this.#handshake(this.#selectedDB, credential);
435454
for (let i = commands.length - 1; i >= 0; --i) {
436455
promises.push(
437456
this.#queue.addCommand(commands[i], {
@@ -981,10 +1000,11 @@ export default class RedisClient<
9811000
* Reset the client to its default state (i.e. stop PubSub, stop monitoring, select default DB, etc.)
9821001
*/
9831002
async reset() {
1003+
const credential: AuthOptions | undefined = await this.#credentialSupplier?.();
9841004
const chainId = Symbol('Reset Chain'),
9851005
promises = [this._self.#queue.reset(chainId)],
9861006
selectedDB = this._self.#options?.database ?? 0;
987-
for (const command of this._self.#handshake(selectedDB)) {
1007+
for (const command of this._self.#handshake(selectedDB, credential)) {
9881008
promises.push(
9891009
this._self.#queue.addCommand(command, {
9901010
chainId

0 commit comments

Comments
 (0)