@@ -92,9 +92,17 @@ internal sealed class Options : TransformInputBase
92
92
internal const string ShortName = "Onnx" ;
93
93
internal const string LoaderSignature = "OnnxTransform" ;
94
94
95
- internal readonly string [ ] Inputs ;
96
- internal readonly string [ ] Outputs ;
97
- internal readonly DataViewType [ ] OutputTypes ;
95
+ /// <summary>
96
+ /// Input column names from ML.NET's perspective. It can be ordered differently than ONNX model's input list.
97
+ /// It's also possible that the <see cref="Inputs"/> contains less variables than ONNX model's input list.
98
+ /// </summary>
99
+ internal string [ ] Inputs { get ; }
100
+ /// <summary>
101
+ /// Output column names from ML.NET's perspective. It can be ordered differently than ONNX model's output list.
102
+ /// It's also possible that the <see cref="Outputs"/> contains less variables than ONNX model's output list.
103
+ /// </summary>
104
+ internal string [ ] Outputs { get ; }
105
+ internal DataViewType [ ] OutputTypes { get ; }
98
106
99
107
private static VersionInfo GetVersionInfo ( )
100
108
{
@@ -196,7 +204,7 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
196
204
var shape = outputNodeInfo . Shape ;
197
205
var dims = AdjustDimensions ( shape ) ;
198
206
// OutputTypes[i] = new VectorDataViewType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims.ToArray());
199
- OutputTypes [ i ] = Model . OutputTypes [ i ] ;
207
+ OutputTypes [ i ] = Model . ModelInfo . OutputsInfo [ idx ] . MlnetType ;
200
208
}
201
209
_options = options ;
202
210
}
@@ -302,9 +310,22 @@ private static IEnumerable<int> AdjustDimensions(OnnxShape shape)
302
310
private sealed class Mapper : MapperBase
303
311
{
304
312
private readonly OnnxTransformer _parent ;
313
+ /// <summary>
314
+ /// <see cref="_inputColIndices"/>'s i-th element value tells the <see cref="IDataView"/> column index to
315
+ /// find the i-th ONNX input.
316
+ /// </summary>
305
317
private readonly int [ ] _inputColIndices ;
318
+ /// <summary>
319
+ /// <see cref="_isInputVector"/>'s i-th element value tells if the i-th ONNX input is a tensor.
320
+ /// </summary>
306
321
private readonly bool [ ] _isInputVector ;
322
+ /// <summary>
323
+ /// <see cref="_inputTensorShapes"/>'s i-th element value tells if the i-th ONNX input's shape if it's a tensor.
324
+ /// </summary>
307
325
private readonly OnnxShape [ ] _inputTensorShapes ;
326
+ /// <summary>
327
+ /// <see cref="_inputOnnxTypes"/>'s i-th element value tells if the <see cref="Type"/> of the i-th ONNX input.
328
+ /// </summary>
308
329
private readonly System . Type [ ] _inputOnnxTypes ;
309
330
310
331
public Mapper ( OnnxTransformer parent , DataViewSchema inputSchema ) :
@@ -327,11 +348,11 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
327
348
var inputNodeInfo = model . ModelInfo . InputsInfo [ idx ] ;
328
349
329
350
var shape = inputNodeInfo . Shape ;
330
- var inputType = OnnxUtils . OnnxToMlNetType ( inputNodeInfo . Type ) ;
351
+ var inputType = OnnxUtils . OnnxToMlNetType ( inputNodeInfo . OrtType ) ;
331
352
332
353
var inputShape = AdjustDimensions ( inputNodeInfo . Shape ) ;
333
354
_inputTensorShapes [ i ] = inputShape . ToList ( ) ;
334
- _inputOnnxTypes [ i ] = inputNodeInfo . Type ;
355
+ _inputOnnxTypes [ i ] = inputNodeInfo . OrtType ;
335
356
336
357
var col = inputSchema . GetColumnOrNull ( _parent . Inputs [ i ] ) ;
337
358
if ( ! col . HasValue )
@@ -417,22 +438,21 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
417
438
{
418
439
disposer = null ;
419
440
Host . AssertValue ( input ) ;
420
- //Host.Assert(typeof(T) == _outputItemRawType);
421
441
422
442
var outputCache = new OutputCache ( ) ;
423
443
var activeOutputColNames = _parent . Outputs . Where ( ( x , i ) => activeOutput ( i ) ) . ToArray ( ) ;
424
444
425
- if ( _parent . Model . OutputTypes [ iinfo ] is VectorDataViewType )
445
+ if ( _parent . Model . ModelInfo . OutputsInfo [ iinfo ] . MlnetType is VectorDataViewType vectorType )
426
446
{
427
447
//var type = _parent.OutputTypes[iinfo].RawType;
428
- var type = OnnxUtils . OnnxToMlNetType ( _parent . Model . ModelInfo . OutputsInfo [ iinfo ] . Type ) . RawType ;
448
+ var elemRawType = vectorType . ItemType . RawType ;
429
449
//Host.Assert(type == _parent.OutputTypes[iinfo].GetItemType().RawType);
430
450
var srcNamedValueGetters = GetNamedOnnxValueGetters ( input , _parent . Inputs , _inputColIndices , _isInputVector , _inputOnnxTypes , _inputTensorShapes ) ;
431
- return Utils . MarshalInvoke ( MakeTensorGetter < int > , type , input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCache ) ;
451
+ return Utils . MarshalInvoke ( MakeTensorGetter < int > , elemRawType , input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCache ) ;
432
452
}
433
453
else
434
454
{
435
- var type = _parent . Model . OutputTypes [ iinfo ] . RawType ;
455
+ var type = _parent . Model . ModelInfo . OutputsInfo [ iinfo ] . MlnetType . RawType ;
436
456
var srcNamedValueGetters = GetNamedOnnxValueGetters ( input , _parent . Inputs , _inputColIndices , _isInputVector , _inputOnnxTypes , _inputTensorShapes ) ;
437
457
return Utils . MarshalInvoke ( MakeObjectGetter < int > , type , input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCache ) ;
438
458
}
@@ -441,7 +461,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
441
461
private Delegate MakeTensorGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters , string [ ] activeOutputColNames , OutputCache outputCache )
442
462
{
443
463
Host . AssertValue ( input ) ;
444
- ValueGetter < VBuffer < T > > valuegetter = ( ref VBuffer < T > dst ) =>
464
+ ValueGetter < VBuffer < T > > valueGetter = ( ref VBuffer < T > dst ) =>
445
465
{
446
466
UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCache ) ;
447
467
var namedOnnxValue = outputCache . Outputs [ _parent . Outputs [ iinfo ] ] ;
@@ -452,20 +472,20 @@ private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxVal
452
472
denseTensor . Buffer . Span . CopyTo ( editor . Values ) ;
453
473
dst = editor . Commit ( ) ;
454
474
} ;
455
- return valuegetter ;
475
+ return valueGetter ;
456
476
}
457
477
458
478
private Delegate MakeObjectGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters , string [ ] activeOutputColNames , OutputCache outputCache )
459
479
{
460
480
Host . AssertValue ( input ) ;
461
- ValueGetter < T > valuegetter = ( ref T dst ) =>
481
+ ValueGetter < T > valueGetter = ( ref T dst ) =>
462
482
{
463
483
UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCache ) ;
464
484
var namedOnnxValue = outputCache . Outputs [ _parent . Outputs [ iinfo ] ] ;
465
485
var trueValue = namedOnnxValue . AsEnumerable < NamedOnnxValue > ( ) . Select ( value => value . AsDictionary < string , float > ( ) ) ;
466
486
dst = ( T ) trueValue ;
467
487
} ;
468
- return valuegetter ;
488
+ return valueGetter ;
469
489
}
470
490
471
491
private static INamedOnnxValueGetter [ ] GetNamedOnnxValueGetters ( DataViewRow input ,
@@ -634,7 +654,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
634
654
throw Host . Except ( $ "Column { input } doesn't match input node names of model.") ;
635
655
636
656
var inputNodeInfo = inputsInfo [ idx ] ;
637
- var expectedType = OnnxUtils . OnnxToMlNetType ( inputNodeInfo . Type ) ;
657
+ var expectedType = OnnxUtils . OnnxToMlNetType ( inputNodeInfo . OrtType ) ;
638
658
if ( col . ItemType != expectedType )
639
659
throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , input , expectedType . ToString ( ) , col . ItemType . ToString ( ) ) ;
640
660
}
0 commit comments