Skip to content

Commit

Permalink
Ngram with uint16 input fix (#4813)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lynx1820 authored Feb 7, 2020
1 parent c13565c commit 9fe474d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
10 changes: 6 additions & 4 deletions src/Microsoft.ML.Transforms/Text/NgramTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ public void SaveAsOnnx(OnnxContext ctx)
}
}

private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName )
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
VBuffer<ReadOnlyMemory<char>> slotNames = default;
GetSlotNames(iinfo, 0, ref slotNames);
Expand All @@ -777,13 +777,15 @@ private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName,

// TfIdfVectorizer accepts strings, int32 and int64 tensors.
// But in the ML.NET implementation of the NGramTransform, it only accepts keys as inputs
// That are the result of ValueToKeyMapping transformer, which outputs UInt32 values
// So, if it is UInt32 or UInt64, cast the output here to Int32/Int64
// that are the result of the ValueToKeyMapping transformer, which outputs UInt32 values,
// or of TokenizingByCharacters, which outputs UInt16 values.
// So, if it is UInt32, UInt64, or UInt16, cast the output here to Int32/Int64.
string opType;
var vectorType = _srcTypes[iinfo] as VectorDataViewType;

if ((vectorType != null) &&
((vectorType.RawType == typeof(VBuffer<UInt32>)) || (vectorType.RawType == typeof(VBuffer<UInt64>))))
((vectorType.RawType == typeof(VBuffer<UInt32>)) || (vectorType.RawType == typeof(VBuffer<UInt64>)) ||
(vectorType.RawType == typeof(VBuffer<UInt16>))))
{
var dataKind = _srcTypes[iinfo] == NumberDataViewType.UInt32 ? DataKind.Int32 : DataKind.Int64;

Expand Down
11 changes: 8 additions & 3 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,7 @@ public void WordTokenizerOnnxConversionTest()

[Theory]
[CombinatorialData]
public void NgramOnnxConnversionTest(
public void NgramOnnxConversionTest(
[CombinatorialValues(1, 2, 3)] int ngramLength,
bool useAllLength,
NgramExtractingEstimator.WeightingCriteria weighting)
Expand All @@ -1231,6 +1231,12 @@ public void NgramOnnxConnversionTest(
useAllLengths: useAllLength,
weighting: weighting)),

mlContext.Transforms.Text.TokenizeIntoCharactersAsKeys("Tokens", "Text")
.Append(mlContext.Transforms.Text.ProduceNgrams("NGrams", "Tokens",
ngramLength: ngramLength,
useAllLengths: useAllLength,
weighting: weighting)),

mlContext.Transforms.Text.ProduceWordBags("Tokens", "Text",
ngramLength: ngramLength,
useAllLengths: useAllLength,
Expand All @@ -1255,10 +1261,9 @@ public void NgramOnnxConnversionTest(
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxFilePath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedR4VectorColumns(transformedData.Schema[3].Name, outputNames[outputNames.Length-1], transformedData, onnxResult, 3);
CompareSelectedR4VectorColumns(transformedData.Schema[transformedData.Schema.Count-1].Name, outputNames[outputNames.Length-1], transformedData, onnxResult, 3); //comparing Ngrams
}
}

Done();
}

Expand Down

0 comments on commit 9fe474d

Please sign in to comment.