@@ -627,32 +627,43 @@ def _deserialize_graph(
627
627
628
628
# Initialize the values dictionary for this graph scope with the inputs and initializers
629
629
values : dict [str , _core .Value ] = {v .name : v for v in inputs } # type: ignore[misc]
630
+
631
+ # Enter the graph scope by pushing the values for this scope to the stack
630
632
scoped_values .append (values )
633
+
631
634
initializer_values = []
632
- for tensor in initializer_tensors :
633
- if tensor .name in values :
635
+ for i , tensor in enumerate (initializer_tensors ):
636
+ initializer_name = tensor .name
637
+ if not initializer_name :
638
+ logger .warning (
639
+ "Initializer tensor must have a name but the %s-th initializer does not. Skipping this initializer." ,
640
+ i ,
641
+ )
642
+ continue
643
+ if initializer_name in values :
634
644
# The initializer is for an input
635
- initializer_value = values [tensor . name ]
645
+ initializer_value = values [initializer_name ]
636
646
initializer_value .const_value = tensor
637
647
else :
638
648
# The initializer is for some other value. Create this value first
639
649
initializer_value = _core .Value (
640
650
None ,
641
651
index = None ,
642
- name = tensor . name ,
643
- # TODO(justinchuby): Fix type hinting for shape and dtype
644
- shape = tensor . shape , # type: ignore
652
+ name = initializer_name ,
653
+ # Include shape and type even if the shape or type is not provided as ValueInfoProto.
654
+ # Users expect initialized values to have shape and type information.
645
655
type = _core .TensorType (tensor .dtype ),
656
+ shape = tensor .shape , # type: ignore[arg-type]
646
657
const_value = tensor ,
647
658
)
648
659
if initializer_value .name in quantization_annotations :
649
660
_deserialize_quantization_annotation (
650
661
quantization_annotations [initializer_value .name ], initializer_value
651
662
)
652
- values [tensor . name ] = initializer_value # type: ignore[index]
663
+ values [initializer_name ] = initializer_value
653
664
initializer_values .append (initializer_value )
654
665
655
- # Add ValueInfos for this graph scope
666
+ # Build the value info dictionary to allow for quick lookup for this graph scope
656
667
value_info = {info .name : info for info in proto .value_info }
657
668
658
669
# Deserialize nodes with all known values
@@ -663,7 +674,10 @@ def _deserialize_graph(
663
674
664
675
# Fill in values for graph outputs
665
676
outputs = [deserialize_value_info_proto (info , values [info .name ]) for info in proto .output ]
677
+
678
+ # Exit the graph scope by popping the values for this scope from the stack
666
679
scoped_values .pop ()
680
+
667
681
return _core .Graph (
668
682
inputs ,
669
683
outputs ,
@@ -1204,24 +1218,24 @@ def _serialize_opset_imports_into(
1204
1218
opset_ids .add (domain = domain , version = version )
1205
1219
1206
1220
1207
- def _serialize_metadata_props_into (
1221
+ def _serialize_string_string_maps (
1208
1222
string_string_entries : proto_containers .RepeatedCompositeFieldContainer [
1209
1223
onnx .StringStringEntryProto
1210
1224
],
1211
1225
from_ : Mapping [str , str ],
1212
1226
) -> None :
1213
- """Serialize metadata properties into a repeated field of string-string entries.
1227
+ """Serialize a <str, str> mapping into a repeated field of string-string entries.
1214
1228
1215
1229
Args:
1216
1230
string_string_entries: The repeated field to serialize into.
1217
- from_: The mapping of metadata properties to serialize.
1231
+ from_: The mapping of a <str, str> mapping to serialize.
1218
1232
"""
1219
1233
# Sort names for deterministic serialization
1220
1234
for key in sorted (from_ ):
1221
1235
string_string_entries .add (key = key , value = from_ [key ])
1222
1236
1223
1237
1224
- _serialize_string_string_maps = _serialize_metadata_props_into
1238
+ _serialize_metadata_props_into = _serialize_string_string_maps
1225
1239
1226
1240
1227
1241
def _maybe_add_quantization_annotation (
@@ -1284,18 +1298,21 @@ def serialize_graph_into(
1284
1298
# TODO(justinchuby): We should add a method is_initializer() on Value when
1285
1299
# the initializer list is tracked
1286
1300
_maybe_add_quantization_annotation (graph_proto , input_ )
1301
+ input_names = {input_ .name for input_ in from_ .inputs }
1287
1302
# TODO(justinchuby): Support sparse_initializer
1288
- for initializer in from_ .initializers .values ():
1289
- _maybe_add_quantization_annotation (graph_proto , initializer )
1290
- if initializer .const_value is None :
1303
+ for value in from_ .initializers .values ():
1304
+ _maybe_add_quantization_annotation (graph_proto , value )
1305
+ if _should_create_value_info_for_value (value ) and value .name not in input_names :
1306
+ # Serialize information about all initializers into value_info,
1307
+ # except for those that are also graph inputs
1308
+ serialize_value_into (graph_proto .value_info .add (), value )
1309
+ if value .const_value is None :
1291
1310
# Skip initializers without constant values
1292
- logger .warning (
1293
- "Initializer '%s' does not have a constant value set." , initializer .name
1294
- )
1311
+ logger .warning ("Initializer '%s' does not have a constant value set." , value .name )
1295
1312
continue
1296
1313
# Make sure the tensor's name is the same as the value's name
1297
- initializer .const_value .name = initializer .name
1298
- serialize_tensor_into (graph_proto .initializer .add (), from_ = initializer .const_value )
1314
+ value .const_value .name = value .name
1315
+ serialize_tensor_into (graph_proto .initializer .add (), from_ = value .const_value )
1299
1316
for node in from_ :
1300
1317
serialize_node_into (graph_proto .node .add (), from_ = node )
1301
1318
for node_output in node .outputs :
0 commit comments