Skip to content

Commit 9d48fd8

Browse files
committed
Simplify extract_metadata.py by moving logic into tools/webassembly.py. NFC
This creates a caching layer and some extra helper functions on the module object, which avoid the need to track state such as the number of imported elements of a given type.
1 parent ea65615 commit 9d48fd8

File tree

2 files changed

+73
-19
lines changed

2 files changed

+73
-19
lines changed

tools/extract_metadata.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def find_segment_with_address(module, address, size=0):
6868
if seg.size == size:
6969
return (seg, 0)
7070

71+
raise AssertionError('unable to find segment for address: %s' % address)
72+
7173

7274
def data_to_string(data):
7375
data = data.decode('utf8')
@@ -82,14 +84,14 @@ def data_to_string(data):
8284
return data
8385

8486

85-
def get_asm_strings(module, globls, export_map, imported_globals):
87+
def get_asm_strings(module, export_map):
8688
if '__start_em_asm' not in export_map or '__stop_em_asm' not in export_map:
8789
return {}
8890

8991
start = export_map['__start_em_asm']
9092
end = export_map['__stop_em_asm']
91-
start_global = globls[start.index - imported_globals]
92-
end_global = globls[end.index - imported_globals]
93+
start_global = module.get_global(start.index)
94+
end_global = module.get_global(end.index)
9395
start_addr = get_global_value(start_global)
9496
end_addr = get_global_value(end_global)
9597

@@ -110,29 +112,28 @@ def get_asm_strings(module, globls, export_map, imported_globals):
110112
return asm_strings
111113

112114

113-
def get_main_reads_params(module, export_map, imported_funcs):
115+
def get_main_reads_params(module, export_map):
114116
if settings.STANDALONE_WASM:
115117
return 1
116118

117119
main = export_map.get('main') or export_map.get('__main_argc_argv')
118120
if not main or main.kind != webassembly.ExternType.FUNC:
119121
return 0
120122

121-
functions = module.get_functions()
122-
main_func = functions[main.index - imported_funcs]
123+
main_func = module.get_function(main.index)
123124
if is_wrapper_function(module, main_func):
124125
return 0
125126
else:
126127
return 1
127128

128129

129-
def get_names_globals(globls, exports, imported_globals):
130+
def get_named_globals(module, exports):
130131
named_globals = {}
131132
for export in exports:
132133
if export.kind == webassembly.ExternType.GLOBAL:
133134
if export.name in ('__start_em_asm', '__stop_em_asm') or export.name.startswith('__em_js__'):
134135
continue
135-
g = globls[export.index - imported_globals]
136+
g = module.get_global(export.index)
136137
named_globals[export.name] = str(get_global_value(g))
137138
return named_globals
138139

@@ -167,26 +168,20 @@ def extract_metadata(filename):
167168
export_names = []
168169
declares = []
169170
invoke_funcs = []
170-
imported_funcs = 0
171-
imported_globals = 0
172171
global_imports = []
173172
em_js_funcs = {}
174173
exports = module.get_exports()
175174
imports = module.get_imports()
176-
globls = module.get_globals()
177175

178176
for i in imports:
179-
if i.kind == webassembly.ExternType.FUNC:
180-
imported_funcs += 1
181-
elif i.kind == webassembly.ExternType.GLOBAL:
182-
imported_globals += 1
177+
if i.kind == webassembly.ExternType.GLOBAL:
183178
global_imports.append(i.field)
184179

185180
export_map = {e.name: e for e in exports}
186181
for e in exports:
187182
if e.kind == webassembly.ExternType.GLOBAL and e.name.startswith('__em_js__'):
188183
name = e.name[len('__em_js__'):]
189-
globl = globls[e.index - imported_globals]
184+
globl = module.get_global(e.index)
190185
string_address = get_global_value(globl)
191186
em_js_funcs[name] = get_string_at(module, string_address)
192187

@@ -208,14 +203,14 @@ def extract_metadata(filename):
208203
# If main does not read its parameters, it will just be a stub that
209204
# calls __original_main (which has no parameters).
210205
metadata = {}
211-
metadata['asmConsts'] = get_asm_strings(module, globls, export_map, imported_globals)
206+
metadata['asmConsts'] = get_asm_strings(module, export_map)
212207
metadata['declares'] = declares
213208
metadata['emJsFuncs'] = em_js_funcs
214209
metadata['exports'] = export_names
215210
metadata['features'] = features
216211
metadata['globalImports'] = global_imports
217212
metadata['invokeFuncs'] = invoke_funcs
218-
metadata['mainReadsParams'] = get_main_reads_params(module, export_map, imported_funcs)
219-
metadata['namedGlobals'] = get_names_globals(globls, exports, imported_globals)
213+
metadata['mainReadsParams'] = get_main_reads_params(module, export_map)
214+
metadata['namedGlobals'] = get_named_globals(module, exports)
220215
# print("Metadata parsed: " + pprint.pformat(metadata))
221216
return metadata

tools/webassembly.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,28 @@ def read_sleb(iobuf):
5454
return leb128.i.decode_reader(iobuf)[0]
5555

5656

57+
# TODO(sbc): Use the builtin functools.cache once we update to python 3.9
58+
def cache(f):
    """Memoize `f` on its positional arguments.

    Results are kept in a dict owned by the returned wrapper and keyed by the
    tuple of positional arguments, so every argument must be hashable.  The
    cache is never evicted.  When this decorates a method, `self` is part of
    the key, so the instance stays referenced for as long as the wrapper does.

    Keyword arguments are not supported; passing any raises AssertionError.

    Returns a wrapper that carries the name and docstring of `f`.
    """
    import functools

    results = {}

    @functools.wraps(f)
    def helper(*args, **kwargs):
        # Raise explicitly instead of using `assert` so the check is not
        # stripped when running under `python -O`.
        if kwargs:
            raise AssertionError('cache: keyword arguments are not supported')
        if args not in results:
            results[args] = f(*args)
        return results[args]

    return helper
67+
68+
69+
def once(f):
    """Wrap `f` so that only the first call to the wrapper invokes it.

    Later calls do nothing.  The wrapper always returns None; the return
    value of `f` is discarded.

    The "already ran" flag belongs to the wrapper itself, not to any argument.
    If this decorates a method, the first call on *any* instance suppresses
    the call on every other instance as well.

    Returns a wrapper that carries the name and docstring of `f`.
    """
    import functools

    done = False

    @functools.wraps(f)
    def helper(*args, **kwargs):
        nonlocal done
        if done:
            return
        # Flip the flag before calling so that a call which raises (or
        # re-enters the wrapper) is not attempted a second time.
        done = True
        f(*args, **kwargs)

    return helper
77+
78+
5779
class Type(IntEnum):
5880
I32 = 0x7f # -0x1
5981
I64 = 0x7e # -0x2
@@ -141,6 +163,7 @@ def __init__(self, filename):
141163
version = self.buf.read(4)
142164
if magic != MAGIC or version != VERSION:
143165
raise InvalidWasmError(f'{filename} is not a valid wasm file')
166+
self._done_calc_indexes = False
144167

145168
def __del__(self):
146169
if self.buf:
@@ -250,6 +273,7 @@ def parse_features_section(self):
250273
feature_count -= 1
251274
return features
252275

276+
@cache
253277
def parse_dylink_section(self):
254278
dylink_section = next(self.sections())
255279
assert dylink_section.type == SecType.CUSTOM
@@ -314,6 +338,7 @@ def parse_dylink_section(self):
314338

315339
return Dylink(mem_size, mem_align, table_size, table_align, needed, export_info, import_info)
316340

341+
@cache
317342
def get_exports(self):
318343
export_section = self.get_section(SecType.EXPORT)
319344
if not export_section:
@@ -330,6 +355,7 @@ def get_exports(self):
330355

331356
return exports
332357

358+
@cache
333359
def get_imports(self):
334360
import_section = self.get_section(SecType.IMPORT)
335361
if not import_section:
@@ -362,6 +388,7 @@ def get_imports(self):
362388

363389
return imports
364390

391+
@cache
365392
def get_globals(self):
366393
global_section = self.get_section(SecType.GLOBAL)
367394
if not global_section:
@@ -376,6 +403,7 @@ def get_globals(self):
376403
globls.append(Global(global_type, mutable, init))
377404
return globls
378405

406+
@cache
379407
def get_functions(self):
380408
code_section = self.get_section(SecType.CODE)
381409
if not code_section:
@@ -393,12 +421,14 @@ def get_functions(self):
393421
def get_section(self, section_code):
394422
return next((s for s in self.sections() if s.type == section_code), None)
395423

424+
@cache
396425
def get_custom_section(self, name):
397426
for section in self.sections():
398427
if section.type == SecType.CUSTOM and section.name == name:
399428
return section
400429
return None
401430

431+
@cache
402432
def get_segments(self):
403433
segments = []
404434
data_section = self.get_section(SecType.DATA)
@@ -416,6 +446,7 @@ def get_segments(self):
416446
self.seek(offset + size)
417447
return segments
418448

449+
@cache
419450
def get_tables(self):
420451
table_section = self.get_section(SecType.TABLE)
421452
if not table_section:
@@ -434,6 +465,34 @@ def get_tables(self):
434465
def has_name_section(self):
435466
return self.get_custom_section('name') is not None
436467

468+
@once
469+
def _calc_indexes(self):
470+
self.num_imported_funcs = 0
471+
self.num_imported_globals = 0
472+
self.num_imported_memories = 0
473+
self.num_imported_tables = 0
474+
for i in self.get_imports():
475+
if i.kind == ExternType.FUNC:
476+
self.num_imported_funcs += 1
477+
elif i.kind == ExternType.GLOBAL:
478+
self.num_imported_globals += 1
479+
elif i.kind == ExternType.MEMORY:
480+
self.num_imported_memories += 1
481+
elif i.kind == ExternType.TABLE:
482+
self.num_imported_tables += 1
483+
else:
484+
assert False
485+
486+
def get_function(self, idx):
    """Return the defined function at index `idx`.

    `idx` counts imported functions first: the number of imported functions
    is subtracted from it before indexing into `self.get_functions()`.

    Raises:
        AssertionError: if `idx` is below the number of imported functions,
            i.e. it refers to an import rather than a defined function.
    """
    self._calc_indexes()
    first_defined = self.num_imported_funcs
    # Explicit raise: with a stripped `assert` (python -O) a too-small index
    # would turn into a negative list index and return the wrong function.
    if idx < first_defined:
        raise AssertionError(f'function index {idx} refers to an import')
    return self.get_functions()[idx - first_defined]
490+
491+
def get_global(self, idx):
    """Return the defined global at index `idx`.

    `idx` counts imported globals first: the number of imported globals is
    subtracted from it before indexing into `self.get_globals()`.

    Raises:
        AssertionError: if `idx` is below the number of imported globals,
            i.e. it refers to an import rather than a defined global.
    """
    self._calc_indexes()
    first_defined = self.num_imported_globals
    # Explicit raise: with a stripped `assert` (python -O) a too-small index
    # would turn into a negative list index and return the wrong global.
    if idx < first_defined:
        raise AssertionError(f'global index {idx} refers to an import')
    return self.get_globals()[idx - first_defined]
495+
437496

438497
def parse_dylink_section(wasm_file):
439498
module = Module(wasm_file)

0 commit comments

Comments
 (0)