Skip to content

Commit 209b496

Browse files
committed
Add 'SerializationArtifacts' to hold program+segments
1 parent 24c6961 commit 209b496

File tree

2 files changed

+79
-24
lines changed

2 files changed

+79
-24
lines changed

exir/_serialize/_program.py

Lines changed: 66 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@
4545
# endian.
4646
_HEADER_BYTEORDER: Literal["little"] = "little"
4747

48+
@dataclass
49+
class SerializationArtifacts:
50+
"""
51+
Holds data required to serialize into a PTE.
52+
"""
53+
program: Program
54+
mutable_data: Optional[List[Buffer]]
55+
named_data: Optional[NamedDataStoreOutput]
56+
# TODO: add constants here and remove constant buffer.
4857

4958
@dataclass
5059
class AlignedData:
@@ -575,7 +584,7 @@ def serialize_pte_binary(
575584
return pte_data
576585

577586

578-
def _restore_segments(program: Program, segment_data: bytes) -> Program:
587+
def _restore_segments(program: Program, segment_data: bytes) -> SerializationArtifacts:
579588
"""Moves segments from `segment_data` into `program`.
580589
581590
This should recreate the original Program that the segments were extracted
@@ -589,7 +598,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
589598
the preceding data has been stripped off so that the first segment
590599
begins at offset zero.
591600
Returns:
592-
The Program with segments restored.
601+
SerializationArtifacts, containing the Program with delegate and constant segments restored, as well as mutable and named data segments.
593602
"""
594603
# Extract the list of segment data blobs, which parallel program.segments.
595604
segments: List[bytes] = []
@@ -624,7 +633,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
624633

625634
# Replace constants from constant_segment into constant_buffer.
626635
if program.constant_segment and len(program.constant_segment.offsets) > 0:
627-
buffers: List[Buffer] = []
636+
constant_buffers: List[Buffer] = []
628637
constant_segment = segments[program.constant_segment.segment_index]
629638
for i in range(len(program.constant_segment.offsets)):
630639
start_offset = program.constant_segment.offsets[i]
@@ -635,17 +644,60 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
635644
if i < len(program.constant_segment.offsets) - 1
636645
else len(constant_segment)
637646
)
638-
buffers.append(Buffer(storage=constant_segment[start_offset:end_offset]))
639-
program.constant_buffer = buffers
647+
constant_buffers.append(Buffer(storage=constant_segment[start_offset:end_offset]))
648+
program.constant_buffer = constant_buffers
640649
program.constant_segment.segment_index = 0
641650
program.constant_segment.offsets = []
642651

652+
# Extract mutable segments.
653+
mutable_data = None
654+
if program.mutable_data_segment and len(program.mutable_data_segments.offsets) > 0:
655+
mutable_buffers: List[Buffer] = []
656+
mutable_segment = segments[program.mutable_segment.segment_index]
657+
for i in range(len(program.mutable_segments.offsets)):
658+
start_offset = program.mutable_segment.offsets[i]
659+
# Note: this is the original end offset plus any padding between
660+
# it and the next start offset.
661+
end_offset = (
662+
program.mutable_segment.offsets[i + 1]
663+
if i < len(program.mutable_segment.offsets) - 1
664+
else len(mutable_segment)
665+
)
666+
mutable_buffers.append(Buffer(storage=mutable_segment[start_offset:end_offset]))
667+
mutable_data = mutable_buffers
668+
# Is this correct?
669+
program.mutable_segment.segment_index = 0
670+
program.mutable_segment.offsets = []
671+
672+
# Extract named data.
673+
named_data = None
674+
if program.named_data:
675+
named_data_buffers: List[bytes] = []
676+
pte_data: Dict[str, DataEntry] = {}
677+
678+
for entry in program.named_data:
679+
if (entry.segment_index >= len(segments)):
680+
raise ValueError(
681+
"Named data segment index "
682+
f"{entry.segment_index} >= num segments {len(segments)}"
683+
)
684+
named_data_buffers.append(segments[entry.segment_index])
685+
pte_data[entry.key] = DataEntry(
686+
buffer_index = len(named_data_buffers) - 1,
687+
alignment = 1, # Deserialization does not preserve alignment.
688+
tensor_layout = None
689+
)
690+
named_data = NamedDataStoreOutput(buffers=named_data_buffers, pte_data=pte_data, external_data=None)
691+
643692
# Clear out the segments list since the original Program didn't have one.
644693
program.segments = []
645-
return program
646-
694+
return SerializationArtifacts(
695+
program=program,
696+
mutable_data=mutable_data,
697+
named_data=named_data
698+
)
647699

648-
def deserialize_pte_binary(program_data: bytes) -> Program:
700+
def deserialize_pte_binary(program_data: bytes) -> SerializationArtifacts:
649701
"""Returns a Program deserialized from the given runtime binary data."""
650702
program_size = len(program_data)
651703
segment_base_offset = 0
@@ -664,8 +716,12 @@ def deserialize_pte_binary(program_data: bytes) -> Program:
664716

665717
if segment_base_offset != 0:
666718
# Move segment data back into the Program.
667-
program = _restore_segments(
719+
return _restore_segments(
668720
program=program, segment_data=program_data[segment_base_offset:]
669721
)
670722

671-
return program
723+
return SerializationArtifacts(
724+
program=program,
725+
mutable_data=None,
726+
named_data=None,
727+
)

exir/_serialize/test/test_program.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -281,13 +281,13 @@ def constant_segment_with_tensor_alignment(
281281
)
282282

283283
# Convert back.
284-
program2 = deserialize_pte_binary(pte_data)
284+
deserialized = deserialize_pte_binary(pte_data)
285285
# Programs are the same besides constant_buffer, as deserialization
286286
# does not preserve constant segment; padding may be added
287287
# during serialization.
288-
self.assertEqual(program2.execution_plan, program.execution_plan)
288+
self.assertEqual(deserialized.program.execution_plan, program.execution_plan)
289289
# Number of constant tensors should be the same.
290-
self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))
290+
self.assertEqual(len(deserialized.program.constant_buffer), len(program.constant_buffer))
291291

292292
def test_canonicalize_delegate_indices(self) -> None:
293293
def make_execution_plan(
@@ -426,10 +426,9 @@ def test_round_trip_no_header_no_segments(self) -> None:
426426
self.assertIsNone(eh)
427427

428428
# Convert back.
429-
program2 = deserialize_pte_binary(pte_data)
430-
429+
deserialized = deserialize_pte_binary(pte_data)
431430
# Programs should be the same.
432-
self.assert_programs_equal(program, program2)
431+
self.assert_programs_equal(program, deserialized.program)
433432

434433
def test_round_trip_large_buffer_sizes(self) -> None:
435434
"""Tests that when the non_const_buffer_sizes contains integers
@@ -439,7 +438,7 @@ def test_round_trip_large_buffer_sizes(self) -> None:
439438
program = get_test_program()
440439
program.execution_plan[0].non_const_buffer_sizes = [0, 2**48]
441440
flatbuffer_from_py = bytes(serialize_pte_binary(program))
442-
self.assert_programs_equal(program, deserialize_pte_binary(flatbuffer_from_py))
441+
self.assert_programs_equal(program, deserialize_pte_binary(flatbuffer_from_py).program)
443442

444443
def test_round_trip_no_segments_and_no_header(self) -> None:
445444
"""Tests that a Program serialized with extract_delegate_segments=True
@@ -463,10 +462,10 @@ def test_round_trip_no_segments_and_no_header(self) -> None:
463462
self.assertEqual(program_with_segments.segments, [])
464463

465464
# Convert back.
466-
program2 = deserialize_pte_binary(pte_data)
465+
deserialized = deserialize_pte_binary(pte_data)
467466

468467
# Programs should be the same.
469-
self.assert_programs_equal(program, program2)
468+
self.assert_programs_equal(program, deserialized.program)
470469

471470
@staticmethod
472471
def gen_blob_data(size: int, pattern: bytes) -> bytes:
@@ -598,8 +597,8 @@ def test_round_trip_with_segments(self) -> None:
598597
# meaning that the segments were moved back to inline. This also
599598
# demonstrates that the contents of all segments survived, and weren't
600599
# truncated or corrupted.
601-
program2 = deserialize_pte_binary(pte_data)
602-
self.assert_programs_equal(program, program2)
600+
deserialized = deserialize_pte_binary(pte_data)
601+
self.assert_programs_equal(program, deserialized.program)
603602

604603
def test_no_constants(self) -> None:
605604
program = get_test_program()
@@ -884,13 +883,13 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
884883
)
885884

886885
# Convert back.
887-
program2 = deserialize_pte_binary(pte_data)
886+
deserialized = deserialize_pte_binary(pte_data)
888887
# Programs are the same besides constant_buffer, as deserialization
889888
# does not preserve constant segment; padding may be added
890889
# during serialization.
891-
self.assertEqual(program2.execution_plan, program.execution_plan)
890+
self.assertEqual(deserialized.program.execution_plan, program.execution_plan)
892891
# Number of constant tensors should be the same.
893-
self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))
892+
self.assertEqual(len(deserialized.program.constant_buffer), len(program.constant_buffer))
894893

895894
def test_named_data_segments(self) -> None:
896895
# Set segment alignment to 12 to test the padding.

0 commit comments

Comments
 (0)