Skip to content

Added convenience constructors for set of transforms. #491

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ public bool TryUnparse(StringBuilder sb)

public sealed class Arguments
{
/// <summary>
/// Parameterless constructor. Performs no initialization; the Column array is left unassigned.
/// </summary>
public Arguments()
{
}

/// <summary>
/// Creates arguments that select the given columns. Each entry becomes one column
/// definition whose Source and Name are both set to the supplied string.
/// </summary>
/// <param name="columns">Names of the columns to choose. Must not be null.</param>
internal Arguments(params string[] columns)
{
    // A params array can still be passed as an explicit null; report that as an argument
    // error instead of failing with a NullReferenceException on columns.Length below.
    if (columns == null)
    {
        throw new System.ArgumentNullException(nameof(columns));
    }

    Column = new Column[columns.Length];
    for (int i = 0; i < columns.Length; i++)
    {
        Column[i] = new Column() { Source = columns[i], Name = columns[i] };
    }
}

[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
public Column[] Column;

Expand Down Expand Up @@ -442,6 +456,17 @@ private static VersionInfo GetVersionInfo()

private const string RegistrationName = "ChooseColumns";

/// <summary>
/// Convenience constructor for the public-facing API. Wraps the column names in an
/// <see cref="Arguments"/> instance and forwards to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from the previous transform or loader.</param>
/// <param name="columns">Names of the columns to choose.</param>
public ChooseColumnsTransform(
    IHostEnvironment env,
    IDataView input,
    params string[] columns)
    : this(env, new Arguments(columns), input)
{
}

/// <summary>
/// Public constructor corresponding to SignatureDataTransform.
/// </summary>
Expand Down
17 changes: 17 additions & 0 deletions src/Microsoft.ML.Data/Transforms/ConvertTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,23 @@ private static VersionInfo GetVersionInfo()
// This is parallel to Infos.
private readonly ColInfoEx[] _exes;

/// <summary>
/// Convenience constructor for the public-facing API. Builds a single-column
/// <see cref="Arguments"/> and forwards to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from the previous transform or loader.</param>
/// <param name="resultType">The expected type of the converted column.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be converted. When null, <paramref name="name"/> is used instead.</param>
public ConvertTransform(IHostEnvironment env, IDataView input, DataKind resultType, string name, string source = null)
    : this(
        env,
        new Arguments
        {
            Column = new[]
            {
                new Column { Source = source ?? name, Name = name }
            },
            ResultType = resultType
        },
        input)
{
}

public ConvertTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column,
input, null)
Expand Down
22 changes: 20 additions & 2 deletions src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,22 @@ private bool TryParse(string str)
}
}

// Default values shared by the Arguments field initializers and the optional
// parameters of the convenience constructor.
private static class Defaults
{
// Default for Arguments.UseCounter.
public const bool UseCounter = false;
// Default for Arguments.Seed.
public const uint Seed = 42;
}

/// <summary>
/// Arguments of the transform. UseCounter and Seed take their initial values from
/// <see cref="Defaults"/>; each field is declared exactly once.
/// </summary>
public sealed class Arguments : TransformInputBase
{
    [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:seed)", ShortName = "col", SortOrder = 1)]
    public Column[] Column;

    [Argument(ArgumentType.AtMostOnce, HelpText = "Use an auto-incremented integer starting at zero instead of a random number", ShortName = "cnt")]
    public bool UseCounter = Defaults.UseCounter;

    [Argument(ArgumentType.AtMostOnce, HelpText = "The random seed")]
    public uint Seed = Defaults.Seed;
}

private sealed class Bindings : ColumnBindingsBase
Expand Down Expand Up @@ -250,6 +256,18 @@ private static VersionInfo GetVersionInfo()

private const string RegistrationName = "GenerateNumber";

/// <summary>
/// Convenience constructor for the public-facing API. Builds a single-column
/// <see cref="Arguments"/> and forwards to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from the previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="useCounter">Use an auto-incremented integer starting at zero instead of a random number.</param>
public GenerateNumberTransform(
    IHostEnvironment env,
    IDataView input,
    string name,
    bool useCounter = Defaults.UseCounter)
    : this(
        env,
        new Arguments
        {
            Column = new[] { new Column { Name = name } },
            UseCounter = useCounter
        },
        input)
{
}

/// <summary>
/// Public constructor corresponding to SignatureDataTransform.
/// </summary>
Expand Down
37 changes: 33 additions & 4 deletions src/Microsoft.ML.Data/Transforms/HashTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ public sealed class HashTransform : OneToOneTransformBase, ITransformTemplate
public const int NumBitsMin = 1;
public const int NumBitsLim = 32;

// Default values shared by the Arguments field initializers and the optional
// parameters of the convenience constructor.
private static class Defaults
{
// Default for Arguments.HashBits: one less than NumBitsLim.
public const int HashBits = NumBitsLim - 1;
// Default for Arguments.Seed.
public const uint Seed = 314489979;
// Default for Arguments.Ordered.
public const bool Ordered = false;
// Default for Arguments.InvertHash; per that argument's help text, 0 means no invert hashing.
public const int InvertHash = 0;
}

public sealed class Arguments
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col",
Expand All @@ -41,18 +49,18 @@ public sealed class Arguments

[Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. Must be between 1 and 31, inclusive",
ShortName = "bits", SortOrder = 2)]
public int HashBits = NumBitsLim - 1;
public int HashBits = Defaults.HashBits;

[Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
public uint Seed = 314489979;
public uint Seed = Defaults.Seed;

[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash",
ShortName = "ord")]
public bool Ordered;
public bool Ordered = Defaults.Ordered;

[Argument(ArgumentType.AtMostOnce, HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.",
ShortName = "ih")]
public int InvertHash;
public int InvertHash = Defaults.InvertHash;
}

public sealed class Column : OneToOneColumn
Expand Down Expand Up @@ -234,6 +242,27 @@ public override void Save(ModelSaveContext ctx)
TextModelHelper.SaveAll(Host, ctx, Infos.Length, _keyValues);
}

/// <summary>
/// Convenience constructor for the public-facing API. Builds a single-column
/// <see cref="Arguments"/> and forwards to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from the previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be transformed. When null, <paramref name="name"/> is used instead.</param>
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param>
/// <param name="invertHash">Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.</param>
public HashTransform(IHostEnvironment env, IDataView input, string name, string source = null, int hashBits = Defaults.HashBits, int invertHash = Defaults.InvertHash)
    : this(
        env,
        new Arguments
        {
            Column = new[]
            {
                new Column { Source = source ?? name, Name = name }
            },
            HashBits = hashBits,
            InvertHash = invertHash
        },
        input)
{
}

public HashTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(Contracts.CheckRef(env, nameof(env)), RegistrationName, env.CheckRef(args, nameof(args)).Column,
input, TestType)
Expand Down
13 changes: 13 additions & 0 deletions src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,19 @@ private static VersionInfo GetVersionInfo()
private readonly ColumnType[] _types;
private KeyToValueMap[] _kvMaps;

/// <summary>
/// Convenience constructor for public facing API.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the input column. If this is null '<paramref name="name"/>' will be used.</param>
public KeyToValueTransform(IHostEnvironment env, IDataView input, string name, string source = null)
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Jul 6, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IHostEnvironment env, IDataView input, string name, string source = null [](start = 35, length = 72)

Half of your files are formatted this way, and the other half put one parameter on each line. Why? #Closed

Copy link
Contributor Author

@zeahmed zeahmed Jul 6, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the line is going to be too long to fit into the view, I put every parameter on a separate line; otherwise I don't.

For this particular one, it's fine to have it on one line. #Closed

: this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input)
{
}


/// <summary>
/// Public constructor corresponding to SignatureDataTransform.
/// </summary>
Expand Down
24 changes: 23 additions & 1 deletion src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,19 @@ public bool TryUnparse(StringBuilder sb)
}
}

// Default values shared by the Arguments field initializers and the optional
// parameters of the convenience constructor.
private static class Defaults
{
// Default for Arguments.Bag.
public const bool Bag = false;
}

/// <summary>
/// Arguments of the transform. Bag takes its initial value from <see cref="Defaults"/>
/// and is declared exactly once.
/// </summary>
public sealed class Arguments
{
    [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
    public Column[] Column;

    [Argument(ArgumentType.AtMostOnce,
        HelpText = "Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input is a vector.")]
    public bool Bag = Defaults.Bag;
}

internal const string Summary = "Converts a key column to an indicator vector.";
Expand Down Expand Up @@ -112,6 +117,23 @@ private static VersionInfo GetVersionInfo()
private readonly bool[] _concat;
private readonly VectorType[] _types;

/// <summary>
/// Convenience constructor for the public-facing API. Builds a single-column
/// <see cref="Arguments"/> and forwards to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from the previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the input column. When null, <paramref name="name"/> is used instead.</param>
/// <param name="bag">Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input is a vector.</param>
public KeyToVectorTransform(IHostEnvironment env, IDataView input, string name, string source = null, bool bag = Defaults.Bag)
    : this(
        env,
        new Arguments
        {
            Column = new[]
            {
                new Column { Source = source ?? name, Name = name }
            },
            Bag = bag
        },
        input)
{
}

/// <summary>
/// Public constructor corresponding to SignatureDataTransform.
/// </summary>
Expand Down
12 changes: 12 additions & 0 deletions src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "LabelConvert";
private VectorType _slotType;

/// <summary>
/// Convenience constructor for the public-facing API. Builds a single-column
/// <see cref="Arguments"/> and forwards to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from the previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the input column. When null, <paramref name="name"/> is used instead.</param>
public LabelConvertTransform(
    IHostEnvironment env,
    IDataView input,
    string name,
    string source = null)
    : this(
        env,
        new Arguments
        {
            Column = new[]
            {
                new Column { Source = source ?? name, Name = name }
            }
        },
        input)
{
}

public LabelConvertTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input, RowCursorUtils.TestGetLabelGetter)
{
Expand Down
17 changes: 17 additions & 0 deletions src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,23 @@ private static string TestIsMulticlassLabel(ColumnType type)
return $"Label column type is not supported for binary remapping: {type}. Supported types: key, float, double.";
}

/// <summary>
/// Convenience constructor for the public-facing API. Builds a single-column
/// <see cref="Arguments"/> and forwards to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from the previous transform or loader.</param>
/// <param name="classIndex">Label of the positive class.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the input column. When null, <paramref name="name"/> is used instead.</param>
public LabelIndicatorTransform(IHostEnvironment env, IDataView input, int classIndex, string name, string source = null)
    : this(
        env,
        new Arguments
        {
            Column = new[]
            {
                new Column { Source = source ?? name, Name = name }
            },
            ClassIndex = classIndex
        },
        input)
{
}

public LabelIndicatorTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(env, LoadName, Contracts.CheckRef(args, nameof(args)).Column,
input, TestIsMulticlassLabel)
Expand Down
13 changes: 13 additions & 0 deletions src/Microsoft.ML.Data/Transforms/RangeFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,19 @@ private static VersionInfo GetVersionInfo()
private readonly bool _includeMin;
private readonly bool _includeMax;

/// <summary>
/// Convenience constructor for the public-facing API. Packs the column name and bounds into
/// an <see cref="Arguments"/> instance and forwards to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from the previous transform or loader.</param>
/// <param name="column">Name of the input column.</param>
/// <param name="minimum">Minimum value (0 to 1 for key types).</param>
/// <param name="maximum">Maximum value (0 to 1 for key types).</param>
public RangeFilter(
    IHostEnvironment env,
    IDataView input,
    string column,
    double? minimum = null,
    double? maximum = null)
    : this(
        env,
        new Arguments
        {
            Column = column,
            Min = minimum,
            Max = maximum
        },
        input)
{
}

public RangeFilter(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, input)
{
Expand Down
30 changes: 27 additions & 3 deletions src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,25 @@ namespace Microsoft.ML.Runtime.Data
/// </summary>
public sealed class ShuffleTransform : RowToRowTransformBase
{
// Default values shared by the Arguments field initializers and the optional
// parameters of the convenience constructor.
private static class Defaults
{
// Default for Arguments.PoolRows.
public const int PoolRows = 1000;
// Default for Arguments.PoolOnly.
public const bool PoolOnly = false;
// Default for Arguments.ForceShuffle.
public const bool ForceShuffle = false;
}

public sealed class Arguments
{
// REVIEW: A more intelligent heuristic, based on the expected size of the inputs, perhaps?
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The pool will have this many rows", ShortName = "rows")]
public int PoolRows = 1000;
public int PoolRows = Defaults.PoolRows;

// REVIEW: Come up with a better way to specify the desired set of functionality.
[Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will not attempt to shuffle the input cursor but only shuffle based on the pool. This parameter has no effect if the input data was not itself shufflable.", ShortName = "po")]
public bool PoolOnly;
public bool PoolOnly = Defaults.PoolOnly;

[Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will always provide a shuffled view.", ShortName = "force")]
public bool ForceShuffle;
public bool ForceShuffle = Defaults.ForceShuffle;

[Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will always shuffle the input. The default value is the same as forceShuffle.", ShortName = "forceSource")]
public bool? ForceShuffleSource;
Expand Down Expand Up @@ -79,6 +86,23 @@ private static VersionInfo GetVersionInfo()
// know how to copy other types of values.
private readonly IDataView _subsetInput;

/// <summary>
/// Convenience constructor for the public-facing API. Packs the shuffle options into an
/// <see cref="Arguments"/> instance and forwards to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from the previous transform or loader.</param>
/// <param name="poolRows">The pool will have this many rows.</param>
/// <param name="poolOnly">If true, the transform will not attempt to shuffle the input cursor but only shuffle based on the pool. This parameter has no effect if the input data was not itself shufflable.</param>
/// <param name="forceShuffle">If true, the transform will always provide a shuffled view.</param>
public ShuffleTransform(IHostEnvironment env, IDataView input, int poolRows = Defaults.PoolRows, bool poolOnly = Defaults.PoolOnly, bool forceShuffle = Defaults.ForceShuffle)
    : this(
        env,
        new Arguments
        {
            PoolRows = poolRows,
            PoolOnly = poolOnly,
            ForceShuffle = forceShuffle
        },
        input)
{
}

/// <summary>
/// Public constructor corresponding to SignatureDataTransform.
/// </summary>
Expand Down
Loading