Default seed is not propagated from MLContext #4752

@najeeb-kazmi

Description

In theory, the seed set in MLContext is intended to provide the global seed for all components and operations that require randomness (e.g. sampling and permutation). In practice, this does not always hold.

TrainTestSplit, CrossValidationSplit, and CrossValidate all accept an optional user-specified seed and call EnsureGroupPreservationColumn with it. EnsureGroupPreservationColumn in turn uses GenerateNumberTransform and HashingEstimator.

When the user does not specify a seed, none is derived from MLContext. Instead, GenerateNumberTransform and HashingEstimator fall back to their own default seeds. As a result, a user who calls TrainTestSplit, CrossValidationSplit, or CrossValidate without a seed always gets the same deterministic split, regardless of the seed set in MLContext.
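To make the symptom concrete, here is a small simulation in plain C#. It uses hypothetical names and is not ML.NET's API. It assumes a split that gives each row a random value from a seeded System.Random, loosely standing in for the sampling-key column that GenerateNumberTransform adds. It then compares two ways of handling a missing user seed: falling back to a fixed component default (the reported behavior) and falling back to the context seed (the expected behavior). The default value 42 is an arbitrary assumption.

```csharp
using System;
using System.Linq;

// Hypothetical simulation, not ML.NET code. Each row receives a random value in
// [0, 1), loosely standing in for the sampling-key column; rows whose value is
// below testFraction are assigned to the test set.
public static class SplitSimulation
{
    // Stand-in for a component's own hard-coded default seed (arbitrary value).
    public const int ComponentDefaultSeed = 42;

    public static bool[] AssignToTest(int rowCount, double testFraction, int seed)
    {
        var rng = new Random(seed);
        return Enumerable.Range(0, rowCount)
                         .Select(_ => rng.NextDouble() < testFraction)
                         .ToArray();
    }

    // Reported behavior: without a user seed, the component default is used and
    // contextSeed is never consulted.
    public static bool[] SplitAsReported(int rowCount, double testFraction, int? userSeed, int contextSeed)
        => AssignToTest(rowCount, testFraction, userSeed ?? ComponentDefaultSeed);

    // Expected behavior: without a user seed, fall back to the context seed.
    public static bool[] SplitAsExpected(int rowCount, double testFraction, int? userSeed, int contextSeed)
        => AssignToTest(rowCount, testFraction, userSeed ?? contextSeed);
}
```

In this model, two contexts with different seeds still produce identical splits under SplitAsReported, which matches the symptom described above. SplitAsExpected makes the split follow the context seed while keeping it reproducible for a fixed seed.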

internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
{
    // We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to
    // build a single hash of it. If it is not, we generate a random number.
    if (samplingKeyColumn == null)
    {
        samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
        // When seed is null, GenerateNumberTransform falls back to its own default seed.
        data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)seed);
    }

    // ... (excerpt: the code that declares origStratCol and columnOptions is omitted)

    // When seed is null, HashingEstimator likewise uses its own default seed.
    if (seed.HasValue)
        columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)seed.Value);
    else
        columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30);
    data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);

    // ...
}
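One possible shape for a fix is to resolve the seed once, before either consumer uses it, so that GenerateNumberTransform and HashingEstimator always receive an explicit seed. The sketch below is plain C# with a hypothetical helper name rather than ML.NET's internal API. It assumes that a System.Random created from the MLContext seed is available to draw from.

```csharp
using System;

// Hypothetical sketch, not ML.NET code: resolve the sampling seed once so that
// both consumers in EnsureGroupPreservationColumn always get an explicit seed.
public static class SamplingSeed
{
    // contextRandom is assumed to be a generator created from the MLContext seed.
    // An explicit user seed always wins; otherwise a seed is drawn from the context.
    public static uint Resolve(int? userSeed, Random contextRandom)
    {
        if (contextRandom == null)
            throw new ArgumentNullException(nameof(contextRandom));

        int seed = userSeed ?? contextRandom.Next();
        // Reinterpret the int bits as uint so negative user seeds wrap instead of throwing.
        return unchecked((uint)seed);
    }
}
```

The resolved value could then be passed to both GenerateNumberTransform and the HashingEstimator column options, which removes the seed.HasValue branch. This sketch does not address how the environment would expose a seeded generator.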

cc: @harishsk @justinormont

Metadata

Labels: P0 (Priority of the issue for triage purpose: IMPORTANT, needs to be fixed right away.), bug (Something isn't working)

Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
