Robust Scaler now added to the Normalizer catalog #5166

Merged · 7 commits · May 29, 2020
55 changes: 55 additions & 0 deletions src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
@@ -152,6 +152,9 @@ private static class Defaults
public const bool LogMeanVarCdf = true;
public const int NumBins = 1024;
public const int MinBinSize = 10;
public const bool CenterData = true;
public const int QuantileMin = 25;
public const int QuantileMax = 75;
}

public abstract class ControlZeroArgumentsBase : ArgumentsBase
@@ -245,6 +248,18 @@ public sealed class SupervisedBinArguments : BinArgumentsBase
public int MinBinSize = Defaults.MinBinSize;
}

public sealed class RobustScalingArguments : AffineArgumentsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Should the data be centered around 0", Name = "CenterData", ShortName = "center", SortOrder = 1)]
public bool CenterData = Defaults.CenterData;

[Argument(ArgumentType.AtMostOnce, HelpText = "Minimum quantile value. Defaults to 25", Name = "QuantileMin", ShortName = "qmin", SortOrder = 2)]
public uint QuantileMin = Defaults.QuantileMin;

[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum quantile value. Defaults to 75", Name = "QuantileMax", ShortName = "qmax", SortOrder = 3)]
public uint QuantileMax = Defaults.QuantileMax;
}

internal const string MinMaxNormalizerSummary = "Normalizes the data based on the observed minimum and maximum values of the data.";
internal const string MeanVarNormalizerSummary = "Normalizes the data based on the computed mean and variance of the data.";
internal const string LogMeanVarNormalizerSummary = "Normalizes the data based on the computed mean and variance of the logarithm of the data.";
@@ -1145,6 +1160,46 @@ public static int GetLabelColumnId(IExceptionContext host, DataViewSchema schema
return labelColumnId;
}
}

internal static partial class RobustScaleUtils
{
public static IColumnFunctionBuilder CreateBuilder(RobustScalingArguments args, IHost host,
int icol, int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
{
Contracts.AssertValue(host);
host.AssertValue(args);

return CreateBuilder(new NormalizingEstimator.RobustScalingColumnOptions(
args.Columns[icol].Name,
args.Columns[icol].Source ?? args.Columns[icol].Name,
args.Columns[icol].MaximumExampleCount ?? args.MaximumExampleCount,
args.CenterData,
args.QuantileMin,
args.QuantileMax), host, srcIndex, srcType, cursor);
}

public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.RobustScalingColumnOptions column, IHost host,
int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
{
var srcColumn = cursor.Schema[srcIndex];
if (srcType is NumberDataViewType)
{
if (srcType == NumberDataViewType.Single)
return Sng.RobustScalerOneColumnFunctionBuilder.Create(column, host, srcType, column.CenterData, column.QuantileMin, column.QuantileMax, cursor.GetGetter<Single>(srcColumn));
if (srcType == NumberDataViewType.Double)
return Dbl.RobustScalerOneColumnFunctionBuilder.Create(column, host, srcType, column.CenterData, column.QuantileMin, column.QuantileMax, cursor.GetGetter<double>(srcColumn));
}
if (srcType is VectorDataViewType vectorType && vectorType.IsKnownSize && vectorType.ItemType is NumberDataViewType)
{
if (vectorType.ItemType == NumberDataViewType.Single)
return Sng.RobustScalerVecFunctionBuilder.Create(column, host, srcType as VectorDataViewType, column.CenterData, column.QuantileMin, column.QuantileMax, cursor.GetGetter<VBuffer<float>>(srcColumn));
if (vectorType.ItemType == NumberDataViewType.Double)
return Dbl.RobustScalerVecFunctionBuilder.Create(column, host, srcType as VectorDataViewType, column.CenterData, column.QuantileMin, column.QuantileMax, cursor.GetGetter<VBuffer<double>>(srcColumn));
}

throw host.ExceptParam(nameof(srcType), "Wrong column type for input column. Expected: Single, Double, Vector of Single or Vector of Double. Got: {0}.", srcType.ToString());
}
}
}

internal static partial class AffineNormSerializationUtils
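
The arguments added to NormalizeColumn.cs above are what the catalog surfaces to users. As a rough sketch of how the new normalizer might be invoked once this change ships, the snippet below assumes a NormalizeRobustScaling extension on MLContext.Transforms whose centerData, quantileMin, and quantileMax parameters mirror the defaults in this file; treat the exact method and parameter names as assumptions rather than something quoted from this diff.

using Microsoft.ML;

public class FeatureRow
{
    public float Feature;
}

// Inside some method (assumed entry point: a NormalizeRobustScaling extension on the transforms catalog):
var mlContext = new MLContext();
var data = mlContext.Data.LoadFromEnumerable(new[]
{
    new FeatureRow { Feature = 1f },
    new FeatureRow { Feature = 2f },
    new FeatureRow { Feature = 100f }
});
var pipeline = mlContext.Transforms.NormalizeRobustScaling(
    outputColumnName: "Feature",
    centerData: true,   // subtract the per-slot median
    quantileMin: 25,    // lower quantile used when computing the scale
    quantileMax: 75);   // upper quantile used when computing the scale
var transformed = pipeline.Fit(data).Transform(data);
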
249 changes: 249 additions & 0 deletions src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs
@@ -516,6 +516,117 @@ private void Update(int j, TFloat origVal)
}
}

[BestFriend]
internal static partial class MedianAggregatorUtils
{
/// <summary>
/// Based on the algorithm on GeeksForGeeks https://www.geeksforgeeks.org/median-of-stream-of-integers-running-integers/.
/// </summary>
/// <param name="num">The new number to account for in our median calculation.</param>
/// <param name="median">The current median.</param>
/// <param name="belowMedianHeap">The MaxHeap that has all the numbers below the median.</param>
/// <param name="aboveMedianHeap">The MinHeap that has all the numbers above the median.</param>
[BestFriend]
internal static void GetMedianSoFar(in double num, ref double median, ref MaxHeap<double> belowMedianHeap, ref MinHeap<double> aboveMedianHeap)
{
int comparison = belowMedianHeap.Count().CompareTo(aboveMedianHeap.Count());

if (comparison < 0)
{ // More elements in aboveMedianHeap than belowMedianHeap.
if (num < median)
{ // Current element belongs in the belowMedianHeap.
// Insert new number into belowMedianHeap
belowMedianHeap.Add(num);

}
else
{ // Current element belongs in aboveMedianHeap.
// Need to move one to belowMedianHeap to keep the heaps balanced.
belowMedianHeap.Add(aboveMedianHeap.Pop());

aboveMedianHeap.Add(num);
}

// Both heaps are balanced, so the median is the average of the two heap roots.
median = (aboveMedianHeap.Peek() + belowMedianHeap.Peek()) / 2;

}
else if (comparison == 0)
{ // Both heaps have the same number of elements. Simply put the number where it belongs.
if (num < median)
{ // Current element belongs in the belowMedianHeap.
belowMedianHeap.Add(num);

// Now we have an odd number of items; the median is the root of the belowMedianHeap.
median = belowMedianHeap.Peek();

}
else
{ // Current element belongs in above median heap.
aboveMedianHeap.Add(num);

// Now we have an odd number of items; the median is the root of the aboveMedianHeap.
median = aboveMedianHeap.Peek();
}

}
else
{ // More elements in belowMedianHeap than aboveMedianHeap.
if (num < median)
{ // Current element belongs in the belowMedianHeap.
// Need to move one to aboveMedianHeap to keep the heaps balanced.
aboveMedianHeap.Add(belowMedianHeap.Pop());

// Insert new number into belowMedianHeap
belowMedianHeap.Add(num);

}
else
{ // Current element belongs in aboveMedianHeap.
aboveMedianHeap.Add(num);
}

// Both heaps are balanced, so the median is the average of the two heap roots.
median = (aboveMedianHeap.Peek() + belowMedianHeap.Peek()) / 2;
}
}
}
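
// Illustrative trace of GetMedianSoFar (the numbers are chosen for demonstration): starting
// from empty heaps and the default median of 0, feeding the stream 5, 2, 8, 1 gives:
//   5 -> heaps are the same size, 5 >= 0 so it goes to aboveMedianHeap;   median = 5
//   2 -> aboveMedianHeap is larger, 2 < 5 so it goes to belowMedianHeap;  median = (5 + 2) / 2 = 3.5
//   8 -> heaps are the same size, 8 >= 3.5 so it goes to aboveMedianHeap; median = 5
//   1 -> aboveMedianHeap is larger, 1 < 5 so it goes to belowMedianHeap;  median = (5 + 2) / 2 = 3.5
// which matches the true median of {1, 2, 5, 8}.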

/// <summary>
/// Aggregator that tracks the running median of a single-valued column.
/// The median is updated incrementally via a pair of heaps as values are processed.
/// </summary>
internal sealed class MedianDblAggregator : IColumnAggregator<double>
{
private MedianAggregatorUtils.MaxHeap<double> _belowMedianHeap;
private MedianAggregatorUtils.MinHeap<double> _aboveMedianHeap;
private double _median;

public MedianDblAggregator(int containerStartingSize = 1000)
{
Contracts.Check(containerStartingSize > 0);
_belowMedianHeap = new MedianAggregatorUtils.MaxHeap<double>(containerStartingSize);
_aboveMedianHeap = new MedianAggregatorUtils.MinHeap<double>(containerStartingSize);
_median = default;
}

public double Median
{
get { return _median; }
}

public void ProcessValue(in double value)
{
MedianAggregatorUtils.GetMedianSoFar(value, ref _median, ref _belowMedianHeap, ref _aboveMedianHeap);
}

public void Finish()
{
// Finish is a no-op because we are updating the median continually as we go
}
}
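
// Minimal usage sketch of the aggregator above (the loop and variable names are illustrative;
// the robust-scaler builders below drive it the same way through ProcessValue):
//   var medianAggregator = new MedianDblAggregator();
//   foreach (double value in values)
//       medianAggregator.ProcessValue(in value);
//   medianAggregator.Finish();               // no-op; the median is kept current as values arrive
//   double median = medianAggregator.Median;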

internal sealed partial class NormalizeTransform
{
internal abstract partial class AffineColumnFunction
@@ -1912,6 +2023,144 @@ public static IColumnFunctionBuilder Create(NormalizingEstimator.SupervisedBinni
return new SupervisedBinVecColumnFunctionBuilder(host, lim, fix, numBins, column.MininimumBinSize, valueColumnId, labelColumnId, dataRow);
}
}

public sealed class RobustScalerOneColumnFunctionBuilder : OneColumnFunctionBuilderBase<double>
{
private readonly MinMaxDblAggregator _minMaxAggregator;
private readonly MedianDblAggregator _medianAggregator;
private readonly bool _centerData;
private readonly uint _quantileMin;
private readonly uint _quantileMax;
private VBuffer<double> _buffer;

private RobustScalerOneColumnFunctionBuilder(IHost host, long lim, bool centerData, uint quantileMin, uint quantileMax, ValueGetter<double> getSrc)
: base(host, lim, getSrc)
{
// Reuse the MinMax aggregator since the min and max are needed here as well;
// the difference is only in how they are used.
_minMaxAggregator = new MinMaxDblAggregator(1);
_medianAggregator = new MedianDblAggregator();
_buffer = new VBuffer<double>(1, new double[1]);
_centerData = centerData;
_quantileMin = quantileMin;
_quantileMax = quantileMax;
}

protected override bool ProcessValue(in double val)
{
if (!base.ProcessValue(in val))
return false;
VBufferEditor.CreateFromBuffer(ref _buffer).Values[0] = val;
_minMaxAggregator.ProcessValue(in _buffer);
_medianAggregator.ProcessValue(in val);
return true;
}

public static IColumnFunctionBuilder Create(NormalizingEstimator.RobustScalingColumnOptions column, IHost host, DataViewType srcType,
bool centerData, uint quantileMin, uint quantileMax, ValueGetter<double> getter)
{
host.CheckUserArg(column.MaximumExampleCount > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
return new RobustScalerOneColumnFunctionBuilder(host, column.MaximumExampleCount, centerData, quantileMin, quantileMax, getter);
}

public override IColumnFunction CreateColumnFunction()
{
_minMaxAggregator.Finish();
_medianAggregator.Finish();

double median = _medianAggregator.Median;
double range = _minMaxAggregator.Max[0] - _minMaxAggregator.Min[0];
// Divide the quantile difference by 100 to turn a percentage (e.g. 50) into a fraction (e.g. 0.5).
double quantileRange = (_quantileMax - _quantileMin) / 100f;
double scale = 1 / (range * quantileRange);

if (_centerData)
return AffineColumnFunction.Create(Host, scale, median);
else
return AffineColumnFunction.Create(Host, scale, 0);
}
}
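
// Worked example of the scale computed in CreateColumnFunction above (the numbers are
// illustrative): with an observed min of 0, max of 10, quantileMin = 25 and quantileMax = 75,
//   quantileRange = (75 - 25) / 100 = 0.5
//   scale         = 1 / (10 * 0.5)  = 0.2
// Assuming the affine function applies (value - offset) * scale, as the other affine
// normalizers do, centering maps a value to (value - median) * 0.2 and the non-centered
// variant maps it to value * 0.2.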

public sealed class RobustScalerVecFunctionBuilder : OneColumnFunctionBuilderBase<VBuffer<double>>
{
private readonly MinMaxDblAggregator _minMaxAggregator;
private readonly MedianDblAggregator[] _medianAggregators;
private readonly bool _centerData;
private readonly uint _quantileMin;
private readonly uint _quantileMax;

private RobustScalerVecFunctionBuilder(IHost host, long lim, int vectorSize, bool centerData, uint quantileMin, uint quantileMax, ValueGetter<VBuffer<double>> getSrc)
: base(host, lim, getSrc)
{
// Reuse the MinMax aggregator since the min and max are needed here as well;
// the difference is only in how they are used.
_minMaxAggregator = new MinMaxDblAggregator(vectorSize);

// If we aren't centering the data, we don't need the median.
_medianAggregators = new MedianDblAggregator[vectorSize];

for (int i = 0; i < vectorSize; i++)
{
_medianAggregators[i] = new MedianDblAggregator();
}

_centerData = centerData;
_quantileMin = quantileMin;
_quantileMax = quantileMax;
}

protected override bool ProcessValue(in VBuffer<double> val)
{
if (!base.ProcessValue(in val))
return false;
_minMaxAggregator.ProcessValue(in val);

// Have to calculate the median per slot
var span = val.GetValues();
for (int i = 0; i < _medianAggregators.Length; i++)
{
_medianAggregators[i].ProcessValue(span[i]);
}

return true;
}

public static IColumnFunctionBuilder Create(NormalizingEstimator.RobustScalingColumnOptions column, IHost host, VectorDataViewType srcType,
bool centerData, uint quantileMin, uint quantileMax, ValueGetter<VBuffer<double>> getter)
{
host.CheckUserArg(column.MaximumExampleCount > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
var vectorSize = srcType.Size;
return new RobustScalerVecFunctionBuilder(host, column.MaximumExampleCount, vectorSize, centerData, quantileMin, quantileMax, getter);
}

public override IColumnFunction CreateColumnFunction()
{
_minMaxAggregator.Finish();

double[] scale = new double[_medianAggregators.Length];
double[] median = new double[_medianAggregators.Length];

// Have to calculate the median per slot
for (int i = 0; i < _medianAggregators.Length; i++)
{
_medianAggregators[i].Finish();
median[i] = _medianAggregators[i].Median;

double range = _minMaxAggregator.Max[i] - _minMaxAggregator.Min[i];

// Divide the quantile difference by 100 to turn a percentage (e.g. 50) into a fraction (e.g. 0.5).
double quantileRange = (_quantileMax - _quantileMin) / 100f;
scale[i] = 1 / (range * quantileRange);

}

if (_centerData)
return AffineColumnFunction.Create(Host, scale, median, null);
else
return AffineColumnFunction.Create(Host, scale, null, null);

}
}
}
}
}