Skip to content

Commit 1b0ac32

Browse files
committed
feat: add credentialProvider option when creating clients
In some instances, credentials for the redis client will be short-lived and need to be fetched on-demand when connecting to redis. This is the case when connecting in AWS using IAM authentication or Entra ID in Azure. This feature allows for a credentialProvider to be provided which is a callable function returning a Promise that resolves to a username/password object.
1 parent ffa7d25 commit 1b0ac32

File tree

3 files changed

+51
-10
lines changed

3 files changed

+51
-10
lines changed

docs/client-configuration.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
| scripts | | Script definitions (see [Lua Scripts](../README.md#lua-scripts)) |
2323
| functions | | Function definitions (see [Functions](../README.md#functions)) |
2424
| commandsQueueMaxLength | | Maximum length of the client's internal command queue |
25+
| credentialSupplier | | A callable function that returns a Promise which resolves to an object with username and password properties |
2526
| disableOfflineQueue | `false` | Disables offline queuing, see [FAQ](./FAQ.md#what-happens-when-the-network-goes-down) |
2627
| readonly | `false` | Connect in [`READONLY`](https://redis.io/commands/readonly) mode |
2728
| legacyMode | `false` | Maintain some backwards compatibility (see the [Migration Guide](./v3-to-v4.md)) |

packages/client/lib/client/index.spec.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { strict as assert } from 'node:assert';
2+
import { setTimeout } from 'node:timers/promises';
23
import testUtils, { GLOBAL, waitTillBeenCalled } from '../test-utils';
34
import RedisClient, { RedisClientType } from '.';
45
import { AbortError, ClientClosedError, ClientOfflineError, ConnectionTimeoutError, DisconnectsClientError, ErrorReply, MultiErrorReply, SocketClosedUnexpectedlyError, WatchError } from '../errors';
@@ -103,6 +104,21 @@ describe('Client', () => {
103104
},
104105
minimumDockerVersion: [6, 2]
105106
});
107+
108+
testUtils.testWithClient('should accept a credentialSupplier', async client => {
109+
assert.equal(
110+
await client.ping(),
111+
'PONG'
112+
);
113+
}, {
114+
...GLOBAL.SERVERS.PASSWORD,
115+
clientOptions: {
116+
// simulate a slight pause to fetch the credentials
117+
credentialSupplier: () => setTimeout(50).then(() => Promise.resolve({
118+
...GLOBAL.SERVERS.PASSWORD.clientOptions,
119+
})),
120+
}
121+
});
106122
});
107123

108124
testUtils.testWithClient('should set connection name', async client => {

packages/client/lib/client/index.ts

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@ import { Command, CommandSignature, TypeMapping, CommanderConfig, RedisFunction,
1111
import RedisClientMultiCommand, { RedisClientMultiCommandType } from './multi-command';
1212
import { RedisMultiQueuedCommand } from '../multi-command';
1313
import HELLO, { HelloOptions } from '../commands/HELLO';
14+
import { AuthOptions } from '../commands/AUTH';
1415
import { ScanOptions, ScanCommonOptions } from '../commands/SCAN';
1516
import { RedisLegacyClient, RedisLegacyClientType } from './legacy-mode';
1617
import { RedisPoolOptions, RedisClientPool } from './pool';
1718
import { RedisVariadicArgument, parseArgs, pushVariadicArguments } from '../commands/generic-transformers';
1819
import { BasicCommandParser, CommandParser } from './parser';
1920

21+
export type RedisCredentialSupplier = () => Promise<AuthOptions | undefined>;
22+
2023
export interface RedisClientOptions<
2124
M extends RedisModules = RedisModules,
2225
F extends RedisFunctions = RedisFunctions,
@@ -34,6 +37,10 @@ export interface RedisClientOptions<
3437
* Socket connection properties
3538
*/
3639
socket?: SocketOptions;
40+
/**
41+
* Credential supplier callback function
42+
*/
43+
credentialSupplier?: RedisCredentialSupplier;
3744
/**
3845
* ACL username ([see ACL guide](https://redis.io/topics/acl))
3946
*/
@@ -276,6 +283,7 @@ export default class RedisClient<
276283
readonly #options?: RedisClientOptions<M, F, S, RESP, TYPE_MAPPING>;
277284
readonly #socket: RedisSocket;
278285
readonly #queue: RedisCommandsQueue;
286+
#credentialSupplier: RedisCredentialSupplier;
279287
#selectedDB = 0;
280288
#monitorCallback?: MonitorCallback<TYPE_MAPPING>;
281289
private _self = this;
@@ -313,6 +321,8 @@ export default class RedisClient<
313321
this.#options = this.#initiateOptions(options);
314322
this.#queue = this.#initiateQueue();
315323
this.#socket = this.#initiateSocket();
324+
this.#credentialSupplier = this.#initiateCredentialSupplier();
325+
316326
this.#epoch = 0;
317327
}
318328

@@ -345,16 +355,16 @@ export default class RedisClient<
345355
);
346356
}
347357

348-
#handshake(selectedDB: number) {
358+
#handshake(selectedDB: number, credential?: AuthOptions) {
349359
const commands = [];
350360

351361
if (this.#options?.RESP) {
352362
const hello: HelloOptions = {};
353363

354-
if (this.#options.password) {
364+
if (credential?.password) {
355365
hello.AUTH = {
356-
username: this.#options.username ?? 'default',
357-
password: this.#options.password
366+
username: credential?.username ?? 'default',
367+
password: credential?.password
358368
};
359369
}
360370

@@ -366,11 +376,11 @@ export default class RedisClient<
366376
parseArgs(HELLO, this.#options.RESP, hello)
367377
);
368378
} else {
369-
if (this.#options?.username || this.#options?.password) {
379+
if (credential) {
370380
commands.push(
371381
parseArgs(COMMANDS.AUTH, {
372-
username: this.#options.username,
373-
password: this.#options.password ?? ''
382+
username: credential.username,
383+
password: credential.password ?? ''
374384
})
375385
);
376386
}
@@ -396,7 +406,11 @@ export default class RedisClient<
396406
}
397407

398408
#initiateSocket(): RedisSocket {
399-
const socketInitiator = () => {
409+
const socketInitiator = async () => {
410+
// we have to call the credential fetch before pushing any commands into the queue,
411+
// so fetch the credentials before doing anything else.
412+
const credential: AuthOptions | undefined = await this.#credentialSupplier();
413+
400414
const promises = [],
401415
chainId = Symbol('Socket Initiator');
402416

@@ -418,7 +432,7 @@ export default class RedisClient<
418432
);
419433
}
420434

421-
const commands = this.#handshake(this.#selectedDB);
435+
const commands = this.#handshake(this.#selectedDB, credential);
422436
for (let i = commands.length - 1; i >= 0; --i) {
423437
promises.push(
424438
this.#queue.addCommand(commands[i], {
@@ -463,6 +477,15 @@ export default class RedisClient<
463477
.on('end', () => this.emit('end'));
464478
}
465479

480+
#initiateCredentialSupplier(): RedisCredentialSupplier {
481+
// if a credential supplier has been provided, use it, otherwise create a provider from the
482+
// supplier username and password (if provided)
483+
return this.#options?.credentialSupplier ?? (() => Promise.resolve((this.#options?.username || this.#options?.password) ? {
484+
username: this.#options?.username,
485+
password: this.#options?.password ?? '',
486+
} : undefined));
487+
}
488+
466489
#pingTimer?: NodeJS.Timeout;
467490

468491
#setPingTimer(): void {
@@ -997,10 +1020,11 @@ export default class RedisClient<
9971020
* Reset the client to its default state (i.e. stop PubSub, stop monitoring, select default DB, etc.)
9981021
*/
9991022
async reset() {
1023+
const credential: AuthOptions | undefined = await this.#credentialSupplier?.();
10001024
const chainId = Symbol('Reset Chain'),
10011025
promises = [this._self.#queue.reset(chainId)],
10021026
selectedDB = this._self.#options?.database ?? 0;
1003-
for (const command of this._self.#handshake(selectedDB)) {
1027+
for (const command of this._self.#handshake(selectedDB, credential)) {
10041028
promises.push(
10051029
this._self.#queue.addCommand(command, {
10061030
chainId

0 commit comments

Comments
 (0)