Skip to content

Commit 6b2d1a5

Browse files
authored
Add Seed property to MLContext and use as default for data splits (#4775)
* Add Seed property to MLContext and use as default for data splits
* Separate Seed property out into an internal interface
* Change seed in AutoFitMultiTest
* Check typeof ISeededEnvironment in ComponentCatalog
* PR feedback
* nit
* PR feedback
* Handle casting of nullable int to uint
1 parent dc4e5f8 commit 6b2d1a5

File tree

7 files changed

+30
-29
lines changed

7 files changed

+30
-29
lines changed

src/Microsoft.ML.Core/Data/IHostEnvironment.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,15 @@ internal interface ICancelable
8282
bool IsCanceled { get; }
8383
}
8484

85+
[BestFriend]
86+
internal interface ISeededEnvironment : IHostEnvironment
87+
{
88+
/// <summary>
89+
/// The seed property that, if assigned, makes components requiring randomness behave deterministically.
90+
/// </summary>
91+
int? Seed { get; }
92+
}
93+
8594
/// <summary>
8695
/// A host is coupled to a component and provides random number generation and concurrency guidance.
8796
/// Note that the random number generation, like the host environment methods, should be accessed only

src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -366,24 +366,7 @@ protected override void Dispose(bool disposing)
366366
public ConsoleEnvironment(int? seed = null, bool verbose = false,
367367
MessageSensitivity sensitivity = MessageSensitivity.All,
368368
TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null)
369-
: this(RandomUtils.Create(seed), verbose, sensitivity, outWriter, errWriter, testWriter)
370-
{
371-
}
372-
373-
// REVIEW: do we really care about custom random? If we do, let's make this ctor public.
374-
/// <summary>
375-
/// Create an ML.NET environment for local execution, with console feedback.
376-
/// </summary>
377-
/// <param name="rand">An custom source of randomness to use in the environment.</param>
378-
/// <param name="verbose">Set to <c>true</c> for fully verbose logging.</param>
379-
/// <param name="sensitivity">Allowed message sensitivity.</param>
380-
/// <param name="outWriter">Text writer to print normal messages to.</param>
381-
/// <param name="errWriter">Text writer to print error messages to.</param>
382-
/// <param name="testWriter">Optional TextWriter to write messages if the host is a test environment.</param>
383-
private ConsoleEnvironment(Random rand, bool verbose = false,
384-
MessageSensitivity sensitivity = MessageSensitivity.All,
385-
TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null)
386-
: base(rand, verbose, nameof(ConsoleEnvironment))
369+
: base(seed, verbose, nameof(ConsoleEnvironment))
387370
{
388371
Contracts.CheckValueOrNull(outWriter);
389372
Contracts.CheckValueOrNull(errWriter);

src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ internal interface IMessageSource
9393
/// query progress.
9494
/// </summary>
9595
[BestFriend]
96-
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, IHostEnvironment, IChannelProvider, ICancelable
96+
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, ISeededEnvironment, IChannelProvider, ICancelable
9797
where TEnv : HostEnvironmentBase<TEnv>
9898
{
9999
void ICancelable.CancelExecution()
@@ -330,6 +330,9 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
330330

331331
// The random number generator for this host.
332332
private readonly Random _rand;
333+
334+
public int? Seed { get; }
335+
333336
// A dictionary mapping the type of message to the Dispatcher that gets the strongly typed dispatch delegate.
334337
protected readonly ConcurrentDictionary<Type, Dispatcher> ListenerDict;
335338

@@ -345,14 +348,14 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
345348
private readonly List<WeakReference<IHost>> _children;
346349

347350
/// <summary>
348-
/// The main constructor.
351+
/// The main constructor.
349352
/// </summary>
350-
protected HostEnvironmentBase(Random rand, bool verbose,
353+
protected HostEnvironmentBase(int? seed, bool verbose,
351354
string shortName = null, string parentFullName = null)
352355
: base(shortName, parentFullName, verbose)
353356
{
354-
Contracts.CheckValueOrNull(rand);
355-
_rand = rand ?? RandomUtils.Create();
357+
Seed = seed;
358+
_rand = RandomUtils.Create(Seed);
356359
ListenerDict = new ConcurrentDictionary<Type, Dispatcher>();
357360
ProgressTracker = new ProgressReporting.ProgressTracker(this);
358361
_cancelLock = new object();

src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -496,13 +496,12 @@ internal static IEnumerable<TrainTestData> CrossValidationSplit(IHostEnvironment
496496
internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
497497
{
498498
Contracts.CheckValue(env, nameof(env));
499-
var host = env.Register("rand");
500499
// We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to
501500
// build a single hash of it. If it is not, we generate a random number.
502501
if (samplingKeyColumn == null)
503502
{
504503
samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
505-
data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? host.Rand.Next()));
504+
data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed));
506505
}
507506
else
508507
{
@@ -518,7 +517,13 @@ internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDa
518517
// instead of having two hash transformations.
519518
var origStratCol = samplingKeyColumn;
520519
samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
521-
var columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)(seed ?? host.Rand.Next()));
520+
HashingEstimator.ColumnOptionsInternal columnOptions;
521+
if (seed.HasValue)
522+
columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)seed.Value);
523+
else if (((ISeededEnvironment)env).Seed.HasValue)
524+
columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)((ISeededEnvironment)env).Seed.Value);
525+
else
526+
columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30);
522527
data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
523528
}
524529
else

src/Microsoft.ML.Data/MLContext.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace Microsoft.ML
1414
/// create components for data preparation, feature engineering, training, prediction, model evaluation.
1515
/// It also allows logging, execution control, and the ability to set repeatable random numbers.
1616
/// </summary>
17-
public sealed class MLContext : IHostEnvironment
17+
public sealed class MLContext : ISeededEnvironment
1818
{
1919
// REVIEW: consider making LocalEnvironment and MLContext the same class instead of encapsulation.
2020
private readonly LocalEnvironment _env;
@@ -140,6 +140,7 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
140140
IChannel IChannelProvider.Start(string name) => _env.Start(name);
141141
IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);
142142
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);
143+
int? ISeededEnvironment.Seed => _env.Seed;
143144

144145
[BestFriend]
145146
internal void CancelExecution() => ((ICancelable)_env).CancelExecution();

src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ protected override void Dispose(bool disposing)
4747
/// </summary>
4848
/// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
4949
public LocalEnvironment(int? seed = null)
50-
: base(RandomUtils.Create(seed), verbose: false)
50+
: base(seed, verbose: false)
5151
{
5252
}
5353

test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public void AutoFitBinaryTest()
3737
[Fact]
3838
public void AutoFitMultiTest()
3939
{
40-
var context = new MLContext(1);
40+
var context = new MLContext(42);
4141
var columnInference = context.Auto().InferColumns(DatasetUtil.TrivialMulticlassDatasetPath, DatasetUtil.TrivialMulticlassDatasetLabel);
4242
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions);
4343
var trainData = textLoader.Load(DatasetUtil.TrivialMulticlassDatasetPath);

0 commit comments

Comments
 (0)