Skip to content

Commit 9660423

Browse files
committed
possibly a fix
1 parent d6d383e commit 9660423

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

src/Microsoft.ML.Transforms/Text/TextNormalizing.cs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,23 @@ public void SaveAsOnnx(OnnxContext ctx)
233233

234234
private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
235235
{
236-
var opType = "StringNormalizer";
237-
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
236+
// StringNormalizer only takes input of shapes [C] or [1,C],
237+
// so the input is squeezed to support inferred shapes ( e.g. [-1,C] ).
238+
var opType = "Squeeze";
239+
var squeezeOutput = ctx.AddIntermediateVariable(null, "SqueezeOutput", true);
240+
var node = ctx.CreateNode(opType, srcVariableName, squeezeOutput, ctx.GetNodeName(opType), "");
241+
node.AddAttribute("axes", new long[] { 0 });
242+
243+
opType = "StringNormalizer";
244+
var normalizerOutput = ctx.AddIntermediateVariable(null, "NormalizerOutput", true);
245+
node = ctx.CreateNode(opType, squeezeOutput, normalizerOutput, ctx.GetNodeName(opType), "");
238246
var isCaseChange = (_parent._caseMode == TextNormalizingEstimator.CaseMode.Lower) ? "LOWER" :
239247
(_parent._caseMode == TextNormalizingEstimator.CaseMode.Upper) ? "UPPER" : "NONE";
240248
node.AddAttribute("case_change_action", isCaseChange);
249+
250+
opType = "Unsqueeze";
251+
node = ctx.CreateNode(opType, normalizerOutput, dstVariableName, ctx.GetNodeName(opType), "");
252+
node.AddAttribute("axes", new long[] { 0 });
241253
}
242254
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
243255
{

0 commit comments

Comments
 (0)