Skip to content

Commit c491651

Browse files
authored
[Part 3] Added convenience constructors for set of transforms. (#520)
1 parent 3053f3d commit c491651

13 files changed

+284
-15
lines changed

src/Microsoft.ML.Transforms/GroupTransform.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,18 @@ public sealed class Arguments : TransformInputBase
8888

8989
private readonly GroupSchema _schema;
9090

91+
/// <summary>
92+
/// Convenience constructor for public facing API.
93+
/// </summary>
94+
/// <param name="env">Host Environment.</param>
95+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
96+
/// <param name="groupKey">Columns to group by</param>
97+
/// <param name="columns">Columns to group together</param>
98+
public GroupTransform(IHostEnvironment env, IDataView input, string groupKey, params string[] columns)
99+
: this(env, new Arguments() { GroupKey = new[] { groupKey }, Column = columns }, input)
100+
{
101+
}
102+
91103
public GroupTransform(IHostEnvironment env, Arguments args, IDataView input)
92104
: base(env, RegistrationName, input)
93105
{

src/Microsoft.ML.Transforms/HashJoinTransform.cs

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ public sealed class HashJoinTransform : OneToOneTransformBase
3737
public const int NumBitsMin = 1;
3838
public const int NumBitsLim = 32;
3939

40+
private static class Defaults
41+
{
42+
public const bool Join = true;
43+
public const int HashBits = NumBitsLim - 1;
44+
public const uint Seed = 314489979;
45+
public const bool Ordered = true;
46+
}
47+
4048
public sealed class Arguments : TransformInputBase
4149
{
4250
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)",
@@ -45,17 +53,17 @@ public sealed class Arguments : TransformInputBase
4553
public Column[] Column;
4654

4755
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the values need to be combined for a single hash")]
48-
public bool Join = true;
56+
public bool Join = Defaults.Join;
4957

5058
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. Must be between 1 and 31, inclusive.",
5159
ShortName = "bits", SortOrder = 2)]
52-
public int HashBits = NumBitsLim - 1;
60+
public int HashBits = Defaults.HashBits;
5361

5462
[Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
55-
public uint Seed = 314489979;
63+
public uint Seed = Defaults.Seed;
5664

5765
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash", ShortName = "ord")]
58-
public bool Ordered = true;
66+
public bool Ordered = Defaults.Ordered;
5967
}
6068

6169
public sealed class Column : OneToOneColumn
@@ -166,6 +174,25 @@ private static VersionInfo GetVersionInfo()
166174

167175
private readonly ColumnInfoEx[] _exes;
168176

177+
/// <summary>
178+
/// Convenience constructor for public facing API.
179+
/// </summary>
180+
/// <param name="env">Host Environment.</param>
181+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
182+
/// <param name="name">Name of the output column.</param>
183+
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
184+
/// <param name="join">Whether the values need to be combined for a single hash.</param>
185+
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param>
186+
public HashJoinTransform(IHostEnvironment env,
187+
IDataView input,
188+
string name,
189+
string source = null,
190+
bool join = Defaults.Join,
191+
int hashBits = Defaults.HashBits)
192+
: this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, Join = join, HashBits = hashBits }, input)
193+
{
194+
}
195+
169196
public HashJoinTransform(IHostEnvironment env, Arguments args, IDataView input)
170197
: base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input, TestColumnType)
171198
{

src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,18 @@ private static VersionInfo GetVersionInfo()
5454

5555
private readonly VectorType[] _types;
5656

57+
/// <summary>
58+
/// Convenience constructor for public facing API.
59+
/// </summary>
60+
/// <param name="env">Host Environment.</param>
61+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
62+
/// <param name="name">Name of the output column.</param>
63+
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
64+
public KeyToBinaryVectorTransform(IHostEnvironment env, IDataView input, string name, string source = null)
65+
: this(env, new Arguments() { Column = new[] { new KeyToVectorTransform.Column() { Source = source ?? name, Name = name } } }, input)
66+
{
67+
}
68+
5769
/// <summary>
5870
/// Public constructor corresponding to SignatureDataTransform.
5971
/// </summary>

src/Microsoft.ML.Transforms/LoadTransform.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,25 @@ public class Arguments
3939

4040
internal const string Summary = "Loads specified transforms from the model file and applies them to current data.";
4141

42+
/// <summary>
43+
/// A helper method to create <see cref="LoadTransform"/> for public facing API.
44+
/// </summary>
45+
/// <param name="env">Host Environment.</param>
46+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
47+
/// <param name="modelFile">Model file to load the transforms from.</param>
48+
/// <param name="tag">The tags (comma-separated) to be loaded (or omitted, if complement is true).</param>
49+
/// <param name="complement">Whether to load all transforms except those marked by tags.</param>
50+
public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string[] tag, bool complement = false)
51+
{
52+
var args = new Arguments()
53+
{
54+
ModelFile = modelFile,
55+
Tag = tag,
56+
Complement = complement
57+
};
58+
return Create(env, args, input);
59+
}
60+
4261
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
4362
{
4463
Contracts.CheckValue(env, nameof(env));

src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ public static class MutualInformationFeatureSelectionTransform
3333
public const string UserName = "Mutual Information Feature Selection Transform";
3434
public const string ShortName = "MIFeatureSelection";
3535

36+
private static class Defaults
37+
{
38+
public const string LabelColumn = DefaultColumnNames.Label;
39+
public const int SlotsInOutput = 1000;
40+
public const int NumBins = 256;
41+
}
42+
3643
public sealed class Arguments : TransformInputBase
3744
{
3845
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Columns to use for feature selection", ShortName = "col",
@@ -41,19 +48,45 @@ public sealed class Arguments : TransformInputBase
4148

4249
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for labels", ShortName = "lab",
4350
SortOrder = 4, Purpose = SpecialPurpose.ColumnName)]
44-
public string LabelColumn = DefaultColumnNames.Label;
51+
public string LabelColumn = Defaults.LabelColumn;
4552

4653
[Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of slots to preserve in output", ShortName = "topk,numSlotsToKeep",
4754
SortOrder = 1)]
48-
public int SlotsInOutput = 1000;
55+
public int SlotsInOutput = Defaults.SlotsInOutput;
4956

5057
[Argument(ArgumentType.AtMostOnce, HelpText = "Max number of bins for R4/R8 columns, power of 2 recommended",
5158
ShortName = "bins")]
52-
public int NumBins = 256;
59+
public int NumBins = Defaults.NumBins;
5360
}
5461

5562
internal static string RegistrationName = "MutualInformationFeatureSelectionTransform";
5663

64+
/// <summary>
65+
/// A helper method to create <see cref="IDataTransform"/> for selecting the top k slots ordered by their mutual information.
66+
/// </summary>
67+
/// <param name="env">Host Environment.</param>
68+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
69+
/// <param name="labelColumn">Column to use for labels.</param>
70+
/// <param name="slotsInOutput">The maximum number of slots to preserve in output.</param>
71+
/// <param name="numBins">Max number of bins for R4/R8 columns, power of 2 recommended.</param>
72+
/// <param name="columns">Columns to use for feature selection.</param>
73+
public static IDataTransform Create(IHostEnvironment env,
74+
IDataView input,
75+
string labelColumn = Defaults.LabelColumn,
76+
int slotsInOutput = Defaults.SlotsInOutput,
77+
int numBins = Defaults.NumBins,
78+
params string[] columns)
79+
{
80+
var args = new Arguments()
81+
{
82+
Column = columns,
83+
LabelColumn = labelColumn,
84+
SlotsInOutput = slotsInOutput,
85+
NumBins = numBins
86+
};
87+
return Create(env, args, input);
88+
}
89+
5790
/// <summary>
5891
/// Create method corresponding to SignatureDataTransform.
5992
/// </summary>

src/Microsoft.ML.Transforms/NADropTransform.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,18 @@ private static VersionInfo GetVersionInfo()
6969
// The isNA delegates, parallel to Infos.
7070
private readonly Delegate[] _isNAs;
7171

72+
/// <summary>
73+
/// Convenience constructor for public facing API.
74+
/// </summary>
75+
/// <param name="env">Host Environment.</param>
76+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
77+
/// <param name="name">Name of the output column.</param>
78+
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
79+
public NADropTransform(IHostEnvironment env, IDataView input, string name, string source = null)
80+
: this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input)
81+
{
82+
}
83+
7284
public NADropTransform(IHostEnvironment env, Arguments args, IDataView input)
7385
: base(Contracts.CheckRef(env, nameof(env)), RegistrationName, env.CheckRef(args, nameof(args)).Column, input, TestType)
7486
{

src/Microsoft.ML.Transforms/NAHandleTransform.cs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,25 @@ public static class NAHandleTransform
3636
{
3737
public enum ReplacementKind
3838
{
39+
/// <summary>
40+
/// Replace with the default value of the column based on it's type. For example, 'zero' for numeric and 'empty' for string/text columns.
41+
/// </summary>
3942
[EnumValueDisplay("Zero/empty")]
4043
DefaultValue,
44+
45+
/// <summary>
46+
/// Replace with the mean value of the column. Supports only numeric/time span/ DateTime columns.
47+
/// </summary>
4148
Mean,
49+
50+
/// <summary>
51+
/// Replace with the minimum value of the column. Supports only numeric/time span/ DateTime columns.
52+
/// </summary>
4253
Minimum,
54+
55+
/// <summary>
56+
/// Replace with the maximum value of the column. Supports only numeric/time span/ DateTime columns.
57+
/// </summary>
4358
Maximum,
4459

4560
[HideEnumValue]
@@ -105,6 +120,27 @@ public bool TryUnparse(StringBuilder sb)
105120
internal const string FriendlyName = "NA Handle Transform";
106121
internal const string ShortName = "NAHandle";
107122

123+
/// <summary>
124+
/// A helper method to create <see cref="NAHandleTransform"/> for public facing API.
125+
/// </summary>
126+
/// <param name="env">Host Environment.</param>
127+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
128+
/// <param name="name">Name of the output column.</param>
129+
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
130+
/// <param name="replaceWith">The replacement method to utilize.</param>
131+
public static IDataTransform Create(IHostEnvironment env, IDataView input, string name, string source = null, ReplacementKind replaceWith = ReplacementKind.DefaultValue)
132+
{
133+
var args = new Arguments()
134+
{
135+
Column = new[]
136+
{
137+
new Column() { Source = source ?? name, Name = name }
138+
},
139+
ReplaceWith = replaceWith
140+
};
141+
return Create(env, args, input);
142+
}
143+
108144
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
109145
{
110146
Contracts.CheckValue(env, nameof(env));

src/Microsoft.ML.Transforms/NAIndicatorTransform.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,18 @@ private static string TestType(ColumnType type)
8585
// The output column types, parallel to Infos.
8686
private readonly ColumnType[] _types;
8787

88+
/// <summary>
89+
/// Convenience constructor for public facing API.
90+
/// </summary>
91+
/// <param name="env">Host Environment.</param>
92+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
93+
/// <param name="name">Name of the output column.</param>
94+
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
95+
public NAIndicatorTransform(IHostEnvironment env, IDataView input, string name, string source = null)
96+
: this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input)
97+
{
98+
}
99+
88100
/// <summary>
89101
/// Public constructor corresponding to SignatureDataTransform.
90102
/// </summary>

src/Microsoft.ML.Transforms/NAReplaceTransform.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,19 @@ private static string TestType<T>(ColumnType type)
186186

187187
public override bool CanSaveOnnx => true;
188188

189+
/// <summary>
190+
/// Convenience constructor for public facing API.
191+
/// </summary>
192+
/// <param name="env">Host Environment.</param>
193+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
194+
/// <param name="name">Name of the output column.</param>
195+
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
196+
/// <param name="replacementKind">The replacement method to utilize.</param>
197+
public NAReplaceTransform(IHostEnvironment env, IDataView input, string name, string source = null, ReplacementKind replacementKind = ReplacementKind.DefaultValue)
198+
: this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, ReplacementKind = replacementKind }, input)
199+
{
200+
}
201+
189202
/// <summary>
190203
/// Public constructor corresponding to SignatureDataTransform.
191204
/// </summary>

src/Microsoft.ML.Transforms/OptionalColumnTransform.cs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,14 @@
2626

2727
namespace Microsoft.ML.Runtime.DataPipe
2828
{
29-
public class OptionalColumnTransform : RowToRowMapperTransformBase
29+
/// <summary>
30+
/// This transform is used to mark some of the columns (e.g. Label) optional during training so that the columns are not required during scoring.
31+
/// When applied to new data, if any of the optional columns is not present a dummy columns is created having the same properties (e.g. 'name', 'type' etc.) as used during training.
32+
/// The columns are filled with default values. The value is
33+
/// - scalar for scalar column
34+
/// - totally sparse vector for vector column.
35+
/// </summary>
36+
public sealed class OptionalColumnTransform : RowToRowMapperTransformBase
3037
{
3138
public sealed class Arguments : TransformInputBase
3239
{
@@ -232,6 +239,17 @@ private static VersionInfo GetVersionInfo()
232239

233240
private const string RegistrationName = "OptionalColumn";
234241

242+
/// <summary>
243+
/// Convenience constructor for public facing API.
244+
/// </summary>
245+
/// <param name="env">Host Environment.</param>
246+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
247+
/// <param name="columns">Columns to transform.</param>
248+
public OptionalColumnTransform(IHostEnvironment env, IDataView input, params string[] columns)
249+
: this(env, new Arguments() { Column = columns }, input)
250+
{
251+
}
252+
235253
/// <summary>
236254
/// Public constructor corresponding to SignatureDataTransform.
237255
/// </summary>

0 commit comments

Comments
 (0)