2020import typing as tp
2121from types import MappingProxyType
2222import functools
23+ import itertools
2324from types import SimpleNamespace
2425
2526import jax
@@ -83,7 +84,7 @@ def lower(self, *args, **kwargs):
8384 except TypeError as e :
8485 result = self .f (* args , ** kwargs )
8586 # Mock a `Lowered` instance with a SimpleNamespace
86- return SimpleNamespace (cost_analysis = - 1 , lowered = SimpleNamespace (out_info = (None , None , result )))
87+ return SimpleNamespace (cost_analysis = "0" , lowered = SimpleNamespace (out_info = (None , None , result )))
8788
8889class SizeBytes (typing .SizeBytes ):
8990 def __repr__ (self ) -> str :
@@ -165,6 +166,7 @@ def __str__(self):
165166
166167@dataclasses .dataclass
167168class CallInfo :
169+ call_order : int
168170 object_id : int
169171 type : type
170172 path : statelib .PathParts
@@ -201,28 +203,6 @@ def inner(state, *args, **kwargs):
201203 return f (model , * args , ** kwargs )
202204 return jax .vjp (inner , state , * args , ** kwargs )
203205
204- def _get_call_info (lowered , method_name , node_stats , obj , compute_flops , inputs_repr , vjp_flops ):
205- flops = _get_flops (lowered ) if compute_flops else None
206- outputs = lowered .lowered .out_info [2 ]
207- output_repr = jax .tree .map (_to_dummy_array , outputs )
208- object_id : int = getattr (obj , '_nnx_tabulate_id' )
209- node_info = node_stats [object_id ]
210- assert node_info is not None
211- path = node_info .path
212- if method_name != '__call__' :
213- path = (* path , method_name )
214-
215- return CallInfo (
216- object_id = object_id ,
217- type = type (obj ),
218- path = path ,
219- inputs_repr = inputs_repr ,
220- outputs = output_repr ,
221- flops = flops ,
222- vjp_flops = vjp_flops
223- )
224-
225-
226206def filter_rng_streams (row : CallInfo ):
227207 return not issubclass (row .type , nnx .RngStream )
228208
@@ -235,38 +215,52 @@ def _create_obj_env(object_types):
235215 result [(obj_type , name )] = top_method
236216 return result
237217
238- def _argsave (counter , tracer_args , f , compute_vjp_flops ):
218+ def _get_inputs_repr (args , kwargs ):
219+ input_args , input_kwargs = jax .tree .map (
220+ _to_dummy_array , (args , kwargs )
221+ )
222+ inputs_repr = ''
223+ if input_args :
224+ if len (input_args ) == 1 and not input_kwargs :
225+ inputs_repr += _as_yaml_str (input_args [0 ])
226+ else :
227+ inputs_repr += _as_yaml_str (input_args )
228+ if input_kwargs :
229+ inputs_repr += '\n '
230+ if input_kwargs :
231+ inputs_repr += _as_yaml_str (input_kwargs )
232+ return inputs_repr
233+
234+ def _argsave (counter , tracer_args , f , node_stats , compute_flops , compute_vjp_flops ):
239235 "Wrap a function to save its arguments"
236+
237+ # Used when computing vjp flops
240238 def do_vjp (* args , ** kwargs ):
241239 primals , f_vjp = jax .vjp (f , * args , ** kwargs )
242240 return f_vjp (primals )
243- n = f .__name__
241+
242+ method_name = f .__name__
243+
244244 @wraps (f )
245245 def wrapper (obj , * args , ** kwargs ):
246- input_args , input_kwargs = jax .tree .map (
247- _to_dummy_array , (args , kwargs )
248- )
249- inputs_repr = ''
250- if input_args :
251- if len (input_args ) == 1 and not input_kwargs :
252- inputs_repr += _as_yaml_str (input_args [0 ])
253- else :
254- inputs_repr += _as_yaml_str (input_args )
255- if input_kwargs :
256- inputs_repr += '\n '
257- if input_kwargs :
258- inputs_repr += _as_yaml_str (input_kwargs )
259-
260- identifier = (inputs_repr , getattr (obj , '_nnx_tabulate_id' ))
246+ inputs_repr = _get_inputs_repr (args , kwargs )
247+ object_id = getattr (obj , '_nnx_tabulate_id' )
248+ node_info = node_stats [object_id ]
249+ path = node_info .path
250+ if method_name != '__call__' :
251+ path = (* path , method_name )
252+ identifier = (inputs_repr , object_id )
261253 if identifier not in f .seen :
262- counter_val = counter [0 ]
263- counter [0 ] += 1
254+ counter_val = next (counter )
264255 lowered = f .lower (obj , * args , ** kwargs )
265- if compute_vjp_flops :
266- vjp_flops = _get_flops (jax .jit (do_vjp ).lower (obj , * args , ** kwargs ))
267- else :
268- vjp_flops = None
269- tracer_args .append ((counter_val , obj , n , lowered , inputs_repr , vjp_flops ))
256+ flops = _get_flops (lowered ) if compute_flops else None
257+ outputs = lowered .lowered .out_info [2 ]
258+ output_repr = jax .tree .map (_to_dummy_array , outputs )
259+ vjp_flops = _get_flops (jax .jit (do_vjp ).lower (
260+ obj , * args , ** kwargs )) if compute_vjp_flops else None
261+ tracer_args .append (
262+ CallInfo (counter_val , object_id , type (obj ), path , inputs_repr ,
263+ output_repr , flops , vjp_flops ))
270264 f .seen .add (identifier )
271265 return f (obj , * args , ** kwargs )
272266 return wrapper
@@ -278,7 +272,7 @@ def _overwrite_methods(env):
278272
279273def _get_flops (e ) -> int :
280274 cost = e .cost_analysis () or e .compile ().cost_analysis ()
281- return - 1 if cost is None or 'flops' not in cost else int (cost ['flops' ])
275+ return 0 if cost is None or 'flops' not in cost else int (cost ['flops' ])
282276
283277def tabulate (
284278 obj ,
@@ -424,22 +418,20 @@ def tabulate(
424418 env = _create_obj_env (object_types )
425419
426420 # Information is recorded in post-order, but should be presented as a pre-order traversal.
427- # This counter is incremented in pre-order traversal to keep track of the order of calls.
428- counter = [ 0 ]
421+ # This keeps track of the order of calls.
422+ counter = itertools . count ( 0 )
429423
430424 # Modify all the object's methods to save their lowered JIT representations.
431- tracer_args = []
432- jits = {k : _argsave (counter , tracer_args , MaybeJit (v ), compute_vjp_flops ) for k ,v in env .items ()}
425+ rows : list [CallInfo ] = []
426+ jits = {k : _argsave (counter , rows , MaybeJit (v ), node_stats , compute_flops , compute_vjp_flops )
427+ for k ,v in env .items ()}
433428 _overwrite_methods (jits )
434429
435430 # Trace the top function (which indirectly traces all the others)
436431 jits [(type (obj ), method )](obj , * input_args , ** input_kwargs )
437432
438- # Get call_info
439- rows : list [CallInfo ] = [_get_call_info (
440- lowered , name , node_stats , object ,
441- compute_flops , inputs_repr , vjp_flops )
442- for (_ , object , name , lowered , inputs_repr , vjp_flops ) in sorted (tracer_args , key = lambda x : x [0 ])]
433+ # Sort call info in pre-order traversal order
434+ rows .sort (key = lambda x : x .call_order )
443435
444436 # Restore the object's original methods
445437 _overwrite_methods (env )
0 commit comments