Skip to content

Fix a bug with group Id column in CV macro and add NameColumn argument to CV and TrainTest macros #467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 41 additions & 48 deletions src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ public float Cost

private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCatalog, RunContext context,
string id, string entryPointName, JObject inputs, JObject outputs, bool checkpoint = false,
string stageId = "", float cost = float.NaN, string label = null, string group = null, string weight = null)
string stageId = "", float cost = float.NaN, string label = null, string group = null, string weight = null, string name = null)
{
Contracts.AssertValue(env);
env.AssertNonEmpty(id);
Expand Down Expand Up @@ -510,49 +510,10 @@ private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCa
throw _host.Except($"The following required inputs were not provided: {String.Join(", ", missing)}");

var inputInstance = _inputBuilder.GetInstance();
var warning = "Different {0} column specified in trainer and in macro: '{1}', '{2}'." +
" Using column '{2}'. To column use '{1}' instead, please specify this name in" +
"the trainer node arguments.";
if (!string.IsNullOrEmpty(label) && Utils.Size(_entryPoint.InputKinds) > 0 &&
_entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithLabel)))
{
var labelColField = _inputBuilder.GetFieldNameOrNull("LabelColumn");
ch.AssertNonEmpty(labelColField);
var labelColFieldType = _inputBuilder.GetFieldTypeOrNull(labelColField);
ch.Assert(labelColFieldType == typeof(string));
var inputLabel = inputInstance.GetType().GetField(labelColField).GetValue(inputInstance);
if (label != (string)inputLabel)
ch.Warning(warning, "label", label, inputLabel);
else
_inputBuilder.TrySetValue(labelColField, label);
}
if (!string.IsNullOrEmpty(group) && Utils.Size(_entryPoint.InputKinds) > 0 &&
_entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithGroupId)))
{
var groupColField = _inputBuilder.GetFieldNameOrNull("GroupIdColumn");
ch.AssertNonEmpty(groupColField);
var groupColFieldType = _inputBuilder.GetFieldTypeOrNull(groupColField);
ch.Assert(groupColFieldType == typeof(string));
var inputGroup = inputInstance.GetType().GetField(groupColField).GetValue(inputInstance);
if (group != (Optional<string>)inputGroup)
ch.Warning(warning, "group Id", label, inputGroup);
else
_inputBuilder.TrySetValue(groupColField, label);
}
if (!string.IsNullOrEmpty(weight) && Utils.Size(_entryPoint.InputKinds) > 0 &&
(_entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithWeight)) ||
_entryPoint.InputKinds.Contains(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))))
{
var weightColField = _inputBuilder.GetFieldNameOrNull("WeightColumn");
ch.AssertNonEmpty(weightColField);
var weightColFieldType = _inputBuilder.GetFieldTypeOrNull(weightColField);
ch.Assert(weightColFieldType == typeof(string));
var inputWeight = inputInstance.GetType().GetField(weightColField).GetValue(inputInstance);
if (weight != (Optional<string>)inputWeight)
ch.Warning(warning, "weight", label, inputWeight);
else
_inputBuilder.TrySetValue(weightColField, label);
}
SetColumnArgument(ch, inputInstance, "LabelColumn", label, "label", typeof(CommonInputs.ITrainerInputWithLabel));
SetColumnArgument(ch, inputInstance, "GroupIdColumn", group, "group Id", typeof(CommonInputs.ITrainerInputWithGroupId));
SetColumnArgument(ch, inputInstance, "WeightColumn", weight, "weight", typeof(CommonInputs.ITrainerInputWithWeight), typeof(CommonInputs.IUnsupervisedTrainerWithWeight));
SetColumnArgument(ch, inputInstance, "NameColumn", name, "name");

// Validate outputs.
_outputHelper = new OutputHelper(_host, _entryPoint.OutputType);
Expand All @@ -568,6 +529,38 @@ private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCa
Cost = cost;
}

/// <summary>
/// Reconciles a column name supplied by a macro with the corresponding column argument of this
/// node's entry point input (for example "LabelColumn" or "WeightColumn").
/// </summary>
/// <remarks>
/// Nothing happens when the input builder has no field for <paramref name="argName"/>, when
/// <paramref name="colName"/> is null or empty, when the entry point declares no input kinds, or when
/// <paramref name="inputKinds"/> is non-empty and none of them is among the entry point's input kinds.
/// Otherwise the current value of the field is read from <paramref name="inputInstance"/>: if it differs
/// from <paramref name="colName"/> a warning is logged and the node's own value is left untouched;
/// if it matches, the value is set on the input builder.
/// </remarks>
/// <param name="ch">Channel used for assertions and for emitting the mismatch warning.</param>
/// <param name="inputInstance">The input object whose public field is inspected through reflection.</param>
/// <param name="argName">Name of the argument to look up in the input builder.</param>
/// <param name="colName">Column name coming from the macro; may be null or empty.</param>
/// <param name="columnRole">Human-readable role of the column, used only in the warning text.</param>
/// <param name="inputKinds">
/// Input kinds the entry point must implement (at least one) for the argument to apply.
/// When empty or null, no input-kind filtering is performed.
/// </param>
private void SetColumnArgument(IChannel ch, object inputInstance, string argName, string colName, string columnRole, params Type[] inputKinds)
{
    Contracts.AssertValue(ch);
    ch.AssertValue(inputInstance);
    ch.AssertNonEmpty(argName);
    ch.AssertValueOrNull(colName);
    ch.AssertNonEmpty(columnRole);
    ch.AssertValueOrNull(inputKinds);

    // The entry point may simply not have this argument; that is not an error.
    var colField = _inputBuilder.GetFieldNameOrNull(argName);
    if (string.IsNullOrEmpty(colField))
        return;

    // {0}: column role, {1}: name given in the macro, {2}: name given in the trainer node.
    const string warning = "Different {0} column specified in trainer and in macro: '{2}', '{1}'." +
        " Using column '{2}'. To use column '{1}' instead, please specify this name in" +
        " the trainer node arguments.";
    if (!string.IsNullOrEmpty(colName) && Utils.Size(_entryPoint.InputKinds) > 0 &&
        (Utils.Size(inputKinds) == 0 || _entryPoint.InputKinds.Intersect(inputKinds).Any()))
    {
        ch.AssertNonEmpty(colField);
        var colFieldType = _inputBuilder.GetFieldTypeOrNull(colField);
        // NOTE(review): this presumably reports the unwrapped type for Optional<string> fields,
        // since the value read below is allowed to be an Optional<string> - confirm.
        ch.Assert(colFieldType == typeof(string));
        var inputColName = inputInstance.GetType().GetField(colField).GetValue(inputInstance);
        ch.Assert(inputColName is string || inputColName is Optional<string>);
        var str = inputColName is string ? (string)inputColName : ((Optional<string>)inputColName).Value;
        if (colName != str)
        {
            // The node's own setting wins; report the unwrapped name so the message shows the
            // actual column rather than the wrapper object.
            ch.Warning(warning, columnRole, colName, str);
        }
        else
            _inputBuilder.TrySetValue(colField, colName);
    }
}

public static EntryPointNode Create(
IHostEnvironment env,
string entryPointName,
Expand Down Expand Up @@ -902,7 +895,7 @@ private object BuildParameterValue(List<ParameterBinding> bindings)
}

public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContext context, JArray nodes,
ModuleCatalog moduleCatalog, string label = null, string group = null, string weight = null)
ModuleCatalog moduleCatalog, string label = null, string group = null, string weight = null, string name = null)
{
Contracts.AssertValue(env);
env.AssertValue(context);
Expand All @@ -918,7 +911,7 @@ public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContex
if (node == null)
throw env.Except("Unexpected node token: '{0}'", nodes[i]);

string name = node[FieldNames.Name].Value<string>();
string nodeName = node[FieldNames.Name].Value<string>();
var inputs = node[FieldNames.Inputs] as JObject;
if (inputs == null && node[FieldNames.Inputs] != null)
throw env.Except("Unexpected {0} token: '{1}'", FieldNames.Inputs, node[FieldNames.Inputs]);
Expand All @@ -927,7 +920,7 @@ public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContex
if (outputs == null && node[FieldNames.Outputs] != null)
throw env.Except("Unexpected {0} token: '{1}'", FieldNames.Outputs, node[FieldNames.Outputs]);

var id = context.GenerateId(name);
var id = context.GenerateId(nodeName);
var unexpectedFields = node.Properties().Where(
x => x.Name != FieldNames.Name && x.Name != FieldNames.Inputs && x.Name != FieldNames.Outputs
&& x.Name != FieldNames.StageId && x.Name != FieldNames.Checkpoint && x.Name != FieldNames.Cost);
Expand All @@ -942,7 +935,7 @@ public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContex
ch.Warning("Node '{0}' has unexpected fields that are ignored: {1}", id, string.Join(", ", unexpectedFields.Select(x => x.Name)));
}

result.Add(new EntryPointNode(env, ch, moduleCatalog, context, id, name, inputs, outputs, checkpoint, stageId, cost, label, group, weight));
result.Add(new EntryPointNode(env, ch, moduleCatalog, context, id, nodeName, inputs, outputs, checkpoint, stageId, cost, label, group, weight, name));
}

ch.Done();
Expand Down
15 changes: 15 additions & 0 deletions src/Microsoft.ML/CSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2509,6 +2509,11 @@ public sealed partial class CrossValidationResultsCombiner
/// </summary>
public Microsoft.ML.Runtime.EntryPoints.Optional<string> GroupColumn { get; set; }

/// <summary>
/// Name column name
/// </summary>
public Microsoft.ML.Runtime.EntryPoints.Optional<string> NameColumn { get; set; }

/// <summary>
/// Specifies the trainer kind, which determines the evaluator to be used.
/// </summary>
Expand Down Expand Up @@ -2629,6 +2634,11 @@ public sealed partial class CrossValidator
/// </summary>
public Microsoft.ML.Runtime.EntryPoints.Optional<string> GroupColumn { get; set; }

/// <summary>
/// Name column name
/// </summary>
public Microsoft.ML.Runtime.EntryPoints.Optional<string> NameColumn { get; set; }


public sealed class Output
{
Expand Down Expand Up @@ -4020,6 +4030,11 @@ public sealed partial class TrainTestEvaluator
/// </summary>
public Microsoft.ML.Runtime.EntryPoints.Optional<string> GroupColumn { get; set; }

/// <summary>
/// Name column name
/// </summary>
public Microsoft.ML.Runtime.EntryPoints.Optional<string> NameColumn { get; set; }


public sealed class Output
{
Expand Down
31 changes: 20 additions & 11 deletions src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,26 +66,29 @@ public sealed class Arguments

// For splitting the data into folds, this column is used for grouping rows and makes sure
// that a group of rows is not split among folds.
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for stratification", ShortName = "strat", SortOrder = 6)]
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for stratification", ShortName = "strat", SortOrder = 6)]
public string StratificationColumn;

// The number of folds to generate.
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of folds in k-fold cross-validation", ShortName = "k", SortOrder = 7)]
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of folds in k-fold cross-validation", ShortName = "k", SortOrder = 7)]
public int NumFolds = 2;

// REVIEW: suggest moving to subcomponents for evaluators, to allow for different parameters on the evaluators
// (and the same for the TrainTest macro). I currently do not know how to do this, so this should be revisited in the future.
[Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 8)]
public MacroUtils.TrainerKinds Kind = MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer;

[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 10)]
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 9)]
public string LabelColumn = DefaultColumnNames.Label;

[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 11)]
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 10)]
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);

[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)]
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 11)]
public Optional<string> GroupColumn = Optional<string>.Implicit(DefaultColumnNames.GroupId);

[Argument(ArgumentType.AtMostOnce, HelpText = "Name column name", ShortName = "name", SortOrder = 12)]
public Optional<string> NameColumn = Optional<string>.Implicit(DefaultColumnNames.Name);
}

// REVIEW: This output would be much better as an array of CommonOutputs.ClassificationEvaluateOutput,
Expand Down Expand Up @@ -127,16 +130,19 @@ public sealed class CombineMetricsInput
[Argument(ArgumentType.Multiple, HelpText = "Warning datasets", SortOrder = 4)]
public IDataView[] Warnings;

[Argument(ArgumentType.AtMostOnce, HelpText = "The label column name", ShortName = "Label", SortOrder = 5)]
[Argument(ArgumentType.AtMostOnce, HelpText = "The label column name", ShortName = "Label", SortOrder = 6)]
public string LabelColumn = DefaultColumnNames.Label;

[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 6)]
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 7)]
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);

[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)]
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 8)]
public Optional<string> GroupColumn = Optional<string>.Implicit(DefaultColumnNames.GroupId);

[Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 6)]
[Argument(ArgumentType.AtMostOnce, HelpText = "Name column name", ShortName = "name", SortOrder = 9)]
public Optional<string> NameColumn = Optional<string>.Implicit(DefaultColumnNames.Name);

[Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 5)]
public MacroUtils.TrainerKinds Kind = MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer;
}

Expand Down Expand Up @@ -206,7 +212,8 @@ public static CommonOutputs.MacroOutput<Output> CrossValidate(
TransformModel = null,
LabelColumn = input.LabelColumn,
GroupColumn = input.GroupColumn,
WeightColumn = input.WeightColumn
WeightColumn = input.WeightColumn,
NameColumn = input.NameColumn
};

if (transformModelVarName != null)
Expand Down Expand Up @@ -377,6 +384,7 @@ public static CommonOutputs.MacroOutput<Output> CrossValidate(
combineArgs.LabelColumn = input.LabelColumn;
combineArgs.WeightColumn = input.WeightColumn;
combineArgs.GroupColumn = input.GroupColumn;
combineArgs.NameColumn = input.NameColumn;

// Set the input bindings for the CombineMetrics entry point.
var combineInputBindingMap = new Dictionary<string, List<ParameterBinding>>();
Expand Down Expand Up @@ -429,7 +437,8 @@ public static CombinedOutput CombineMetrics(IHostEnvironment env, CombineMetrics
{
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn),
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Weight, input.WeightColumn.Value),
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, input.GroupColumn.Value)
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, input.GroupColumn.Value),
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Name, input.NameColumn.Value)
})).ToArray(),
out var variableSizeVectorColumnNames);

Expand Down
16 changes: 11 additions & 5 deletions src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,17 @@ public sealed class Arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "Indicates whether to include and output training dataset metrics.", SortOrder = 9)]
public Boolean IncludeTrainingMetrics = false;

[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 10)]
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 10)]
public string LabelColumn = DefaultColumnNames.Label;

[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 11)]
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 11)]
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);

[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)]
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)]
public Optional<string> GroupColumn = Optional<string>.Implicit(DefaultColumnNames.GroupId);

[Argument(ArgumentType.AtMostOnce, HelpText = "Name column name", ShortName = "name", SortOrder = 13)]
public Optional<string> NameColumn = Optional<string>.Implicit(DefaultColumnNames.Name);
}

public sealed class Output
Expand Down Expand Up @@ -120,7 +123,9 @@ public static CommonOutputs.MacroOutput<Output> TrainTest(
// Parse the subgraph.
var subGraphRunContext = new RunContext(env);
var subGraphNodes = EntryPointNode.ValidateNodes(env, subGraphRunContext, input.Nodes, node.Catalog, input.LabelColumn,
input.GroupColumn.IsExplicit ? input.GroupColumn.Value : null, input.WeightColumn.IsExplicit ? input.WeightColumn.Value : null);
input.GroupColumn.IsExplicit ? input.GroupColumn.Value : null,
input.WeightColumn.IsExplicit ? input.WeightColumn.Value : null,
input.NameColumn.IsExplicit ? input.NameColumn.Value : null);

// Change the subgraph to use the training data as input.
var varName = input.Inputs.Data.VarName;
Expand Down Expand Up @@ -221,7 +226,8 @@ public static CommonOutputs.MacroOutput<Output> TrainTest(
{
LabelColumn = input.LabelColumn,
WeightColumn = input.WeightColumn.IsExplicit ? input.WeightColumn.Value : null,
GroupColumn = input.GroupColumn.IsExplicit ? input.GroupColumn.Value : null
GroupColumn = input.GroupColumn.IsExplicit ? input.GroupColumn.Value : null,
NameColumn = input.NameColumn.IsExplicit ? input.NameColumn.Value : null
};

string outVariableName;
Expand Down
Loading