Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combined methods related to splitting data into one single method. Also fixed related issues. #5227

Merged
merged 14 commits into from
Jun 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 2 additions & 35 deletions src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -305,41 +305,8 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output
}
}

if (string.IsNullOrEmpty(stratificationColumn))
{
stratificationColumn = "StratificationColumn";
int tmp;
int inc = 0;
while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
stratificationColumn = string.Format("StratificationColumn_{0:000}", ++inc);
var keyGenArgs = new GenerateNumberTransform.Options();
var col = new GenerateNumberTransform.Column();
col.Name = stratificationColumn;
keyGenArgs.Columns = new[] { col };
output = new GenerateNumberTransform(Host, keyGenArgs, input);
}
else
{
int col;
if (!input.Schema.TryGetColumnIndex(stratificationColumn, out col))
throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn);
var type = input.Schema[col].Type;
if (!RangeFilter.IsValidRangeFilterColumnType(ch, type))
{
ch.Info("Hashing the stratification column");
var origStratCol = stratificationColumn;
stratificationColumn = input.Schema.GetTempColumnName("strat");

// HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
var itemType = type.GetItemType();
if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
input = new TypeConvertingTransformer(Host, origStratCol, DataKind.Int64, origStratCol).Transform(input);

output = new HashingEstimator(Host, stratificationColumn, origStratCol, 30).Fit(input).Transform(input);
}
}

return stratificationColumn;
var splitColumn = DataOperationsCatalog.CreateSplitColumn(Host, ref output, stratificationColumn);
return splitColumn;
}

private bool TryGetOverallMetrics(Dictionary<string, IDataView>[] metrics, out List<IDataView> overallList)
Expand Down
108 changes: 79 additions & 29 deletions src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -413,24 +413,27 @@ public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, s
_env.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive");
_env.CheckValueOrNull(samplingKeyColumnName);

EnsureGroupPreservationColumn(_env, ref data, ref samplingKeyColumnName, seed);
var splitColumn = CreateSplitColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true);

var trainFilter = new RangeFilter(_env, new RangeFilter.Options()
{
Column = samplingKeyColumnName,
Column = splitColumn,
Min = 0,
Max = testFraction,
Complement = true
}, data);
var testFilter = new RangeFilter(_env, new RangeFilter.Options()
{
Column = samplingKeyColumnName,
Column = splitColumn,
Min = 0,
Max = testFraction,
Complement = false
}, data);

return new TrainTestData(trainFilter, testFilter);
var trainDV = ColumnSelectingTransformer.CreateDrop(_env, trainFilter, splitColumn);
var testDV = ColumnSelectingTransformer.CreateDrop(_env, testFilter, splitColumn);

return new TrainTestData(trainDV, testDV);
}

/// <summary>
Expand All @@ -455,20 +458,26 @@ public IReadOnlyList<TrainTestData> CrossValidationSplit(IDataView data, int num
_env.CheckValue(data, nameof(data));
_env.CheckParam(numberOfFolds > 1, nameof(numberOfFolds), "Must be more than 1");
_env.CheckValueOrNull(samplingKeyColumnName);
EnsureGroupPreservationColumn(_env, ref data, ref samplingKeyColumnName, seed);
var splitColumn = CreateSplitColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true);
var result = new List<TrainTestData>();
foreach (var split in CrossValidationSplit(_env, data, numberOfFolds, samplingKeyColumnName))
foreach (var split in CrossValidationSplit(_env, data, splitColumn, numberOfFolds))
result.Add(split);
return result;
}

internal static IEnumerable<TrainTestData> CrossValidationSplit(IHostEnvironment env, IDataView data, int numberOfFolds = 5, string samplingKeyColumnName = null)
/// <summary>
/// Splits the data based on the splitColumn, and drops that column as it is only
/// intended to be used for splitting the data, and shouldn't be part of the output schema.
/// </summary>
internal static IEnumerable<TrainTestData> CrossValidationSplit(IHostEnvironment env, IDataView data, string splitColumn, int numberOfFolds = 5)
{
env.CheckValue(splitColumn, nameof(splitColumn));

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

splitColumn [](start = 27, length = 11)

Should we check/assert that this column exists in the data?

Copy link
Member Author

@antoniovs1029 antoniovs1029 Jun 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did it because if the column doesn't exist (or if splitColumn was null) it is going to throw an exception anyway inside the RangeFilter. Since it's necessary that the column exists for the RangeFilter to work, I thought it was sensible to check for this here and not let splitColumn be null.


In reply to: 447414598 [](ancestors = 447414598)


for (int fold = 0; fold < numberOfFolds; fold++)
{
var trainFilter = new RangeFilter(env, new RangeFilter.Options
{
Column = samplingKeyColumnName,
Column = splitColumn,
Min = (double)fold / numberOfFolds,
Max = (double)(fold + 1) / numberOfFolds,
Complement=true,
Expand All @@ -478,63 +487,104 @@ internal static IEnumerable<TrainTestData> CrossValidationSplit(IHostEnvironment

var testFilter = new RangeFilter(env, new RangeFilter.Options
{
Column = samplingKeyColumnName,
Column = splitColumn,
Min = (double)fold / numberOfFolds,
Max = (double)(fold + 1) / numberOfFolds,
Complement = false,
IncludeMin = true,
IncludeMax = true
}, data);

yield return new TrainTestData(trainFilter, testFilter);
var trainDV = ColumnSelectingTransformer.CreateDrop(env, trainFilter, splitColumn);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CreateDrop [](start = 57, length = 10)

Is this a new behavior? Or is it needed now because of the other changes?
If it is a new behavior, why is it needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dropping the SplitColumn is something new to fix the situation I mentioned here:
#5227 (comment)

Not dropping it was causing an issue with AutoML. It also didn't make sense to keep it: if we didn't drop it, then splitting a DataView changed its schema, which I think should be considered unexpected behavior.


In reply to: 447415004 [](ancestors = 447415004)

var testDV = ColumnSelectingTransformer.CreateDrop(env, testFilter, splitColumn);

yield return new TrainTestData(trainDV, testDV);
}
}

/// <summary>
/// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing it if necessary, or creates a new column <paramref name="samplingKeyColumn"/> is null.
/// Based on the input samplingKeyColumn creates a new splitColumn that will be used by the callers to apply a RangeFilter that will produce train-test splits
/// or cross-validation splits.
///
/// Notice that the new splitColumn might get dropped by the callers of this method after using it, as it wasn't part of
/// the input DataView schema.
/// </summary>
internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
/// <param name="env">IHostEnvironment of the caller</param>
/// <param name="data">DataView that should contain the "samplingKeyColumn". The new splitColumn will be added to this DataView.</param>
/// <param name="samplingKeyColumn">Name of the column that will be used as base of the new splitColumn.
/// Notice that in other places in the code the samplingKeyColumn, and/or the splitColumn this method creates,
/// are refered to as "SamplingKeyColumn", "StratificationColumn", "SplitColumn", "GroupPreservationColumn" or similar names. </param>
/// <param name="seed">The seed that might be used by the transformers that will create the new splitColumn</param>
/// <param name="fallbackInEnvSeed">If seed = null, then should we use the env seed? If seed = null, and this parameter is false, then we won't use a seed.</param>
/// <return>The name of the new column</return>
[BestFriend]
internal static string CreateSplitColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int? seed = null, bool fallbackInEnvSeed = false)
{
Contracts.CheckValue(env, nameof(env));
// We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to
// build a single hash of it. If it is not, we generate a random number.
Contracts.CheckValueOrNull(samplingKeyColumn);

var splitColumnName = data.Schema.GetTempColumnName("SplitColumn");
int? seedToUse;

if(seed.HasValue)
{
seedToUse = seed.Value;
}
else if(fallbackInEnvSeed)
{
ISeededEnvironment seededEnv = (ISeededEnvironment)env;
seedToUse = seededEnv.Seed;
}
else
{
seedToUse = null;
}

// We need to handle two cases: if samplingKeyColumn is not provided, we generate a random number.
if (samplingKeyColumn == null)
{
samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed));
data = new GenerateNumberTransform(env, data, splitColumnName, (uint?)seedToUse);
}
else
{
if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol))
// If samplingKeyColumn was provided we will make a new column based on it, but using a temporary
// name, as it might be dropped elsewhere in the code

if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int samplingColIndex))
throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);

var type = data.Schema[stratCol].Type;
var type = data.Schema[samplingColIndex].Type;
if (!RangeFilter.IsValidRangeFilterColumnType(env, type))
{
var origStratCol = samplingKeyColumn;
samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
var hashInputColumnName = samplingKeyColumn;
// HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
var itemType = type.GetItemType();
if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
data = new TypeConvertingTransformer(env, origStratCol, DataKind.Int64, origStratCol).Transform(data);
{
data = new TypeConvertingTransformer(env, splitColumnName, DataKind.Int64, samplingKeyColumn).Transform(data);
hashInputColumnName = splitColumnName;
}

var localSeed = seed.HasValue ? seed : ((ISeededEnvironment)env).Seed.HasValue ? ((ISeededEnvironment)env).Seed : null;
var columnOptions =
localSeed.HasValue ?
new HashingEstimator.ColumnOptions(samplingKeyColumn, origStratCol, 30, (uint)localSeed.Value, combine: true) :
new HashingEstimator.ColumnOptions(samplingKeyColumn, origStratCol, 30, combine: true);
seedToUse.HasValue ?
new HashingEstimator.ColumnOptions(splitColumnName, hashInputColumnName, 30, (uint)seedToUse.Value, combine: true) :
new HashingEstimator.ColumnOptions(splitColumnName, hashInputColumnName, 30, combine: true);
data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
}
else
{
if (!data.Schema[samplingKeyColumn].IsNormalized() && (type == NumberDataViewType.Single || type == NumberDataViewType.Double))
if (type != NumberDataViewType.Single && type != NumberDataViewType.Double)
{
var origStratCol = samplingKeyColumn;
samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: true)).Fit(data).Transform(data);
data = new ColumnCopyingEstimator(env, (splitColumnName, samplingKeyColumn)).Fit(data).Transform(data);
}
else
{
data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(splitColumnName, samplingKeyColumn, ensureZeroUntouched: false)).Fit(data).Transform(data);
}
antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved
}
}

return splitColumnName;
}
}
}
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/TrainCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ private protected CrossValidationResult[] CrossValidateTrain(IDataView data, IEs
Environment.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
Environment.CheckValueOrNull(samplingKeyColumn);

DataOperationsCatalog.EnsureGroupPreservationColumn(Environment, ref data, ref samplingKeyColumn, seed);
var splitColumn = DataOperationsCatalog.CreateSplitColumn(Environment, ref data, samplingKeyColumn, seed, fallbackInEnvSeed: true);
var result = new CrossValidationResult[numFolds];
int fold = 0;
// Sequential per-fold training.
// REVIEW: we could have a parallel implementation here. We would need to
// spawn off a separate host per fold in that case.
foreach (var split in DataOperationsCatalog.CrossValidationSplit(Environment, data, numFolds, samplingKeyColumn))
foreach (var split in DataOperationsCatalog.CrossValidationSplit(Environment, data, splitColumn, numFolds))
{
var model = estimator.Fit(split.TrainSet);
var scoredTest = model.Transform(split.TestSet);
Expand Down
11 changes: 6 additions & 5 deletions src/Microsoft.ML.EntryPoints/CVSplit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
Expand Down Expand Up @@ -53,7 +54,7 @@ public static Output Split(IHostEnvironment env, Input input)

var data = input.Data;

var stratCol = SplitUtils.CreateStratificationColumn(host, ref data, input.StratificationColumn);
var splitCol = DataOperationsCatalog.CreateSplitColumn(env, ref data, input.StratificationColumn);

int n = input.NumFolds;
var output = new Output
Expand All @@ -67,12 +68,12 @@ public static Output Split(IHostEnvironment env, Input input)
for (int i = 0; i < n; i++)
{
var trainData = new RangeFilter(host,
new RangeFilter.Options { Column = stratCol, Min = i * fraction, Max = (i + 1) * fraction, Complement = true }, data);
output.TrainData[i] = ColumnSelectingTransformer.CreateDrop(host, trainData, stratCol);
new RangeFilter.Options { Column = splitCol, Min = i * fraction, Max = (i + 1) * fraction, Complement = true }, data);
output.TrainData[i] = ColumnSelectingTransformer.CreateDrop(host, trainData, splitCol);

var testData = new RangeFilter(host,
new RangeFilter.Options { Column = stratCol, Min = i * fraction, Max = (i + 1) * fraction, Complement = false }, data);
output.TestData[i] = ColumnSelectingTransformer.CreateDrop(host, testData, stratCol);
new RangeFilter.Options { Column = splitCol, Min = i * fraction, Max = (i + 1) * fraction, Complement = false }, data);
output.TestData[i] = ColumnSelectingTransformer.CreateDrop(host, testData, splitCol);
}

return output;
Expand Down
Loading