@@ -179,23 +179,37 @@ def get_recorded_associated_file_list(self):
179179 metadata = _metadata_fb .ModelMetadata .GetRootAsModelMetadata (
180180 self ._metadata_buf , 0 )
181181
182- # Add associated files attached to ModelMetadata
183- self ._get_associated_files_from_metadata_struct (metadata , recorded_files )
182+ # Add associated files attached to ModelMetadata.
183+ recorded_files += self ._get_associated_files_from_table (
184+ metadata , "AssociatedFiles" )
184185
185- # Add associated files attached to each SubgraphMetadata
186+ # Add associated files attached to each SubgraphMetadata.
186187 for j in range (metadata .SubgraphMetadataLength ()):
187188 subgraph = metadata .SubgraphMetadata (j )
188- self ._get_associated_files_from_metadata_struct (subgraph , recorded_files )
189+ recorded_files += self ._get_associated_files_from_table (
190+ subgraph , "AssociatedFiles" )
189191
190- # Add associated files attached to each input tensor
192+ # Add associated files attached to each input tensor.
191193 for k in range (subgraph .InputTensorMetadataLength ()):
192- tensor = subgraph .InputTensorMetadata (k )
193- self ._get_associated_files_from_metadata_struct (tensor , recorded_files )
194+ recorded_files += self ._get_associated_files_from_table (
195+ subgraph .InputTensorMetadata (k ), "AssociatedFiles" )
196+ recorded_files += self ._get_associated_files_from_process_units (
197+ subgraph .InputTensorMetadata (k ), "ProcessUnits" )
194198
195- # Add associated files attached to each output tensor
199+ # Add associated files attached to each output tensor.
196200 for k in range (subgraph .OutputTensorMetadataLength ()):
197- tensor = subgraph .OutputTensorMetadata (k )
198- self ._get_associated_files_from_metadata_struct (tensor , recorded_files )
201+ recorded_files += self ._get_associated_files_from_table (
202+ subgraph .OutputTensorMetadata (k ), "AssociatedFiles" )
203+ recorded_files += self ._get_associated_files_from_process_units (
204+ subgraph .OutputTensorMetadata (k ), "ProcessUnits" )
205+
206+ # Add associated files attached to the input_process_units.
207+ recorded_files += self ._get_associated_files_from_process_units (
208+ subgraph , "InputProcessUnits" )
209+
210+ # Add associated files attached to the output_process_units.
211+ recorded_files += self ._get_associated_files_from_process_units (
212+ subgraph , "OutputProcessUnits" )
199213
200214 return recorded_files
201215
@@ -317,9 +331,69 @@ def _copy_archived_files(self, src_zip, dst_zip, file_list):
317331 file_buffer = src_zf .read (f )
318332 dst_zf .writestr (f , file_buffer )
319333
320- def _get_associated_files_from_metadata_struct (self , file_holder , file_list ):
321- for j in range (file_holder .AssociatedFilesLength ()):
322- file_list .append (file_holder .AssociatedFiles (j ).Name ().decode ("utf-8" ))
def _get_associated_files_from_process_units(self, table, field_name):
  """Gets the files that are attached to the process units field of a table.

  Only tokenizer process units are inspected: BertTokenizerOptions contributes
  its "VocabFile" entries, and SentencePieceTokenizerOptions contributes its
  "SentencePieceModel" and "VocabFile" entries. Every other options type is
  ignored.

  Args:
    table: a Flatbuffers table object that contains fields of an array of
      ProcessUnit, such as TensorMetadata and SubGraphMetadata. May be None.
    field_name: the name of the field in the table that represents an array of
      ProcessUnit. If the table is TensorMetadata, field_name can be
      "ProcessUnits". If the table is SubGraphMetadata, field_name can be
      either "InputProcessUnits" or "OutputProcessUnits".

  Returns:
    A list with the names of the associated files. The list is empty (never
    None) when `table` is None or when no process unit references a file.
  """
  # Callers extend a list with the result (`recorded_files += ...`), so a
  # missing table must yield an empty list rather than None.
  if table is None:
    return []

  file_list = []
  # Flatbuffers accessors for a vector field `X` are `XLength()` and `X(i)`.
  length_method = getattr(table, field_name + "Length")
  member_method = getattr(table, field_name)
  for k in range(length_method()):
    process_unit = member_method(k)
    options = process_unit.Options()
    # NOTE(review): Options() is expected to be None when the union is not
    # populated; such a unit cannot reference any file, so skip it.
    if options is None:
      continue

    # The union type tags are plain integers, so compare them by value (`==`)
    # rather than by identity (`is`).
    options_type = process_unit.OptionsType()
    if options_type == _metadata_fb.ProcessUnitOptions.BertTokenizerOptions:
      # Re-interpret the generic union table as the concrete options table.
      bert_tokenizer = _metadata_fb.BertTokenizerOptions()
      bert_tokenizer.Init(options.Bytes, options.Pos)
      file_list += self._get_associated_files_from_table(
          bert_tokenizer, "VocabFile")
    elif (options_type ==
          _metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions):
      sentence_piece = _metadata_fb.SentencePieceTokenizerOptions()
      sentence_piece.Init(options.Bytes, options.Pos)
      file_list += self._get_associated_files_from_table(
          sentence_piece, "SentencePieceModel")
      file_list += self._get_associated_files_from_table(
          sentence_piece, "VocabFile")
  return file_list
373+
def _get_associated_files_from_table(self, table, field_name):
  """Gets the associated files that are attached to a table directly.

  Args:
    table: a Flatbuffers table object that contains fields of an array of
      AssociatedFile, such as TensorMetadata and BertTokenizerOptions. May be
      None.
    field_name: the name of the field in the table that represents an array of
      AssociatedFile. If the table is TensorMetadata, field_name can be
      "AssociatedFiles". If the table is BertTokenizerOptions, field_name can
      be "VocabFile".

  Returns:
    A list with the UTF-8 decoded names of the associated files. The list is
    empty (never None) when `table` is None or the field holds no entries.
  """
  # Callers extend a list with the result (`recorded_files += ...`), so a
  # missing table must yield an empty list rather than None.
  if table is None:
    return []

  # Flatbuffers accessors for a vector field `X` are `XLength()` and `X(i)`.
  length_method = getattr(table, field_name + "Length")
  member_method = getattr(table, field_name)
  # Name() returns bytes; decode to get the file name as text.
  return [
      member_method(j).Name().decode("utf-8")
      for j in range(length_method())
  ]
323397
324398 def _populate_associated_files (self ):
325399 """Concatenates associated files after TensorFlow Lite model file.
0 commit comments