Skip to content

Support more types for HashEstimator #5104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
May 19, 2020
Merged
2 changes: 1 addition & 1 deletion build/Dependencies.props
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
<GoogleProtobufPackageVersion>3.10.1</GoogleProtobufPackageVersion>
<LightGBMPackageVersion>2.2.3</LightGBMPackageVersion>
<MicrosoftExtensionsPackageVersion>2.1.0</MicrosoftExtensionsPackageVersion>
<MicrosoftMLOnnxRuntimePackageVersion>1.2</MicrosoftMLOnnxRuntimePackageVersion>
<MicrosoftMLOnnxRuntimePackageVersion>1.3.0</MicrosoftMLOnnxRuntimePackageVersion>
<MlNetMklDepsPackageVersion>0.0.0.9</MlNetMklDepsPackageVersion>
<ParquetDotNetPackageVersion>2.1.3</ParquetDotNetPackageVersion>
<SystemDrawingCommonPackageVersion>4.5.0</SystemDrawingCommonPackageVersion>
Expand Down
19 changes: 7 additions & 12 deletions src/Microsoft.ML.Data/Transforms/Hashing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ public uint HashCoreOld(uint seed, uint mask, in float value)

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public uint HashCore(uint seed, uint mask, in float value)
=> float.IsNaN(value) ? 0 : (Hashing.MixHash(Hashing.MurmurRound(seed, FloatUtils.GetBits(value == 0 ? 0 : value)), sizeof(uint)) & mask) + 1;
=> float.IsNaN(value) ? 0 : (Hashing.MixHash(Hashing.MurmurRound(seed, FloatUtils.GetBits(value == 0 ? 0 : value)), sizeof(float)) & mask) + 1;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public uint HashCore(uint seed, uint mask, in VBuffer<float> values)
Expand Down Expand Up @@ -578,7 +578,7 @@ public uint HashCore(uint seed, uint mask, in double value)
if (double.IsNaN(value))
return 0;

return (Hashing.MixHash(HashRound(seed, value), sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(HashRound(seed, value), sizeof(double)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand All @@ -600,8 +600,6 @@ private uint HashRound(uint seed, double value)
ulong v = FloatUtils.GetBits(value == 0 ? 0 : value);
var hash = Hashing.MurmurRound(seed, Utils.GetLo(v));
var hi = Utils.GetHi(v);
if (hi == 0)
return hash;
return Hashing.MurmurRound(hash, hi);
}
}
Expand Down Expand Up @@ -815,7 +813,7 @@ public uint HashCoreOld(uint seed, uint mask, in ulong value)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public uint HashCore(uint seed, uint mask, in ulong value)
{
return (Hashing.MixHash(HashRound(seed, value), sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(HashRound(seed, value), sizeof(ulong)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand All @@ -832,8 +830,6 @@ private uint HashRound(uint seed, ulong value)
{
var hash = Hashing.MurmurRound(seed, Utils.GetLo(value));
var hi = Utils.GetHi(value);
if (hi == 0)
return hash;
return Hashing.MurmurRound(hash, hi);
}
}
Expand Down Expand Up @@ -970,7 +966,7 @@ public uint HashCoreOld(uint seed, uint mask, in long value)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public uint HashCore(uint seed, uint mask, in long value)
{
return (Hashing.MixHash(HashRound(seed, value), sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(HashRound(seed, value), sizeof(long)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand All @@ -987,8 +983,6 @@ private uint HashRound(uint seed, long value)
{
var hash = Hashing.MurmurRound(seed, Utils.GetLo((ulong)value));
var hi = Utils.GetHi((ulong)value);
if (hi == 0)
return hash;
return Hashing.MurmurRound(hash, hi);
}
}
Expand Down Expand Up @@ -1378,8 +1372,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri
castNode.AddAttribute("to", NumberDataViewType.UInt32.RawType);
murmurNode = ctx.CreateNode(opType, castOutput, murmurOutput, ctx.GetNodeName(opType), "com.microsoft");
}
else if (srcType == NumberDataViewType.UInt32 ||
srcType == NumberDataViewType.Int32 || srcType == TextDataViewType.Instance)
else if (srcType == NumberDataViewType.UInt32 || srcType == NumberDataViewType.Int32 || srcType == NumberDataViewType.UInt64 ||
srcType == NumberDataViewType.Int64 || srcType == NumberDataViewType.Single || srcType == NumberDataViewType.Double || srcType == TextDataViewType.Instance)

{
murmurNode = ctx.CreateNode(opType, srcVariable, murmurOutput, ctx.GetNodeName(opType), "com.microsoft");
}
Expand Down
4 changes: 2 additions & 2 deletions test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ public sealed partial class TestDataPipe : TestDataPipeBase

private static Double[] _dataDouble = new Double[] { -0.0, 0, 1, -1, 2, -2, Double.NaN, Double.MinValue,
Double.MaxValue, Double.Epsilon, Double.NegativeInfinity, Double.PositiveInfinity };
private static uint[] _resultsDouble = new uint[] { 16, 16, 25, 27, 12, 2, 0, 6, 17, 4, 11, 30 };
private static uint[] _resultsDouble = new uint[] { 30, 30, 19, 24, 32, 25, 0, 2, 7, 30, 5, 3 };

private static VBuffer<Double> _dataDoubleSparse = new VBuffer<Double>(5, 3, new double[] { -0.0, 0, 1 }, new[] { 0, 3, 4 });
private static uint[] _resultsDoubleSparse = new uint[] { 16,16,16,16, 25 };
private static uint[] _resultsDoubleSparse = new uint[] { 30, 30, 30, 30, 19 };

[Fact()]
public void SavePipeLabelParsers()
Expand Down
28 changes: 20 additions & 8 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1201,8 +1201,8 @@ public void OneHotHashEncodingOnnxConversionTest()
// An InvalidOperationException stating that the onnx pipeline can't be fully converted is thrown
// when users try to convert the items mentioned above.
public void MurmurHashScalarTest(
[CombinatorialValues(DataKind.SByte, DataKind.Int16, DataKind.Int32, DataKind.Byte,
DataKind.UInt16, DataKind.UInt32, DataKind.String, DataKind.Boolean)] DataKind type,
[CombinatorialValues(DataKind.SByte, DataKind.Int16, DataKind.Int32, DataKind.Int64, DataKind.Byte,
DataKind.UInt16, DataKind.UInt32, DataKind.UInt64, DataKind.Single, DataKind.Double, DataKind.String, DataKind.Boolean)] DataKind type,
[CombinatorialValues(1, 5, 31)] int numberOfBits, bool useOrderedHashing)
{

Expand All @@ -1215,7 +1215,11 @@ public void MurmurHashScalarTest(
(type == DataKind.UInt16) ? 6 :
(type == DataKind.Int32) ? 8 :
(type == DataKind.UInt32) ? 10 :
(type == DataKind.String) ? 12 : 14;
(type == DataKind.Int64) ? 12 :
(type == DataKind.UInt64) ? 14 :
(type == DataKind.Single) ? 16 :
(type == DataKind.Double) ? 18 :
(type == DataKind.String) ? 20 : 22;

var dataView = mlContext.Data.LoadFromTextFile(dataPath, new[] {
new TextLoader.Column("Value", type, column),
Expand Down Expand Up @@ -1252,9 +1256,9 @@ public void MurmurHashScalarTest(
// An InvalidOperationException stating that the onnx pipeline can't be fully converted is thrown
// when users try to convert the items mentioned above.
public void MurmurHashVectorTest(
[CombinatorialValues(DataKind.SByte, DataKind.Int16, DataKind.Int32, DataKind.Byte,
DataKind.UInt16, DataKind.UInt32, DataKind.String, DataKind.Boolean)] DataKind type,
[CombinatorialValues(1, 5, 31)] int numberOfBits)
[CombinatorialValues(DataKind.SByte, DataKind.Int16, DataKind.Int32, DataKind.Int64, DataKind.Byte,
DataKind.UInt16, DataKind.UInt32, DataKind.UInt64, DataKind.Single, DataKind.Double, DataKind.String, DataKind.Boolean)] DataKind type,
[CombinatorialValues(1, 5, 31)] int numberOfBits)
{

var mlContext = new MLContext();
Expand All @@ -1266,15 +1270,23 @@ public void MurmurHashVectorTest(
(type == DataKind.UInt16) ? 6 :
(type == DataKind.Int32) ? 8 :
(type == DataKind.UInt32) ? 10 :
(type == DataKind.String) ? 12 : 14;
(type == DataKind.Int64) ? 12 :
(type == DataKind.UInt64) ? 14 :
(type == DataKind.Single) ? 16 :
(type == DataKind.Double) ? 18 :
(type == DataKind.String) ? 20 : 22;

var columnEnd = (type == DataKind.SByte) ? 1 :
(type == DataKind.Byte) ? 3 :
(type == DataKind.Int16) ? 5 :
(type == DataKind.UInt16) ? 7 :
(type == DataKind.Int32) ? 9 :
(type == DataKind.UInt32) ? 11 :
(type == DataKind.String) ? 13 : 15;
(type == DataKind.Int64) ? 13 :
(type == DataKind.UInt64) ? 15 :
(type == DataKind.Single) ? 17 :
(type == DataKind.Double) ? 19 :
(type == DataKind.String) ? 21 : 23;

var dataView = mlContext.Data.LoadFromTextFile(dataPath, new[] {
new TextLoader.Column("Value", type, columnStart, columnEnd),
Expand Down
14 changes: 6 additions & 8 deletions test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -416,18 +416,16 @@ public void TestTrainTestSplitWithStratification()
Assert.Contains(4, ids);
split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.DateTimeStrat));
ids = split.TestSet.GetColumn<int>(split.TestSet.Schema[nameof(Input.Id)]);
Assert.Contains(0, ids);
Assert.Contains(7, ids);
Assert.Contains(5, ids);
Assert.Contains(6, ids);
split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.DateTimeOffsetStrat));
ids = split.TrainSet.GetColumn<int>(split.TrainSet.Schema[nameof(Input.Id)]);
Assert.Contains(1, ids);
Assert.Contains(3, ids);
split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.TimeSpanStrat));
ids = split.TestSet.GetColumn<int>(split.TestSet.Schema[nameof(Input.Id)]);
Assert.Contains(4, ids);
Assert.Contains(5, ids);
Assert.Contains(6, ids);
Assert.Contains(7, ids);
split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.TimeSpanStrat));
ids = split.TestSet.GetColumn<int>(split.TestSet.Schema[nameof(Input.Id)]);
Assert.Contains(1, ids);
Assert.Contains(2, ids);
}
}
}
40 changes: 28 additions & 12 deletions test/Microsoft.ML.Tests/Transformers/HashTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ ValueGetter<TType> hashGetter<TType>(HashingEstimator.ColumnOptions colInfo)
Assert.Equal(expectedCombinedSparse, result);
}

private void HashTestPositiveIntegerCore(ulong value, uint expected, uint expectedOrdered, uint expectedOrdered3, uint expectedCombined, uint expectedCombinedSparse)
private void HashTestPositiveIntegerCore32Bits(ulong value, uint expected, uint expectedOrdered, uint expectedOrdered3, uint expectedCombined, uint expectedCombinedSparse)

{
uint eKey = value == 0 ? 0 : expected;
uint eoKey = value == 0 ? 0 : expectedOrdered;
Expand All @@ -241,29 +242,44 @@ private void HashTestPositiveIntegerCore(ulong value, uint expected, uint expect
HashTestCore((uint)value, NumberDataViewType.UInt32, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
HashTestCore((uint)value, new KeyDataViewType(typeof(uint), int.MaxValue - 1), eKey, eoKey, e3Key, ecKey, 0);
}
HashTestCore(value, NumberDataViewType.UInt64, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
HashTestCore((ulong)value, new KeyDataViewType(typeof(ulong), int.MaxValue - 1), eKey, eoKey, e3Key, ecKey, 0);

HashTestCore(new DataViewRowId(value, 0), RowIdDataViewType.Instance, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
HashTestCore((ulong)value, new KeyDataViewType(typeof(ulong), int.MaxValue - 1), eKey, eoKey, e3Key, ecKey, 0);

// Next let's check signed numbers.

if (value <= (ulong)sbyte.MaxValue)
HashTestCore((sbyte)value, NumberDataViewType.SByte, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
if (value <= (ulong)short.MaxValue)
HashTestCore((short)value, NumberDataViewType.Int16, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
if (value <= int.MaxValue)
HashTestCore((int)value, NumberDataViewType.Int32, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
}

private void HashTestPositiveIntegerCore64Bits(ulong value, uint expected, uint expectedOrdered, uint expectedOrdered3, uint expectedCombined, uint expectedCombinedSparse)

{
uint eKey = value == 0 ? 0 : expected;
uint eoKey = value == 0 ? 0 : expectedOrdered;
uint e3Key = value == 0 ? 0 : expectedOrdered3;
uint ecKey = value == 0 ? 0 : expectedCombined;

HashTestCore(value, NumberDataViewType.UInt64, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);

// Next let's check signed numbers.
if (value <= long.MaxValue)
HashTestCore((long)value, NumberDataViewType.Int64, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
HashTestCore((long)value, NumberDataViewType.Int64, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
}

[Fact]
public void TestHashIntegerNumbers()
{
HashTestPositiveIntegerCore(0, 842, 358, 20, 882, 1010);
HashTestPositiveIntegerCore(1, 502, 537, 746, 588, 286);
HashTestPositiveIntegerCore(2, 407, 801, 652, 696, 172);
HashTestPositiveIntegerCore32Bits(0, 842, 358, 20, 882, 1010);
HashTestPositiveIntegerCore32Bits(1, 502, 537, 746, 588, 286);
HashTestPositiveIntegerCore32Bits(2, 407, 801, 652, 696, 172);

HashTestPositiveIntegerCore64Bits(0, 512, 851, 795, 1010, 620);
HashTestPositiveIntegerCore64Bits(1, 329, 190, 574, 491, 805);
HashTestPositiveIntegerCore64Bits(2, 484, 713, 128, 606, 326);
}

[Fact]
Expand All @@ -279,10 +295,10 @@ public void TestHashFloatingPointNumbers()
HashTestCore(1f, NumberDataViewType.Single, 463, 855, 732, 75, 487);
HashTestCore(-1f, NumberDataViewType.Single, 252, 612, 780, 179, 80);
HashTestCore(0f, NumberDataViewType.Single, 842, 358, 20, 882, 1010);
// Note that while we have the hash for numeric types be equal, the same is not necessarily the case for floating point numbers.
HashTestCore(1d, NumberDataViewType.Double, 937, 667, 424, 727, 510);
HashTestCore(-1d, NumberDataViewType.Double, 930, 78, 813, 582, 179);
HashTestCore(0d, NumberDataViewType.Double, 842, 358, 20, 882, 1010);

HashTestCore(1d, NumberDataViewType.Double, 188, 57, 690, 727, 36);
HashTestCore(-1d, NumberDataViewType.Double, 885, 804, 22, 582, 346);
HashTestCore(0d, NumberDataViewType.Double, 512, 851, 795, 1010, 620);
}

[Fact]
Expand Down
12 changes: 6 additions & 6 deletions test/data/type-samples.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
sbyte byte short ushort int uint strings boolean
0 1 0 23 0 4554 0 53 0 25 0 35 0 rain 0 1
2 3 2 13 2 455 2 63 2 63 2 63 djldaoiejffjauhglehdlgh pink 1 0
127 23 127 65 127 93 127 99 127 69 127 91 alibaba bug
-128 24 255 25 32767 325 65535 632 2147483647 34 4294967295 45 to mato monkey
0 2 5 98 -32768 335 78 698 -2147483648 97 3 56 U+12w blue
sbyte byte short ushort int uint long ulong float double strings boolean
0 1 0 23 0 4554 0 53 0 25 0 35 0 -1 0 1 0 -1 0 -1 0 rain 0 1
2 3 2 13 2 455 2 63 2 63 2 63 2 63 2 63 1 2 1 2 djldaoiejffjauhglehdlgh pink 1 0
127 23 127 65 127 93 127 99 127 69 127 91 2147483647 34 2147483647 34 -2 300 -2 300 alibaba bug
-128 24 255 25 32767 325 65535 632 2147483647 34 4294967295 45 9223372036854775807 97 9223372036854775807 97 355 4 355 4 to mato monkey
0 2 5 98 -32768 335 78 698 -2147483648 97 3 56 -9223372036854775808 5 4 5 -4000 5 -4000 5 U+12w blue