-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Combined methods related to splitting data into one single method. Also fixed related issues. #5227
Changes from all commits
d2fb318
e9454a0
8a747e5
1f94584
033ae83
1638341
5e9a8e4
cc9f6ed
37332c1
177fa04
e2a76c7
c7e0437
2cf9d6a
3656236
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -413,24 +413,27 @@ public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, s | |
_env.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); | ||
_env.CheckValueOrNull(samplingKeyColumnName); | ||
|
||
EnsureGroupPreservationColumn(_env, ref data, ref samplingKeyColumnName, seed); | ||
var splitColumn = CreateSplitColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); | ||
|
||
var trainFilter = new RangeFilter(_env, new RangeFilter.Options() | ||
{ | ||
Column = samplingKeyColumnName, | ||
Column = splitColumn, | ||
Min = 0, | ||
Max = testFraction, | ||
Complement = true | ||
}, data); | ||
var testFilter = new RangeFilter(_env, new RangeFilter.Options() | ||
{ | ||
Column = samplingKeyColumnName, | ||
Column = splitColumn, | ||
Min = 0, | ||
Max = testFraction, | ||
Complement = false | ||
}, data); | ||
|
||
return new TrainTestData(trainFilter, testFilter); | ||
var trainDV = ColumnSelectingTransformer.CreateDrop(_env, trainFilter, splitColumn); | ||
var testDV = ColumnSelectingTransformer.CreateDrop(_env, testFilter, splitColumn); | ||
|
||
return new TrainTestData(trainDV, testDV); | ||
} | ||
|
||
/// <summary> | ||
|
@@ -455,20 +458,26 @@ public IReadOnlyList<TrainTestData> CrossValidationSplit(IDataView data, int num | |
_env.CheckValue(data, nameof(data)); | ||
_env.CheckParam(numberOfFolds > 1, nameof(numberOfFolds), "Must be more than 1"); | ||
_env.CheckValueOrNull(samplingKeyColumnName); | ||
EnsureGroupPreservationColumn(_env, ref data, ref samplingKeyColumnName, seed); | ||
var splitColumn = CreateSplitColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); | ||
var result = new List<TrainTestData>(); | ||
foreach (var split in CrossValidationSplit(_env, data, numberOfFolds, samplingKeyColumnName)) | ||
foreach (var split in CrossValidationSplit(_env, data, splitColumn, numberOfFolds)) | ||
result.Add(split); | ||
return result; | ||
} | ||
|
||
internal static IEnumerable<TrainTestData> CrossValidationSplit(IHostEnvironment env, IDataView data, int numberOfFolds = 5, string samplingKeyColumnName = null) | ||
/// <summary> | ||
/// Splits the data based on the splitColumn, and drops that column as it is only | ||
/// intended to be used for splitting the data, and shouldn't be part of the output schema. | ||
/// </summary> | ||
internal static IEnumerable<TrainTestData> CrossValidationSplit(IHostEnvironment env, IDataView data, string splitColumn, int numberOfFolds = 5) | ||
{ | ||
env.CheckValue(splitColumn, nameof(splitColumn)); | ||
|
||
for (int fold = 0; fold < numberOfFolds; fold++) | ||
{ | ||
var trainFilter = new RangeFilter(env, new RangeFilter.Options | ||
{ | ||
Column = samplingKeyColumnName, | ||
Column = splitColumn, | ||
Min = (double)fold / numberOfFolds, | ||
Max = (double)(fold + 1) / numberOfFolds, | ||
Complement=true, | ||
|
@@ -478,63 +487,104 @@ internal static IEnumerable<TrainTestData> CrossValidationSplit(IHostEnvironment | |
|
||
var testFilter = new RangeFilter(env, new RangeFilter.Options | ||
{ | ||
Column = samplingKeyColumnName, | ||
Column = splitColumn, | ||
Min = (double)fold / numberOfFolds, | ||
Max = (double)(fold + 1) / numberOfFolds, | ||
Complement = false, | ||
IncludeMin = true, | ||
IncludeMax = true | ||
}, data); | ||
|
||
yield return new TrainTestData(trainFilter, testFilter); | ||
var trainDV = ColumnSelectingTransformer.CreateDrop(env, trainFilter, splitColumn); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Is this a new behavior? Or is it needed now because of the other changes? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dropping the SplitColumn is something new to fix the situation I mentioned here: Not dropping it was causing an issue with AutoML. And also it didn't make sense that we didn't drop it, because if we didn't do it then the schema of a DataView was changed by splitting it, which I think should be considered unexpected behavior. In reply to: 447415004 [](ancestors = 447415004) |
||
var testDV = ColumnSelectingTransformer.CreateDrop(env, testFilter, splitColumn); | ||
|
||
yield return new TrainTestData(trainDV, testDV); | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing it if necessary, or creates a new column <paramref name="samplingKeyColumn"/> is null. | ||
/// Based on the input samplingKeyColumn creates a new splitColumn that will be used by the callers to apply a RangeFilter that will produce train-test splits | ||
/// or cross-validation splits. | ||
/// | ||
/// Notice that the new splitColumn might get dropped by the callers of this method after using it, as it wasn't part of | ||
/// the input DataView schema. | ||
/// </summary> | ||
internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null) | ||
/// <param name="env">IHostEnvironment of the caller</param> | ||
/// <param name="data">DataView that should contain the "samplingKeyColumn". The new splitColumn will be added to this DataView.</param> | ||
/// <param name="samplingKeyColumn">Name of the column that will be used as base of the new splitColumn. | ||
/// Notice that in other places in the code the samplingKeyColumn, and/or the splitColumn this method creates, | ||
/// are refered to as "SamplingKeyColumn", "StratificationColumn", "SplitColumn", "GroupPreservationColumn" or similar names. </param> | ||
/// <param name="seed">The seed that might be used by the transformers that will create the new splitColumn</param> | ||
/// <param name="fallbackInEnvSeed">If seed = null, then should we use the env seed? If seed = null, and this parameter is false, then we won't use a seed.</param> | ||
/// <return>The name of the new column</return> | ||
[BestFriend] | ||
internal static string CreateSplitColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int? seed = null, bool fallbackInEnvSeed = false) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
// We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to | ||
// build a single hash of it. If it is not, we generate a random number. | ||
Contracts.CheckValueOrNull(samplingKeyColumn); | ||
|
||
var splitColumnName = data.Schema.GetTempColumnName("SplitColumn"); | ||
int? seedToUse; | ||
|
||
if(seed.HasValue) | ||
{ | ||
seedToUse = seed.Value; | ||
} | ||
else if(fallbackInEnvSeed) | ||
{ | ||
ISeededEnvironment seededEnv = (ISeededEnvironment)env; | ||
seedToUse = seededEnv.Seed; | ||
} | ||
else | ||
{ | ||
seedToUse = null; | ||
} | ||
|
||
// We need to handle two cases: if samplingKeyColumn is not provided, we generate a random number. | ||
if (samplingKeyColumn == null) | ||
{ | ||
samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn"); | ||
data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed)); | ||
data = new GenerateNumberTransform(env, data, splitColumnName, (uint?)seedToUse); | ||
} | ||
else | ||
{ | ||
if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol)) | ||
// If samplingKeyColumn was provided we will make a new column based on it, but using a temporary | ||
// name, as it might be dropped elsewhere in the code | ||
|
||
if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int samplingColIndex)) | ||
throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn); | ||
|
||
var type = data.Schema[stratCol].Type; | ||
var type = data.Schema[samplingColIndex].Type; | ||
if (!RangeFilter.IsValidRangeFilterColumnType(env, type)) | ||
{ | ||
var origStratCol = samplingKeyColumn; | ||
samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn); | ||
var hashInputColumnName = samplingKeyColumn; | ||
// HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan. | ||
var itemType = type.GetItemType(); | ||
if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType) | ||
data = new TypeConvertingTransformer(env, origStratCol, DataKind.Int64, origStratCol).Transform(data); | ||
{ | ||
data = new TypeConvertingTransformer(env, splitColumnName, DataKind.Int64, samplingKeyColumn).Transform(data); | ||
hashInputColumnName = splitColumnName; | ||
} | ||
|
||
var localSeed = seed.HasValue ? seed : ((ISeededEnvironment)env).Seed.HasValue ? ((ISeededEnvironment)env).Seed : null; | ||
var columnOptions = | ||
localSeed.HasValue ? | ||
new HashingEstimator.ColumnOptions(samplingKeyColumn, origStratCol, 30, (uint)localSeed.Value, combine: true) : | ||
new HashingEstimator.ColumnOptions(samplingKeyColumn, origStratCol, 30, combine: true); | ||
seedToUse.HasValue ? | ||
new HashingEstimator.ColumnOptions(splitColumnName, hashInputColumnName, 30, (uint)seedToUse.Value, combine: true) : | ||
new HashingEstimator.ColumnOptions(splitColumnName, hashInputColumnName, 30, combine: true); | ||
data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data); | ||
} | ||
else | ||
{ | ||
if (!data.Schema[samplingKeyColumn].IsNormalized() && (type == NumberDataViewType.Single || type == NumberDataViewType.Double)) | ||
if (type != NumberDataViewType.Single && type != NumberDataViewType.Double) | ||
{ | ||
var origStratCol = samplingKeyColumn; | ||
samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn); | ||
data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: true)).Fit(data).Transform(data); | ||
data = new ColumnCopyingEstimator(env, (splitColumnName, samplingKeyColumn)).Fit(data).Transform(data); | ||
} | ||
else | ||
{ | ||
data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(splitColumnName, samplingKeyColumn, ensureZeroUntouched: false)).Fit(data).Transform(data); | ||
} | ||
antoniovs1029 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
|
||
return splitColumnName; | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we check/assert that this column exists in the data?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did it because if the column doesn't exist (or if splitColumn was null) it is going to throw an exception anyway inside the RangeFilter. Since it's necessary that the column exists for the RangeFilter to work, I thought it was sensible to check for this here and don't let
splitColumn
to be null.In reply to: 447414598 [](ancestors = 447414598)