Skip to content

Commit 427af05

Browse files
lu-wang-gtflite-support-robot
authored andcommitted
Check the associated files in the process units.
PiperOrigin-RevId: 318554536
1 parent 02e5df0 commit 427af05

File tree

4 files changed

+283
-67
lines changed

4 files changed

+283
-67
lines changed

tensorflow_lite_support/metadata/BUILD

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -79,36 +79,33 @@ py_library(
7979
],
8080
)
8181

82-
# TODO(b/160127851): Temporarily disable the test because it doesn't build with TF internal libs.
83-
# py_test(
84-
# name = "metadata_test",
85-
# srcs = ["metadata_test.py"],
86-
# data = ["testdata/golden_json.json"],
87-
# python_version = "PY3",
88-
# srcs_version = "PY2AND3",
89-
# tags = [
90-
# "no_mac", # TODO(b/148247402): flatbuffers import broken on Mac OS.
91-
# ],
92-
# deps = [
93-
# ":metadata",
94-
# ":metadata_schema_py",
95-
# ":schema_py",
96-
# "@flatbuffers//:runtime_py",
97-
# "@six_archive//:six",
98-
# "@org_tensorflow//tensorflow/python:client_testlib",
99-
# "@org_tensorflow//tensorflow/python:platform",
100-
# "@org_tensorflow//tensorflow/python:platform_test",
101-
# ],
102-
# )
82+
py_test(
83+
name = "metadata_test",
84+
srcs = ["metadata_test.py"],
85+
data = ["testdata/golden_json.json"],
86+
python_version = "PY3",
87+
srcs_version = "PY2AND3",
88+
tags = [
89+
"no_mac", # TODO(b/148247402): flatbuffers import broken on Mac OS.
90+
],
91+
deps = [
92+
":metadata",
93+
":metadata_schema_py",
94+
":schema_py",
95+
"//third_party/py/tensorflow",
96+
"@flatbuffers//:runtime_py",
97+
"@org_tensorflow//tensorflow/python:platform",
98+
"@six_archive//:six",
99+
],
100+
)
103101

104-
# TODO(b/160127851): Temporarily disable the test because it doesn't build with TF internal libs.
105-
# py_test(
106-
# name = "metadata_parser_test",
107-
# srcs = ["metadata_parser_test.py"],
108-
# python_version = "PY3",
109-
# srcs_version = "PY2AND3",
110-
# deps = [
111-
# ":metadata",
112-
# "@org_tensorflow//tensorflow/python:client_testlib",
113-
# ],
114-
# )
102+
py_test(
103+
name = "metadata_parser_test",
104+
srcs = ["metadata_parser_test.py"],
105+
python_version = "PY3",
106+
srcs_version = "PY2AND3",
107+
deps = [
108+
":metadata",
109+
"//third_party/py/tensorflow",
110+
],
111+
)

tensorflow_lite_support/metadata/metadata.py

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -179,23 +179,37 @@ def get_recorded_associated_file_list(self):
179179
metadata = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
180180
self._metadata_buf, 0)
181181

182-
# Add associated files attached to ModelMetadata
183-
self._get_associated_files_from_metadata_struct(metadata, recorded_files)
182+
# Add associated files attached to ModelMetadata.
183+
recorded_files += self._get_associated_files_from_table(
184+
metadata, "AssociatedFiles")
184185

185-
# Add associated files attached to each SubgraphMetadata
186+
# Add associated files attached to each SubgraphMetadata.
186187
for j in range(metadata.SubgraphMetadataLength()):
187188
subgraph = metadata.SubgraphMetadata(j)
188-
self._get_associated_files_from_metadata_struct(subgraph, recorded_files)
189+
recorded_files += self._get_associated_files_from_table(
190+
subgraph, "AssociatedFiles")
189191

190-
# Add associated files attached to each input tensor
192+
# Add associated files attached to each input tensor.
191193
for k in range(subgraph.InputTensorMetadataLength()):
192-
tensor = subgraph.InputTensorMetadata(k)
193-
self._get_associated_files_from_metadata_struct(tensor, recorded_files)
194+
recorded_files += self._get_associated_files_from_table(
195+
subgraph.InputTensorMetadata(k), "AssociatedFiles")
196+
recorded_files += self._get_associated_files_from_process_units(
197+
subgraph.InputTensorMetadata(k), "ProcessUnits")
194198

195-
# Add associated files attached to each output tensor
199+
# Add associated files attached to each output tensor.
196200
for k in range(subgraph.OutputTensorMetadataLength()):
197-
tensor = subgraph.OutputTensorMetadata(k)
198-
self._get_associated_files_from_metadata_struct(tensor, recorded_files)
201+
recorded_files += self._get_associated_files_from_table(
202+
subgraph.OutputTensorMetadata(k), "AssociatedFiles")
203+
recorded_files += self._get_associated_files_from_process_units(
204+
subgraph.OutputTensorMetadata(k), "ProcessUnits")
205+
206+
# Add associated files attached to the input_process_units.
207+
recorded_files += self._get_associated_files_from_process_units(
208+
subgraph, "InputProcessUnits")
209+
210+
# Add associated files attached to the output_process_units.
211+
recorded_files += self._get_associated_files_from_process_units(
212+
subgraph, "OutputProcessUnits")
199213

200214
return recorded_files
201215

@@ -317,9 +331,69 @@ def _copy_archived_files(self, src_zip, dst_zip, file_list):
317331
file_buffer = src_zf.read(f)
318332
dst_zf.writestr(f, file_buffer)
319333

320-
def _get_associated_files_from_metadata_struct(self, file_holder, file_list):
321-
for j in range(file_holder.AssociatedFilesLength()):
322-
file_list.append(file_holder.AssociatedFiles(j).Name().decode("utf-8"))
334+
def _get_associated_files_from_process_units(self, table, field_name):
335+
"""Gets the files that are attached the process units field of a table.
336+
337+
Args:
338+
table: a Flatbuffers table object that contains fields of an array of
339+
ProcessUnit, such as TensorMetadata and SubGraphMetadata.
340+
field_name: the name of the field in the table that represents an array of
341+
ProcessUnit. If the table is TensorMetadata, field_name can be
342+
"ProcessUnits". If the table is SubGraphMetadata, field_name can be
343+
either "InputProcessUnits" or "OutputProcessUnits".
344+
345+
Returns:
346+
the associated files list.
347+
"""
348+
349+
if table is None:
350+
return
351+
352+
file_list = []
353+
length_method = getattr(table, field_name + "Length")
354+
member_method = getattr(table, field_name)
355+
for k in range(length_method()):
356+
process_unit = member_method(k)
357+
tokenizer = process_unit.Options()
358+
if (process_unit.OptionsType() is
359+
_metadata_fb.ProcessUnitOptions.BertTokenizerOptions):
360+
bert_tokenizer = _metadata_fb.BertTokenizerOptions()
361+
bert_tokenizer.Init(tokenizer.Bytes, tokenizer.Pos)
362+
file_list += self._get_associated_files_from_table(
363+
bert_tokenizer, "VocabFile")
364+
elif (process_unit.OptionsType() is
365+
_metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions):
366+
sentence_piece = _metadata_fb.SentencePieceTokenizerOptions()
367+
sentence_piece.Init(tokenizer.Bytes, tokenizer.Pos)
368+
file_list += self._get_associated_files_from_table(
369+
sentence_piece, "SentencePieceModel")
370+
file_list += self._get_associated_files_from_table(
371+
sentence_piece, "VocabFile")
372+
return file_list
373+
374+
def _get_associated_files_from_table(self, table, field_name):
375+
"""Gets the associated files that are attached a table directly.
376+
377+
Args:
378+
table: a Flatbuffers table object that contains fields of an array of
379+
AssociatedFile, such as TensorMetadata and BertTokenizerOptions.
380+
field_name: the name of the field in the table that represents an array of
381+
ProcessUnit. If the table is TensorMetadata, field_name can be
382+
"AssociatedFiles". If the table is BertTokenizerOptions, field_name can
383+
be "VocabFile".
384+
385+
Returns:
386+
the associated files list.
387+
"""
388+
389+
if table is None:
390+
return
391+
file_list = []
392+
length_method = getattr(table, field_name + "Length")
393+
member_method = getattr(table, field_name)
394+
for j in range(length_method()):
395+
file_list.append(member_method(j).Name().decode("utf-8"))
396+
return file_list
323397

324398
def _populate_associated_files(self):
325399
"""Concatenates associated files after TensorFlow Lite model file.

tensorflow_lite_support/metadata/metadata_parser_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,19 @@
1919
from __future__ import print_function
2020

2121
import re
22+
import tensorflow as tf
2223

23-
from third_party.tensorflow.python.framework import test_util
24-
from third_party.tensorflow.python.platform import test
2524
from tensorflow_lite_support.metadata import metadata_parser
2625

2726

28-
class MetadataParserTest(test_util.TensorFlowTestCase):
27+
class MetadataParserTest(tf.test.TestCase):
2928

30-
def test_version_wellFormedSemanticVersion(self):
29+
def testVersionWellFormedSemanticVersion(self):
3130
# Validates that the version is well-formed (x.y.z).
3231
self.assertTrue(
3332
re.match('[0-9]+\\.[0-9]+\\.[0-9]+',
3433
metadata_parser.MetadataParser.VERSION))
3534

3635

3736
if __name__ == '__main__':
38-
test.main()
37+
tf.test.main()

0 commit comments

Comments
 (0)