Skip to content

Commit 9e539b6

Browse files
committed
Address comments
1 parent 75f8ff0 commit 9e539b6

File tree

4 files changed

+38
-33
lines changed

4 files changed

+38
-33
lines changed

src/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
<ItemGroup>
1010
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
1111
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
12-
<ProjectReference Include="..\Microsoft.ML.OnnxConverter\Microsoft.ML.OnnxConverter.csproj" />
1312
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimePackageVersion)" />
13+
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufPackageVersion)" />
1414
</ItemGroup>
1515

1616
</Project>

src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
212212
for (int i = 0; i < Outputs.Length; i++)
213213
{
214214
var outputInfo = Model.ModelInfo.GetOutput(Outputs[i]);
215-
OutputTypes[i] = outputInfo.MlnetType;
215+
OutputTypes[i] = outputInfo.DataViewType;
216216
}
217217
_options = options;
218218
}
@@ -351,7 +351,7 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
351351

352352
var inputShape = AdjustDimensions(inputNodeInfo.Shape);
353353
_inputTensorShapes[i] = inputShape.ToList();
354-
_inputOnnxTypes[i] = inputNodeInfo.OrtType;
354+
_inputOnnxTypes[i] = inputNodeInfo.TypeInOnnxRuntime;
355355

356356
var col = inputSchema.GetColumnOrNull(_parent.Inputs[i]);
357357
if (!col.HasValue)
@@ -365,8 +365,8 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
365365
if (vectorType != null && vectorType.Size == 0)
366366
throw Host.Except($"Variable length input columns not supported");
367367

368-
if (type.GetItemType() != inputNodeInfo.MlnetType.GetItemType())
369-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.MlnetType.GetItemType().ToString(), type.ToString());
368+
if (type.GetItemType() != inputNodeInfo.DataViewType.GetItemType())
369+
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString());
370370

371371
// If the column is one dimension we make sure that the total size of the Onnx shape matches.
372372
// Compute the total size of the known dimensions of the shape.
@@ -400,7 +400,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
400400

401401
var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();
402402

403-
if (_parent.Model.ModelInfo.OutputsInfo[iinfo].MlnetType is VectorDataViewType vectorType)
403+
if (_parent.Model.ModelInfo.OutputsInfo[iinfo].DataViewType is VectorDataViewType vectorType)
404404
{
405405
var elemRawType = vectorType.ItemType.RawType;
406406
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
@@ -411,7 +411,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
411411
}
412412
else
413413
{
414-
var type = _parent.Model.ModelInfo.OutputsInfo[iinfo].MlnetType.RawType;
414+
var type = _parent.Model.ModelInfo.OutputsInfo[iinfo].DataViewType.RawType;
415415
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
416416
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames);
417417
}
@@ -731,18 +731,18 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
731731
throw Host.Except($"Column {input} doesn't match input node names of model.");
732732

733733
var inputNodeInfo = inputsInfo[idx];
734-
var expectedType = ((VectorDataViewType)inputNodeInfo.MlnetType).ItemType;
734+
var expectedType = ((VectorDataViewType)inputNodeInfo.DataViewType).ItemType;
735735
if (col.ItemType != expectedType)
736736
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
737737
}
738738

739739
for (var i = 0; i < Transformer.Outputs.Length; i++)
740740
{
741-
var outputName = Transformer.Outputs[i];
742741
resultDic[Transformer.Outputs[i]] = new SchemaShape.Column(Transformer.Outputs[i],
743742
Transformer.OutputTypes[i].IsKnownSizeVector() ? SchemaShape.Column.VectorKind.Vector
744743
: SchemaShape.Column.VectorKind.VariableVector, Transformer.OutputTypes[i].GetItemType(), false);
745744
}
745+
746746
return new SchemaShape(resultDic.Values);
747747
}
748748
}

src/Microsoft.ML.OnnxTransformer/OnnxTypes.cs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,6 @@ public static Func<NamedOnnxValue, object> GetDataViewValueCasterAndResultedType
289289

290290
// NamedOnnxValue to scalar.
291291
Func<NamedOnnxValue, object> caster = (NamedOnnxValue value) => {
292-
var tensor = methodSpecialized.Invoke(value, new object[] { });
293292
var scalar = accessSpecialized.Invoke(value, new object[] { 0 });
294293
return scalar;
295294
};
@@ -447,23 +446,24 @@ public override int GetHashCode()
447446
public sealed class OnnxSequenceTypeAttribute : DataViewTypeAttribute
448447
{
449448
private Type _elemType;
449+
450450
/// <summary>
451-
/// Create an image type without knowing its height and width.
451+
/// Create a sequence type.
452452
/// </summary>
453453
public OnnxSequenceTypeAttribute()
454454
{
455455
}
456456

457457
/// <summary>
458-
/// Create an image type with known height and width.
458+
/// Create a <paramref name="elemType"/>-sequence type.
459459
/// </summary>
460460
public OnnxSequenceTypeAttribute(Type elemType)
461461
{
462462
_elemType = elemType;
463463
}
464464

465465
/// <summary>
466-
/// Images with the same width and height should equal.
466+
/// Sequence types with the same element type should be equal.
467467
/// </summary>
468468
public override bool Equals(DataViewTypeAttribute other)
469469
{
@@ -473,7 +473,7 @@ public override bool Equals(DataViewTypeAttribute other)
473473
}
474474

475475
/// <summary>
476-
/// Produce the same hash code for all images with the same height and the same width.
476+
/// Produce the same hash code for sequence types with the same element type.
477477
/// </summary>
478478
public override int GetHashCode()
479479
{
@@ -501,15 +501,18 @@ public sealed class OnnxMapTypeAttribute : DataViewTypeAttribute
501501
{
502502
private Type _keyType;
503503
private Type _valueType;
504+
504505
/// <summary>
505-
/// Create an image type without knowing its height and width.
506+
/// Create a map (aka dictionary) type.
506507
/// </summary>
507508
public OnnxMapTypeAttribute()
508509
{
509510
}
510511

511512
/// <summary>
512-
/// Create an image type with known height and width.
513+
/// Create a map (aka dictionary) type. A map is a collection of key-value
514+
/// pairs. <paramref name="keyType"/> specifies the type of keys and <paramref name="valueType"/>
515+
/// is the type of values.
513516
/// </summary>
514517
public OnnxMapTypeAttribute(Type keyType, Type valueType)
515518
{
@@ -518,7 +521,7 @@ public OnnxMapTypeAttribute(Type keyType, Type valueType)
518521
}
519522

520523
/// <summary>
521-
/// Images with the same width and height should equal.
524+
/// Map types with the same key type and the same value type should be equal.
522525
/// </summary>
523526
public override bool Equals(DataViewTypeAttribute other)
524527
{
@@ -528,7 +531,7 @@ public override bool Equals(DataViewTypeAttribute other)
528531
}
529532

530533
/// <summary>
531-
/// Produce the same hash code for all images with the same height and the same width.
534+
/// Produce the same hash code for map types with the same key type and the same value type.
532535
/// </summary>
533536
public override int GetHashCode()
534537
{

src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -86,34 +86,36 @@ public OnnxVariableInfo GetOutput(string name)
8686
public class OnnxVariableInfo
8787
{
8888
/// <summary>
89-
/// The Name of the node
89+
/// The Name of the variable. Note that ONNX variable are named.
9090
/// </summary>
9191
public string Name { get; }
9292
/// <summary>
93-
/// The shape of the node
93+
/// The shape of the variable if the variable is a tensor. For other
94+
/// types such sequence and dictionary, <see cref="Shape"/> would be
95+
/// <see langword="null"/>.
9496
/// </summary>
9597
public OnnxShape Shape { get; }
9698
/// <summary>
97-
/// The type of the node
99+
/// The type of the variable produced by ONNXRuntime.
98100
/// </summary>
99-
public System.Type OrtType { get; }
101+
public Type TypeInOnnxRuntime { get; }
100102
/// <summary>
101-
/// The <see cref="DataViewType"/> that this ONNX variable corresponds
103+
/// The <see cref="Data.DataViewType"/> that this ONNX variable corresponds
102104
/// to in <see cref="IDataView"/>'s type system.
103105
/// </summary>
104-
public DataViewType MlnetType { get; }
106+
public DataViewType DataViewType { get; }
105107
/// <summary>
106108
/// A method to case <see cref="NamedOnnxValue"/> produced by
107-
/// ONNXRuntime to the type specified in <see cref="MlnetType"/>.
109+
/// ONNXRuntime to the type specified in <see cref="DataViewType"/>.
108110
/// </summary>
109111
public Func<NamedOnnxValue, object> Caster { get; }
110112

111-
public OnnxVariableInfo(string name, OnnxShape shape, System.Type ortType, DataViewType mlnetType, Func<NamedOnnxValue, object> caster)
113+
public OnnxVariableInfo(string name, OnnxShape shape, Type typeInOnnxRuntime, DataViewType mlnetType, Func<NamedOnnxValue, object> caster)
112114
{
113115
Name = name;
114116
Shape = shape;
115-
OrtType = ortType;
116-
MlnetType = mlnetType;
117+
TypeInOnnxRuntime = typeInOnnxRuntime;
118+
DataViewType = mlnetType;
117119
Caster = caster;
118120
}
119121
}
@@ -301,8 +303,8 @@ private void Dispose(bool disposing)
301303

302304
internal sealed class OnnxUtils
303305
{
304-
private static HashSet<System.Type> _onnxTypeMap =
305-
new HashSet<System.Type>
306+
private static HashSet<Type> _onnxTypeMap =
307+
new HashSet<Type>
306308
{
307309
typeof(Double),
308310
typeof(Single),
@@ -313,8 +315,8 @@ internal sealed class OnnxUtils
313315
typeof(UInt32),
314316
typeof(UInt64)
315317
};
316-
private static Dictionary<System.Type, InternalDataKind> _typeToKindMap=
317-
new Dictionary<System.Type, InternalDataKind>
318+
private static Dictionary<Type, InternalDataKind> _typeToKindMap=
319+
new Dictionary<Type, InternalDataKind>
318320
{
319321
{ typeof(Single) , InternalDataKind.R4},
320322
{ typeof(Double) , InternalDataKind.R8},
@@ -364,7 +366,7 @@ public static NamedOnnxValue CreateNamedOnnxValue<T>(string name, ReadOnlySpan<T
364366
/// </summary>
365367
/// <param name="type"></param>
366368
/// <returns></returns>
367-
public static PrimitiveDataViewType OnnxToMlNetType(System.Type type)
369+
public static PrimitiveDataViewType OnnxToMlNetType(Type type)
368370
{
369371
if (!_typeToKindMap.ContainsKey(type))
370372
throw Contracts.ExceptNotSupp("Onnx type not supported", type);

0 commit comments

Comments
 (0)