Skip to content

Commit f3d57a1

Browse files
yaeldMSTomFinley
authored andcommitted
Fix a bug with group Id column in CV macro and add NameColumn argument to CV and TrainTest macros (#467)
* Fix a bug with group Id column in CV macro * add NameColumn argument
1 parent 53c2a15 commit f3d57a1

File tree

6 files changed

+171
-95
lines changed

6 files changed

+171
-95
lines changed

src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs

Lines changed: 41 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ public float Cost
475475

476476
private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCatalog, RunContext context,
477477
string id, string entryPointName, JObject inputs, JObject outputs, bool checkpoint = false,
478-
string stageId = "", float cost = float.NaN, string label = null, string group = null, string weight = null)
478+
string stageId = "", float cost = float.NaN, string label = null, string group = null, string weight = null, string name = null)
479479
{
480480
Contracts.AssertValue(env);
481481
env.AssertNonEmpty(id);
@@ -510,49 +510,10 @@ private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCa
510510
throw _host.Except($"The following required inputs were not provided: {String.Join(", ", missing)}");
511511

512512
var inputInstance = _inputBuilder.GetInstance();
513-
var warning = "Different {0} column specified in trainer and in macro: '{1}', '{2}'." +
514-
" Using column '{2}'. To column use '{1}' instead, please specify this name in" +
515-
"the trainer node arguments.";
516-
if (!string.IsNullOrEmpty(label) && Utils.Size(_entryPoint.InputKinds) > 0 &&
517-
_entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithLabel)))
518-
{
519-
var labelColField = _inputBuilder.GetFieldNameOrNull("LabelColumn");
520-
ch.AssertNonEmpty(labelColField);
521-
var labelColFieldType = _inputBuilder.GetFieldTypeOrNull(labelColField);
522-
ch.Assert(labelColFieldType == typeof(string));
523-
var inputLabel = inputInstance.GetType().GetField(labelColField).GetValue(inputInstance);
524-
if (label != (string)inputLabel)
525-
ch.Warning(warning, "label", label, inputLabel);
526-
else
527-
_inputBuilder.TrySetValue(labelColField, label);
528-
}
529-
if (!string.IsNullOrEmpty(group) && Utils.Size(_entryPoint.InputKinds) > 0 &&
530-
_entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithGroupId)))
531-
{
532-
var groupColField = _inputBuilder.GetFieldNameOrNull("GroupIdColumn");
533-
ch.AssertNonEmpty(groupColField);
534-
var groupColFieldType = _inputBuilder.GetFieldTypeOrNull(groupColField);
535-
ch.Assert(groupColFieldType == typeof(string));
536-
var inputGroup = inputInstance.GetType().GetField(groupColField).GetValue(inputInstance);
537-
if (group != (Optional<string>)inputGroup)
538-
ch.Warning(warning, "group Id", label, inputGroup);
539-
else
540-
_inputBuilder.TrySetValue(groupColField, label);
541-
}
542-
if (!string.IsNullOrEmpty(weight) && Utils.Size(_entryPoint.InputKinds) > 0 &&
543-
(_entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithWeight)) ||
544-
_entryPoint.InputKinds.Contains(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))))
545-
{
546-
var weightColField = _inputBuilder.GetFieldNameOrNull("WeightColumn");
547-
ch.AssertNonEmpty(weightColField);
548-
var weightColFieldType = _inputBuilder.GetFieldTypeOrNull(weightColField);
549-
ch.Assert(weightColFieldType == typeof(string));
550-
var inputWeight = inputInstance.GetType().GetField(weightColField).GetValue(inputInstance);
551-
if (weight != (Optional<string>)inputWeight)
552-
ch.Warning(warning, "weight", label, inputWeight);
553-
else
554-
_inputBuilder.TrySetValue(weightColField, label);
555-
}
513+
SetColumnArgument(ch, inputInstance, "LabelColumn", label, "label", typeof(CommonInputs.ITrainerInputWithLabel));
514+
SetColumnArgument(ch, inputInstance, "GroupIdColumn", group, "group Id", typeof(CommonInputs.ITrainerInputWithGroupId));
515+
SetColumnArgument(ch, inputInstance, "WeightColumn", weight, "weight", typeof(CommonInputs.ITrainerInputWithWeight), typeof(CommonInputs.IUnsupervisedTrainerWithWeight));
516+
SetColumnArgument(ch, inputInstance, "NameColumn", name, "name");
556517

557518
// Validate outputs.
558519
_outputHelper = new OutputHelper(_host, _entryPoint.OutputType);
@@ -568,6 +529,38 @@ private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCa
568529
Cost = cost;
569530
}
570531

532+
private void SetColumnArgument(IChannel ch, object inputInstance, string argName, string colName, string columnRole, params Type[] inputKinds)
533+
{
534+
Contracts.AssertValue(ch);
535+
ch.AssertValue(inputInstance);
536+
ch.AssertNonEmpty(argName);
537+
ch.AssertValueOrNull(colName);
538+
ch.AssertNonEmpty(columnRole);
539+
ch.AssertValueOrNull(inputKinds);
540+
541+
var colField = _inputBuilder.GetFieldNameOrNull(argName);
542+
if (string.IsNullOrEmpty(colField))
543+
return;
544+
545+
const string warning = "Different {0} column specified in trainer and in macro: '{1}', '{2}'." +
546+
" Using column '{2}'. To column use '{1}' instead, please specify this name in" +
547+
"the trainer node arguments.";
548+
if (!string.IsNullOrEmpty(colName) && Utils.Size(_entryPoint.InputKinds) > 0 &&
549+
(Utils.Size(inputKinds) == 0 || _entryPoint.InputKinds.Intersect(inputKinds).Any()))
550+
{
551+
ch.AssertNonEmpty(colField);
552+
var colFieldType = _inputBuilder.GetFieldTypeOrNull(colField);
553+
ch.Assert(colFieldType == typeof(string));
554+
var inputColName = inputInstance.GetType().GetField(colField).GetValue(inputInstance);
555+
ch.Assert(inputColName is string || inputColName is Optional<string>);
556+
var str = inputColName is string ? (string)inputColName : ((Optional<string>)inputColName).Value;
557+
if (colName != str)
558+
ch.Warning(warning, columnRole, colName, inputColName);
559+
else
560+
_inputBuilder.TrySetValue(colField, colName);
561+
}
562+
}
563+
571564
public static EntryPointNode Create(
572565
IHostEnvironment env,
573566
string entryPointName,
@@ -902,7 +895,7 @@ private object BuildParameterValue(List<ParameterBinding> bindings)
902895
}
903896

904897
public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContext context, JArray nodes,
905-
ModuleCatalog moduleCatalog, string label = null, string group = null, string weight = null)
898+
ModuleCatalog moduleCatalog, string label = null, string group = null, string weight = null, string name = null)
906899
{
907900
Contracts.AssertValue(env);
908901
env.AssertValue(context);
@@ -918,7 +911,7 @@ public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContex
918911
if (node == null)
919912
throw env.Except("Unexpected node token: '{0}'", nodes[i]);
920913

921-
string name = node[FieldNames.Name].Value<string>();
914+
string nodeName = node[FieldNames.Name].Value<string>();
922915
var inputs = node[FieldNames.Inputs] as JObject;
923916
if (inputs == null && node[FieldNames.Inputs] != null)
924917
throw env.Except("Unexpected {0} token: '{1}'", FieldNames.Inputs, node[FieldNames.Inputs]);
@@ -927,7 +920,7 @@ public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContex
927920
if (outputs == null && node[FieldNames.Outputs] != null)
928921
throw env.Except("Unexpected {0} token: '{1}'", FieldNames.Outputs, node[FieldNames.Outputs]);
929922

930-
var id = context.GenerateId(name);
923+
var id = context.GenerateId(nodeName);
931924
var unexpectedFields = node.Properties().Where(
932925
x => x.Name != FieldNames.Name && x.Name != FieldNames.Inputs && x.Name != FieldNames.Outputs
933926
&& x.Name != FieldNames.StageId && x.Name != FieldNames.Checkpoint && x.Name != FieldNames.Cost);
@@ -942,7 +935,7 @@ public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContex
942935
ch.Warning("Node '{0}' has unexpected fields that are ignored: {1}", id, string.Join(", ", unexpectedFields.Select(x => x.Name)));
943936
}
944937

945-
result.Add(new EntryPointNode(env, ch, moduleCatalog, context, id, name, inputs, outputs, checkpoint, stageId, cost, label, group, weight));
938+
result.Add(new EntryPointNode(env, ch, moduleCatalog, context, id, nodeName, inputs, outputs, checkpoint, stageId, cost, label, group, weight, name));
946939
}
947940

948941
ch.Done();

src/Microsoft.ML/CSharpApi.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2509,6 +2509,11 @@ public sealed partial class CrossValidationResultsCombiner
25092509
/// </summary>
25102510
public Microsoft.ML.Runtime.EntryPoints.Optional<string> GroupColumn { get; set; }
25112511

2512+
/// <summary>
2513+
/// Name column name
2514+
/// </summary>
2515+
public Microsoft.ML.Runtime.EntryPoints.Optional<string> NameColumn { get; set; }
2516+
25122517
/// <summary>
25132518
/// Specifies the trainer kind, which determines the evaluator to be used.
25142519
/// </summary>
@@ -2629,6 +2634,11 @@ public sealed partial class CrossValidator
26292634
/// </summary>
26302635
public Microsoft.ML.Runtime.EntryPoints.Optional<string> GroupColumn { get; set; }
26312636

2637+
/// <summary>
2638+
/// Name column name
2639+
/// </summary>
2640+
public Microsoft.ML.Runtime.EntryPoints.Optional<string> NameColumn { get; set; }
2641+
26322642

26332643
public sealed class Output
26342644
{
@@ -4020,6 +4030,11 @@ public sealed partial class TrainTestEvaluator
40204030
/// </summary>
40214031
public Microsoft.ML.Runtime.EntryPoints.Optional<string> GroupColumn { get; set; }
40224032

4033+
/// <summary>
4034+
/// Name column name
4035+
/// </summary>
4036+
public Microsoft.ML.Runtime.EntryPoints.Optional<string> NameColumn { get; set; }
4037+
40234038

40244039
public sealed class Output
40254040
{

src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,26 +66,29 @@ public sealed class Arguments
6666

6767
// For splitting the data into folds, this column is used for grouping rows and makes sure
6868
// that a group of rows is not split among folds.
69-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for stratification", ShortName = "strat", SortOrder = 6)]
69+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for stratification", ShortName = "strat", SortOrder = 6)]
7070
public string StratificationColumn;
7171

7272
// The number of folds to generate.
73-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of folds in k-fold cross-validation", ShortName = "k", SortOrder = 7)]
73+
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of folds in k-fold cross-validation", ShortName = "k", SortOrder = 7)]
7474
public int NumFolds = 2;
7575

7676
// REVIEW: suggest moving to subcomponents for evaluators, to allow for different parameters on the evaluators
7777
// (and the same for the TrainTest macro). I currently do not know how to do this, so this should be revisited in the future.
7878
[Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 8)]
7979
public MacroUtils.TrainerKinds Kind = MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer;
8080

81-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 10)]
81+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 9)]
8282
public string LabelColumn = DefaultColumnNames.Label;
8383

84-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 11)]
84+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 10)]
8585
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);
8686

87-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)]
87+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 11)]
8888
public Optional<string> GroupColumn = Optional<string>.Implicit(DefaultColumnNames.GroupId);
89+
90+
[Argument(ArgumentType.AtMostOnce, HelpText = "Name column name", ShortName = "name", SortOrder = 12)]
91+
public Optional<string> NameColumn = Optional<string>.Implicit(DefaultColumnNames.Name);
8992
}
9093

9194
// REVIEW: This output would be much better as an array of CommonOutputs.ClassificationEvaluateOutput,
@@ -127,16 +130,19 @@ public sealed class CombineMetricsInput
127130
[Argument(ArgumentType.Multiple, HelpText = "Warning datasets", SortOrder = 4)]
128131
public IDataView[] Warnings;
129132

130-
[Argument(ArgumentType.AtMostOnce, HelpText = "The label column name", ShortName = "Label", SortOrder = 5)]
133+
[Argument(ArgumentType.AtMostOnce, HelpText = "The label column name", ShortName = "Label", SortOrder = 6)]
131134
public string LabelColumn = DefaultColumnNames.Label;
132135

133-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 6)]
136+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 7)]
134137
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);
135138

136-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)]
139+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 8)]
137140
public Optional<string> GroupColumn = Optional<string>.Implicit(DefaultColumnNames.GroupId);
138141

139-
[Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 6)]
142+
[Argument(ArgumentType.AtMostOnce, HelpText = "Name column name", ShortName = "name", SortOrder = 9)]
143+
public Optional<string> NameColumn = Optional<string>.Implicit(DefaultColumnNames.Name);
144+
145+
[Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 5)]
140146
public MacroUtils.TrainerKinds Kind = MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer;
141147
}
142148

@@ -206,7 +212,8 @@ public static CommonOutputs.MacroOutput<Output> CrossValidate(
206212
TransformModel = null,
207213
LabelColumn = input.LabelColumn,
208214
GroupColumn = input.GroupColumn,
209-
WeightColumn = input.WeightColumn
215+
WeightColumn = input.WeightColumn,
216+
NameColumn = input.NameColumn
210217
};
211218

212219
if (transformModelVarName != null)
@@ -377,6 +384,7 @@ public static CommonOutputs.MacroOutput<Output> CrossValidate(
377384
combineArgs.LabelColumn = input.LabelColumn;
378385
combineArgs.WeightColumn = input.WeightColumn;
379386
combineArgs.GroupColumn = input.GroupColumn;
387+
combineArgs.NameColumn = input.NameColumn;
380388

381389
// Set the input bindings for the CombineMetrics entry point.
382390
var combineInputBindingMap = new Dictionary<string, List<ParameterBinding>>();
@@ -429,7 +437,8 @@ public static CombinedOutput CombineMetrics(IHostEnvironment env, CombineMetrics
429437
{
430438
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn),
431439
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Weight, input.WeightColumn.Value),
432-
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, input.GroupColumn.Value)
440+
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, input.GroupColumn.Value),
441+
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Name, input.NameColumn.Value)
433442
})).ToArray(),
434443
out var variableSizeVectorColumnNames);
435444

src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,17 @@ public sealed class Arguments
6363
[Argument(ArgumentType.AtMostOnce, HelpText = "Indicates whether to include and output training dataset metrics.", SortOrder = 9)]
6464
public Boolean IncludeTrainingMetrics = false;
6565

66-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 10)]
66+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 10)]
6767
public string LabelColumn = DefaultColumnNames.Label;
6868

69-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 11)]
69+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 11)]
7070
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);
7171

72-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)]
72+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)]
7373
public Optional<string> GroupColumn = Optional<string>.Implicit(DefaultColumnNames.GroupId);
74+
75+
[Argument(ArgumentType.AtMostOnce, HelpText = "Name column name", ShortName = "name", SortOrder = 13)]
76+
public Optional<string> NameColumn = Optional<string>.Implicit(DefaultColumnNames.Name);
7477
}
7578

7679
public sealed class Output
@@ -120,7 +123,9 @@ public static CommonOutputs.MacroOutput<Output> TrainTest(
120123
// Parse the subgraph.
121124
var subGraphRunContext = new RunContext(env);
122125
var subGraphNodes = EntryPointNode.ValidateNodes(env, subGraphRunContext, input.Nodes, node.Catalog, input.LabelColumn,
123-
input.GroupColumn.IsExplicit ? input.GroupColumn.Value : null, input.WeightColumn.IsExplicit ? input.WeightColumn.Value : null);
126+
input.GroupColumn.IsExplicit ? input.GroupColumn.Value : null,
127+
input.WeightColumn.IsExplicit ? input.WeightColumn.Value : null,
128+
input.NameColumn.IsExplicit ? input.NameColumn.Value : null);
124129

125130
// Change the subgraph to use the training data as input.
126131
var varName = input.Inputs.Data.VarName;
@@ -221,7 +226,8 @@ public static CommonOutputs.MacroOutput<Output> TrainTest(
221226
{
222227
LabelColumn = input.LabelColumn,
223228
WeightColumn = input.WeightColumn.IsExplicit ? input.WeightColumn.Value : null,
224-
GroupColumn = input.GroupColumn.IsExplicit ? input.GroupColumn.Value : null
229+
GroupColumn = input.GroupColumn.IsExplicit ? input.GroupColumn.Value : null,
230+
NameColumn = input.NameColumn.IsExplicit ? input.NameColumn.Value : null
225231
};
226232

227233
string outVariableName;

0 commit comments

Comments
 (0)