@@ -233,11 +233,23 @@ public void SaveAsOnnx(OnnxContext ctx)
233
233
234
234
private void SaveAsOnnxCore ( OnnxContext ctx , string srcVariableName , string dstVariableName )
235
235
{
236
- var opType = "StringNormalizer" ;
237
- var node = ctx . CreateNode ( opType , srcVariableName , dstVariableName , ctx . GetNodeName ( opType ) , "" ) ;
236
+ // StringNormalizer only takes input of shapes [C] or [1,C],
237
+ // so the input is squeezed to support inferred shapes ( e.g. [-1,C] ).
238
+ var opType = "Squeeze" ;
239
+ var squeezeOutput = ctx . AddIntermediateVariable ( null , "SqueezeOutput" , true ) ;
240
+ var node = ctx . CreateNode ( opType , srcVariableName , squeezeOutput , ctx . GetNodeName ( opType ) , "" ) ;
241
+ node . AddAttribute ( "axes" , new long [ ] { 0 } ) ;
242
+
243
+ opType = "StringNormalizer" ;
244
+ var normalizerOutput = ctx . AddIntermediateVariable ( null , "NormalizerOutput" , true ) ;
245
+ node = ctx . CreateNode ( opType , squeezeOutput , normalizerOutput , ctx . GetNodeName ( opType ) , "" ) ;
238
246
var isCaseChange = ( _parent . _caseMode == TextNormalizingEstimator . CaseMode . Lower ) ? "LOWER" :
239
247
( _parent . _caseMode == TextNormalizingEstimator . CaseMode . Upper ) ? "UPPER" : "NONE" ;
240
248
node . AddAttribute ( "case_change_action" , isCaseChange ) ;
249
+
250
+ opType = "Unsqueeze" ;
251
+ node = ctx . CreateNode ( opType , normalizerOutput , dstVariableName , ctx . GetNodeName ( opType ) , "" ) ;
252
+ node . AddAttribute ( "axes" , new long [ ] { 0 } ) ;
241
253
}
242
254
protected override DataViewSchema . DetachedColumn [ ] GetOutputColumnsCore ( )
243
255
{
0 commit comments