Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Seed property to MLContext and use as default for data splits #4775

Merged
merged 8 commits into from
Feb 11, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ internal EntryPointInfo(MethodInfo method,
var parameters = method.GetParameters();
if (parameters.Length != 2 && parameters.Length != 3)
throw Contracts.Except("Method '{0}' has {1} parameters, but must have 2 or 3", method.Name, parameters.Length);
if (parameters[0].ParameterType != typeof(IHostEnvironment))
if (parameters[0].ParameterType != typeof(IHostEnvironment) && parameters[0].ParameterType != typeof(ISeededEnvironment))
najeeb-kazmi marked this conversation as resolved.
Show resolved Hide resolved
throw Contracts.Except("Method '{0}', 1st parameter is {1}, but must be IHostEnvironment", method.Name, parameters[0].ParameterType);
InputType = parameters[1].ParameterType;
var outputType = method.ReturnType;
Expand Down
9 changes: 9 additions & 0 deletions src/Microsoft.ML.Core/Data/IHostEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ internal interface ICancelable
bool IsCanceled { get; }
}

[BestFriend]
internal interface ISeededEnvironment : IHostEnvironment
{
/// <summary>
/// The seed property that, if assigned, makes components requiring randomness behave deterministically.
/// </summary>
int? Seed { get; }
}

/// <summary>
/// A host is coupled to a component and provides random number generation and concurrency guidance.
/// Note that the random number generation, like the host environment methods, should be accessed only
Expand Down
19 changes: 1 addition & 18 deletions src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -366,24 +366,7 @@ protected override void Dispose(bool disposing)
public ConsoleEnvironment(int? seed = null, bool verbose = false,
MessageSensitivity sensitivity = MessageSensitivity.All,
TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null)
: this(RandomUtils.Create(seed), verbose, sensitivity, outWriter, errWriter, testWriter)
{
}

// REVIEW: do we really care about custom random? If we do, let's make this ctor public.
/// <summary>
/// Create an ML.NET environment for local execution, with console feedback.
/// </summary>
/// <param name="rand">An custom source of randomness to use in the environment.</param>
/// <param name="verbose">Set to <c>true</c> for fully verbose logging.</param>
/// <param name="sensitivity">Allowed message sensitivity.</param>
/// <param name="outWriter">Text writer to print normal messages to.</param>
/// <param name="errWriter">Text writer to print error messages to.</param>
/// <param name="testWriter">Optional TextWriter to write messages if the host is a test environment.</param>
private ConsoleEnvironment(Random rand, bool verbose = false,
MessageSensitivity sensitivity = MessageSensitivity.All,
TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null)
: base(rand, verbose, nameof(ConsoleEnvironment))
: base(seed, verbose, nameof(ConsoleEnvironment))
{
Contracts.CheckValueOrNull(outWriter);
Contracts.CheckValueOrNull(errWriter);
Expand Down
13 changes: 8 additions & 5 deletions src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ internal interface IMessageSource
/// query progress.
/// </summary>
[BestFriend]
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, IHostEnvironment, IChannelProvider, ICancelable
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, ISeededEnvironment, IHostEnvironment, IChannelProvider, ICancelable
where TEnv : HostEnvironmentBase<TEnv>
{
void ICancelable.CancelExecution()
Expand Down Expand Up @@ -330,6 +330,9 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)

// The random number generator for this host.
private readonly Random _rand;

public int? Seed { get; }

// A dictionary mapping the type of message to the Dispatcher that gets the strongly typed dispatch delegate.
protected readonly ConcurrentDictionary<Type, Dispatcher> ListenerDict;

Expand All @@ -345,14 +348,14 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
private readonly List<WeakReference<IHost>> _children;

/// <summary>
/// The main constructor.
/// The main constructor.
/// </summary>
protected HostEnvironmentBase(Random rand, bool verbose,
protected HostEnvironmentBase(int? seed, bool verbose,
string shortName = null, string parentFullName = null)
: base(shortName, parentFullName, verbose)
{
Contracts.CheckValueOrNull(rand);
_rand = rand ?? RandomUtils.Create();
Seed = seed;
_rand = RandomUtils.Create(Seed);
ListenerDict = new ConcurrentDictionary<Type, Dispatcher>();
ProgressTracker = new ProgressReporting.ProgressTracker(this);
_cancelLock = new object();
Expand Down
11 changes: 5 additions & 6 deletions src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Microsoft.ML
public sealed class DataOperationsCatalog : IInternalCatalog
{
IHostEnvironment IInternalCatalog.Environment => _env;
private readonly IHostEnvironment _env;
private readonly ISeededEnvironment _env;

/// <summary>
/// A pair of datasets, for the train and test set.
Expand All @@ -44,7 +44,7 @@ internal TrainTestData(IDataView trainSet, IDataView testSet)
}
}

internal DataOperationsCatalog(IHostEnvironment env)
internal DataOperationsCatalog(ISeededEnvironment env)
{
Contracts.AssertValue(env);
_env = env;
Expand Down Expand Up @@ -493,16 +493,15 @@ internal static IEnumerable<TrainTestData> CrossValidationSplit(IHostEnvironment
/// <summary>
/// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing it if necessary, or creates a new column <paramref name="samplingKeyColumn"/> is null.
/// </summary>
internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
internal static void EnsureGroupPreservationColumn(ISeededEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("rand");
// We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to
// build a single hash of it. If it is not, we generate a random number.
if (samplingKeyColumn == null)
{
samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? host.Rand.Next()));
data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? env.Seed));
najeeb-kazmi marked this conversation as resolved.
Show resolved Hide resolved
}
else
{
Expand All @@ -518,7 +517,7 @@ internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDa
// instead of having two hash transformations.
var origStratCol = samplingKeyColumn;
samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
var columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)(seed ?? host.Rand.Next()));
var columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)(seed ?? env.Seed));
Copy link

@yaeldekel yaeldekel Feb 5, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

env.Seed [](start = 135, length = 8)

This can also be null, can't it? #Resolved

Copy link
Member Author

@najeeb-kazmi najeeb-kazmi Feb 5, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, if env.Seed is also null, it will fall back to the HashingEstimator default seed:

internal static class Defaults
{
public const int NumberOfBits = NumBitsLim - 1;
public const uint Seed = 314489979;
#Resolved

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But you are casting it to uint here. If it's null it will throw an exception.


In reply to: 375521058 [](ancestors = 375521058)

Copy link
Member Author

@najeeb-kazmi najeeb-kazmi Feb 6, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, you're right. Thanks for catching that @yaeldekel. I've changed the handling back to how it was before I added (uint)(seed ?? host.Rand.Next()) and added a case to handle seed from the environment. Order of precedence for seed now is (1) user specified to TrainTestSplit/CrossValidata, (2) MLContext seed, (3) default seed for HashingEstimator #Resolved

data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
}
else
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/MLContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace Microsoft.ML
/// create components for data preparation, feature enginering, training, prediction, model evaluation.
/// It also allows logging, execution control, and the ability set repeatable random numbers.
/// </summary>
public sealed class MLContext : IHostEnvironment
public sealed class MLContext : ISeededEnvironment, IHostEnvironment
najeeb-kazmi marked this conversation as resolved.
Show resolved Hide resolved
{
// REVIEW: consider making LocalEnvironment and MLContext the same class instead of encapsulation.
private readonly LocalEnvironment _env;
Expand Down Expand Up @@ -140,6 +140,7 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
IChannel IChannelProvider.Start(string name) => _env.Start(name);
IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);
int? ISeededEnvironment.Seed => _env.Seed;

[BestFriend]
internal void CancelExecution() => ((ICancelable)_env).CancelExecution();
Expand Down
18 changes: 9 additions & 9 deletions src/Microsoft.ML.Data/TrainCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public abstract class TrainCatalogBase : IInternalCatalog
IHostEnvironment IInternalCatalog.Environment => Environment;

[BestFriend]
private protected IHostEnvironment Environment { get; }
private protected ISeededEnvironment Environment { get; }

/// <summary>
/// Results for specific cross-validation fold.
Expand Down Expand Up @@ -111,7 +111,7 @@ private protected CrossValidationResult[] CrossValidateTrain(IDataView data, IEs
}

[BestFriend]
private protected TrainCatalogBase(IHostEnvironment env, string registrationName)
private protected TrainCatalogBase(ISeededEnvironment env, string registrationName)
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonEmpty(registrationName, nameof(registrationName));
Expand Down Expand Up @@ -151,7 +151,7 @@ public sealed class BinaryClassificationCatalog : TrainCatalogBase
/// </summary>
public BinaryClassificationTrainers Trainers { get; }

internal BinaryClassificationCatalog(IHostEnvironment env)
internal BinaryClassificationCatalog(ISeededEnvironment env)
: base(env, nameof(BinaryClassificationCatalog))
{
Calibrators = new CalibratorsCatalog(this);
Expand Down Expand Up @@ -388,7 +388,7 @@ public sealed class ClusteringCatalog : TrainCatalogBase
/// <summary>
/// The clustering context.
/// </summary>
internal ClusteringCatalog(IHostEnvironment env)
internal ClusteringCatalog(ISeededEnvironment env)
: base(env, nameof(ClusteringCatalog))
{
Trainers = new ClusteringTrainers(this);
Expand Down Expand Up @@ -468,7 +468,7 @@ public sealed class MulticlassClassificationCatalog : TrainCatalogBase
/// </summary>
public MulticlassClassificationTrainers Trainers { get; }

internal MulticlassClassificationCatalog(IHostEnvironment env)
internal MulticlassClassificationCatalog(ISeededEnvironment env)
: base(env, nameof(MulticlassClassificationCatalog))
{
Trainers = new MulticlassClassificationTrainers(this);
Expand Down Expand Up @@ -549,7 +549,7 @@ public sealed class RegressionCatalog : TrainCatalogBase
/// </summary>
public RegressionTrainers Trainers { get; }

internal RegressionCatalog(IHostEnvironment env)
internal RegressionCatalog(ISeededEnvironment env)
: base(env, nameof(RegressionCatalog))
{
Trainers = new RegressionTrainers(this);
Expand Down Expand Up @@ -619,7 +619,7 @@ public sealed class RankingCatalog : TrainCatalogBase
/// </summary>
public RankingTrainers Trainers { get; }

internal RankingCatalog(IHostEnvironment env)
internal RankingCatalog(ISeededEnvironment env)
: base(env, nameof(RankingCatalog))
{
Trainers = new RankingTrainers(this);
Expand Down Expand Up @@ -685,7 +685,7 @@ public sealed class AnomalyDetectionCatalog : TrainCatalogBase
/// </summary>
public AnomalyDetectionTrainers Trainers { get; }

internal AnomalyDetectionCatalog(IHostEnvironment env)
internal AnomalyDetectionCatalog(ISeededEnvironment env)
: base(env, nameof(AnomalyDetectionCatalog))
{
Trainers = new AnomalyDetectionTrainers(this);
Expand Down Expand Up @@ -753,7 +753,7 @@ public sealed class ForecastingCatalog : TrainCatalogBase
/// </summary>
public Forecasters Trainers { get; }

internal ForecastingCatalog(IHostEnvironment env) : base(env, nameof(ForecastingCatalog))
internal ForecastingCatalog(ISeededEnvironment env) : base(env, nameof(ForecastingCatalog))
{
Trainers = new Forecasters(this);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ protected override void Dispose(bool disposing)
/// </summary>
/// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
public LocalEnvironment(int? seed = null)
: base(RandomUtils.Create(seed), verbose: false)
: base(seed, verbose: false)
{
}

Expand Down
12 changes: 6 additions & 6 deletions src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace Microsoft.ML.Transforms
internal static class PermutationFeatureImportanceEntryPoints
{
[TlcModule.EntryPoint(Name = "Transforms.PermutationFeatureImportance", Desc = "Permutation Feature Importance (PFI)", UserName = "PFI", ShortName = "PFI")]
public static PermutationFeatureImportanceOutput PermutationFeatureImportance(IHostEnvironment env, PermutationFeatureImportanceArguments input)
public static PermutationFeatureImportanceOutput PermutationFeatureImportance(ISeededEnvironment env, PermutationFeatureImportanceArguments input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("Pfi");
Expand Down Expand Up @@ -57,7 +57,7 @@ internal sealed class PermutationFeatureImportanceArguments : TransformInputBase
internal static class PermutationFeatureImportanceUtils
{
internal static IDataView GetMetrics(
IHostEnvironment env,
ISeededEnvironment env,
IPredictor predictor,
RoleMappedData roleMappedData,
PermutationFeatureImportanceArguments input)
Expand All @@ -82,7 +82,7 @@ internal static IDataView GetMetrics(
}

private static IDataView GetBinaryMetrics(
IHostEnvironment env,
ISeededEnvironment env,
IPredictor predictor,
RoleMappedData roleMappedData,
PermutationFeatureImportanceArguments input)
Expand Down Expand Up @@ -139,7 +139,7 @@ private static IDataView GetBinaryMetrics(
}

private static IDataView GetMulticlassMetrics(
IHostEnvironment env,
ISeededEnvironment env,
IPredictor predictor,
RoleMappedData roleMappedData,
PermutationFeatureImportanceArguments input)
Expand Down Expand Up @@ -198,7 +198,7 @@ private static IDataView GetMulticlassMetrics(
}

private static IDataView GetRegressionMetrics(
IHostEnvironment env,
ISeededEnvironment env,
IPredictor predictor,
RoleMappedData roleMappedData,
PermutationFeatureImportanceArguments input)
Expand Down Expand Up @@ -249,7 +249,7 @@ private static IDataView GetRegressionMetrics(
}

private static IDataView GetRankingMetrics(
IHostEnvironment env,
ISeededEnvironment env,
IPredictor predictor,
RoleMappedData roleMappedData,
PermutationFeatureImportanceArguments input)
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Recommender/RecommenderCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public sealed class RecommendationCatalog : TrainCatalogBase
/// </summary>
public RecommendationTrainers Trainers { get; }

internal RecommendationCatalog(IHostEnvironment env)
internal RecommendationCatalog(ISeededEnvironment env)
: base(env, nameof(RecommendationCatalog))
{
Trainers = new RecommendationTrainers(this);
Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public void AutoFitBinaryTest()
[Fact]
public void AutoFitMultiTest()
{
var context = new MLContext(1);
var context = new MLContext(42);
najeeb-kazmi marked this conversation as resolved.
Show resolved Hide resolved
var columnInference = context.Auto().InferColumns(DatasetUtil.TrivialMulticlassDatasetPath, DatasetUtil.TrivialMulticlassDatasetLabel);
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions);
var trainData = textLoader.Load(DatasetUtil.TrivialMulticlassDatasetPath);
Expand Down