Skip to content

Commit 05ef676

Browse files
authored
Offer suggestions for possibly mistyped label column names in AutoML (#5574) (#5624)
* Offer suggestions for possibly mistyped label column names * review changes
1 parent 3d3d45c commit 05ef676

File tree

3 files changed

+106
-3
lines changed

3 files changed

+106
-3
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
using System;
2+
3+
namespace Microsoft.ML.AutoML.Utils
{
    /// <summary>
    /// Utility for measuring how many single-character edits separate two strings.
    /// </summary>
    internal static class StringEditDistance
    {
        /// <summary>
        /// Computes the Levenshtein distance between two strings: the minimum number of
        /// single-character insertions, deletions and substitutions needed to turn
        /// <paramref name="first"/> into <paramref name="second"/>.
        /// Characters are compared by exact <see cref="char"/> value, so the comparison is case-sensitive.
        /// </summary>
        /// <param name="first">The source string. Must not be null.</param>
        /// <param name="second">The target string. Must not be null.</param>
        /// <returns>The edit distance; 0 when the strings are identical.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="first"/> or <paramref name="second"/> is null.
        /// </exception>
        public static int GetLevenshteinDistance(string first, string second)
        {
            if (first is null)
            {
                throw new ArgumentNullException(nameof(first));
            }

            if (second is null)
            {
                throw new ArgumentNullException(nameof(second));
            }

            // When either string is empty the distance is the length of the other one.
            if (first.Length == 0 || second.Length == 0)
            {
                return first.Length + second.Length;
            }

            // Only the previous and the current row of the dynamic-programming table are
            // ever read, so two rows of (second.Length + 1) cells are sufficient.
            var previousRow = new int[second.Length + 1];
            var currentRow = new int[second.Length + 1];

            // Row 0: turning the empty prefix of 'first' into a prefix of 'second' takes j insertions.
            for (var j = 0; j <= second.Length; ++j)
            {
                previousRow[j] = j;
            }

            for (var i = 1; i <= first.Length; ++i)
            {
                // Column 0: turning a prefix of 'first' into the empty string takes i deletions.
                currentRow[0] = i;
                for (var j = 1; j <= second.Length; ++j)
                {
                    var deletion = previousRow[j] + 1;
                    var insertion = currentRow[j - 1] + 1;
                    var substitution = previousRow[j - 1] + (first[i - 1] == second[j - 1] ? 0 : 1);

                    currentRow[j] = Math.Min(deletion, Math.Min(insertion, substitution));
                }

                // The row just filled becomes the "previous" row of the next iteration;
                // the old previous row is reused as scratch space.
                var recycled = previousRow;
                previousRow = currentRow;
                currentRow = recycled;
            }

            // After the final swap the last computed row is in previousRow.
            return previousRow[second.Length];
        }
    }
}

src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Collections.Generic;
77
using System.IO;
88
using System.Linq;
9+
using Microsoft.ML.AutoML.Utils;
910
using Microsoft.ML.Data;
1011

1112
namespace Microsoft.ML.AutoML
@@ -248,7 +249,15 @@ private static void ValidateTrainDataColumn(IDataView trainData, string columnNa
248249
var nullableColumn = trainData.Schema.GetColumnOrNull(columnName);
249250
if (nullableColumn == null)
250251
{
251-
throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' not found in training data.");
252+
var closestNamed = ClosestNamed(trainData, columnName, 7);
253+
254+
var exceptionMessage = $"Provided {columnPurpose} column '{columnName}' not found in training data.";
255+
if (closestNamed != string.Empty)
256+
{
257+
exceptionMessage += $" Did you mean '{closestNamed}'.";
258+
}
259+
260+
throw new ArgumentException(exceptionMessage);
252261
}
253262

254263
if(allowedTypes == null)
@@ -272,6 +281,23 @@ private static void ValidateTrainDataColumn(IDataView trainData, string columnNa
272281
}
273282
}
274283

284+
/// <summary>
/// Scans the schema of <paramref name="trainData"/> for the column whose name has the
/// smallest Levenshtein distance to <paramref name="columnName"/>.
/// When several columns share the smallest distance, the first one encountered is kept.
/// </summary>
/// <param name="trainData">Data view whose schema columns are the candidates.</param>
/// <param name="columnName">The name to match against.</param>
/// <param name="maxAllowableEditDistance">
/// Largest distance still accepted as a match; defaults to <see cref="int.MaxValue"/>.
/// </param>
/// <returns>
/// The closest column name, or <see cref="string.Empty"/> when the schema yields no candidate
/// or the best distance exceeds <paramref name="maxAllowableEditDistance"/>.
/// </returns>
private static string ClosestNamed(IDataView trainData, string columnName, int maxAllowableEditDistance = int.MaxValue)
{
    var bestMatch = string.Empty;
    var bestDistance = int.MaxValue;

    foreach (var candidate in trainData.Schema)
    {
        var distance = StringEditDistance.GetLevenshteinDistance(candidate.Name, columnName);
        if (distance >= bestDistance)
        {
            // Not strictly better than what we already have; keep the earlier match.
            continue;
        }

        bestDistance = distance;
        bestMatch = candidate.Name;
    }

    if (bestDistance > maxAllowableEditDistance)
    {
        return string.Empty;
    }

    return bestMatch;
}
300+
275301
private static string FindFirstDuplicate(IEnumerable<string> values)
276302
{
277303
var groups = values.GroupBy(v => v);

test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Collections.Generic;
77
using System.IO;
8+
using System.Linq;
89
using System.Threading.Tasks;
910
using Microsoft.ML.Data;
1011
using Microsoft.ML.TestFramework;
@@ -43,10 +44,26 @@ public void ValidateExperimentExecuteLabelNotInTrain()
4344
{
4445
foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
4546
{
47+
const string columnName = "ReallyLongNonExistingColumnName";
4648
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data,
47-
new ColumnInformation() { LabelColumnName = "L" }, null, task));
49+
new ColumnInformation() { LabelColumnName = columnName }, null, task));
4850

49-
Assert.Equal("Provided label column 'L' not found in training data.", ex.Message);
51+
Assert.Equal($"Provided label column '{columnName}' not found in training data.", ex.Message);
52+
}
53+
}
54+
55+
// Verifies that a label column name that is one appended character away from an
// existing column is rejected with a message suggesting the existing column.
[Fact]
public void ValidateExperimentExecuteLabelNotInTrainMistyped()
{
    var tasks = new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking };

    foreach (var task in tasks)
    {
        // Take the first schema column and append one character to simulate a typo.
        var existingName = _data.Schema.First().Name;
        var typoName = existingName + "a";

        var thrown = Assert.Throws<ArgumentException>(
            () => UserInputValidationUtil.ValidateExperimentExecuteArgs(
                _data,
                new ColumnInformation() { LabelColumnName = typoName },
                null,
                task));

        var expectedMessage = $"Provided label column '{typoName}' not found in training data. Did you mean '{existingName}'.";
        Assert.Equal(expectedMessage, thrown.Message);
    }
}
5269

0 commit comments

Comments
 (0)