Skip to content

Added in new MissingValueReplacing method. #5205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/Microsoft.ML.Transforms/MissingValueReplacing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ internal enum ReplacementKind : byte
Minimum = 2,
Maximum = 3,
SpecifiedValue = 4,
Mode = 5,

[HideEnumValue]
Def = DefaultValue,
Expand Down Expand Up @@ -306,6 +307,7 @@ private void GetReplacementValues(IDataView input, MissingValueReplacingEstimato
case ReplacementKind.Mean:
case ReplacementKind.Minimum:
case ReplacementKind.Maximum:
case ReplacementKind.Mode:
if (!(type.GetItemType() is NumberDataViewType))
throw Host.Except("Cannot perform mean imputations on non-numeric '{0}'", type.GetItemType());
imputationModes[iinfo] = kind;
Expand Down Expand Up @@ -944,6 +946,10 @@ public enum ReplacementMode : byte
/// Replace with the maximum value of the column.
/// </summary>
Maximum = 3,
/// <summary>
/// Replace with the most frequent value of the column.
/// </summary>
Mode = 5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did we skip 4 here? It went from 0, 1, 2, 3 and then jumped to 5.

}

[BestFriend]
Expand Down
203 changes: 202 additions & 1 deletion src/Microsoft.ML.Transforms/MissingValueReplacingUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Utilities;
Expand All @@ -26,13 +27,20 @@ private static StatAggregator CreateStatAggregator(IChannel ch, DataViewType typ
else if (type.RawType == typeof(double))
return new R8.MeanAggregatorOne(ch, cursor, col);
}
if (kind == ReplacementKind.Min || kind == ReplacementKind.Max)
else if (kind == ReplacementKind.Min || kind == ReplacementKind.Max)
{
if (type.RawType == typeof(float))
return new R4.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max);
else if (type.RawType == typeof(double))
return new R8.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max);
}
else if (kind == ReplacementKind.Mode)
{
if (type.RawType == typeof(float))
return new R4.ModeAggregatorOne(ch, cursor, col);
else if (type.RawType == typeof(double))
return new R8.ModeAggregatorOne(ch, cursor, col);
}
}
else if (bySlot)
{
Expand All @@ -55,6 +63,13 @@ private static StatAggregator CreateStatAggregator(IChannel ch, DataViewType typ
else if (vectorType.ItemType.RawType == typeof(double))
return new R8.MinMaxAggregatorBySlot(ch, vectorType, cursor, col, kind == ReplacementKind.Max);
}
else if (kind == ReplacementKind.Mode)
{
if (vectorType.ItemType.RawType == typeof(float))
return new R4.ModeAggregatorBySlot(ch, vectorType, cursor, col);
else if (vectorType.ItemType.RawType == typeof(double))
return new R8.ModeAggregatorBySlot(ch, vectorType, cursor, col);
}
}
else
{
Expand All @@ -73,6 +88,13 @@ private static StatAggregator CreateStatAggregator(IChannel ch, DataViewType typ
else if (vectorType.ItemType.RawType == typeof(double))
return new R8.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max);
}
else if (kind == ReplacementKind.Mode)
{
if (vectorType.ItemType.RawType == typeof(float))
return new R4.ModeAggregatorAcrossSlots(ch, cursor, col);
else if (vectorType.ItemType.RawType == typeof(double))
return new R8.ModeAggregatorAcrossSlots(ch, cursor, col);
}
}
ch.Assert(false);
throw ch.Except("Internal error, unrecognized imputation method ReplacementKind '{0}' or unrecognized type '{1}' " +
Expand Down Expand Up @@ -340,6 +362,55 @@ protected long GetValuesProcessed(int slot)
protected abstract void ProcessValueMax(in TItem val, int slot);
}

/// <summary>
/// Keeps the running statistics needed to compute the mode (most frequent value) of a stream of values.
/// This is a mutable reference type (a class, not a struct): an instance is created once, stored in a
/// field or an array slot by its owner, and then updated in place through <see cref="Update"/>.
/// </summary>
/// <typeparam name="TType">The type of the values being counted. Values are used as dictionary keys.</typeparam>
private class ModeStat<TType>
{
    /// <summary>
    /// Delegate used to check if a value is valid, i.e. whether it should be counted at all.
    /// A delegate is used so that this class can support modes of all types.
    /// </summary>
    public delegate bool IsValid(in TType val);

    // Assigned once in the constructor and never replaced afterwards.
    private readonly Dictionary<TType, int> _valueCounts;
    private readonly IsValid _validityCheck;

    // The value with the highest count seen so far, and that count.
    private TType _modeSoFar;
    private int _maxCount;

    /// <summary>
    /// Creates an empty statistic.
    /// </summary>
    /// <param name="valid">Predicate deciding which values take part in the mode computation.</param>
    public ModeStat(IsValid valid)
    {
        _modeSoFar = default;
        _maxCount = 0;
        _valueCounts = new Dictionary<TType, int>();
        _validityCheck = valid;
    }

    /// <summary>
    /// Records one occurrence of <paramref name="val"/>. Values rejected by the validity check are ignored.
    /// </summary>
    /// <param name="val">The value to count.</param>
    public void Update(TType val)
    {
        // Values that fail the caller-supplied validity check never contribute to the mode.
        if (!_validityCheck(val))
            return;

        // If the key is already in the dictionary, we want to get the current count and increment it.
        // If the key is not in the dictionary, we want to set count to 1 so the count is correct.
        if (_valueCounts.TryGetValue(val, out int count))
            count++;
        else
            count = 1;

        _valueCounts[val] = count;

        // Strictly greater: on a tie, the value that reached the count first remains the mode.
        if (count > _maxCount)
        {
            _modeSoFar = val;
            _maxCount = count;
        }
    }

    /// <summary>
    /// Returns the most frequent valid value seen so far, or <c>default</c> when no valid value
    /// has been recorded yet.
    /// </summary>
    public TType GetCurrentValue()
    {
        return _modeSoFar;
    }
}

/// <summary>
/// This is a mutable struct (so is evil). However, its scope is restricted
/// and the only instances are in a field or an array, so the mutation does
Expand Down Expand Up @@ -618,6 +689,71 @@ public override object GetStat()
return Stat;
}
}

/// <summary>
/// Mode aggregator over float values that keeps a single <see cref="ModeStat{TType}"/> in <c>Stat</c>.
/// </summary>
public sealed class ModeAggregatorOne : StatAggregator<float, ModeStat<float>>
{
    public ModeAggregatorOne(IChannel ch, DataViewRowCursor cursor, int col)
        : base(ch, cursor, col)
    {
        // Only values accepted by FloatUtils.IsFinite are counted towards the mode.
        ModeStat<float>.IsValid finiteOnly = (in float candidate) => FloatUtils.IsFinite(candidate);
        Stat = new ModeStat<float>(finiteOnly);
    }

    /// <summary>Feeds the supplied value into the running mode statistic.</summary>
    protected override void ProcessRow(in float val) => Stat.Update(val);

    /// <summary>Returns the mode computed so far (a float, boxed as object).</summary>
    public override object GetStat() => Stat.GetCurrentValue();
}

/// <summary>
/// Mode aggregator over float values that funnels every processed value into one shared
/// <see cref="ModeStat{TType}"/> held in <c>Stat</c>.
/// </summary>
public sealed class ModeAggregatorAcrossSlots : StatAggregatorAcrossSlots<float, ModeStat<float>>
{
    public ModeAggregatorAcrossSlots(IChannel ch, DataViewRowCursor cursor, int col)
        : base(ch, cursor, col)
    {
        // Only values accepted by FloatUtils.IsFinite are counted towards the mode.
        ModeStat<float>.IsValid finiteOnly = (in float candidate) => FloatUtils.IsFinite(candidate);
        Stat = new ModeStat<float>(finiteOnly);
    }

    /// <summary>Feeds the supplied value into the running mode statistic.</summary>
    protected override void ProcessValue(in float val) => Stat.Update(val);

    /// <summary>Returns the mode computed so far (a float, boxed as object).</summary>
    public override object GetStat() => Stat.GetCurrentValue();
}

/// <summary>
/// Mode aggregator over float values that keeps one independent <see cref="ModeStat{TType}"/>
/// per entry of <c>Stat</c>, indexed by slot.
/// </summary>
public sealed class ModeAggregatorBySlot : StatAggregatorBySlot<float, ModeStat<float>>
{
    public ModeAggregatorBySlot(IChannel ch, VectorDataViewType type, DataViewRowCursor cursor, int col)
        : base(ch, type, cursor, col)
    {
        // Give every slot its own statistic; only values accepted by FloatUtils.IsFinite are counted.
        int slotCount = Stat.Length;
        for (int slot = 0; slot < slotCount; slot++)
            Stat[slot] = new ModeStat<float>((in float candidate) => FloatUtils.IsFinite(candidate));
    }

    /// <summary>Feeds the supplied value into the statistic of the given slot.</summary>
    protected override void ProcessValue(in float val, int slot) => Stat[slot].Update(val);

    /// <summary>Returns a float[] holding the mode computed so far for each slot.</summary>
    public override object GetStat()
    {
        var modes = new float[Stat.Length];
        for (int i = 0; i < modes.Length; i++)
            modes[i] = Stat[i].GetCurrentValue();

        return modes;
    }
}
}

private static class R8
Expand Down Expand Up @@ -777,6 +913,71 @@ public override object GetStat()
return Stat;
}
}

/// <summary>
/// Mode aggregator over double values that keeps a single <see cref="ModeStat{TType}"/> in <c>Stat</c>.
/// </summary>
public sealed class ModeAggregatorOne : StatAggregator<double, ModeStat<double>>
{
    public ModeAggregatorOne(IChannel ch, DataViewRowCursor cursor, int col)
        : base(ch, cursor, col)
    {
        // Only values accepted by FloatUtils.IsFinite are counted towards the mode.
        ModeStat<double>.IsValid finiteOnly = (in double candidate) => FloatUtils.IsFinite(candidate);
        Stat = new ModeStat<double>(finiteOnly);
    }

    /// <summary>Feeds the supplied value into the running mode statistic.</summary>
    protected override void ProcessRow(in double val) => Stat.Update(val);

    /// <summary>Returns the mode computed so far (a double, boxed as object).</summary>
    public override object GetStat() => Stat.GetCurrentValue();
}

/// <summary>
/// Mode aggregator over double values that funnels every processed value into one shared
/// <see cref="ModeStat{TType}"/> held in <c>Stat</c>.
/// </summary>
public sealed class ModeAggregatorAcrossSlots : StatAggregatorAcrossSlots<double, ModeStat<double>>
{
    public ModeAggregatorAcrossSlots(IChannel ch, DataViewRowCursor cursor, int col)
        : base(ch, cursor, col)
    {
        // Only values accepted by FloatUtils.IsFinite are counted towards the mode.
        ModeStat<double>.IsValid finiteOnly = (in double candidate) => FloatUtils.IsFinite(candidate);
        Stat = new ModeStat<double>(finiteOnly);
    }

    /// <summary>Feeds the supplied value into the running mode statistic.</summary>
    protected override void ProcessValue(in double val) => Stat.Update(val);

    /// <summary>Returns the mode computed so far (a double, boxed as object).</summary>
    public override object GetStat() => Stat.GetCurrentValue();
}

/// <summary>
/// Mode aggregator over double values that keeps one independent <see cref="ModeStat{TType}"/>
/// per entry of <c>Stat</c>, indexed by slot.
/// </summary>
public sealed class ModeAggregatorBySlot : StatAggregatorBySlot<double, ModeStat<double>>
{
    public ModeAggregatorBySlot(IChannel ch, VectorDataViewType type, DataViewRowCursor cursor, int col)
        : base(ch, type, cursor, col)
    {
        // Give every slot its own statistic; only values accepted by FloatUtils.IsFinite are counted.
        int slotCount = Stat.Length;
        for (int slot = 0; slot < slotCount; slot++)
            Stat[slot] = new ModeStat<double>((in double candidate) => FloatUtils.IsFinite(candidate));
    }

    /// <summary>Feeds the supplied value into the statistic of the given slot.</summary>
    protected override void ProcessValue(in double val, int slot) => Stat[slot].Update(val);

    /// <summary>Returns a double[] holding the mode computed so far for each slot.</summary>
    public override object GetStat()
    {
        var modes = new double[Stat.Length];
        for (int i = 0; i < modes.Length; i++)
            modes[i] = Stat[i].GetCurrentValue();

        return modes;
    }
}
}
}
}
8 changes: 5 additions & 3 deletions test/BaselineOutput/Common/EntryPoints/core_manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -21677,7 +21677,8 @@
"Mean",
"Minimum",
"Maximum",
"SpecifiedValue"
"SpecifiedValue",
"Mode"
]
},
"Desc": "The replacement method to utilize",
Expand Down Expand Up @@ -21747,7 +21748,8 @@
"Mean",
"Minimum",
"Maximum",
"SpecifiedValue"
"SpecifiedValue",
"Mode"
]
},
"Desc": "The replacement method to utilize",
Expand All @@ -21757,7 +21759,7 @@
"Required": false,
"SortOrder": 150.0,
"IsNullable": false,
"Default": "Default"
"Default": "Def"
},
{
"Name": "ImputeBySlot",
Expand Down
11 changes: 6 additions & 5 deletions test/BaselineOutput/Common/NAReplace/featurized.tsv
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
#@ col=B:R8:1
#@ col=C:R4:2-5
#@ col=D:R8:6-9
#@ col=E:R8:10-13
#@ }
A B 8 0:""
5 5 5 1 1 1 5 1 1 1
5 5 5 4 4 5 5 4 4 5
3 3 3 1 1 1 3 1 1 1
6 6 6 8 8 1 6 8 8 1
A B 12 0:""
5 5 5 1 1 1 5 1 1 1 5 1 1 1
5 5 5 4 4 5 5 4 4 5 5 4 4 5
3 3 3 1 1 1 3 1 1 1 3 1 1 1
6 6 6 8 8 1 6 8 8 1 6 8 8 1
Loading