4545# endian.
4646_HEADER_BYTEORDER : Literal ["little" ] = "little"
4747
@dataclass
class SerializationArtifacts:
    """Bundle of the data required to serialize into a PTE.

    Attributes:
        program: The Program to serialize.
        mutable_data: Buffers of mutable data, or None when there are none.
        named_data: Named data store output, or None when there is none.
    """

    # TODO: add constants here and remove constant buffer.
    program: Program
    mutable_data: Optional[List[Buffer]]
    named_data: Optional[NamedDataStoreOutput]
4857
4958@dataclass
5059class AlignedData :
@@ -575,7 +584,7 @@ def serialize_pte_binary(
575584 return pte_data
576585
577586
578- def _restore_segments (program : Program , segment_data : bytes ) -> Program :
587+ def _restore_segments (program : Program , segment_data : bytes ) -> SerializationArtifacts :
579588 """Moves segments from `segment_data` into `program`.
580589
581590 This should recreate the original Program that the segments were extracted
@@ -589,7 +598,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
589598 the preceding data has been stripped off so that the first segment
590599 begins at offset zero.
591600 Returns:
592- The Program with segments restored.
601+ SerializationArtifacts, containing the Program with delegate and constant segments restored, as well as mutable and named data segments.
593602 """
594603 # Extract the list of segment data blobs, which parallel program.segments.
595604 segments : List [bytes ] = []
@@ -624,7 +633,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
624633
625634 # Replace constants from constant_segment into constant_buffer.
626635 if program .constant_segment and len (program .constant_segment .offsets ) > 0 :
627- buffers : List [Buffer ] = []
636+ constant_buffers : List [Buffer ] = []
628637 constant_segment = segments [program .constant_segment .segment_index ]
629638 for i in range (len (program .constant_segment .offsets )):
630639 start_offset = program .constant_segment .offsets [i ]
@@ -635,17 +644,60 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
635644 if i < len (program .constant_segment .offsets ) - 1
636645 else len (constant_segment )
637646 )
638- buffers .append (Buffer (storage = constant_segment [start_offset :end_offset ]))
639- program .constant_buffer = buffers
647+ constant_buffers .append (Buffer (storage = constant_segment [start_offset :end_offset ]))
648+ program .constant_buffer = constant_buffers
640649 program .constant_segment .segment_index = 0
641650 program .constant_segment .offsets = []
642651
652+ # Extract mutable segments.
653+ mutable_data = None
654+ if program .mutable_data_segment and len (program .mutable_data_segments .offsets ) > 0 :
655+ mutable_buffers : List [Buffer ] = []
656+ mutable_segment = segments [program .mutable_segment .segment_index ]
657+ for i in range (len (program .mutable_segments .offsets )):
658+ start_offset = program .mutable_segment .offsets [i ]
659+ # Note: this is the original end offset plus any padding between
660+ # it and the next start offset.
661+ end_offset = (
662+ program .mutable_segment .offsets [i + 1 ]
663+ if i < len (program .mutable_segment .offsets ) - 1
664+ else len (mutable_segment )
665+ )
666+ mutable_buffers .append (Buffer (storage = mutable_segment [start_offset :end_offset ]))
667+ mutable_data = mutable_buffers
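668+ # NOTE(review): is resetting the index/offsets below correct? Also, this block refers to the field as `mutable_data_segment`, `mutable_data_segments`, `mutable_segment` and `mutable_segments` - confirm the real attribute name against the Program schema.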
669+ program .mutable_segment .segment_index = 0
670+ program .mutable_segment .offsets = []
671+
672+ # Extract named data.
673+ named_data = None
674+ if program .named_data :
675+ named_data_buffers : List [bytes ] = []
676+ pte_data : Dict [str , DataEntry ] = {}
677+
678+ for entry in program .named_data :
679+ if (entry .segment_index >= len (segments )):
680+ raise ValueError (
681+ "Named data segment index "
682+ f"{ entry .segment_index } >= num segments { len (segments )} "
683+ )
684+ named_data_buffers .append (segments [entry .segment_index ])
685+ pte_data [entry .key ] = DataEntry (
686+ buffer_index = len (named_data_buffers ) - 1 ,
687+ alignment = 1 , # Deserialization does not preserve alignment.
688+ tensor_layout = None
689+ )
690+ named_data = NamedDataStoreOutput (buffers = named_data_buffers , pte_data = pte_data , external_data = None )
691+
643692 # Clear out the segments list since the original Program didn't have one.
644693 program .segments = []
645- return program
646-
694+ return SerializationArtifacts (
695+ program = program ,
696+ mutable_data = mutable_data ,
697+ named_data = named_data
698+ )
647699
648- def deserialize_pte_binary (program_data : bytes ) -> Program :
700+ def deserialize_pte_binary (program_data : bytes ) -> SerializationArtifacts :
649701 """Returns the SerializationArtifacts (the Program plus any mutable and named data) deserialized from the given runtime binary data."""
650702 program_size = len (program_data )
651703 segment_base_offset = 0
@@ -664,8 +716,12 @@ def deserialize_pte_binary(program_data: bytes) -> Program:
664716
665717 if segment_base_offset != 0 :
666718 # Move segment data back into the Program.
667- program = _restore_segments (
719+ return _restore_segments (
668720 program = program , segment_data = program_data [segment_base_offset :]
669721 )
670722
671- return program
723+ return SerializationArtifacts (
724+ program = program ,
725+ mutable_data = None ,
726+ named_data = None ,
727+ )
0 commit comments