Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 66 additions & 10 deletions exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@
# endian.
_HEADER_BYTEORDER: Literal["little"] = "little"

@dataclass
class SerializationArtifacts:
"""
Holds data required to serialize into a PTE.
"""
program: Program
mutable_data: Optional[List[Buffer]]
named_data: Optional[NamedDataStoreOutput]
# TODO: add constants here and remove constant buffer.

@dataclass
class AlignedData:
Expand Down Expand Up @@ -575,7 +584,7 @@ def serialize_pte_binary(
return pte_data


def _restore_segments(program: Program, segment_data: bytes) -> Program:
def _restore_segments(program: Program, segment_data: bytes) -> SerializationArtifacts:
"""Moves segments from `segment_data` into `program`.

This should recreate the original Program that the segments were extracted
Expand All @@ -589,7 +598,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
the preceding data has been stripped off so that the first segment
begins at offset zero.
Returns:
The Program with segments restored.
SerializationArtifacts, containing the Program with delegate and constant segments restored, as well as mutable and named data segments.
"""
# Extract the list of segment data blobs, which parallel program.segments.
segments: List[bytes] = []
Expand Down Expand Up @@ -624,7 +633,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:

# Replace constants from constant_segment into constant_buffer.
if program.constant_segment and len(program.constant_segment.offsets) > 0:
buffers: List[Buffer] = []
constant_buffers: List[Buffer] = []
constant_segment = segments[program.constant_segment.segment_index]
for i in range(len(program.constant_segment.offsets)):
start_offset = program.constant_segment.offsets[i]
Expand All @@ -635,17 +644,60 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
if i < len(program.constant_segment.offsets) - 1
else len(constant_segment)
)
buffers.append(Buffer(storage=constant_segment[start_offset:end_offset]))
program.constant_buffer = buffers
constant_buffers.append(Buffer(storage=constant_segment[start_offset:end_offset]))
program.constant_buffer = constant_buffers
program.constant_segment.segment_index = 0
program.constant_segment.offsets = []

# Extract mutable segments.
mutable_data = None
if program.mutable_data_segment and len(program.mutable_data_segments.offsets) > 0:
mutable_buffers: List[Buffer] = []
mutable_segment = segments[program.mutable_segment.segment_index]
for i in range(len(program.mutable_segments.offsets)):
start_offset = program.mutable_segment.offsets[i]
# Note: this is the original end offset plus any padding between
# it and the next start offset.
end_offset = (
program.mutable_segment.offsets[i + 1]
if i < len(program.mutable_segment.offsets) - 1
else len(mutable_segment)
)
mutable_buffers.append(Buffer(storage=mutable_segment[start_offset:end_offset]))
mutable_data = mutable_buffers
# Is this correct?
program.mutable_segment.segment_index = 0
program.mutable_segment.offsets = []

# Extract named data.
named_data = None
if program.named_data:
named_data_buffers: List[bytes] = []
pte_data: Dict[str, DataEntry] = {}

for entry in program.named_data:
if (entry.segment_index >= len(segments)):
raise ValueError(
"Named data segment index "
f"{entry.segment_index} >= num segments {len(segments)}"
)
named_data_buffers.append(segments[entry.segment_index])
pte_data[entry.key] = DataEntry(
buffer_index = len(named_data_buffers) - 1,
alignment = 1, # Deserialization does not preserve alignment.
tensor_layout = None
)
named_data = NamedDataStoreOutput(buffers=named_data_buffers, pte_data=pte_data, external_data=None)

# Clear out the segments list since the original Program didn't have one.
program.segments = []
return program

return SerializationArtifacts(
program=program,
mutable_data=mutable_data,
named_data=named_data
)

def deserialize_pte_binary(program_data: bytes) -> Program:
def deserialize_pte_binary(program_data: bytes) -> SerializationArtifacts:
"""Returns a Program deserialized from the given runtime binary data."""
program_size = len(program_data)
segment_base_offset = 0
Expand All @@ -664,8 +716,12 @@ def deserialize_pte_binary(program_data: bytes) -> Program:

if segment_base_offset != 0:
# Move segment data back into the Program.
program = _restore_segments(
return _restore_segments(
program=program, segment_data=program_data[segment_base_offset:]
)

return program
return SerializationArtifacts(
program=program,
mutable_data=None,
named_data=None,
)
27 changes: 13 additions & 14 deletions exir/_serialize/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,13 +281,13 @@ def constant_segment_with_tensor_alignment(
)

# Convert back.
program2 = deserialize_pte_binary(pte_data)
deserialized = deserialize_pte_binary(pte_data)
# Programs are the same besides constant_buffer, as deserialization
# does not preserve constant segment; padding may be added
# during serialization.
self.assertEqual(program2.execution_plan, program.execution_plan)
self.assertEqual(deserialized.program.execution_plan, program.execution_plan)
# Number of constant tensors should be the same.
self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))
self.assertEqual(len(deserialized.program.constant_buffer), len(program.constant_buffer))

def test_canonicalize_delegate_indices(self) -> None:
def make_execution_plan(
Expand Down Expand Up @@ -426,10 +426,9 @@ def test_round_trip_no_header_no_segments(self) -> None:
self.assertIsNone(eh)

# Convert back.
program2 = deserialize_pte_binary(pte_data)

deserialized = deserialize_pte_binary(pte_data)
# Programs should be the same.
self.assert_programs_equal(program, program2)
self.assert_programs_equal(program, deserialized.program)

def test_round_trip_large_buffer_sizes(self) -> None:
"""Tests that when the non_const_buffer_sizes contains integers
Expand All @@ -439,7 +438,7 @@ def test_round_trip_large_buffer_sizes(self) -> None:
program = get_test_program()
program.execution_plan[0].non_const_buffer_sizes = [0, 2**48]
flatbuffer_from_py = bytes(serialize_pte_binary(program))
self.assert_programs_equal(program, deserialize_pte_binary(flatbuffer_from_py))
self.assert_programs_equal(program, deserialize_pte_binary(flatbuffer_from_py).program)

def test_round_trip_no_segments_and_no_header(self) -> None:
"""Tests that a Program serialized with extract_delegate_segments=True
Expand All @@ -463,10 +462,10 @@ def test_round_trip_no_segments_and_no_header(self) -> None:
self.assertEqual(program_with_segments.segments, [])

# Convert back.
program2 = deserialize_pte_binary(pte_data)
deserialized = deserialize_pte_binary(pte_data)

# Programs should be the same.
self.assert_programs_equal(program, program2)
self.assert_programs_equal(program, deserialized.program)

@staticmethod
def gen_blob_data(size: int, pattern: bytes) -> bytes:
Expand Down Expand Up @@ -598,8 +597,8 @@ def test_round_trip_with_segments(self) -> None:
# meaning that the segments were moved back to inline. This also
# demonstrates that the contents of all segments survived, and weren't
# truncated or corrupted.
program2 = deserialize_pte_binary(pte_data)
self.assert_programs_equal(program, program2)
deserialized = deserialize_pte_binary(pte_data)
self.assert_programs_equal(program, deserialized.program)

def test_no_constants(self) -> None:
program = get_test_program()
Expand Down Expand Up @@ -884,13 +883,13 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
)

# Convert back.
program2 = deserialize_pte_binary(pte_data)
deserialized = deserialize_pte_binary(pte_data)
# Programs are the same besides constant_buffer, as deserialization
# does not preserve constant segment; padding may be added
# during serialization.
self.assertEqual(program2.execution_plan, program.execution_plan)
self.assertEqual(deserialized.program.execution_plan, program.execution_plan)
# Number of constant tensors should be the same.
self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))
self.assertEqual(len(deserialized.program.constant_buffer), len(program.constant_buffer))

def test_named_data_segments(self) -> None:
# Set segment alignment to 12 to test the padding.
Expand Down
Loading