Skip to content

Commit 8d1809e

Browse files
authored
Added slot names support for OnnxTransformer (#4857)
* Added slot names support for OnnxTransformer * Updated baselines for failing tests
1 parent f0b9aa4 commit 8d1809e

File tree

7 files changed

+312
-8
lines changed

7 files changed

+312
-8
lines changed

src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using System.Collections.Generic;
67
using System.IO;
8+
using System.Linq;
79
using Google.Protobuf;
810
using Microsoft.ML;
911
using Microsoft.ML.Command;
@@ -188,7 +190,9 @@ internal static ModelProto ConvertTransformListToOnnxModel(OnnxContextImpl ctx,
188190
if (outputData.Schema[i].IsHidden)
189191
continue;
190192

191-
var idataviewColumnName = outputData.Schema[i].Name;
193+
var column = outputData.Schema[i];
194+
195+
var idataviewColumnName = column.Name;
192196

193197
// Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
194198
// _inputToDrop should be removed too.
@@ -204,11 +208,39 @@ internal static ModelProto ConvertTransformListToOnnxModel(OnnxContextImpl ctx,
204208
var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName + ".output", true);
205209
ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
206210
ctx.AddOutputVariable(outputData.Schema[i].Type, trueVariableName);
211+
212+
if (column.HasSlotNames())
213+
AddSlotNames(ctx, column);
207214
}
208215

216+
// Add metadata graph outputs
217+
209218
return ctx.MakeModel();
210219
}
211220

221+
private static void AddSlotNames(OnnxContextImpl ctx, DataViewSchema.Column column)
222+
{
223+
VBuffer<ReadOnlyMemory<char>> slotNames = default;
224+
column.GetSlotNames(ref slotNames);
225+
IEnumerable<string> slotNamesAsStrings = slotNames.DenseValues().Select(name => name.ToString());
226+
227+
string opType = "LabelEncoder";
228+
string labelEncoderInputName = $"mlnet.{column.Name}.unusedInput";
229+
string labelEncoderOutputName = $"mlnet.{column.Name}.unusedOutput";
230+
string labelEncoderNodeName = $"mlnet.{column.Name}.SlotNames";
231+
232+
string[] oneVals = new string[] { "one" };
233+
long[] dims = new long[] { 1, 1 };
234+
var one = ctx.AddInitializer(oneVals, dims, labelEncoderNodeName);
235+
236+
var labelEncoderOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, labelEncoderOutputName, true);
237+
var node = ctx.CreateNode(opType, one, labelEncoderOutput, labelEncoderNodeName);
238+
node.AddAttribute("keys_strings", slotNamesAsStrings);
239+
node.AddAttribute("values_int64s", Enumerable.Range(0, slotNames.Length).Select(x => (long)x));
240+
241+
ctx.AddOutputVariable(NumberDataViewType.Int64, labelEncoderOutput);
242+
}
243+
212244
private void Run(IChannel ch)
213245
{
214246
ILegacyDataLoader loader = null;

src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
using Microsoft.ML.OnnxRuntime;
1616
using Microsoft.ML.Runtime;
1717
using Microsoft.ML.Transforms.Onnx;
18+
using static Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper;
1819
using OnnxShape = System.Collections.Generic.List<int>;
1920

2021
[assembly: LoadableClass(OnnxTransformer.Summary, typeof(IDataTransform), typeof(OnnxTransformer),
@@ -416,11 +417,40 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
416417
{
417418
var onnxOutputName = _parent.Outputs[i];
418419
var columnName = onnxOutputName.EndsWith(stdSuffix) ? onnxOutputName.Replace(stdSuffix, "") : onnxOutputName;
419-
info[i] = new DataViewSchema.DetachedColumn(columnName, _parent.OutputTypes[i], null);
420+
421+
var builder = new DataViewSchema.Annotations.Builder();
422+
AddSlotNames(columnName, builder);
423+
424+
info[i] = new DataViewSchema.DetachedColumn(columnName, _parent.OutputTypes[i], builder.ToAnnotations());
420425
}
421426
return info;
422427
}
423428

429+
private void AddSlotNames(string columnName, DataViewSchema.Annotations.Builder builder)
430+
{
431+
var graph = _parent.Model.Graph;
432+
var nodes = graph.Node;
433+
434+
var slotNamesNodeName = $"mlnet.{columnName}.SlotNames";
435+
var slotsNode = nodes.FirstOrDefault(node => node.Name == slotNamesNodeName);
436+
var slotsAttr = slotsNode?.Attribute.FirstOrDefault(attr => attr.Name == "keys_strings");
437+
if (slotsAttr == null)
438+
return;
439+
440+
int count = slotsAttr.Strings.Count();
441+
ValueGetter<VBuffer<ReadOnlyMemory<char>>> getter = (ref VBuffer<ReadOnlyMemory<char>> dst) =>
442+
{
443+
var dstEditor = VBufferEditor.Create(ref dst, count);
444+
for (int i = 0; i < count; i++)
445+
{
446+
dstEditor.Values[i] = slotsAttr.Strings[i].ToString(Encoding.UTF8).AsMemory();
447+
}
448+
dst = dstEditor.Commit();
449+
};
450+
451+
builder.AddSlotNames(count, getter);
452+
}
453+
424454
private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
425455
{
426456
return col => Enumerable.Range(0, _parent.Outputs.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col);

src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using Microsoft.ML.OnnxRuntime;
1212
using Microsoft.ML.OnnxRuntime.Tensors;
1313
using Microsoft.ML.Runtime;
14+
using static Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper;
1415
using OnnxShape = System.Collections.Generic.List<int>;
1516

1617
namespace Microsoft.ML.Transforms.Onnx
@@ -157,6 +158,8 @@ public OnnxVariableInfo(string name, OnnxShape shape, Type typeInOnnxRuntime, Da
157158
/// </summary>
158159
internal OnnxModelInfo ModelInfo { get; }
159160

161+
internal GraphProto Graph { get; }
162+
160163
/// <summary>
161164
/// Constructs OnnxModel object from file.
162165
/// </summary>
@@ -217,6 +220,8 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
217220

218221
// Create a view to the used ONNX model from ONNXRuntime's perspective.
219222
ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);
223+
224+
Graph = model.Graph;
220225
}
221226

222227
private List<OnnxVariableInfo> GetOnnxVariablesFromMetadata(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata,
@@ -233,6 +238,10 @@ private List<OnnxVariableInfo> GetOnnxVariablesFromMetadata(IReadOnlyDictionary<
233238
var dataViewType = typePool[name];
234239
var caster = casterPool?[name];
235240

241+
if (name.StartsWith("mlnet.") &&
242+
(name.EndsWith(".unusedInput") || name.EndsWith(".unusedOutput")))
243+
continue;
244+
236245
OnnxVariableInfo info = null;
237246
if (shapeDictionary != null && shapeDictionary.ContainsKey(name))
238247
{

src/Microsoft.ML.Transforms/Text/NgramTransform.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -770,9 +770,6 @@ public void SaveAsOnnx(OnnxContext ctx)
770770

771771
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
772772
{
773-
VBuffer<ReadOnlyMemory<char>> slotNames = default;
774-
GetSlotNames(iinfo, 0, ref slotNames);
775-
776773
var transformInfo = _parent._transformInfos[iinfo];
777774

778775
// TfIdfVectorizer accepts strings, int32 and int64 tensors.

test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt

Lines changed: 152 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,8 +324,8 @@
324324
{
325325
"name": "target_weights",
326326
"floats": [
327-
0.50476193,
328-
-0.97911227
327+
0.504761934,
328+
-0.979112267
329329
],
330330
"type": "FLOATS"
331331
}
@@ -428,6 +428,51 @@
428428
"name": "Identity1",
429429
"opType": "Identity"
430430
},
431+
{
432+
"input": [
433+
"mlnet.F2.SlotNames"
434+
],
435+
"output": [
436+
"mlnet.F2.unusedOutput"
437+
],
438+
"name": "mlnet.F2.SlotNames",
439+
"opType": "LabelEncoder",
440+
"attribute": [
441+
{
442+
"name": "keys_strings",
443+
"strings": [
444+
"NA==",
445+
"MQ==",
446+
"OA==",
447+
"MTA=",
448+
"Mg==",
449+
"Mw==",
450+
"Nw==",
451+
"NQ==",
452+
"Ng==",
453+
"OQ=="
454+
],
455+
"type": "STRINGS"
456+
},
457+
{
458+
"name": "values_int64s",
459+
"ints": [
460+
"0",
461+
"1",
462+
"2",
463+
"3",
464+
"4",
465+
"5",
466+
"6",
467+
"7",
468+
"8",
469+
"9"
470+
],
471+
"type": "INTS"
472+
}
473+
],
474+
"domain": "ai.onnx.ml"
475+
},
431476
{
432477
"input": [
433478
"Features"
@@ -438,6 +483,53 @@
438483
"name": "Identity2",
439484
"opType": "Identity"
440485
},
486+
{
487+
"input": [
488+
"mlnet.Features.SlotNames"
489+
],
490+
"output": [
491+
"mlnet.Features.unusedOutput"
492+
],
493+
"name": "mlnet.Features.SlotNames",
494+
"opType": "LabelEncoder",
495+
"attribute": [
496+
{
497+
"name": "keys_strings",
498+
"strings": [
499+
"RjE=",
500+
"RjIuNA==",
501+
"RjIuMQ==",
502+
"RjIuOA==",
503+
"RjIuMTA=",
504+
"RjIuMg==",
505+
"RjIuMw==",
506+
"RjIuNw==",
507+
"RjIuNQ==",
508+
"RjIuNg==",
509+
"RjIuOQ=="
510+
],
511+
"type": "STRINGS"
512+
},
513+
{
514+
"name": "values_int64s",
515+
"ints": [
516+
"0",
517+
"1",
518+
"2",
519+
"3",
520+
"4",
521+
"5",
522+
"6",
523+
"7",
524+
"8",
525+
"9",
526+
"10"
527+
],
528+
"type": "INTS"
529+
}
530+
],
531+
"domain": "ai.onnx.ml"
532+
},
441533
{
442534
"input": [
443535
"PredictedLabel"
@@ -484,6 +576,28 @@
484576
0
485577
],
486578
"name": "Offset"
579+
},
580+
{
581+
"dims": [
582+
"1",
583+
"1"
584+
],
585+
"dataType": 8,
586+
"stringData": [
587+
"b25l"
588+
],
589+
"name": "mlnet.F2.SlotNames"
590+
},
591+
{
592+
"dims": [
593+
"1",
594+
"1"
595+
],
596+
"dataType": 8,
597+
"stringData": [
598+
"b25l"
599+
],
600+
"name": "mlnet.Features.SlotNames"
487601
}
488602
],
489603
"input": [
@@ -597,6 +711,24 @@
597711
}
598712
}
599713
},
714+
{
715+
"name": "mlnet.F2.unusedOutput",
716+
"type": {
717+
"tensorType": {
718+
"elemType": 7,
719+
"shape": {
720+
"dim": [
721+
{
722+
"dimValue": "-1"
723+
},
724+
{
725+
"dimValue": "1"
726+
}
727+
]
728+
}
729+
}
730+
}
731+
},
600732
{
601733
"name": "Features.output",
602734
"type": {
@@ -615,6 +747,24 @@
615747
}
616748
}
617749
},
750+
{
751+
"name": "mlnet.Features.unusedOutput",
752+
"type": {
753+
"tensorType": {
754+
"elemType": 7,
755+
"shape": {
756+
"dim": [
757+
{
758+
"dimValue": "-1"
759+
},
760+
{
761+
"dimValue": "1"
762+
}
763+
]
764+
}
765+
}
766+
}
767+
},
618768
{
619769
"name": "PredictedLabel.output",
620770
"type": {

0 commit comments

Comments
 (0)