Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ internal sealed class ScramShaFirstSaslStep : ISaslStep
private readonly IScramShaAlgorithm _algorithm;
private readonly ScramCache _cache;
private readonly UsernamePasswordCredential _credential;
private readonly IRandomStringGenerator _randomStringGenerator;
private readonly IRandom _random;

public ScramShaFirstSaslStep(IScramShaAlgorithm algorithm, UsernamePasswordCredential credential, IRandomStringGenerator randomStringGenerator, ScramCache cache)
public ScramShaFirstSaslStep(IScramShaAlgorithm algorithm, UsernamePasswordCredential credential, IRandom random, ScramCache cache)
{
_algorithm = algorithm;
_credential = credential;
_randomStringGenerator = randomStringGenerator;
_random = random;
_cache = cache;
}

Expand Down Expand Up @@ -63,7 +63,7 @@ private string GenerateRandomString()
{
const string legalCharacters = "!\"#$%&'()*+-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~";

return _randomStringGenerator.Generate(20, legalCharacters);
return _random.GenerateString(20, legalCharacters);
}

private string PrepUsername(string username)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,25 @@ internal sealed class ScramShaSaslMechanism : ISaslMechanism
public const string ScramSha256MechanismName = "SCRAM-SHA-256";

public static ScramShaSaslMechanism CreateScramSha1Mechanism(SaslContext context)
=> CreateScramSha1Mechanism(context, DefaultRandomStringGenerator.Instance);
=> CreateScramSha1Mechanism(context, DefaultRandom.Instance);

internal static ScramShaSaslMechanism CreateScramSha1Mechanism(SaslContext context, IRandomStringGenerator randomStringGenerator)
=> Create(context, ScramSha1MechanismName, new ScramSha1Algorithm(), randomStringGenerator);
internal static ScramShaSaslMechanism CreateScramSha1Mechanism(SaslContext context, IRandom random)
=> Create(context, ScramSha1MechanismName, new ScramSha1Algorithm(), random);

public static ScramShaSaslMechanism CreateScramSha256Mechanism(SaslContext context)
=> CreateScramSha256Mechanism(context, DefaultRandomStringGenerator.Instance);
=> CreateScramSha256Mechanism(context, DefaultRandom.Instance);

internal static ScramShaSaslMechanism CreateScramSha256Mechanism(SaslContext context, IRandomStringGenerator randomStringGenerator)
=> Create(context, ScramSha256MechanismName, new ScramSha256Algorithm(), randomStringGenerator);
internal static ScramShaSaslMechanism CreateScramSha256Mechanism(SaslContext context, IRandom random)
=> Create(context, ScramSha256MechanismName, new ScramSha256Algorithm(), random);

private static ScramShaSaslMechanism Create(
SaslContext context,
string mechanismName,
IScramShaAlgorithm algorithm,
IRandomStringGenerator randomStringGenerator)
IRandom random)
{
Ensure.IsNotNull(context, nameof(context));
Ensure.IsNotNull(randomStringGenerator, nameof(randomStringGenerator));
Ensure.IsNotNull(random, nameof(random));
if (context.Mechanism != mechanismName)
{
throw new InvalidOperationException($"Unexpected authentication mechanism: {context.Mechanism}");
Expand All @@ -63,25 +63,25 @@ private static ScramShaSaslMechanism Create(
throw new NotSupportedException($"{mechanismName} auth mechanism require password.");
}

return new ScramShaSaslMechanism(mechanismName, algorithm, credential, randomStringGenerator, new ScramCache());
return new ScramShaSaslMechanism(mechanismName, algorithm, credential, random, new ScramCache());
}

private readonly IScramShaAlgorithm _algorithm;
private readonly ScramCache _cache;
private readonly UsernamePasswordCredential _credential;
private readonly IRandomStringGenerator _randomStringGenerator;
private readonly IRandom _random;

private ScramShaSaslMechanism(
string mechanismName,
IScramShaAlgorithm algorithm,
UsernamePasswordCredential credential,
IRandomStringGenerator randomStringGenerator,
IRandom random,
ScramCache cache)
{
Name = mechanismName;
_algorithm = algorithm;
_credential = credential;
_randomStringGenerator = randomStringGenerator;
_random = random;
_cache = cache;
}

Expand All @@ -99,7 +99,7 @@ public BsonDocument CustomizeSaslStartCommand(BsonDocument startCommand)
}

public ISaslStep Initialize(SaslConversation conversation, ConnectionDescription description)
=> new ScramShaFirstSaslStep(_algorithm, _credential, _randomStringGenerator, _cache);
=> new ScramShaFirstSaslStep(_algorithm, _credential, _random, _cache);

public void OnReAuthenticationRequired()
{
Expand Down
10 changes: 6 additions & 4 deletions src/MongoDB.Driver/ClientSessionHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ internal sealed class ClientSessionHandle : IClientSessionHandle, IClientSession
private readonly ICoreSessionHandle _coreSession;
private bool _disposed;
private readonly ClientSessionOptions _options;
private readonly IRandom _random;
private IServerSession _serverSession;

// constructors
Expand All @@ -44,16 +45,17 @@ internal sealed class ClientSessionHandle : IClientSessionHandle, IClientSession
/// <param name="options">The options.</param>
/// <param name="coreSession">The wrapped session.</param>
public ClientSessionHandle(IMongoClient client, ClientSessionOptions options, ICoreSessionHandle coreSession)
: this(client, options, coreSession, SystemClock.Instance)
: this(client, options, coreSession, SystemClock.Instance, DefaultRandom.Instance)
{
}

internal ClientSessionHandle(IMongoClient client, ClientSessionOptions options, ICoreSessionHandle coreSession, IClock clock)
internal ClientSessionHandle(IMongoClient client, ClientSessionOptions options, ICoreSessionHandle coreSession, IClock clock, IRandom random)
{
_client = client;
_options = options;
_coreSession = coreSession;
_clock = clock;
_random = random;
}

// public properties
Expand Down Expand Up @@ -166,15 +168,15 @@ public void StartTransaction(TransactionOptions transactionOptions = null)
{
Ensure.IsNotNull(callback, nameof(callback));

return TransactionExecutor.ExecuteWithRetries(this, callback, transactionOptions, _clock, cancellationToken);
return TransactionExecutor.ExecuteWithRetries(this, callback, transactionOptions, _clock, _random, cancellationToken);
}

/// <inheritdoc />
public Task<TResult> WithTransactionAsync<TResult>(Func<IClientSessionHandle, CancellationToken, Task<TResult>> callbackAsync, TransactionOptions transactionOptions = null, CancellationToken cancellationToken = default(CancellationToken))
{
Ensure.IsNotNull(callbackAsync, nameof(callbackAsync));

return TransactionExecutor.ExecuteWithRetriesAsync(this, callbackAsync, transactionOptions, _clock, cancellationToken);
return TransactionExecutor.ExecuteWithRetriesAsync(this, callbackAsync, transactionOptions, _clock, _random, cancellationToken);
}

private TransactionOptions GetEffectiveTransactionOptions(TransactionOptions transactionOptions)
Expand Down
71 changes: 71 additions & 0 deletions src/MongoDB.Driver/Core/Misc/DefaultRandom.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Security.Cryptography;

namespace MongoDB.Driver.Core.Misc;

internal sealed class DefaultRandom : IRandom
{
public static DefaultRandom Instance { get; } = new DefaultRandom();

public string GenerateString(int length, string legalCharacters)
{
Ensure.IsGreaterThanOrEqualToZero(length, nameof(length));
Ensure.IsNotNullOrEmpty(legalCharacters, nameof(legalCharacters));

if (length == 0)
{
return string.Empty;
}

#if NET472
var randomData = new byte[length];
using (var rng = RandomNumberGenerator.Create())
{
rng.GetBytes(randomData);
}

var sb = new System.Text.StringBuilder(length);
for (var i = 0; i < length; i++)
{
sb.Append(GetResultChar(legalCharacters, randomData[i]));
}

return sb.ToString();
#else
return string.Create(length, legalCharacters, (buffer, charset) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

{
var randomData = buffer.Length < 1024 ? stackalloc byte[buffer.Length] : new byte[buffer.Length];
RandomNumberGenerator.Fill(randomData);
for (var i = 0; i < buffer.Length; i++)
{
buffer[i] = GetResultChar(charset, randomData[i]);
}
});
#endif

static char GetResultChar(string charset, byte randomValue) => charset[randomValue % charset.Length];
}

public double NextDouble()
{
#if NET6_0_OR_GREATER
return System.Random.Shared.NextDouble();
#else
return ThreadStaticRandom.NextDouble();
#endif
}
}
43 changes: 0 additions & 43 deletions src/MongoDB.Driver/Core/Misc/DefaultRandomStringGenerator.cs

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2013-present MongoDB Inc.
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,10 +13,12 @@
* limitations under the License.
*/

namespace MongoDB.Driver.Core.Misc
namespace MongoDB.Driver.Core.Misc;

internal interface IRandom
{
internal interface IRandomStringGenerator
{
string Generate(int length, string legalCharacters);
}
string GenerateString(int length, string legalCharacters);

double NextDouble();
}

15 changes: 8 additions & 7 deletions src/MongoDB.Driver/Core/Misc/ThreadStaticRandom.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2013-present MongoDB Inc.
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -26,13 +26,14 @@ internal static class ThreadStaticRandom
// static methods
public static int Next(int maxValue)
{
var random = __threadStaticRandom;
if (random == null)
{
random = __threadStaticRandom = new Random();
}
__threadStaticRandom ??= new Random();
return __threadStaticRandom.Next(maxValue);
}

return random.Next(maxValue);
public static double NextDouble()
{
__threadStaticRandom ??= new Random();
return __threadStaticRandom.NextDouble();
}
}
}
12 changes: 12 additions & 0 deletions src/MongoDB.Driver/Core/Operations/RetryabilityHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ public static void AddRetryableWriteErrorLabelIfRequired(MongoException exceptio
}
}

public static int GetRetryDelayMs(IRandom random, int attempt, double backoffBase, int backoffInitial, int backoffMax)
{
Ensure.IsNotNull(random, nameof(random));
Ensure.IsGreaterThanZero(attempt, nameof(attempt));
Ensure.IsGreaterThanZero(backoffBase, nameof(backoffBase));
Ensure.IsGreaterThanZero(backoffInitial, nameof(backoffInitial));
Ensure.IsGreaterThan(backoffMax, backoffInitial, nameof(backoffMax));

var j = random.NextDouble();
return (int)(j * Math.Min(backoffMax, backoffInitial * Math.Pow(backoffBase, attempt - 1)));
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The calculation backoffInitial * Math.Pow(backoffBase, attempt - 1) may overflow for large attempt values, which would then be capped by Math.Min. Consider adding overflow protection or documenting the maximum safe attempt value to prevent unexpected behavior.

Suggested change
return (int)(j * Math.Min(backoffMax, backoffInitial * Math.Pow(backoffBase, attempt - 1)));
// compute the largest exponent such that backoffInitial * backoffBase^exponent <= backoffMax
var maxExponent = Math.Log(backoffMax / (double)backoffInitial, backoffBase);
var effectiveExponent = attempt - 1;
double delayWithoutJitter;
if (effectiveExponent >= maxExponent)
{
delayWithoutJitter = backoffMax;
}
else
{
delayWithoutJitter = backoffInitial * Math.Pow(backoffBase, effectiveExponent);
}
return (int)(j * delayWithoutJitter);

Copilot uses AI. Check for mistakes.
}

public static bool IsCommandRetryable(BsonDocument command)
{
return
Expand Down
Loading