Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/onnx_ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -1975,11 +1975,26 @@ def serialize_type(type_protocol: _protocols.TypeProtocol) -> onnx.TypeProto:
@_capture_errors(lambda type_proto, from_: repr(from_))
def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProtocol) -> None:
value_field = type_proto.WhichOneof("value")
if value_field is None:
# We cannot write the shape because we do not know where to write it
logger.warning(
# TODO(justinchuby): Show more context about the value when move everything to an object
"The value type for shape %s is not known. Please set type for the value. Skipping serialization",
from_,
)
return
tensor_type = getattr(type_proto, value_field)
while not isinstance(tensor_type.elem_type, int):
# Find the leaf type that has the shape field
type_proto = tensor_type.elem_type
value_field = type_proto.WhichOneof("value")
if value_field is None:
logger.warning(
# TODO(justinchuby): Show more context about the value when move everything to an object
"The value type for shape %s is not known. Please set type for the value. Skipping serialization",
from_,
)
return
tensor_type = getattr(type_proto, value_field)
# When from is empty, we still need to set the shape field to an empty list by touching it
tensor_type.shape.ClearField("dim")
Expand Down
9 changes: 9 additions & 0 deletions src/onnx_ir/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,15 @@ def test_serialize_attribute(self, _: str, typ: ir.AttributeType, value, expecte
self.assertEqual(deserialized_attr.type, attr.type)
self.assertEqual(deserialized_attr.value, expected)

def test_serialize_shape_into_skips_writing_when_value_type_not_known(self):
shape = ir.Shape((1, 2, 3))
proto = onnx.TypeProto()
self.assertIsNone(proto.WhichOneof("value"))
serde.serialize_shape_into(proto, shape)
self.assertIsNone(proto.WhichOneof("value"))
deserialized = serde.deserialize_type_proto_for_shape(proto)
self.assertIsNone(deserialized, shape)


class QuantizationAnnotationTest(unittest.TestCase):
"""Test that quantization annotations are correctly serialized and deserialized."""
Expand Down