Skip to content

Commit 15f8aad

Browse files
committed
Add test
1 parent 4880ecd commit 15f8aad

File tree

4 files changed

+478
-318
lines changed

4 files changed

+478
-318
lines changed

src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,17 @@ internal sealed class Options : TransformInputBase
9292
internal const string ShortName = "Onnx";
9393
internal const string LoaderSignature = "OnnxTransform";
9494

95-
internal readonly string[] Inputs;
96-
internal readonly string[] Outputs;
97-
internal readonly DataViewType[] OutputTypes;
95+
/// <summary>
96+
/// Input column names from ML.NET's perspective. It can be ordered differently than ONNX model's input list.
97+
/// It's also possible that the <see cref="Inputs"/> contains less variables than ONNX model's input list.
98+
/// </summary>
99+
internal string[] Inputs { get; }
100+
/// <summary>
101+
/// Output column names from ML.NET's perspective. It can be ordered differently than ONNX model's output list.
102+
/// It's also possible that the <see cref="Outputs"/> contains less variables than ONNX model's output list.
103+
/// </summary>
104+
internal string[] Outputs { get; }
105+
internal DataViewType[] OutputTypes { get; }
98106

99107
private static VersionInfo GetVersionInfo()
100108
{
@@ -196,7 +204,7 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
196204
var shape = outputNodeInfo.Shape;
197205
var dims = AdjustDimensions(shape);
198206
// OutputTypes[i] = new VectorDataViewType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims.ToArray());
199-
OutputTypes[i] = Model.OutputTypes[i];
207+
OutputTypes[i] = Model.ModelInfo.OutputsInfo[idx].MlnetType;
200208
}
201209
_options = options;
202210
}
@@ -302,9 +310,22 @@ private static IEnumerable<int> AdjustDimensions(OnnxShape shape)
302310
private sealed class Mapper : MapperBase
303311
{
304312
private readonly OnnxTransformer _parent;
313+
/// <summary>
314+
/// <see cref="_inputColIndices"/>'s i-th element value tells the <see cref="IDataView"/> column index to
315+
/// find the i-th ONNX input.
316+
/// </summary>
305317
private readonly int[] _inputColIndices;
318+
/// <summary>
319+
/// <see cref="_isInputVector"/>'s i-th element value tells if the i-th ONNX input is a tensor.
320+
/// </summary>
306321
private readonly bool[] _isInputVector;
322+
/// <summary>
323+
/// <see cref="_inputTensorShapes"/>'s i-th element value tells if the i-th ONNX input's shape if it's a tensor.
324+
/// </summary>
307325
private readonly OnnxShape[] _inputTensorShapes;
326+
/// <summary>
327+
/// <see cref="_inputOnnxTypes"/>'s i-th element value tells if the <see cref="Type"/> of the i-th ONNX input.
328+
/// </summary>
308329
private readonly System.Type[] _inputOnnxTypes;
309330

310331
public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
@@ -327,11 +348,11 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
327348
var inputNodeInfo = model.ModelInfo.InputsInfo[idx];
328349

329350
var shape = inputNodeInfo.Shape;
330-
var inputType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
351+
var inputType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.OrtType);
331352

332353
var inputShape = AdjustDimensions(inputNodeInfo.Shape);
333354
_inputTensorShapes[i] = inputShape.ToList();
334-
_inputOnnxTypes[i] = inputNodeInfo.Type;
355+
_inputOnnxTypes[i] = inputNodeInfo.OrtType;
335356

336357
var col = inputSchema.GetColumnOrNull(_parent.Inputs[i]);
337358
if (!col.HasValue)
@@ -417,22 +438,21 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
417438
{
418439
disposer = null;
419440
Host.AssertValue(input);
420-
//Host.Assert(typeof(T) == _outputItemRawType);
421441

422442
var outputCache = new OutputCache();
423443
var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();
424444

425-
if (_parent.Model.OutputTypes[iinfo] is VectorDataViewType)
445+
if (_parent.Model.ModelInfo.OutputsInfo[iinfo].MlnetType is VectorDataViewType vectorType)
426446
{
427447
//var type = _parent.OutputTypes[iinfo].RawType;
428-
var type = OnnxUtils.OnnxToMlNetType(_parent.Model.ModelInfo.OutputsInfo[iinfo].Type).RawType;
448+
var elemRawType = vectorType.ItemType.RawType;
429449
//Host.Assert(type == _parent.OutputTypes[iinfo].GetItemType().RawType);
430450
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _parent.Inputs, _inputColIndices, _isInputVector, _inputOnnxTypes, _inputTensorShapes);
431-
return Utils.MarshalInvoke(MakeTensorGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCache);
451+
return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCache);
432452
}
433453
else
434454
{
435-
var type = _parent.Model.OutputTypes[iinfo].RawType;
455+
var type = _parent.Model.ModelInfo.OutputsInfo[iinfo].MlnetType.RawType;
436456
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _parent.Inputs, _inputColIndices, _isInputVector, _inputOnnxTypes, _inputTensorShapes);
437457
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCache);
438458
}
@@ -441,7 +461,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
441461
private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames, OutputCache outputCache)
442462
{
443463
Host.AssertValue(input);
444-
ValueGetter<VBuffer<T>> valuegetter = (ref VBuffer<T> dst) =>
464+
ValueGetter<VBuffer<T>> valueGetter = (ref VBuffer<T> dst) =>
445465
{
446466
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCache);
447467
var namedOnnxValue = outputCache.Outputs[_parent.Outputs[iinfo]];
@@ -452,20 +472,20 @@ private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxVal
452472
denseTensor.Buffer.Span.CopyTo(editor.Values);
453473
dst = editor.Commit();
454474
};
455-
return valuegetter;
475+
return valueGetter;
456476
}
457477

458478
private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames, OutputCache outputCache)
459479
{
460480
Host.AssertValue(input);
461-
ValueGetter<T> valuegetter = (ref T dst) =>
481+
ValueGetter<T> valueGetter = (ref T dst) =>
462482
{
463483
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCache);
464484
var namedOnnxValue = outputCache.Outputs[_parent.Outputs[iinfo]];
465485
var trueValue = namedOnnxValue.AsEnumerable<NamedOnnxValue>().Select(value => value.AsDictionary<string, float>());
466486
dst = (T)trueValue;
467487
};
468-
return valuegetter;
488+
return valueGetter;
469489
}
470490

471491
private static INamedOnnxValueGetter[] GetNamedOnnxValueGetters(DataViewRow input,
@@ -634,7 +654,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
634654
throw Host.Except($"Column {input} doesn't match input node names of model.");
635655

636656
var inputNodeInfo = inputsInfo[idx];
637-
var expectedType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
657+
var expectedType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.OrtType);
638658
if (col.ItemType != expectedType)
639659
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
640660
}

0 commit comments

Comments
 (0)