Skip to content

Commit 87b6766

Browse files
daholsteDmitry-A
authored andcommitted
exception fixes (dotnet#136)
1 parent 4365d98 commit 87b6766

File tree

9 files changed

+248
-46
lines changed

9 files changed

+248
-46
lines changed

src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
65
using System.Collections.Generic;
76
using System.Linq;
87
using Microsoft.ML.Data;
@@ -16,13 +15,7 @@ public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) Infer
1615
{
1716
var sample = TextFileSample.CreateFromFullFile(path);
1817
var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
19-
var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader);
20-
21-
// If label column index > inferred # of columns, throw error
22-
if (labelColumnIndex >= typeInference.Columns.Count())
23-
{
24-
throw new ArgumentOutOfRangeException(nameof(labelColumnIndex), $"Label column index ({labelColumnIndex}) is >= than # of inferred columns ({typeInference.Columns.Count()}).");
25-
}
18+
var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader, labelColumnIndex, null);
2619

2720
// if no column is named label,
2821
// rename label column to default ML.NET label column name
@@ -40,7 +33,7 @@ public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) Infer
4033
{
4134
var sample = TextFileSample.CreateFromFullFile(path);
4235
var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
43-
var typeInference = InferColumnTypes(context, sample, splitInference, true);
36+
var typeInference = InferColumnTypes(context, sample, splitInference, true, null, label);
4437
return InferColumns(context, path, label, true, splitInference, typeInference, trimWhitespace, groupColumns);
4538
}
4639

@@ -49,10 +42,6 @@ public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) Infer
4942
bool trimWhitespace, bool groupColumns)
5043
{
5144
var loaderColumns = ColumnTypeInference.GenerateLoaderColumns(typeInference.Columns);
52-
if (!loaderColumns.Any(t => label.Equals(t.Name)))
53-
{
54-
throw new InferenceException(InferenceType.Label, $"Specified Label Column '{label}' was not found.");
55-
}
5645
var typedLoaderArgs = new TextLoader.Arguments
5746
{
5847
Column = loaderColumns,
@@ -121,7 +110,7 @@ private static TextFileContents.ColumnSplitResult InferSplit(TextFileSample samp
121110
}
122111

123112
private static ColumnTypeInference.InferenceResult InferColumnTypes(MLContext context, TextFileSample sample,
124-
TextFileContents.ColumnSplitResult splitInference, bool hasHeader)
113+
TextFileContents.ColumnSplitResult splitInference, bool hasHeader, uint? labelColumnIndex, string label)
125114
{
126115
// infer column types
127116
var typeInferenceResult = ColumnTypeInference.InferTextFileColumnTypes(context, sample,
@@ -131,7 +120,9 @@ private static ColumnTypeInference.InferenceResult InferColumnTypes(MLContext co
131120
Separator = splitInference.Separator.Value,
132121
AllowSparse = splitInference.AllowSparse,
133122
AllowQuote = splitInference.AllowQuote,
134-
HasHeader = hasHeader
123+
HasHeader = hasHeader,
124+
LabelColumnIndex = labelColumnIndex,
125+
Label = label
135126
});
136127

137128
if (!typeInferenceResult.IsSuccess)

src/Microsoft.ML.Auto/ColumnInference/ColumnTypeInference.cs

Lines changed: 67 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ internal sealed class Arguments
3030
public int ColumnCount;
3131
public bool HasHeader;
3232
public int MaxRowsToRead;
33+
public uint? LabelColumnIndex;
34+
public string Label;
3335

3436
public Arguments()
3537
{
@@ -68,13 +70,31 @@ public IntermediateColumn(ReadOnlyMemory<char>[] data, int columnId)
6870
}
6971

7072
public ReadOnlyMemory<char>[] RawData { get { return _data; } }
73+
74+
public string Name { get; set; }
75+
76+
public bool HasAllBooleanValues()
77+
{
78+
if (this.RawData.Skip(1)
79+
.All(x => {
80+
bool value;
81+
// (note: Conversions.TryParse parses an empty string as a Boolean)
82+
return !string.IsNullOrEmpty(x.ToString()) &&
83+
Conversions.TryParse(in x, out value);
84+
}))
85+
{
86+
return true;
87+
}
88+
89+
return false;
90+
}
7191
}
7292

73-
public struct Column
93+
public class Column
7494
{
7595
public readonly int ColumnIndex;
76-
public readonly PrimitiveType ItemType;
7796

97+
public PrimitiveType ItemType;
7898
public string SuggestedName;
7999

80100
public Column(int columnIndex, string suggestedName, PrimitiveType itemType)
@@ -131,13 +151,10 @@ public void Apply(IntermediateColumn[] columns)
131151
{
132152
foreach (var col in columns)
133153
{
134-
if (!col.RawData.Skip(1)
135-
.All(x =>
136-
{
137-
bool value;
138-
return Conversions.TryParse(in x, out value);
139-
})
140-
)
154+
// skip columns that already have a suggested type,
155+
// or that don't have all Boolean values
156+
if (col.SuggestedType != null ||
157+
!col.HasAllBooleanValues())
141158
{
142159
continue;
143160
}
@@ -156,12 +173,6 @@ public void Apply(IntermediateColumn[] columns)
156173
{
157174
foreach (var col in columns)
158175
{
159-
// skip columns that already have a suggested type
160-
if(col.SuggestedType != null)
161-
{
162-
continue;
163-
}
164-
165176
if (!col.RawData.Skip(1)
166177
.All(x =>
167178
{
@@ -215,9 +226,9 @@ public void Apply(IntermediateColumn[] columns)
215226
private static IEnumerable<ITypeInferenceExpert> GetExperts()
216227
{
217228
// Current logic is pretty primitive: if every value (except the first) of a column
218-
// parses as a boolean it's boolean, if it parses as numeric then it's numeric. Otherwise, it is text.
219-
yield return new Experts.BooleanValues();
229+
// parses as numeric then it's numeric. Else if it parses as a Boolean, it's Boolean. Otherwise, it is text.
220230
yield return new Experts.AllNumericValues();
231+
yield return new Experts.BooleanValues();
221232
yield return new Experts.EverythingText();
222233
}
223234

@@ -329,7 +340,6 @@ private static InferenceResult InferTextFileColumnTypesCore(MLContext env, IMult
329340
}
330341

331342
// suggest names
332-
var names = new List<string>();
333343
usedNames.Clear();
334344
foreach (var col in cols)
335345
{
@@ -338,14 +348,23 @@ private static InferenceResult InferTextFileColumnTypesCore(MLContext env, IMult
338348
name0 = name = SuggestName(col, args.HasHeader);
339349
int i = 0;
340350
while (!usedNames.Add(name))
351+
{
341352
name = string.Format("{0}_{1:00}", name0, i++);
342-
names.Add(name);
353+
}
354+
col.Name = name;
355+
}
356+
357+
// validate & retrieve label column
358+
var labelColumn = GetAndValidateLabelColumn(args, cols);
359+
360+
// if label column has all Boolean values, set its type as Boolean
361+
if(labelColumn.HasAllBooleanValues())
362+
{
363+
labelColumn.SuggestedType = BoolType.Instance;
343364
}
344-
var outCols =
345-
cols.Select((x, i) => new Column(x.ColumnId, names[i], x.SuggestedType)).ToArray();
346365

347-
var numerics = outCols.Count(x => x.ItemType.IsNumber());
348-
366+
var outCols = cols.Select(x => new Column(x.ColumnId, x.Name, x.SuggestedType)).ToArray();
367+
349368
return InferenceResult.Success(outCols, args.HasHeader, cols.Select(col => col.RawData).ToArray());
350369
}
351370

@@ -361,6 +380,31 @@ private static string Sanitize(string header)
361380
return string.Join("", header.Select(x => Char.IsLetterOrDigit(x) ? x : '_'));
362381
}
363382

383+
private static IntermediateColumn GetAndValidateLabelColumn(Arguments args, IntermediateColumn[] cols)
384+
{
385+
IntermediateColumn labelColumn = null;
386+
if (args.LabelColumnIndex != null)
387+
{
388+
// if label column index > inferred # of columns, throw error
389+
if (args.LabelColumnIndex >= cols.Count())
390+
{
391+
throw new ArgumentOutOfRangeException(nameof(args.LabelColumnIndex), $"Label column index ({args.LabelColumnIndex}) is >= than # of inferred columns ({cols.Count()}).");
392+
}
393+
394+
labelColumn = cols[args.LabelColumnIndex.Value];
395+
}
396+
else
397+
{
398+
labelColumn = cols.FirstOrDefault(c => c.Name == args.Label);
399+
if (labelColumn == null)
400+
{
401+
throw new ArgumentException($"Specified label column '{args.Label}' was not found.");
402+
}
403+
}
404+
405+
return labelColumn;
406+
}
407+
364408
public static TextLoader.Column[] GenerateLoaderColumns(Column[] columns)
365409
{
366410
var loaderColumns = new List<TextLoader.Column>();

src/Microsoft.ML.Auto/ColumnInference/PurposeInference.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ public void Apply(IntermediateColumn[] columns)
171171
Double avgSpaces = 1.0 * sumSpaces / data.Length;
172172
if (cardinalityRatio < 0.7 || seen.Count < 100)
173173
column.SuggestedPurpose = ColumnPurpose.CategoricalFeature;
174-
else if (cardinalityRatio >= 0.85 && (avgLength > 30 || avgSpaces >= 1))
174+
// (note: the columns.Count() == 1 condition below, in case a dataset has only
175+
// a 'name' and a 'label' column, forces what would be a 'name' column to become a text feature)
176+
else if (cardinalityRatio >= 0.85 && (avgLength > 30 || avgSpaces >= 1 || columns.Count() == 1))
175177
column.SuggestedPurpose = ColumnPurpose.TextFeature;
176178
else if (cardinalityRatio >= 0.9)
177179
column.SuggestedPurpose = ColumnPurpose.Name;

src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,13 @@ public static void ValidateInferColumnsArgs(string path)
3434
ValidatePath(path);
3535
}
3636

37-
public static void ValidateAutoReadArgs(string path, string label)
38-
{
39-
ValidateLabel(label);
40-
ValidatePath(path);
41-
}
42-
4337
private static void ValidateTrainData(IDataView trainData)
4438
{
4539
if (trainData == null)
4640
{
4741
throw new ArgumentNullException(nameof(trainData), "Training data cannot be null");
4842
}
43+
4944
var type = trainData.Schema.GetColumnOrNull(DefaultColumnNames.Features)?.Type.GetItemType();
5045
if (type != null && type != NumberType.R4)
5146
{

src/Test/ColumnInferenceTests.cs

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public void IncorrectLabelColumnTest()
2828
{
2929
var dataPath = DatasetUtil.DownloadUciAdultDataset();
3030
var context = new MLContext();
31-
Assert.ThrowsException<InferenceException>(new System.Action(() => context.Data.InferColumns(dataPath, "Junk", groupColumns: false)));
31+
Assert.ThrowsException<ArgumentException>(new System.Action(() => context.Data.InferColumns(dataPath, "Junk", groupColumns: false)));
3232
}
3333

3434
[TestMethod]
@@ -62,5 +62,51 @@ public void InferColumnsLabelIndexNoHeaders()
6262
Assert.AreEqual(1, labelPurposes.Count());
6363
Assert.AreEqual(DefaultColumnNames.Label, labelPurposes.First().Name);
6464
}
65+
66+
[TestMethod]
67+
public void InferColumnsWithDatasetWithEmptyColumn()
68+
{
69+
var result = new MLContext().Data.InferColumns(@".\TestData\DatasetWithEmptyColumn.txt", DefaultColumnNames.Label);
70+
var emptyColumn = result.TextLoaderArgs.Column.First(c => c.Name == "Empty");
71+
Assert.AreEqual(DataKind.TX, emptyColumn.Type);
72+
}
73+
74+
[TestMethod]
75+
public void InferColumnsWithDatasetWithBoolColumn()
76+
{
77+
var result = new MLContext().Data.InferColumns(@".\TestData\BinaryDatasetWithBoolColumn.txt", DefaultColumnNames.Label);
78+
Assert.AreEqual(2, result.TextLoaderArgs.Column.Count());
79+
Assert.AreEqual(2, result.ColumnPurpopses.Count());
80+
81+
var boolColumn = result.TextLoaderArgs.Column.First(c => c.Name == "Bool");
82+
var labelColumn = result.TextLoaderArgs.Column.First(c => c.Name == DefaultColumnNames.Label);
83+
// ensure non-label Boolean column is detected as R4
84+
Assert.AreEqual(DataKind.R4, boolColumn.Type);
85+
Assert.AreEqual(DataKind.BL, labelColumn.Type);
86+
87+
// ensure non-label Boolean column is detected as R4
88+
var boolPurpose = result.ColumnPurpopses.First(c => c.Name == "Bool").Purpose;
89+
var labelPurpose = result.ColumnPurpopses.First(c => c.Name == DefaultColumnNames.Label).Purpose;
90+
Assert.AreEqual(ColumnPurpose.NumericFeature, boolPurpose);
91+
Assert.AreEqual(ColumnPurpose.Label, labelPurpose);
92+
}
93+
94+
[TestMethod]
95+
public void InferColumnsWhereNameColumnIsOnlyFeature()
96+
{
97+
var result = new MLContext().Data.InferColumns(@".\TestData\NameColumnIsOnlyFeatureDataset.txt", DefaultColumnNames.Label);
98+
Assert.AreEqual(2, result.TextLoaderArgs.Column.Count());
99+
Assert.AreEqual(2, result.ColumnPurpopses.Count());
100+
101+
var nameColumn = result.TextLoaderArgs.Column.First(c => c.Name == DefaultColumnNames.Name);
102+
var labelColumn = result.TextLoaderArgs.Column.First(c => c.Name == DefaultColumnNames.Label);
103+
Assert.AreEqual(DataKind.TX, nameColumn.Type);
104+
Assert.AreEqual(DataKind.BL, labelColumn.Type);
105+
106+
var namePurpose = result.ColumnPurpopses.First(c => c.Name == DefaultColumnNames.Name).Purpose;
107+
var labelPurpose = result.ColumnPurpopses.First(c => c.Name == DefaultColumnNames.Label).Purpose;
108+
Assert.AreEqual(ColumnPurpose.TextFeature, namePurpose);
109+
Assert.AreEqual(ColumnPurpose.Label, labelPurpose);
110+
}
65111
}
66112
}

src/Test/Test.csproj

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,16 @@
1818
<ProjectReference Include="..\Microsoft.ML.Auto\Microsoft.ML.Auto.csproj" />
1919
</ItemGroup>
2020

21+
<ItemGroup>
22+
<None Update="TestData\NameColumnIsOnlyFeatureDataset.txt">
23+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
24+
</None>
25+
<None Update="TestData\BinaryDatasetWithBoolColumn.txt">
26+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
27+
</None>
28+
<None Update="TestData\DatasetWithEmptyColumn.txt">
29+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
30+
</None>
31+
</ItemGroup>
32+
2133
</Project>
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Label,Bool
2+
0,1
3+
0,0
4+
1,1
5+
1,0
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Label,Feature1,Empty
2+
0,2,
3+
0,4,
4+
1,1,

0 commit comments

Comments
 (0)