@@ -89,6 +89,391 @@ def check_fastmath(pkg_dir, pkg_name):
89
89
return
90
90
91
91
92
+ class FunctionCallVisitor (ast .NodeVisitor ):
93
+ """
94
+ A class to traverse the AST of the modules of a package to collect
95
+ the call stacks of njit functions.
96
+
97
+ Parameters
98
+ ----------
99
+ pkg_dir : str
100
+ The path to the package directory containing some .py files.
101
+
102
+ pkg_name : str
103
+ The name of the package.
104
+
105
+ Attributes
106
+ ----------
107
+ module_names : list
108
+ A list of module names to track the modules as the visitor traverses them.
109
+
110
+ call_stack : list
111
+ A list of njit functions, representing a chain of function calls,
112
+ where each element is a string of the form "module_name.func_name".
113
+
114
+ out : list
115
+ A list of unique `call_stack`s.
116
+
117
+ njit_funcs : list
118
+ A list of all njit functions in `pkg_dir`'s modules. Each element is a tuple
119
+ of the form `(module_name, func_name)`.
120
+
121
+ njit_modules : set
122
+ A set that contains the names of all modules, each of which contains at least
123
+ one njit function.
124
+
125
+ njit_nodes : dict
126
+ A dictionary mapping njit function names to their corresponding AST nodes.
127
+ A key is a string, and it is of the form "module_name.func_name", and its
128
+ corresponding value is the AST node- with type ast.FunctionDef- of that
129
+ function.
130
+
131
+ ast_modules : dict
132
+ A dictionary mapping module names to their corresponding AST objects. A key
133
+ is the name of a module, and its corresponding value is the content of that
134
+ module as an AST object.
135
+
136
+ Methods
137
+ -------
138
+ push_module(module_name)
139
+ Push the name of a module onto the stack `module_names`.
140
+
141
+ pop_module()
142
+ Pop the last module name from the stack `module_names`.
143
+
144
+ push_call_stack(module_name, func_name)
145
+ Push a function call onto the stack of function calls, `call_stack`.
146
+
147
+ pop_call_stack()
148
+ Pop the last function call from the stack of function calls, `call_stack`
149
+
150
+ goto_deeper_func(node)
151
+ Calls the visit method from class `ast.NodeVisitor` on all children of
152
+ the `node`.
153
+
154
+ goto_next_func(node)
155
+ Calls the visit method from class `ast.NodeVisitor` on all children of
156
+ the `node`.
157
+
158
+ push_out()
159
+ Push the current function call stack, `call_stack`, onto the output list, `out`,
160
+ unless it is already included in one of the so-far-collected call stacks.
161
+
162
+ visit_Call(node)
163
+ This method is called when the visitor encounters a function call in the AST. It
164
+ checks if the called function is a njit function and, if so, traverses its AST
165
+ to collect its call stack.
166
+ """
167
+
168
+ def __init__ (self , pkg_dir , pkg_name ):
169
+ """
170
+ Initialize the FunctionCallVisitor class. This method sets up the necessary
171
+ attributes and prepares the visitor for traversing the AST of STUMPY's modules.
172
+
173
+ Parameters
174
+ ----------
175
+ pkg_dir : str
176
+ The path to the package directory containing some .py files.
177
+
178
+ pkg_name : str
179
+ The name of the package.
180
+
181
+ Returns
182
+ -------
183
+ None
184
+ """
185
+ super ().__init__ ()
186
+ self .module_names = []
187
+ self .call_stack = []
188
+ self .out = []
189
+
190
+ # Setup lists, dicts, and ast objects
191
+ self .njit_funcs = get_njit_funcs (pkg_dir )
192
+ self .njit_modules = set (mod_name for mod_name , func_name in self .njit_funcs )
193
+ self .njit_nodes = {}
194
+ self .ast_modules = {}
195
+
196
+ filepaths = sorted (f for f in pathlib .Path (pkg_dir ).iterdir () if f .is_file ())
197
+ ignore = ["__init__.py" , "__pycache__" ]
198
+
199
+ for filepath in filepaths :
200
+ file_name = filepath .name
201
+ if (
202
+ file_name not in ignore
203
+ and not file_name .startswith ("gpu" )
204
+ and str (filepath ).endswith (".py" )
205
+ ):
206
+ module_name = file_name .replace (".py" , "" )
207
+ file_contents = ""
208
+ with open (filepath , encoding = "utf8" ) as f :
209
+ file_contents = f .read ()
210
+ self .ast_modules [module_name ] = ast .parse (file_contents )
211
+
212
+ for node in self .ast_modules [module_name ].body :
213
+ if isinstance (node , ast .FunctionDef ):
214
+ func_name = node .name
215
+ if (module_name , func_name ) in self .njit_funcs :
216
+ self .njit_nodes [f"{ module_name } .{ func_name } " ] = node
217
+
218
+ def push_module (self , module_name ):
219
+ """
220
+ Push a module name onto the stack of module names.
221
+
222
+ Parameters
223
+ ----------
224
+ module_name : str
225
+ The name of the module to be pushed onto the stack.
226
+
227
+ Returns
228
+ -------
229
+ None
230
+ """
231
+ self .module_names .append (module_name )
232
+
233
+ return
234
+
235
+ def pop_module (self ):
236
+ """
237
+ Pop the last module name from the stack of module names.
238
+
239
+ Parameters
240
+ ----------
241
+ None
242
+
243
+ Returns
244
+ -------
245
+ None
246
+ """
247
+ if self .module_names :
248
+ self .module_names .pop ()
249
+
250
+ return
251
+
252
+ def push_call_stack (self , module_name , func_name ):
253
+ """
254
+ Push a function call onto the stack of function calls.
255
+
256
+ Parameters
257
+ ----------
258
+ module_name : str
259
+ A module's name
260
+
261
+ func_name : str
262
+ A function's name
263
+
264
+ Returns
265
+ -------
266
+ None
267
+ """
268
+ self .call_stack .append (f"{ module_name } .{ func_name } " )
269
+
270
+ return
271
+
272
+ def pop_call_stack (self ):
273
+ """
274
+ Pop the last function call from the stack of function calls.
275
+
276
+ Parameters
277
+ ----------
278
+ None
279
+
280
+ Returns
281
+ -------
282
+ None
283
+ """
284
+ if self .call_stack :
285
+ self .call_stack .pop ()
286
+
287
+ return
288
+
289
+ def goto_deeper_func (self , node ):
290
+ """
291
+ Calls the visit method from class `ast.NodeVisitor` on
292
+ all children of the `node`.
293
+
294
+ Parameters
295
+ ----------
296
+ node : ast.AST
297
+ The AST node to be visited.
298
+
299
+ Returns
300
+ -------
301
+ None
302
+ """
303
+ self .generic_visit (node )
304
+
305
+ return
306
+
307
+ def goto_next_func (self , node ):
308
+ """
309
+ Calls the visit method from class `ast.NodeVisitor` on
310
+ all children of the node.
311
+
312
+ Parameters
313
+ ----------
314
+ node : ast.AST
315
+ The AST node to be visited.
316
+
317
+ Returns
318
+ -------
319
+ None
320
+ """
321
+ self .generic_visit (node )
322
+
323
+ return
324
+
325
+ def push_out (self ):
326
+ """
327
+ Push the current function call stack onto the output list unless it
328
+ is already included in one of the so-far-collected call stacks.
329
+
330
+
331
+ Parameters
332
+ ----------
333
+ None
334
+
335
+ Returns
336
+ -------
337
+ None
338
+ """
339
+ unique = True
340
+ for cs in self .out :
341
+ if " " .join (self .call_stack ) in " " .join (cs ):
342
+ unique = False
343
+ break
344
+
345
+ if unique :
346
+ self .out .append (self .call_stack .copy ())
347
+
348
+ return
349
+
350
+ def visit_Call (self , node ):
351
+ """
352
+ Called when visiting an AST node of type `ast.Call`.
353
+
354
+ Parameters
355
+ ----------
356
+ node : ast.Call
357
+ The AST node representing a function call.
358
+
359
+ Returns
360
+ -------
361
+ None
362
+ """
363
+ callee_name = ast .unparse (node .func )
364
+
365
+ module_changed = False
366
+ if "." in callee_name :
367
+ new_module_name , new_func_name = callee_name .split ("." )[:2 ]
368
+
369
+ if new_module_name in self .njit_modules :
370
+ self .push_module (new_module_name )
371
+ module_changed = True
372
+ else :
373
+ if self .module_names :
374
+ new_module_name = self .module_names [- 1 ]
375
+ new_func_name = callee_name
376
+ callee_name = f"{ new_module_name } .{ new_func_name } "
377
+
378
+ if callee_name in self .njit_nodes .keys ():
379
+ callee_node = self .njit_nodes [callee_name ]
380
+ self .push_call_stack (new_module_name , new_func_name )
381
+ self .goto_deeper_func (callee_node )
382
+ self .push_out ()
383
+ self .pop_call_stack ()
384
+ if module_changed :
385
+ self .pop_module ()
386
+
387
+ self .goto_next_func (node )
388
+
389
+ return
390
+
391
+
392
+ def get_njit_call_stacks (pkg_dir , pkg_name ):
393
+ """
394
+ Get the call stacks of all njit functions in `pkg_dir`
395
+
396
+ Parameters
397
+ ----------
398
+ pkg_dir : str
399
+ The path to the package directory containing some .py files
400
+
401
+ pkg_name : str
402
+ The name of the package
403
+
404
+ Returns
405
+ -------
406
+ out : list
407
+ A list of unique function call stacks. Each item is of type list,
408
+ representing a chain of function calls.
409
+ """
410
+ visitor = FunctionCallVisitor (pkg_dir , pkg_name )
411
+
412
+ for module_name in visitor .njit_modules :
413
+ visitor .push_module (module_name )
414
+
415
+ for node in visitor .ast_modules [module_name ].body :
416
+ if isinstance (node , ast .FunctionDef ):
417
+ func_name = node .name
418
+ if (module_name , func_name ) in visitor .njit_funcs :
419
+ visitor .push_call_stack (module_name , func_name )
420
+ visitor .visit (node )
421
+ visitor .pop_call_stack ()
422
+
423
+ visitor .pop_module ()
424
+
425
+ return visitor .out
426
+
427
+
428
+ def check_call_stack_fastmath (pkg_dir , pkg_name ):
429
+ """
430
+ Check if all njit functions in a call stack have the same `fastmath` flag.
431
+ This function raises a ValueError if it finds any inconsistencies in the
432
+ `fastmath` flags in at lease one call stack of njit functions.
433
+
434
+ Parameters
435
+ ----------
436
+ pkg_dir : str
437
+ The path to the directory containing some .py files
438
+
439
+ pkg_name : str
440
+ The name of the package
441
+
442
+ Returns
443
+ -------
444
+ None
445
+ """
446
+ # List of call stacks with inconsistent fastmath flags
447
+ inconsistent_call_stacks = []
448
+
449
+ njit_call_stacks = get_njit_call_stacks (pkg_dir , pkg_name )
450
+ for cs in njit_call_stacks :
451
+ # Set the fastmath flag of the first function in the call stack
452
+ # as the reference flag
453
+ module_name , func_name = cs [0 ].split ("." )
454
+ module = importlib .import_module (f".{ module_name } " , package = "stumpy" )
455
+ func = getattr (module , func_name )
456
+ flag_ref = func .targetoptions ["fastmath" ]
457
+
458
+ for item in cs [1 :]:
459
+ module_name , func_name = cs [0 ].split ("." )
460
+ module = importlib .import_module (f".{ module_name } " , package = "stumpy" )
461
+ func = getattr (module , func_name )
462
+ flag = func .targetoptions ["fastmath" ]
463
+ if flag != flag_ref :
464
+ inconsistent_call_stacks .append (cs )
465
+ break
466
+
467
+ if len (inconsistent_call_stacks ) > 0 :
468
+ msg = (
469
+ "Found at least one call stack that has inconsistent `fastmath` flags. "
470
+ + f"Those call stacks are:\n { inconsistent_call_stacks } \n "
471
+ )
472
+ raise ValueError (msg )
473
+
474
+ return
475
+
476
+
92
477
if __name__ == "__main__" :
93
478
parser = argparse .ArgumentParser ()
94
479
parser .add_argument ("--check" , dest = "pkg_dir" )
@@ -98,3 +483,4 @@ def check_fastmath(pkg_dir, pkg_name):
98
483
pkg_dir = pathlib .Path (args .pkg_dir )
99
484
pkg_name = pkg_dir .name
100
485
check_fastmath (str (pkg_dir ), pkg_name )
486
+ check_call_stack_fastmath (str (pkg_dir ), pkg_name )
0 commit comments