Skip to content

NAReplace estimator #917

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ private ColInfo[] CreateInfos(ISchema inputSchema)
if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input);
var type = inputSchema.GetColumnType(colSrc);
_parent.CheckInputColumn(inputSchema, i, colSrc);
infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type);
}
return infos;
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Legacy/CSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13983,7 +13983,7 @@ public MinMaxNormalizerPipelineStep(Output output)

namespace Legacy.Transforms
{
public enum NAHandleTransformReplacementKind
public enum NAHandleTransformReplacementKind : byte
{
DefaultValue = 0,
Mean = 1,
Expand Down Expand Up @@ -14444,7 +14444,7 @@ public MissingValuesRowDropperPipelineStep(Output output)

namespace Legacy.Transforms
{
public enum NAReplaceTransformReplacementKind
public enum NAReplaceTransformReplacementKind : byte
Copy link
Contributor

@Zruty0 Zruty0 Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

: byte [](start = 53, length = 7)

fix codegen? Or is it already good? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's already good


In reply to: 218248180 [](ancestors = 218248180)

{
DefaultValue = 0,
Mean = 1,
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum
Contracts.CheckValue(columns, nameof(columns));
return columns.Select(x => (x.Input, x.Output)).ToArray();
}

public IReadOnlyCollection<ColumnInfo> Columns => _columns.AsReadOnly();
private readonly ColumnInfo[] _columns;

Expand Down Expand Up @@ -209,7 +210,7 @@ private ColInfo[] CreateInfos(ISchema inputSchema)
if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input);
var type = inputSchema.GetColumnType(colSrc);

_parent.CheckInputColumn(inputSchema, i, colSrc);
infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type);
}
return infos;
Expand Down
46 changes: 11 additions & 35 deletions src/Microsoft.ML.Transforms/NAHandleTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,28 @@ namespace Microsoft.ML.Runtime.Data
/// <include file='doc.xml' path='doc/members/member[@name="NAHandle"]'/>
public static class NAHandleTransform
{
public enum ReplacementKind
public enum ReplacementKind : byte
{
/// <summary>
/// Replace with the default value of the column based on it's type. For example, 'zero' for numeric and 'empty' for string/text columns.
/// </summary>
[EnumValueDisplay("Zero/empty")]
DefaultValue,
DefaultValue = 0,

/// <summary>
/// Replace with the mean value of the column. Supports only numeric/time span/ DateTime columns.
/// </summary>
Mean,
Mean = 1,

/// <summary>
/// Replace with the minimum value of the column. Supports only numeric/time span/ DateTime columns.
/// </summary>
Minimum,
Minimum = 2,

/// <summary>
/// Replace with the maximum value of the column. Supports only numeric/time span/ DateTime columns.
/// </summary>
Maximum,
Maximum = 3,

[HideEnumValue]
Def = DefaultValue,
Expand Down Expand Up @@ -135,7 +135,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
h.CheckValue(input, nameof(input));
h.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column));

var replaceCols = new List<NAReplaceTransform.Column>();
var replaceCols = new List<NAReplaceTransform.ColumnInfo>();
var naIndicatorCols = new List<NAIndicatorTransform.Column>();
var naConvCols = new List<ConvertTransform.Column>();
var concatCols = new List<ConcatTransform.TaggedColumn>();
Expand All @@ -149,26 +149,16 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
var addInd = column.ConcatIndicator ?? args.Concat;
if (!addInd)
{
replaceCols.Add(
new NAReplaceTransform.Column()
{
Kind = (NAReplaceTransform.ReplacementKind?)column.Kind,
Name = column.Name,
Source = column.Source,
Slot = column.ImputeBySlot
});
replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, column.Name, (NAReplaceTransform.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot));
continue;
}

// Check that the indicator column has a type that can be converted to the NAReplaceTransform output type,
// so that they can be concatenated.
int inputCol;
if (!input.Schema.TryGetColumnIndex(column.Source, out inputCol))
if (!input.Schema.TryGetColumnIndex(column.Source, out int inputCol))
throw h.Except("Column '{0}' does not exist", column.Source);
var replaceType = input.Schema.GetColumnType(inputCol);
Delegate conv;
bool identity;
if (!Conversions.Instance.TryGetStandardConversion(BoolType.Instance, replaceType.ItemType, out conv, out identity))
if (!Conversions.Instance.TryGetStandardConversion(BoolType.Instance, replaceType.ItemType, out Delegate conv, out bool identity))
{
throw h.Except("Cannot concatenate indicator column of type '{0}' to input column of type '{1}'",
BoolType.Instance, replaceType.ItemType);
Expand All @@ -186,14 +176,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
naConvCols.Add(new ConvertTransform.Column() { Name = tmpIsMissingColName, Source = tmpIsMissingColName, ResultType = replaceType.ItemType.RawKind });

// Add the NAReplaceTransform column.
replaceCols.Add(
new NAReplaceTransform.Column()
{
Kind = (NAReplaceTransform.ReplacementKind?)column.Kind,
Name = tmpReplacementColName,
Source = column.Source,
Slot = column.ImputeBySlot
});
replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, tmpReplacementColName, (NAReplaceTransform.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot));

// Add the ConcatTransform column.
if (replaceType.IsVector)
Expand Down Expand Up @@ -237,15 +220,8 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
h.AssertValue(output);
output = new ConvertTransform(h, new ConvertTransform.Arguments() { Column = naConvCols.ToArray() }, output);
}

// Create the NAReplace transform.
output = new NAReplaceTransform(h,
new NAReplaceTransform.Arguments()
{
Column = replaceCols.ToArray(),
ReplacementKind = (NAReplaceTransform.ReplacementKind)args.ReplaceWith,
ImputeBySlot = args.ImputeBySlot
}, output ?? input);
output = NAReplaceTransform.Create(env, output ?? input, replaceCols.ToArray());

// Concat the NAReplaceTransform output and the NAIndicatorTransform output.
if (naIndicatorCols.Count > 0)
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Transforms/NAHandling.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public static CommonOutputs.TransformOutput Indicator(IHostEnvironment env, NAIn
public static CommonOutputs.TransformOutput Replace(IHostEnvironment env, NAReplaceTransform.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "NAReplace", input);
var xf = new NAReplaceTransform(h, input, input.Data);
var xf = NAReplaceTransform.Create(h, input, input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, xf, input.Data),
Expand Down
Loading