Skip to content

Commit 883a74f

Browse files
authored
[IR][fix] Save value info for initializers (#1552)
Previously, initializers were not included in the graph value_info because they were not easily accessible from the Graph object. Now that we store all the Values for initializers, we can serialize the value information into the graph. Updated test models to include the value info protos for initializers so the round-trip tests can pass. Fix #1501
1 parent d7955f4 commit 883a74f

File tree

5 files changed

+45
-28
lines changed

5 files changed

+45
-28
lines changed

onnxscript/ir/serde.py

+37-20
Original file line numberDiff line numberDiff line change
@@ -627,32 +627,43 @@ def _deserialize_graph(
627627

628628
# Initialize the values dictionary for this graph scope with the inputs and initializers
629629
values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc]
630+
631+
# Enter the graph scope by pushing the values for this scope to the stack
630632
scoped_values.append(values)
633+
631634
initializer_values = []
632-
for tensor in initializer_tensors:
633-
if tensor.name in values:
635+
for i, tensor in enumerate(initializer_tensors):
636+
initializer_name = tensor.name
637+
if not initializer_name:
638+
logger.warning(
639+
"Initializer tensor must have a name but the %s-th initializer does not. Skipping this initializer.",
640+
i,
641+
)
642+
continue
643+
if initializer_name in values:
634644
# The initializer is for an input
635-
initializer_value = values[tensor.name]
645+
initializer_value = values[initializer_name]
636646
initializer_value.const_value = tensor
637647
else:
638648
# The initializer is for some other value. Create this value first
639649
initializer_value = _core.Value(
640650
None,
641651
index=None,
642-
name=tensor.name,
643-
# TODO(justinchuby): Fix type hinting for shape and dtype
644-
shape=tensor.shape, # type: ignore
652+
name=initializer_name,
653+
# Include shape and type even if the shape or type is not provided as ValueInfoProto.
654+
# Users expect initialized values to have shape and type information.
645655
type=_core.TensorType(tensor.dtype),
656+
shape=tensor.shape, # type: ignore[arg-type]
646657
const_value=tensor,
647658
)
648659
if initializer_value.name in quantization_annotations:
649660
_deserialize_quantization_annotation(
650661
quantization_annotations[initializer_value.name], initializer_value
651662
)
652-
values[tensor.name] = initializer_value # type: ignore[index]
663+
values[initializer_name] = initializer_value
653664
initializer_values.append(initializer_value)
654665

655-
# Add ValueInfos for this graph scope
666+
# Build the value info dictionary to allow for quick lookup for this graph scope
656667
value_info = {info.name: info for info in proto.value_info}
657668

658669
# Deserialize nodes with all known values
@@ -663,7 +674,10 @@ def _deserialize_graph(
663674

664675
# Fill in values for graph outputs
665676
outputs = [deserialize_value_info_proto(info, values[info.name]) for info in proto.output]
677+
678+
# Exit the graph scope by popping the values for this scope from the stack
666679
scoped_values.pop()
680+
667681
return _core.Graph(
668682
inputs,
669683
outputs,
@@ -1204,24 +1218,24 @@ def _serialize_opset_imports_into(
12041218
opset_ids.add(domain=domain, version=version)
12051219

12061220

1207-
def _serialize_metadata_props_into(
1221+
def _serialize_string_string_maps(
12081222
string_string_entries: proto_containers.RepeatedCompositeFieldContainer[
12091223
onnx.StringStringEntryProto
12101224
],
12111225
from_: Mapping[str, str],
12121226
) -> None:
1213-
"""Serialize metadata properties into a repeated field of string-string entries.
1227+
"""Serialize a <str, str> mapping into a repeated field of string-string entries.
12141228
12151229
Args:
12161230
string_string_entries: The repeated field to serialize into.
1217-
from_: The mapping of metadata properties to serialize.
1231+
from_: The mapping of a <str, str> mapping to serialize.
12181232
"""
12191233
# Sort names for deterministic serialization
12201234
for key in sorted(from_):
12211235
string_string_entries.add(key=key, value=from_[key])
12221236

12231237

1224-
_serialize_string_string_maps = _serialize_metadata_props_into
1238+
_serialize_metadata_props_into = _serialize_string_string_maps
12251239

12261240

12271241
def _maybe_add_quantization_annotation(
@@ -1284,18 +1298,21 @@ def serialize_graph_into(
12841298
# TODO(justinchuby): We should add a method is_initializer() on Value when
12851299
# the initializer list is tracked
12861300
_maybe_add_quantization_annotation(graph_proto, input_)
1301+
input_names = {input_.name for input_ in from_.inputs}
12871302
# TODO(justinchuby): Support sparse_initializer
1288-
for initializer in from_.initializers.values():
1289-
_maybe_add_quantization_annotation(graph_proto, initializer)
1290-
if initializer.const_value is None:
1303+
for value in from_.initializers.values():
1304+
_maybe_add_quantization_annotation(graph_proto, value)
1305+
if _should_create_value_info_for_value(value) and value.name not in input_names:
1306+
# Serialize information about all initializers into value_info,
1307+
# except for those that are also graph inputs
1308+
serialize_value_into(graph_proto.value_info.add(), value)
1309+
if value.const_value is None:
12911310
# Skip initializers without constant values
1292-
logger.warning(
1293-
"Initializer '%s' does not have a constant value set.", initializer.name
1294-
)
1311+
logger.warning("Initializer '%s' does not have a constant value set.", value.name)
12951312
continue
12961313
# Make sure the tensor's name is the same as the value's name
1297-
initializer.const_value.name = initializer.name
1298-
serialize_tensor_into(graph_proto.initializer.add(), from_=initializer.const_value)
1314+
value.const_value.name = value.name
1315+
serialize_tensor_into(graph_proto.initializer.add(), from_=value.const_value)
12991316
for node in from_:
13001317
serialize_node_into(graph_proto.node.add(), from_=node)
13011318
for node_output in node.outputs:
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:06d78f841f26ec59cea1d15dd2c2a086cb907d6644ef8dac15e6d366935413e8
3-
size 43087292
2+
oid sha256:6dcf6976f8e324c497b0b74b2b9733c4b454f2c259488f5544bbc1aaaf57714c
3+
size 43091738
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:a336102b11d8439daa2c1a164a851f34414529a5610a046943fd869b1b44336f
3-
size 14665355
2+
oid sha256:ba424976b53bf2f141bfd001b48c0cc1c5c798b49def51f39a72f17e1f74e3a2
3+
size 14673089
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:31fbebb580ff85ed8eefa7fb95d4e2cbda41fe267afeaae2d4f4177264d1f4e7
3-
size 46918368
2+
oid sha256:12d24be13a03ea8ddebcc5ea229390d49fb0da08ad1df896b03703c664e2def1
3+
size 46921843
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:efd167b736106103235f42b480027c28c798dd46117526ca49067a2bdbc7b327
3-
size 311182
2+
oid sha256:6519a87ecf89132a9d39c59c47a442ae5833faf14811575d0b2323e8d13e30a8
3+
size 313873

0 commit comments

Comments
 (0)