Skip to content

Commit 82d4bb7

Browse files
kere-nelKeren Fuentes
andauthored
ProduceWordBags Onnx Export Fix (#5435)
* fix for issue * fix documentation * aligning test * adding back line * aligning fix Co-authored-by: Keren Fuentes <kedejesu@microsoft.com>
1 parent afba0bd commit 82d4bb7

File tree

3 files changed

+13
-10
lines changed

3 files changed

+13
-10
lines changed

src/Microsoft.ML.Transforms/Text/TextCatalog.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ public static CustomStopWordsRemovingEstimator RemoveStopWords(this TransformsCa
334334
=> new CustomStopWordsRemovingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), outputColumnName, inputColumnName, stopwords);
335335

336336
/// <summary>
337-
/// Create a <see cref="WordHashBagEstimator"/>, which maps the column specified in <paramref name="inputColumnName"/>
337+
/// Create a <see cref="WordBagEstimator"/>, which maps the column specified in <paramref name="inputColumnName"/>
338338
/// to a vector of n-gram counts in a new column named <paramref name="outputColumnName"/>.
339339
/// </summary>
340340
/// <remarks>
@@ -363,7 +363,7 @@ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransf
363363
outputColumnName, inputColumnName, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting);
364364

365365
/// <summary>
366-
/// Create a <see cref="WordHashBagEstimator"/>, which maps the multiple columns specified in <paramref name="inputColumnNames"/>
366+
/// Create a <see cref="WordBagEstimator"/>, which maps the multiple columns specified in <paramref name="inputColumnNames"/>
367367
/// to a vector of n-gram counts in a new column named <paramref name="outputColumnName"/>.
368368
/// </summary>
369369
/// <remarks>

src/Microsoft.ML.Transforms/Text/WordTokenizing.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -415,10 +415,10 @@ public void SaveAsOnnx(OnnxContext ctx)
415415
string[] separators = column.SeparatorsArray.Select(c => c.ToString()).ToArray();
416416
tokenizerNode.AddAttribute("separators", separators);
417417

418-
opType = "Squeeze";
419-
var squeezeOutput = ctx.AddIntermediateVariable(_type, column.Name);
420-
var squeezeNode = ctx.CreateNode(opType, intermediateVar, squeezeOutput, ctx.GetNodeName(opType), "");
421-
squeezeNode.AddAttribute("axes", new long[] { 1 });
418+
opType = "Reshape";
419+
var shape = ctx.AddInitializer(new long[] { 1, -1 }, new long[] { 2 }, "Shape");
420+
var reshapeOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, 1), column.Name);
421+
var reshapeNode = ctx.CreateNode(opType, new[] { intermediateVar, shape }, new[] { reshapeOutput }, ctx.GetNodeName(opType), "");
422422
}
423423
}
424424
}

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,9 +1323,12 @@ public void NgramOnnxConversionTest(
13231323
weighting: weighting)),
13241324

13251325
mlContext.Transforms.Text.ProduceWordBags("Tokens", "Text",
1326-
ngramLength: ngramLength,
1327-
useAllLengths: useAllLength,
1328-
weighting: weighting)
1326+
ngramLength: ngramLength,
1327+
useAllLengths: useAllLength,
1328+
weighting: weighting),
1329+
1330+
mlContext.Transforms.Text.TokenizeIntoWords("Tokens0", "Text")
1331+
.Append(mlContext.Transforms.Text.ProduceWordBags("Tokens", "Tokens0"))
13291332
};
13301333

13311334
for (int i = 0; i < pipelines.Length; i++)
@@ -1346,7 +1349,7 @@ public void NgramOnnxConversionTest(
13461349
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxFilePath, gpuDeviceId: _gpuDeviceId, fallbackToCpu: _fallbackToCpu);
13471350
var onnxTransformer = onnxEstimator.Fit(dataView);
13481351
var onnxResult = onnxTransformer.Transform(dataView);
1349-
var columnName = i == pipelines.Length - 1 ? "Tokens" : "NGrams";
1352+
var columnName = i >= pipelines.Length - 2 ? "Tokens" : "NGrams";
13501353
CompareResults(columnName, columnName, transformedData, onnxResult, 3);
13511354

13521355
VBuffer<ReadOnlyMemory<char>> mlNetSlots = default;

0 commit comments

Comments
 (0)