Skip to content

[Part 3] Added convenience constructors for set of transforms. #520

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/Microsoft.ML.Transforms/GroupTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ public sealed class Arguments : TransformInputBase

private readonly GroupSchema _schema;

/// <summary>
/// Convenience constructor for the public facing API. Builds an <see cref="Arguments"/> instance
/// from the supplied column names and forwards it to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="groupKey">Column to group by. It is wrapped into a one-element <c>GroupKey</c> array.</param>
/// <param name="columns">Columns to group together. Passed through unchanged as <c>Column</c>.</param>
public GroupTransform(IHostEnvironment env, IDataView input, string groupKey, params string[] columns)
    : this(
        env,
        new Arguments
        {
            Column = columns,
            GroupKey = new string[] { groupKey }
        },
        input)
{
}

public GroupTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, input)
{
Expand Down
35 changes: 31 additions & 4 deletions src/Microsoft.ML.Transforms/HashJoinTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ public sealed class HashJoinTransform : OneToOneTransformBase
public const int NumBitsMin = 1;
public const int NumBitsLim = 32;

// Default option values, kept in one place so that the Arguments fields and the
// optional parameters of the convenience constructor cannot drift apart.
private static class Defaults
{
    // One less than the exclusive upper limit on the number of hash bits.
    public const int HashBits = NumBitsLim - 1;

    // Hashing seed.
    public const uint Seed = 314489979;

    // Whether the values need to be combined for a single hash.
    public const bool Join = true;

    // Whether the position of each term should be included in the hash.
    public const bool Ordered = true;
}

public sealed class Arguments : TransformInputBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)",
Expand All @@ -45,17 +53,17 @@ public sealed class Arguments : TransformInputBase
public Column[] Column;

[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the values need to be combined for a single hash")]
public bool Join = true;
public bool Join = Defaults.Join;

[Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. Must be between 1 and 31, inclusive.",
ShortName = "bits", SortOrder = 2)]
public int HashBits = NumBitsLim - 1;
public int HashBits = Defaults.HashBits;

[Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
public uint Seed = 314489979;
public uint Seed = Defaults.Seed;

[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash", ShortName = "ord")]
public bool Ordered = true;
public bool Ordered = Defaults.Ordered;
}

public sealed class Column : OneToOneColumn
Expand Down Expand Up @@ -166,6 +174,25 @@ private static VersionInfo GetVersionInfo()

private readonly ColumnInfoEx[] _exes;

/// <summary>
/// Convenience constructor for the public facing API. Describes a single output column and
/// forwards the resulting <see cref="Arguments"/> to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
/// <param name="join">Whether the values need to be combined for a single hash.</param>
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param>
public HashJoinTransform(IHostEnvironment env, IDataView input, string name, string source = null,
    bool join = Defaults.Join, int hashBits = Defaults.HashBits)
    : this(
        env,
        new Arguments
        {
            HashBits = hashBits,
            Join = join,
            Column = new Column[]
            {
                // Fall back to the output name when no explicit source column is given.
                new Column { Name = name, Source = source ?? name }
            }
        },
        input)
{
}

public HashJoinTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input, TestColumnType)
{
Expand Down
12 changes: 12 additions & 0 deletions src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ private static VersionInfo GetVersionInfo()

private readonly VectorType[] _types;

/// <summary>
/// Convenience constructor for the public facing API. Describes a single output column and
/// forwards the resulting <see cref="Arguments"/> to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
public KeyToBinaryVectorTransform(IHostEnvironment env, IDataView input, string name, string source = null)
    : this(
        env,
        new Arguments
        {
            Column = new KeyToVectorTransform.Column[]
            {
                // Fall back to the output name when no explicit source column is given.
                new KeyToVectorTransform.Column { Name = name, Source = source ?? name }
            }
        },
        input)
{
}

/// <summary>
/// Public constructor corresponding to SignatureDataTransform.
/// </summary>
Expand Down
19 changes: 19 additions & 0 deletions src/Microsoft.ML.Transforms/LoadTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,25 @@ public class Arguments

internal const string Summary = "Loads specified transforms from the model file and applies them to current data.";

/// <summary>
/// A helper method to create <see cref="LoadTransform"/> for public facing API.
/// Packs the parameters into an <see cref="Arguments"/> instance and delegates to the
/// <see cref="Arguments"/>-based <c>Create</c> overload.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="modelFile">Model file to load the transforms from.</param>
/// <param name="tag">The tags (comma-separated) to be loaded (or omitted, if complement is true).</param>
/// <param name="complement">Whether to load all transforms except those marked by tags.</param>
/// <returns>The <see cref="IDataTransform"/> produced by the <see cref="Arguments"/>-based overload.</returns>
public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string[] tag, bool complement = false)
{
    return Create(
        env,
        new Arguments
        {
            Complement = complement,
            Tag = tag,
            ModelFile = modelFile
        },
        input);
}

public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
Expand Down
39 changes: 36 additions & 3 deletions src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ public static class MutualInformationFeatureSelectionTransform
public const string UserName = "Mutual Information Feature Selection Transform";
public const string ShortName = "MIFeatureSelection";

// Default option values, kept in one place so that the Arguments fields and the
// optional parameters of the Create helper cannot drift apart.
private static class Defaults
{
    // Max number of bins for R4/R8 columns.
    public const int NumBins = 256;

    // The maximum number of slots to preserve in output.
    public const int SlotsInOutput = 1000;

    // Column to use for labels.
    public const string LabelColumn = DefaultColumnNames.Label;
}

public sealed class Arguments : TransformInputBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Columns to use for feature selection", ShortName = "col",
Expand All @@ -41,19 +48,45 @@ public sealed class Arguments : TransformInputBase

[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for labels", ShortName = "lab",
SortOrder = 4, Purpose = SpecialPurpose.ColumnName)]
public string LabelColumn = DefaultColumnNames.Label;
public string LabelColumn = Defaults.LabelColumn;

[Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of slots to preserve in output", ShortName = "topk,numSlotsToKeep",
SortOrder = 1)]
public int SlotsInOutput = 1000;
public int SlotsInOutput = Defaults.SlotsInOutput;

[Argument(ArgumentType.AtMostOnce, HelpText = "Max number of bins for R4/R8 columns, power of 2 recommended",
ShortName = "bins")]
public int NumBins = 256;
public int NumBins = Defaults.NumBins;
}

internal static string RegistrationName = "MutualInformationFeatureSelectionTransform";

/// <summary>
/// A helper method to create <see cref="IDataTransform"/> for selecting the top k slots ordered by their mutual information.
/// Packs the parameters into an <see cref="Arguments"/> instance and delegates to the
/// <see cref="Arguments"/>-based <c>Create</c> overload.
/// </summary>
/// <remarks>
/// <paramref name="columns"/> is a <c>params</c> array that follows three optional parameters.
/// Callers listing columns positionally must therefore also supply <paramref name="labelColumn"/>,
/// <paramref name="slotsInOutput"/> and <paramref name="numBins"/>; to keep the defaults, pass
/// the columns as a named argument (<c>columns: new[] { ... }</c>).
/// </remarks>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="labelColumn">Column to use for labels.</param>
/// <param name="slotsInOutput">The maximum number of slots to preserve in output.</param>
/// <param name="numBins">Max number of bins for R4/R8 columns, power of 2 recommended.</param>
/// <param name="columns">Columns to use for feature selection.</param>
/// <returns>The <see cref="IDataTransform"/> produced by the <see cref="Arguments"/>-based overload.</returns>
public static IDataTransform Create(IHostEnvironment env,
    IDataView input,
    string labelColumn = Defaults.LabelColumn,
    int slotsInOutput = Defaults.SlotsInOutput,
    int numBins = Defaults.NumBins,
    params string[] columns)
{
    return Create(
        env,
        new Arguments
        {
            NumBins = numBins,
            SlotsInOutput = slotsInOutput,
            LabelColumn = labelColumn,
            Column = columns
        },
        input);
}

/// <summary>
/// Create method corresponding to SignatureDataTransform.
/// </summary>
Expand Down
12 changes: 12 additions & 0 deletions src/Microsoft.ML.Transforms/NADropTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ private static VersionInfo GetVersionInfo()
// The isNA delegates, parallel to Infos.
private readonly Delegate[] _isNAs;

/// <summary>
/// Convenience constructor for the public facing API. Describes a single output column and
/// forwards the resulting <see cref="Arguments"/> to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
public NADropTransform(IHostEnvironment env, IDataView input, string name, string source = null)
    : this(
        env,
        new Arguments
        {
            Column = new Column[]
            {
                // Fall back to the output name when no explicit source column is given.
                new Column { Name = name, Source = source ?? name }
            }
        },
        input)
{
}

public NADropTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(Contracts.CheckRef(env, nameof(env)), RegistrationName, env.CheckRef(args, nameof(args)).Column, input, TestType)
{
Expand Down
36 changes: 36 additions & 0 deletions src/Microsoft.ML.Transforms/NAHandleTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,25 @@ public static class NAHandleTransform
{
public enum ReplacementKind
{
/// <summary>
/// Replace with the default value of the column based on its type. For example, 'zero' for numeric and 'empty' for string/text columns.
/// </summary>
[EnumValueDisplay("Zero/empty")]
DefaultValue,

/// <summary>
/// Replace with the mean value of the column. Supports only numeric/time span/ DateTime columns.
/// </summary>
Mean,

/// <summary>
/// Replace with the minimum value of the column. Supports only numeric/time span/ DateTime columns.
/// </summary>
Minimum,

/// <summary>
/// Replace with the maximum value of the column. Supports only numeric/time span/ DateTime columns.
/// </summary>
Maximum,

[HideEnumValue]
Expand Down Expand Up @@ -105,6 +120,27 @@ public bool TryUnparse(StringBuilder sb)
internal const string FriendlyName = "NA Handle Transform";
internal const string ShortName = "NAHandle";

/// <summary>
/// A helper method to create <see cref="NAHandleTransform"/> for public facing API.
/// Describes a single output column, packs it into an <see cref="Arguments"/> instance and
/// delegates to the <see cref="Arguments"/>-based <c>Create</c> overload.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
/// <param name="replaceWith">The replacement method to utilize.</param>
/// <returns>The <see cref="IDataTransform"/> produced by the <see cref="Arguments"/>-based overload.</returns>
public static IDataTransform Create(IHostEnvironment env, IDataView input, string name, string source = null, ReplacementKind replaceWith = ReplacementKind.DefaultValue)
{
    // Fall back to the output name when no explicit source column is given.
    var column = new Column { Name = name, Source = source ?? name };

    return Create(
        env,
        new Arguments
        {
            ReplaceWith = replaceWith,
            Column = new Column[] { column }
        },
        input);
}

public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
Expand Down
12 changes: 12 additions & 0 deletions src/Microsoft.ML.Transforms/NAIndicatorTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ private static string TestType(ColumnType type)
// The output column types, parallel to Infos.
private readonly ColumnType[] _types;

/// <summary>
/// Convenience constructor for the public facing API. Describes a single output column and
/// forwards the resulting <see cref="Arguments"/> to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
public NAIndicatorTransform(IHostEnvironment env, IDataView input, string name, string source = null)
    : this(
        env,
        new Arguments
        {
            Column = new Column[]
            {
                // Fall back to the output name when no explicit source column is given.
                new Column { Name = name, Source = source ?? name }
            }
        },
        input)
{
}

/// <summary>
/// Public constructor corresponding to SignatureDataTransform.
/// </summary>
Expand Down
13 changes: 13 additions & 0 deletions src/Microsoft.ML.Transforms/NAReplaceTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,19 @@ private static string TestType<T>(ColumnType type)

public override bool CanSaveOnnx => true;

/// <summary>
/// Convenience constructor for the public facing API. Describes a single output column and
/// forwards the resulting <see cref="Arguments"/> to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
/// <param name="replacementKind">The replacement method to utilize.</param>
public NAReplaceTransform(IHostEnvironment env, IDataView input, string name, string source = null, ReplacementKind replacementKind = ReplacementKind.DefaultValue)
    : this(
        env,
        new Arguments
        {
            ReplacementKind = replacementKind,
            Column = new Column[]
            {
                // Fall back to the output name when no explicit source column is given.
                new Column { Name = name, Source = source ?? name }
            }
        },
        input)
{
}

/// <summary>
/// Public constructor corresponding to SignatureDataTransform.
/// </summary>
Expand Down
20 changes: 19 additions & 1 deletion src/Microsoft.ML.Transforms/OptionalColumnTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@

namespace Microsoft.ML.Runtime.DataPipe
{
public class OptionalColumnTransform : RowToRowMapperTransformBase
/// <summary>
/// This transform is used to mark some of the columns (e.g. Label) optional during training so that the columns are not required during scoring.
/// When applied to new data, if any of the optional columns is not present, a dummy column is created having the same properties (e.g. 'name', 'type' etc.) as used during training.
/// The columns are filled with default values. The value is
/// - scalar for scalar column
/// - totally sparse vector for vector column.
/// </summary>
public sealed class OptionalColumnTransform : RowToRowMapperTransformBase
{
public sealed class Arguments : TransformInputBase
{
Expand Down Expand Up @@ -232,6 +239,17 @@ private static VersionInfo GetVersionInfo()

private const string RegistrationName = "OptionalColumn";

/// <summary>
/// Convenience constructor for the public facing API. Wraps the column names in an
/// <see cref="Arguments"/> instance and forwards it to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="columns">Columns to transform. Passed through unchanged as <c>Column</c>.</param>
public OptionalColumnTransform(IHostEnvironment env, IDataView input, params string[] columns)
    : this(
        env,
        new Arguments
        {
            Column = columns
        },
        input)
{
}

/// <summary>
/// Public constructor corresponding to SignatureDataTransform.
/// </summary>
Expand Down
27 changes: 25 additions & 2 deletions src/Microsoft.ML.Transforms/RffTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,26 @@ namespace Microsoft.ML.Runtime.Data

public sealed class RffTransform : OneToOneTransformBase
{
// Default option values referenced by the Arguments field initializers.
private static class Defaults
{
    // Whether to create two features (one cos, one sin) for every random Fourier frequency.
    public const bool UseSin = false;

    // The number of random Fourier features to create.
    public const int NewDim = 1000;
}

public sealed class Arguments
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
public Column[] Column;

[Argument(ArgumentType.AtMostOnce, HelpText = "The number of random Fourier features to create", ShortName = "dim")]
public int NewDim = 1000;
public int NewDim = Defaults.NewDim;

[Argument(ArgumentType.Multiple, HelpText = "which kernel to use?", ShortName = "kernel")]
public SubComponent<IFourierDistributionSampler, SignatureFourierDistributionSampler> MatrixGenerator =
new SubComponent<IFourierDistributionSampler, SignatureFourierDistributionSampler>(GaussianFourierSampler.LoadName);

[Argument(ArgumentType.AtMostOnce, HelpText = "create two features for every random Fourier frequency? (one for cos and one for sin)")]
public bool UseSin = false;
public bool UseSin = Defaults.UseSin;

[Argument(ArgumentType.LastOccurenceWins,
HelpText = "The seed of the random number generator for generating the new features (if unspecified, " +
Expand Down Expand Up @@ -232,6 +238,23 @@ private static string TestColumnType(ColumnType type)
return "Expected R4 or vector of R4 with known size";
}

/// <summary>
/// Convenience constructor for the public facing API. Describes a single output column and
/// forwards the resulting <see cref="Arguments"/> to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="newDim">The number of random Fourier features to create.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
public RffTransform(IHostEnvironment env, IDataView input, int newDim, string name, string source = null)
    : this(
        env,
        new Arguments
        {
            NewDim = newDim,
            Column = new Column[]
            {
                // Fall back to the output name when no explicit source column is given.
                new Column { Name = name, Source = source ?? name }
            }
        },
        input)
{
}

/// <summary>
/// Public constructor corresponding to <see cref="SignatureDataTransform"/>.
/// </summary>
Expand Down
Loading