Skip to content

Commit 19bf48c

Browse files
committed
feat: add credentialProvider option when creating clients
In some instances, credentials for the redis client will be short-lived and need to be fetched on-demand when connecting to redis. This is the case when connecting in AWS using IAM authentication or Entra ID in Azure. This feature allows for a credentialProvider to be provided which is a callable function returning a Promise that resolves to a username/password object.
1 parent ffa7d25 commit 19bf48c

File tree

3 files changed

+47
-10
lines changed

3 files changed

+47
-10
lines changed

docs/client-configuration.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
| scripts | | Script definitions (see [Lua Scripts](../README.md#lua-scripts)) |
2323
| functions | | Function definitions (see [Functions](../README.md#functions)) |
2424
| commandsQueueMaxLength | | Maximum length of the client's internal command queue |
25+
| credentialSupplier | | A callable function that returns a Promise which resolves to an object with username and password properties |
2526
| disableOfflineQueue | `false` | Disables offline queuing, see [FAQ](./FAQ.md#what-happens-when-the-network-goes-down) |
2627
| readonly | `false` | Connect in [`READONLY`](https://redis.io/commands/readonly) mode |
2728
| legacyMode | `false` | Maintain some backwards compatibility (see the [Migration Guide](./v3-to-v4.md)) |

packages/client/lib/client/index.spec.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { strict as assert } from 'node:assert';
2+
import { setTimeout } from 'node:timers/promises';
23
import testUtils, { GLOBAL, waitTillBeenCalled } from '../test-utils';
34
import RedisClient, { RedisClientType } from '.';
45
import { AbortError, ClientClosedError, ClientOfflineError, ConnectionTimeoutError, DisconnectsClientError, ErrorReply, MultiErrorReply, SocketClosedUnexpectedlyError, WatchError } from '../errors';
@@ -103,6 +104,21 @@ describe('Client', () => {
103104
},
104105
minimumDockerVersion: [6, 2]
105106
});
107+
108+
testUtils.testWithClient('should accept a credentialSupplier', async client => {
109+
assert.equal(
110+
await client.ping(),
111+
'PONG'
112+
);
113+
}, {
114+
...GLOBAL.SERVERS.PASSWORD,
115+
clientOptions: {
116+
// simulate a slight pause to fetch the credentials
117+
credentialSupplier: () => setTimeout(50).then(() => Promise.resolve({
118+
...GLOBAL.SERVERS.PASSWORD.clientOptions,
119+
})),
120+
}
121+
});
106122
});
107123

108124
testUtils.testWithClient('should set connection name', async client => {

packages/client/lib/client/index.ts

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@ import { Command, CommandSignature, TypeMapping, CommanderConfig, RedisFunction,
1111
import RedisClientMultiCommand, { RedisClientMultiCommandType } from './multi-command';
1212
import { RedisMultiQueuedCommand } from '../multi-command';
1313
import HELLO, { HelloOptions } from '../commands/HELLO';
14+
import { AuthOptions } from '../commands/AUTH';
1415
import { ScanOptions, ScanCommonOptions } from '../commands/SCAN';
1516
import { RedisLegacyClient, RedisLegacyClientType } from './legacy-mode';
1617
import { RedisPoolOptions, RedisClientPool } from './pool';
1718
import { RedisVariadicArgument, parseArgs, pushVariadicArguments } from '../commands/generic-transformers';
1819
import { BasicCommandParser, CommandParser } from './parser';
1920

21+
export type RedisCredentialSupplier = () => Promise<AuthOptions>;
22+
2023
export interface RedisClientOptions<
2124
M extends RedisModules = RedisModules,
2225
F extends RedisFunctions = RedisFunctions,
@@ -34,6 +37,10 @@ export interface RedisClientOptions<
3437
* Socket connection properties
3538
*/
3639
socket?: SocketOptions;
40+
/**
41+
* Credential supplier callback function
42+
*/
43+
credentialSupplier?: RedisCredentialSupplier;
3744
/**
3845
* ACL username ([see ACL guide](https://redis.io/topics/acl))
3946
*/
@@ -276,6 +283,7 @@ export default class RedisClient<
276283
readonly #options?: RedisClientOptions<M, F, S, RESP, TYPE_MAPPING>;
277284
readonly #socket: RedisSocket;
278285
readonly #queue: RedisCommandsQueue;
286+
#credentialSupplier?: RedisCredentialSupplier;
279287
#selectedDB = 0;
280288
#monitorCallback?: MonitorCallback<TYPE_MAPPING>;
281289
private _self = this;
@@ -334,6 +342,13 @@ export default class RedisClient<
334342
this._commandOptions = options.commandOptions;
335343
}
336344

345+
// if a credential supplier has been provided, use it, otherwise create a provider from the
346+
// supplier username and password (if provided), otherwise leave it undefined
347+
this.#credentialSupplier = options?.credentialSupplier ?? ((options?.username || options?.password) ? () => Promise.resolve({
348+
username: options?.username,
349+
password: options?.password ?? '',
350+
}) : undefined);
351+
337352
return options;
338353
}
339354

@@ -345,16 +360,16 @@ export default class RedisClient<
345360
);
346361
}
347362

348-
#handshake(selectedDB: number) {
363+
#handshake(selectedDB: number, credential?: AuthOptions) {
349364
const commands = [];
350365

351366
if (this.#options?.RESP) {
352367
const hello: HelloOptions = {};
353368

354-
if (this.#options.password) {
369+
if (credential?.password) {
355370
hello.AUTH = {
356-
username: this.#options.username ?? 'default',
357-
password: this.#options.password
371+
username: credential?.username ?? 'default',
372+
password: credential?.password
358373
};
359374
}
360375

@@ -366,11 +381,11 @@ export default class RedisClient<
366381
parseArgs(HELLO, this.#options.RESP, hello)
367382
);
368383
} else {
369-
if (this.#options?.username || this.#options?.password) {
384+
if (credential) {
370385
commands.push(
371386
parseArgs(COMMANDS.AUTH, {
372-
username: this.#options.username,
373-
password: this.#options.password ?? ''
387+
username: credential.username,
388+
password: credential.password ?? ''
374389
})
375390
);
376391
}
@@ -396,7 +411,11 @@ export default class RedisClient<
396411
}
397412

398413
#initiateSocket(): RedisSocket {
399-
const socketInitiator = () => {
414+
const socketInitiator = async () => {
415+
// we have to call the credential fetch before pushing any commands into the queue,
416+
// so fetch the credentials before doing anything else.
417+
const credential: AuthOptions | undefined = await this.#credentialSupplier?.();
418+
400419
const promises = [],
401420
chainId = Symbol('Socket Initiator');
402421

@@ -418,7 +437,7 @@ export default class RedisClient<
418437
);
419438
}
420439

421-
const commands = this.#handshake(this.#selectedDB);
440+
const commands = this.#handshake(this.#selectedDB, credential);
422441
for (let i = commands.length - 1; i >= 0; --i) {
423442
promises.push(
424443
this.#queue.addCommand(commands[i], {
@@ -997,10 +1016,11 @@ export default class RedisClient<
9971016
* Reset the client to its default state (i.e. stop PubSub, stop monitoring, select default DB, etc.)
9981017
*/
9991018
async reset() {
1019+
const credential: AuthOptions | undefined = await this.#credentialSupplier?.();
10001020
const chainId = Symbol('Reset Chain'),
10011021
promises = [this._self.#queue.reset(chainId)],
10021022
selectedDB = this._self.#options?.database ?? 0;
1003-
for (const command of this._self.#handshake(selectedDB)) {
1023+
for (const command of this._self.#handshake(selectedDB, credential)) {
10041024
promises.push(
10051025
this._self.#queue.addCommand(command, {
10061026
chainId

0 commit comments

Comments
 (0)