Skip to content

Commit 1706b8e

Browse files
authored
Simplify extract_metadata.py by moving logic in tools/webassembly.py. NFC (#17255)
This creates a caching layer and some extra helper functions on the module object which avoid the need to track start such as the number of imports elements of a given type.
1 parent a8a5f77 commit 1706b8e

File tree

2 files changed

+80
-19
lines changed

2 files changed

+80
-19
lines changed

tools/extract_metadata.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def find_segment_with_address(module, address, size=0):
6868
if seg.size == size:
6969
return (seg, 0)
7070

71+
raise AssertionError('unable to find segment for address: %s' % address)
72+
7173

7274
def data_to_string(data):
7375
data = data.decode('utf8')
@@ -82,14 +84,14 @@ def data_to_string(data):
8284
return data
8385

8486

85-
def get_asm_strings(module, globls, export_map, imported_globals):
87+
def get_asm_strings(module, export_map):
8688
if '__start_em_asm' not in export_map or '__stop_em_asm' not in export_map:
8789
return {}
8890

8991
start = export_map['__start_em_asm']
9092
end = export_map['__stop_em_asm']
91-
start_global = globls[start.index - imported_globals]
92-
end_global = globls[end.index - imported_globals]
93+
start_global = module.get_global(start.index)
94+
end_global = module.get_global(end.index)
9395
start_addr = get_global_value(start_global)
9496
end_addr = get_global_value(end_global)
9597

@@ -110,29 +112,28 @@ def get_asm_strings(module, globls, export_map, imported_globals):
110112
return asm_strings
111113

112114

113-
def get_main_reads_params(module, export_map, imported_funcs):
115+
def get_main_reads_params(module, export_map):
114116
if settings.STANDALONE_WASM:
115117
return 1
116118

117119
main = export_map.get('main') or export_map.get('__main_argc_argv')
118120
if not main or main.kind != webassembly.ExternType.FUNC:
119121
return 0
120122

121-
functions = module.get_functions()
122-
main_func = functions[main.index - imported_funcs]
123+
main_func = module.get_function(main.index)
123124
if is_wrapper_function(module, main_func):
124125
return 0
125126
else:
126127
return 1
127128

128129

129-
def get_names_globals(globls, exports, imported_globals):
130+
def get_named_globals(module, exports):
130131
named_globals = {}
131132
for export in exports:
132133
if export.kind == webassembly.ExternType.GLOBAL:
133134
if export.name in ('__start_em_asm', '__stop_em_asm') or export.name.startswith('__em_js__'):
134135
continue
135-
g = globls[export.index - imported_globals]
136+
g = module.get_global(export.index)
136137
named_globals[export.name] = str(get_global_value(g))
137138
return named_globals
138139

@@ -167,26 +168,20 @@ def extract_metadata(filename):
167168
export_names = []
168169
declares = []
169170
invoke_funcs = []
170-
imported_funcs = 0
171-
imported_globals = 0
172171
global_imports = []
173172
em_js_funcs = {}
174173
exports = module.get_exports()
175174
imports = module.get_imports()
176-
globls = module.get_globals()
177175

178176
for i in imports:
179-
if i.kind == webassembly.ExternType.FUNC:
180-
imported_funcs += 1
181-
elif i.kind == webassembly.ExternType.GLOBAL:
182-
imported_globals += 1
177+
if i.kind == webassembly.ExternType.GLOBAL:
183178
global_imports.append(i.field)
184179

185180
export_map = {e.name: e for e in exports}
186181
for e in exports:
187182
if e.kind == webassembly.ExternType.GLOBAL and e.name.startswith('__em_js__'):
188183
name = e.name[len('__em_js__'):]
189-
globl = globls[e.index - imported_globals]
184+
globl = module.get_global(e.index)
190185
string_address = get_global_value(globl)
191186
em_js_funcs[name] = get_string_at(module, string_address)
192187

@@ -208,14 +203,14 @@ def extract_metadata(filename):
208203
# If main does not read its parameters, it will just be a stub that
209204
# calls __original_main (which has no parameters).
210205
metadata = {}
211-
metadata['asmConsts'] = get_asm_strings(module, globls, export_map, imported_globals)
206+
metadata['asmConsts'] = get_asm_strings(module, export_map)
212207
metadata['declares'] = declares
213208
metadata['emJsFuncs'] = em_js_funcs
214209
metadata['exports'] = export_names
215210
metadata['features'] = features
216211
metadata['globalImports'] = global_imports
217212
metadata['invokeFuncs'] = invoke_funcs
218-
metadata['mainReadsParams'] = get_main_reads_params(module, export_map, imported_funcs)
219-
metadata['namedGlobals'] = get_names_globals(globls, exports, imported_globals)
213+
metadata['mainReadsParams'] = get_main_reads_params(module, export_map)
214+
metadata['namedGlobals'] = get_named_globals(module, exports)
220215
# print("Metadata parsed: " + pprint.pformat(metadata))
221216
return metadata

tools/webassembly.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,32 @@ def read_sleb(iobuf):
5454
return leb128.i.decode_reader(iobuf)[0]
5555

5656

57+
# TODO(sbc): Use the builtin functools.cache once we update to python 3.9
58+
def cache(f):
59+
results = {}
60+
61+
def helper(*args, **kwargs):
62+
assert not kwargs
63+
key = args
64+
if key not in results:
65+
results[key] = f(*args, **kwargs)
66+
return results[key]
67+
68+
return helper
69+
70+
71+
def once(f):
72+
done = False
73+
74+
def helper(*args, **kwargs):
75+
nonlocal done
76+
if not done:
77+
done = True
78+
f(*args, **kwargs)
79+
80+
return helper
81+
82+
5783
class Type(IntEnum):
5884
I32 = 0x7f # -0x1
5985
I64 = 0x7e # -0x2
@@ -141,6 +167,7 @@ def __init__(self, filename):
141167
version = self.buf.read(4)
142168
if magic != MAGIC or version != VERSION:
143169
raise InvalidWasmError(f'{filename} is not a valid wasm file')
170+
self._done_calc_indexes = False
144171

145172
def __del__(self):
146173
if self.buf:
@@ -250,6 +277,7 @@ def parse_features_section(self):
250277
feature_count -= 1
251278
return features
252279

280+
@cache
253281
def parse_dylink_section(self):
254282
dylink_section = next(self.sections())
255283
assert dylink_section.type == SecType.CUSTOM
@@ -314,6 +342,7 @@ def parse_dylink_section(self):
314342

315343
return Dylink(mem_size, mem_align, table_size, table_align, needed, export_info, import_info)
316344

345+
@cache
317346
def get_exports(self):
318347
export_section = self.get_section(SecType.EXPORT)
319348
if not export_section:
@@ -330,6 +359,7 @@ def get_exports(self):
330359

331360
return exports
332361

362+
@cache
333363
def get_imports(self):
334364
import_section = self.get_section(SecType.IMPORT)
335365
if not import_section:
@@ -362,6 +392,7 @@ def get_imports(self):
362392

363393
return imports
364394

395+
@cache
365396
def get_globals(self):
366397
global_section = self.get_section(SecType.GLOBAL)
367398
if not global_section:
@@ -376,6 +407,7 @@ def get_globals(self):
376407
globls.append(Global(global_type, mutable, init))
377408
return globls
378409

410+
@cache
379411
def get_functions(self):
380412
code_section = self.get_section(SecType.CODE)
381413
if not code_section:
@@ -393,12 +425,14 @@ def get_functions(self):
393425
def get_section(self, section_code):
394426
return next((s for s in self.sections() if s.type == section_code), None)
395427

428+
@cache
396429
def get_custom_section(self, name):
397430
for section in self.sections():
398431
if section.type == SecType.CUSTOM and section.name == name:
399432
return section
400433
return None
401434

435+
@cache
402436
def get_segments(self):
403437
segments = []
404438
data_section = self.get_section(SecType.DATA)
@@ -416,6 +450,7 @@ def get_segments(self):
416450
self.seek(offset + size)
417451
return segments
418452

453+
@cache
419454
def get_tables(self):
420455
table_section = self.get_section(SecType.TABLE)
421456
if not table_section:
@@ -434,6 +469,37 @@ def get_tables(self):
434469
def has_name_section(self):
435470
return self.get_custom_section('name') is not None
436471

472+
@once
473+
def _calc_indexes(self):
474+
self.num_imported_funcs = 0
475+
self.num_imported_globals = 0
476+
self.num_imported_memories = 0
477+
self.num_imported_tables = 0
478+
self.num_imported_tags = 0
479+
for i in self.get_imports():
480+
if i.kind == ExternType.FUNC:
481+
self.num_imported_funcs += 1
482+
elif i.kind == ExternType.GLOBAL:
483+
self.num_imported_globals += 1
484+
elif i.kind == ExternType.MEMORY:
485+
self.num_imported_memories += 1
486+
elif i.kind == ExternType.TABLE:
487+
self.num_imported_tables += 1
488+
elif i.kind == ExternType.TAG:
489+
self.num_imported_tags += 1
490+
else:
491+
assert False, 'unhandled export type: %s' % i.kind
492+
493+
def get_function(self, idx):
494+
self._calc_indexes()
495+
assert idx >= self.num_imported_funcs
496+
return self.get_functions()[idx - self.num_imported_funcs]
497+
498+
def get_global(self, idx):
499+
self._calc_indexes()
500+
assert idx >= self.num_imported_globals
501+
return self.get_globals()[idx - self.num_imported_globals]
502+
437503

438504
def parse_dylink_section(wasm_file):
439505
module = Module(wasm_file)

0 commit comments

Comments
 (0)