Skip to content

Commit 16f7883

Browse files
committed
Added convenience constructors for set of transforms (Part 2).
1 parent 17a3813 commit 16f7883

13 files changed

+280
-15
lines changed

src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,18 @@ private static VersionInfo GetVersionInfo()
442442

443443
private const string RegistrationName = "ChooseColumns";
444444

445+
/// <summary>
446+
/// Convenience constructor for public facing API.
447+
/// </summary>
448+
/// <param name="env">Host Environment.</param>
449+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
450+
/// <param name="name">Name of the output column.</param>
451+
/// <param name="source">Name of the selected column. If this is null '<paramref name="name"/>' will be used.</param>
452+
public ChooseColumnsTransform(IHostEnvironment env, IDataView input, string name, string source = null)
453+
: this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input)
454+
{
455+
}
456+
445457
/// <summary>
446458
/// Public constructor corresponding to SignatureDataTransform.
447459
/// </summary>

src/Microsoft.ML.Data/Transforms/ConvertTransform.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,16 @@ public bool TryUnparse(StringBuilder sb)
108108

109109
public class Arguments : TransformInputBase
110110
{
111+
public Arguments()
112+
{
113+
114+
}
115+
116+
public Arguments(string name, string source)
117+
{
118+
Column = new[] { new Column() { Source = source ?? name, Name = name } };
119+
}
120+
111121
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:type:src)", ShortName = "col", SortOrder = 1)]
112122
public Column[] Column;
113123

@@ -169,6 +179,25 @@ private static VersionInfo GetVersionInfo()
169179
// This is parallel to Infos.
170180
private readonly ColInfoEx[] _exes;
171181

182+
/// <summary>
183+
/// Convenience constructor for public facing API.
184+
/// </summary>
185+
/// <param name="env">Host Environment.</param>
186+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
187+
/// <param name="name">Name of the output column.</param>
188+
/// <param name="source">Name of the column to be converted. If this is null '<paramref name="name"/>' will be used.</param>
189+
/// <param name="resultType">The expected type of the converted column.</param>
190+
/// <param name="keyRange">For a key column, this defines the range of values.</param>
191+
public ConvertTransform(IHostEnvironment env,
192+
IDataView input,
193+
string name,
194+
string source = null,
195+
DataKind? resultType = null,
196+
KeyRange keyRange = null)
197+
: this(env, new Arguments(name, source) { ResultType = resultType, KeyRange = keyRange }, input)
198+
{
199+
}
200+
172201
public ConvertTransform(IHostEnvironment env, Arguments args, IDataView input)
173202
: base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column,
174203
input, null)

src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,18 @@ public ColInfoEx(SlotDropper slotDropper, bool suppressed, ColumnType typeDst, i
216216

217217
private readonly ColInfoEx[] _exes;
218218

219+
/// <summary>
220+
/// Convenience constructor for public facing API.
221+
/// </summary>
222+
/// <param name="env">Host Environment.</param>
223+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
224+
/// <param name="name">Name of the output column.</param>
225+
/// <param name="source">Name of the input column. If this is null '<paramref name="name"/>' will be used.</param>
226+
public DropSlotsTransform(IHostEnvironment env, IDataView input, string name, string source = null)
227+
: this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input)
228+
{
229+
}
230+
219231
/// <summary>
220232
/// Public constructor corresponding to SignatureDataTransform.
221233
/// </summary>

src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,22 @@ private bool TryParse(string str)
7777
}
7878
}
7979

80+
private static class Defaults
81+
{
82+
public const bool UseCounter = false;
83+
public const uint Seed = 42;
84+
}
85+
8086
public sealed class Arguments : TransformInputBase
8187
{
8288
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:seed)", ShortName = "col", SortOrder = 1)]
8389
public Column[] Column;
8490

8591
[Argument(ArgumentType.AtMostOnce, HelpText = "Use an auto-incremented integer starting at zero instead of a random number", ShortName = "cnt")]
86-
public bool UseCounter;
92+
public bool UseCounter = Defaults.UseCounter;
8793

8894
[Argument(ArgumentType.AtMostOnce, HelpText = "The random seed")]
89-
public uint Seed = 42;
95+
public uint Seed = Defaults.Seed;
9096
}
9197

9298
private sealed class Bindings : ColumnBindingsBase
@@ -250,6 +256,18 @@ private static VersionInfo GetVersionInfo()
250256

251257
private const string RegistrationName = "GenerateNumber";
252258

259+
/// <summary>
260+
/// Convenience constructor for public facing API.
261+
/// </summary>
262+
/// <param name="env">Host Environment.</param>
263+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
264+
/// <param name="name">Name of the output column.</param>
265+
/// <param name="useCounter">Use an auto-incremented integer starting at zero instead of a random number.</param>
266+
public GenerateNumberTransform(IHostEnvironment env, IDataView input, string name, bool useCounter = Defaults.UseCounter)
267+
: this(env, new Arguments() { Column = new[] { new Column() { Name = name } }, UseCounter = useCounter }, input)
268+
{
269+
}
270+
253271
/// <summary>
254272
/// Public constructor corresponding to SignatureDataTransform.
255273
/// </summary>

src/Microsoft.ML.Data/Transforms/HashTransform.cs

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,26 +33,48 @@ public sealed class HashTransform : OneToOneTransformBase, ITransformTemplate
3333
public const int NumBitsMin = 1;
3434
public const int NumBitsLim = 32;
3535

36+
private static class Defaults
37+
{
38+
public const int HashBits = NumBitsLim - 1;
39+
public const uint Seed = 314489979;
40+
public const bool Ordered = false;
41+
public const int InvertHash = 0;
42+
}
43+
3644
public sealed class Arguments
3745
{
46+
public Arguments()
47+
{
48+
49+
}
50+
51+
public Arguments(string name, string source)
52+
{
53+
Column = new[] { new Column(){
54+
Source = source ?? name,
55+
Name = name
56+
}
57+
};
58+
}
59+
3860
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col",
3961
SortOrder = 1)]
4062
public Column[] Column;
4163

4264
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. Must be between 1 and 31, inclusive",
4365
ShortName = "bits", SortOrder = 2)]
44-
public int HashBits = NumBitsLim - 1;
66+
public int HashBits = Defaults.HashBits;
4567

4668
[Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
47-
public uint Seed = 314489979;
69+
public uint Seed = Defaults.Seed;
4870

4971
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash",
5072
ShortName = "ord")]
51-
public bool Ordered;
73+
public bool Ordered = Defaults.Ordered;
5274

5375
[Argument(ArgumentType.AtMostOnce, HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.",
5476
ShortName = "ih")]
55-
public int InvertHash;
77+
public int InvertHash = Defaults.InvertHash;
5678
}
5779

5880
public sealed class Column : OneToOneColumn
@@ -234,6 +256,25 @@ public override void Save(ModelSaveContext ctx)
234256
TextModelHelper.SaveAll(Host, ctx, Infos.Length, _keyValues);
235257
}
236258

259+
/// <summary>
260+
/// Convenience constructor for public facing API.
261+
/// </summary>
262+
/// <param name="env">Host Environment.</param>
263+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
264+
/// <param name="name">Name of the output column.</param>
265+
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
266+
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param>
267+
/// <param name="invertHash">Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.</param>
268+
public HashTransform(IHostEnvironment env,
269+
IDataView input,
270+
string name,
271+
string source = null,
272+
int hashBits = Defaults.HashBits,
273+
int invertHash = Defaults.InvertHash)
274+
: this(env, new Arguments(name, source) { HashBits = hashBits, InvertHash = invertHash }, input)
275+
{
276+
}
277+
237278
public HashTransform(IHostEnvironment env, Arguments args, IDataView input)
238279
: base(Contracts.CheckRef(env, nameof(env)), RegistrationName, env.CheckRef(args, nameof(args)).Column,
239280
input, TestType)

src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,19 @@ private static VersionInfo GetVersionInfo()
7373
private readonly ColumnType[] _types;
7474
private KeyToValueMap[] _kvMaps;
7575

76+
/// <summary>
77+
/// Convenience constructor for public facing API.
78+
/// </summary>
79+
/// <param name="env">Host Environment.</param>
80+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
81+
/// <param name="name">Name of the output column.</param>
82+
/// <param name="source">Name of the input column. If this is null '<paramref name="name"/>' will be used.</param>
83+
public KeyToValueTransform(IHostEnvironment env, IDataView input, string name, string source = null)
84+
: this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input)
85+
{
86+
}
87+
88+
7689
/// <summary>
7790
/// Public constructor corresponding to SignatureDataTransform.
7891
/// </summary>

src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,19 @@ public bool TryUnparse(StringBuilder sb)
7070
}
7171
}
7272

73+
private static class Defaults
74+
{
75+
public const bool Bag = false;
76+
}
77+
7378
public sealed class Arguments
7479
{
7580
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
7681
public Column[] Column;
7782

7883
[Argument(ArgumentType.AtMostOnce,
7984
HelpText = "Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input is a vector.")]
80-
public bool Bag;
85+
public bool Bag = Defaults.Bag;
8186
}
8287

8388
internal const string Summary = "Converts a key column to an indicator vector.";
@@ -112,6 +117,23 @@ private static VersionInfo GetVersionInfo()
112117
private readonly bool[] _concat;
113118
private readonly VectorType[] _types;
114119

120+
/// <summary>
121+
/// Convenience constructor for public facing API.
122+
/// </summary>
123+
/// <param name="env">Host Environment.</param>
124+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
125+
/// <param name="name">Name of the output column.</param>
126+
/// <param name="source">Name of the input column. If this is null '<paramref name="name"/>' will be used.</param>
127+
/// <param name="bag">Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input is a vector.</param>
128+
public KeyToVectorTransform(IHostEnvironment env,
129+
IDataView input,
130+
string name,
131+
string source = null,
132+
bool bag = Defaults.Bag)
133+
: this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, Bag = bag }, input)
134+
{
135+
}
136+
115137
/// <summary>
116138
/// Public constructor corresponding to SignatureDataTransform.
117139
/// </summary>

src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,18 @@ private static VersionInfo GetVersionInfo()
6464
private const string RegistrationName = "LabelConvert";
6565
private VectorType _slotType;
6666

67+
/// <summary>
68+
/// Convenience constructor for public facing API.
69+
/// </summary>
70+
/// <param name="env">Host Environment.</param>
71+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
72+
/// <param name="name">Name of the output column.</param>
73+
/// <param name="source">Name of the input column. If this is null '<paramref name="name"/>' will be used.</param>
74+
public LabelConvertTransform(IHostEnvironment env, IDataView input, string name, string source = null)
75+
: this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input)
76+
{
77+
}
78+
6779
public LabelConvertTransform(IHostEnvironment env, Arguments args, IDataView input)
6880
: base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input, RowCursorUtils.TestGetLabelGetter)
6981
{

src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,18 @@ public bool TryUnparse(StringBuilder sb)
6464
}
6565
}
6666

67+
private static class Defaults
68+
{
69+
public const int ClassIndex = 0;
70+
}
71+
6772
public sealed class Arguments : TransformInputBase
6873
{
6974
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
7075
public Column[] Column;
7176

7277
[Argument(ArgumentType.AtMostOnce, HelpText = "Label of the positive class.", ShortName = "index")]
73-
public int ClassIndex;
78+
public int ClassIndex = Defaults.ClassIndex;
7479
}
7580

7681
public static LabelIndicatorTransform Create(IHostEnvironment env,
@@ -111,6 +116,23 @@ private static string TestIsMulticlassLabel(ColumnType type)
111116
return $"Label column type is not supported for binary remapping: {type}. Supported types: key, float, double.";
112117
}
113118

119+
/// <summary>
120+
/// Convenience constructor for public facing API.
121+
/// </summary>
122+
/// <param name="env">Host Environment.</param>
123+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
124+
/// <param name="name">Name of the output column.</param>
125+
/// <param name="source">Name of the input column. If this is null '<paramref name="name"/>' will be used.</param>
126+
/// <param name="classIndex">Label of the positive class.</param>
127+
public LabelIndicatorTransform(IHostEnvironment env,
128+
IDataView input,
129+
string name,
130+
string source = null,
131+
int classIndex = Defaults.ClassIndex)
132+
: this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, ClassIndex = classIndex }, input)
133+
{
134+
}
135+
114136
public LabelIndicatorTransform(IHostEnvironment env, Arguments args, IDataView input)
115137
: base(env, LoadName, Contracts.CheckRef(args, nameof(args)).Column,
116138
input, TestIsMulticlassLabel)

src/Microsoft.ML.Data/Transforms/RangeFilter.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,19 @@ private static VersionInfo GetVersionInfo()
7777
private readonly bool _includeMin;
7878
private readonly bool _includeMax;
7979

80+
/// <summary>
81+
/// Convenience constructor for public facing API.
82+
/// </summary>
83+
/// <param name="env">Host Environment.</param>
84+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
85+
/// <param name="column">Name of the input column.</param>
86+
/// <param name="minimum">Minimum value (0 to 1 for key types).</param>
87+
/// <param name="maximum">Maximum value (0 to 1 for key types).</param>
88+
public RangeFilter(IHostEnvironment env, IDataView input, string column, Double? minimum = null, Double? maximum = null)
89+
: this(env, new Arguments() { Column = column, Min = minimum, Max = maximum }, input)
90+
{
91+
}
92+
8093
public RangeFilter(IHostEnvironment env, Arguments args, IDataView input)
8194
: base(env, RegistrationName, input)
8295
{

0 commit comments

Comments
 (0)