Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combined methods related to splitting data into one single method. Also fixed related issues. #5227

Merged
merged 14 commits into from
Jun 30, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions src/Microsoft.ML.EntryPoints/TrainTestSplit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,21 @@ public static Output Split(IHostEnvironment env, Input input)

internal static class SplitUtils
{
// Creates a new Stratification column to be used for splitting.
// Notice that the new column might be dropped elsewhere in the code
// Returns: the name of the new column.
public static string CreateStratificationColumn(IHost host, ref IDataView data, string stratificationColumn = null)
{
host.CheckValue(data, nameof(data));
host.CheckValueOrNull(stratificationColumn);

// Pick a unique name for the stratificationColumn.
// Pick a unique name for the new stratificationColumn.
const string stratColName = "StratificationKey";
string stratCol = data.Schema.GetTempColumnName(stratColName);

// Construct the stratification column. If user-provided stratification column exists, use HashJoin
// of it to construct the strat column, otherwise generate a random number and use it.
if (stratificationColumn == null)
{
// If the stratificationColumn wasn't provided by the user, simply create a new Random Number Generator
data = new GenerateNumberTransform(host,
new GenerateNumberTransform.Options
{
Expand All @@ -106,11 +108,15 @@ public static string CreateStratificationColumn(IHost host, ref IDataView data,
else
{
if (data.Schema[stratificationColumn].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double))
return stratificationColumn;

data = new NormalizingEstimator(host,
new NormalizingEstimator.MinMaxColumnOptions(stratCol, stratificationColumn, ensureZeroUntouched: true))
.Fit(data).Transform(data);
{
data = new ColumnCopyingEstimator(host,(stratCol,stratificationColumn)).Fit(data).Transform(data);
Copy link

@yaeldekel yaeldekel Jun 11, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ColumnCopyingEstimator [](start = 35, length = 22)

Should this be added to EnsureGroupPreservationColumn as well?

And more generally, would it make sense to unify the two methods? There is also a method called GetSplitColumn in line 285 in CrossValidationCommand.cs that does more or less the same thing. I think the same utility method should be called for maml, entry points and the C# API. #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I took a quick look into the EnsureGroupPreservationColumn callers, and it doesn't seem that they drop the created column, so this wouldn't be an issue there. I'll look into it more later to confirm this.

About merging these methods, I think it would be a good idea to maintain consistency across the codebase, although it wouldn't be necessary to fix the issue opened by @ganik to unblock NimbusML update. I'll look into this idea.


In reply to: 438552643 [](ancestors = 438552643)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I confirmed that when EnsureGroupPreservationColumn was called, the samplingKeyColumn was never dropped, so the issue didn't apply on that case.

Also, I've merged the methods you've mentioned, so now only CreateGroupPreservationColumn exists


In reply to: 438969558 [](ancestors = 438969558,438552643)

Copy link
Member Author

@antoniovs1029 antoniovs1029 Jun 27, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update: I've renamed CreateGroupPreservationColumn to CreateSplitColumn and made similar renames elsewhere to make the code easier to read. Now "splitColumn" is the name of the temporary column we create for splitting, regardless if it's based-off a "samplingKeyColumn" (ML.NET) or a "stratificationColumn" (NimbusML, Maml, legacy TLC naming conventions...)

Also now I do drop the new "splitColumn" when it's created while splitting through ML.NET's API, because not dropping it was causing other issues... see comment:
#5227 (comment)
Added 2 tests to cheked this is dropped, and updated my PR description to explain all the issues it's now fixing.


In reply to: 441992076 [](ancestors = 441992076,438969558,438552643)

}
else
{
data = new NormalizingEstimator(host,
new NormalizingEstimator.MinMaxColumnOptions(stratCol, stratificationColumn, ensureZeroUntouched: true))
.Fit(data).Transform(data);
}
}
}

Expand Down