Skip to content

Added convenience constructor for set of transforms (#380). #405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,19 @@ public bool TryUnparse(StringBuilder sb)

public sealed class Arguments : TransformInputBase
{
public Arguments()
{
}

public Arguments(string name, params string[] source)
{
Column = new[] { new Column()
{
Name = name,
Source = source
}};
}

[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:srcs)", ShortName = "col", SortOrder = 1)]
public Column[] Column;
}
Expand Down Expand Up @@ -527,6 +540,18 @@ private static VersionInfo GetVersionInfo()

public override ISchema Schema => _bindings;

/// <summary>
/// Convenience constructor for public facing API.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Input columns to concatenate.</param>
public ConcatTransform(IHostEnvironment env, IDataView input, string name, params string[] source)
: this(env, new Arguments(name, source), input)
{
}

/// <summary>
/// Public constructor corresponding to SignatureDataTransform.
/// </summary>
Expand Down
12 changes: 12 additions & 0 deletions src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ private static VersionInfo GetVersionInfo()

private const string RegistrationName = "CopyColumns";

/// <summary>
/// Convenience constructor for public facing API.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be copied.</param>
public CopyColumnsTransform(IHostEnvironment env, IDataView input, string name, string source)
: this(env, new Arguments(){ Column = new[] { new Column() { Source = source, Name = name }}}, input)
{
}

public CopyColumnsTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, null)
{
Expand Down
24 changes: 24 additions & 0 deletions src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,17 @@ private static VersionInfo GetVersionInfo()
private const string DropRegistrationName = "DropColumns";
private const string KeepRegistrationName = "KeepColumns";

/// <summary>
/// Convenience constructor for public facing API.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="columnsToDrop">Name of the columns to be dropped.</param>
public DropColumnsTransform(IHostEnvironment env, IDataView input, params string[] columnsToDrop)
:this(env, new Arguments() { Column = columnsToDrop }, input)
{
}

/// <summary>
/// Public constructor corresponding to SignatureDataTransform.
/// </summary>
Expand Down Expand Up @@ -383,4 +394,17 @@ public ValueGetter<TValue> GetGetter<TValue>(int col)
}
}
}

public class KeepColumnsTransform
{
/// <summary>
/// A helper method to create <see cref="KeepColumnsTransform"/> for public facing API.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="columnsToKeep">Name of the columns to be kept. All other columns will be removed.</param>
/// <returns></returns>
public static IDataTransform Create(IHostEnvironment env, IDataView input, params string[] columnsToKeep)
=> new DropColumnsTransform(env, new DropColumnsTransform.KeepArguments() { Column = columnsToKeep }, input);
}
}
19 changes: 18 additions & 1 deletion src/Microsoft.ML.Data/Transforms/NAFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,18 @@ namespace Microsoft.ML.Runtime.Data
{
public sealed class NAFilter : FilterBase
{
private static class Defaults
{
public const bool Complement = false;
}

public sealed class Arguments : TransformInputBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Column", ShortName = "col", SortOrder = 1)]
public string[] Column;

[Argument(ArgumentType.Multiple, HelpText = "If true, keep only rows that contain NA values, and filter the rest.")]
public bool Complement;
public bool Complement = Defaults.Complement;
}

private sealed class ColInfo
Expand Down Expand Up @@ -72,6 +77,18 @@ private static VersionInfo GetVersionInfo()
private readonly bool _complement;
private const string RegistrationName = "MissingValueFilter";

/// <summary>
/// Convenience constructor for public facing API.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="complement">If true, keep only rows that contain NA values, and filter the rest.</param>
/// <param name="columns">Name of the columns. Only these columns will be used to filter rows having 'NA' values.</param>
public NAFilter(IHostEnvironment env, IDataView input, bool complement = Defaults.Complement, params string[] columns)
: this(env, new Arguments() { Column = columns, Complement = complement }, input)
{
}

public NAFilter(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, input)
{
Expand Down
32 changes: 29 additions & 3 deletions src/Microsoft.ML.Transforms/BootstrapSampleTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,27 @@ namespace Microsoft.ML.Runtime.Data
/// </summary>
public sealed class BootstrapSampleTransform : FilterBase
{
private static class Defaults
{
public const bool Complement = false;
public const bool ShuffleInput = true;
public const int PoolSize = 1000;
}

public sealed class Arguments : TransformInputBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether this is the out-of-bag sample, that is, all those rows that are not selected by the transform.",
ShortName = "comp")]
public bool Complement;
public bool Complement = Defaults.Complement;

[Argument(ArgumentType.AtMostOnce, HelpText = "The random seed. If unspecified random state will be instead derived from the environment.")]
public uint? Seed;

[Argument(ArgumentType.AtMostOnce, HelpText = "Whether we should attempt to shuffle the source data. By default on, but can be turned off for efficiency.", ShortName = "si")]
public bool ShuffleInput = true;
public bool ShuffleInput = Defaults.ShuffleInput;

[Argument(ArgumentType.LastOccurenceWins, HelpText = "When shuffling the output, the number of output rows to keep in that pool. Note that shuffling of output is completely distinct from shuffling of input.", ShortName = "pool")]
public int PoolSize = 1000;
public int PoolSize = Defaults.PoolSize;
}

internal const string Summary = "Approximate bootstrap sampling.";
Expand Down Expand Up @@ -76,6 +83,25 @@ public BootstrapSampleTransform(IHostEnvironment env, Arguments args, IDataView
_poolSize = args.PoolSize;
}

/// <summary>
/// Convenience constructor for public facing API.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="complement">Whether this is the out-of-bag sample, that is, all those rows that are not selected by the transform.</param>
/// <param name="seed">The random seed. If unspecified random state will be instead derived from the environment.</param>
/// <param name="shuffleInput">Whether we should attempt to shuffle the source data. By default on, but can be turned off for efficiency.</param>
/// <param name="poolSize">When shuffling the output, the number of output rows to keep in that pool. Note that shuffling of output is completely distinct from shuffling of input.</param>
public BootstrapSampleTransform(IHostEnvironment env,
IDataView input,
bool complement = Defaults.Complement,
uint? seed = null,
bool shuffleInput = Defaults.ShuffleInput,
int poolSize = Defaults.PoolSize)
: this(env, new Arguments() { Complement = complement, Seed = seed, ShuffleInput = shuffleInput, PoolSize = poolSize }, input)
{
}

private BootstrapSampleTransform(IHost host, ModelLoadContext ctx, IDataView input)
: base(host, input)
{
Expand Down
51 changes: 46 additions & 5 deletions src/Microsoft.ML.Transforms/CategoricalHashTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ public bool TryUnparse(StringBuilder sb)
}
}

private static class Defaults
{
public const int HashBits = 16;
public const uint Seed = 314489979;
public const bool Ordered = true;
public const int InvertHash = 0;
public const CategoricalTransform.OutputKind OutputKind = CategoricalTransform.OutputKind.Bag;
}

/// <summary>
/// This class is a merger of <see cref="HashTransform.Arguments"/> and <see cref="KeyToVectorTransform.Arguments"/>
/// with join option removed
Expand All @@ -97,29 +106,61 @@ public sealed class Arguments : TransformInputBase

[Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. Must be between 1 and 30, inclusive.",
ShortName = "bits", SortOrder = 2)]
public int HashBits = 16;
public int HashBits = Defaults.HashBits;

[Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
public uint Seed = 314489979;
public uint Seed = Defaults.Seed;

[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash", ShortName = "ord")]
public bool Ordered = true;
public bool Ordered = Defaults.Ordered;

[Argument(ArgumentType.AtMostOnce,
HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.",
ShortName = "ih")]
public int InvertHash;
public int InvertHash = Defaults.InvertHash;

[Argument(ArgumentType.AtMostOnce, HelpText = "Output kind: Bag (multi-set vector), Ind (indicator vector), or Key (index)",
ShortName = "kind", SortOrder = 102)]
public CategoricalTransform.OutputKind OutputKind = CategoricalTransform.OutputKind.Bag;
public CategoricalTransform.OutputKind OutputKind = Defaults.OutputKind;
}

internal const string Summary = "Converts the categorical value into an indicator array by hashing the value and using the hash as an index in the "
+ "bag. If the input column is a vector, a single indicator bag is returned for it.";

public const string UserName = "Categorical Hash Transform";

/// <summary>
/// A helper method to create <see cref="CategoricalHashTransform"/> for public facing API.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 30, inclusive.</param>
/// <param name="invertHash">Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.</param>
/// <param name="outputKind">The type of output expected.</param>
public static IDataTransform Create(IHostEnvironment env,
IDataView input,
string name,
string source =null,
int hashBits = Defaults.HashBits,
int invertHash = Defaults.InvertHash,
CategoricalTransform.OutputKind outputKind = Defaults.OutputKind)
{
var args = new Arguments()
{
Column = new[] { new Column(){
Source = source ?? name,
Name = name
}
},
HashBits = hashBits,
InvertHash = invertHash,
OutputKind = outputKind
};
return Create(env, args, input);
}

public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
Expand Down
41 changes: 40 additions & 1 deletion src/Microsoft.ML.Transforms/CategoricalTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,27 @@ public static class CategoricalTransform
{
public enum OutputKind : byte
{
/// <summary>
/// Output is a bag (multi-set) vector
/// </summary>
[TGUI(Label = "Output is a bag (multi-set) vector")]
Bag = 1,

/// <summary>
/// Output is an indicator vector
/// </summary>
[TGUI(Label = "Output is an indicator vector")]
Ind = 2,

/// <summary>
/// Output is a key value
/// </summary>
[TGUI(Label = "Output is a key value")]
Key = 3,

/// <summary>
/// Output is binary encoded
/// </summary>
[TGUI(Label = "Output is binary encoded")]
Bin = 4,
}
Expand Down Expand Up @@ -96,14 +108,19 @@ public bool TryUnparse(StringBuilder sb)
}
}

private static class Defaults
{
public const OutputKind OutKind = OutputKind.Ind;
}

public sealed class Arguments : TermTransform.ArgumentsBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
public Column[] Column;

[Argument(ArgumentType.AtMostOnce, HelpText = "Output kind: Bag (multi-set vector), Ind (indicator vector), or Key (index)",
ShortName = "kind", SortOrder = 102)]
public OutputKind OutputKind = OutputKind.Ind;
public OutputKind OutputKind = Defaults.OutKind;

public Arguments()
{
Expand All @@ -118,6 +135,28 @@ public Arguments()

public const string UserName = "Categorical Transform";

/// <summary>
/// A helper method to create <see cref="CategoricalTransform"/> for public facing API.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
/// <param name="outputKind">The type of output expected.</param>
public static IDataTransform Create(IHostEnvironment env, IDataView input, string name, string source = null, OutputKind outputKind = Defaults.OutKind)
{
var args = new Arguments()
{
Column = new[] { new Column(){
Source = source ?? name,
Name = name
}
},
OutputKind = outputKind
};
return Create(env, args, input);
}

public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
Expand Down
25 changes: 24 additions & 1 deletion src/Microsoft.ML.Transforms/CountFeatureSelection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,40 @@ public static class CountFeatureSelectionTransform
public const string Summary = "Selects the slots for which the count of non-default values is greater than or equal to a threshold.";
public const string UserName = "Count Feature Selection Transform";

private static class Defaults
{
public const long Count = 1;
}

public sealed class Arguments : TransformInputBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Columns to use for feature selection", ShortName = "col", SortOrder = 1)]
public string[] Column;

[Argument(ArgumentType.Required, HelpText = "If the count of non-default values for a slot is greater than or equal to this threshold, the slot is preserved", ShortName = "c", SortOrder = 1)]
public long Count = 1;
public long Count = Defaults.Count;
}

internal static string RegistrationName = "CountFeatureSelectionTransform";

/// <summary>
/// A helper method to create CountFeatureSelection transform for public facing API.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="count">If the count of non-default values for a slot is greater than or equal to this threshold, the slot is preserved.</param>
/// <param name="columns">Columns to use for feature selection.</param>
/// <returns></returns>
public static IDataTransform Create(IHostEnvironment env, IDataView input, long count = Defaults.Count, params string[] columns)
{
var args = new Arguments()
{
Column = columns,
Count = count
};
return Create(env, args, input);
}

/// <summary>
/// Create method corresponding to SignatureDataTransform.
/// </summary>
Expand Down
Loading