@@ -83,7 +83,7 @@ def lower(self, *args, **kwargs):
8383 except TypeError as e :
8484 result = self .f (* args , ** kwargs )
8585 # Mock a `Lowered` instance with a SimpleNamespace
86- return SimpleNamespace (cost_analysis = - 1 , lowered = SimpleNamespace (out_info = (None , None , result )))
86+ return SimpleNamespace (cost_analysis = "0" , lowered = SimpleNamespace (out_info = (None , None , result )))
8787
8888class SizeBytes (typing .SizeBytes ):
8989 def __repr__ (self ) -> str :
@@ -165,6 +165,7 @@ def __str__(self):
165165
166166@dataclasses .dataclass
167167class CallInfo :
168+ call_order : int
168169 object_id : int
169170 type : type
170171 path : statelib .PathParts
@@ -201,28 +202,6 @@ def inner(state, *args, **kwargs):
201202 return f (model , * args , ** kwargs )
202203 return jax .vjp (inner , state , * args , ** kwargs )
203204
204- def _get_call_info (lowered , method_name , node_stats , obj , compute_flops , inputs_repr , vjp_flops ):
205- flops = _get_flops (lowered ) if compute_flops else None
206- outputs = lowered .lowered .out_info [2 ]
207- output_repr = jax .tree .map (_to_dummy_array , outputs )
208- object_id : int = getattr (obj , '_nnx_tabulate_id' )
209- node_info = node_stats [object_id ]
210- assert node_info is not None
211- path = node_info .path
212- if method_name != '__call__' :
213- path = (* path , method_name )
214-
215- return CallInfo (
216- object_id = object_id ,
217- type = type (obj ),
218- path = path ,
219- inputs_repr = inputs_repr ,
220- outputs = output_repr ,
221- flops = flops ,
222- vjp_flops = vjp_flops
223- )
224-
225-
def filter_rng_streams(row: CallInfo):
  """Row predicate: keep every row whose type is not an ``nnx.RngStream``.

  Returns ``False`` when ``row.type`` is ``nnx.RngStream`` or a subclass of
  it, and ``True`` otherwise.
  """
  is_rng_stream = issubclass(row.type, nnx.RngStream)
  return not is_rng_stream
228207
@@ -235,38 +214,53 @@ def _create_obj_env(object_types):
235214 result [(obj_type , name )] = top_method
236215 return result
237216
238- def _argsave (counter , tracer_args , f , compute_vjp_flops ):
def _get_inputs_repr(args, kwargs):
  """Build the text used to display (and de-duplicate) a call's inputs.

  Every leaf of ``args`` and ``kwargs`` is mapped through
  ``_to_dummy_array`` and the result is rendered with ``_as_yaml_str``.

  * A single positional argument with no keyword arguments is rendered on
    its own instead of as a one-element tuple.
  * When both positional and keyword arguments are present, the two
    renderings are separated by the separator below.
  * With no arguments at all the result is the empty string.
  """
  dummy_args, dummy_kwargs = jax.tree.map(_to_dummy_array, (args, kwargs))

  sections = []
  if dummy_args:
    lone_positional = len(dummy_args) == 1 and not dummy_kwargs
    shown = dummy_args[0] if lone_positional else dummy_args
    sections.append(_as_yaml_str(shown))
  if dummy_kwargs:
    sections.append(_as_yaml_str(dummy_kwargs))

  # NOTE(review): the separator is a newline followed by one space; confirm
  # the trailing space is intentional.
  return '\n '.join(sections)
232+
def _argsave(counter, tracer_args, f, node_stats, compute_flops, compute_vjp_flops):
  """Wrap a function to save its arguments.

  The returned wrapper behaves like ``f`` but, the first time it sees a given
  ``(inputs_repr, object_id)`` pair, it also lowers the call and appends a
  ``CallInfo`` describing it to ``tracer_args``.

  Args:
    counter: single-element list used as a mutable call counter; its value is
      read and incremented *before* lowering so that an outer call gets a
      smaller ``call_order`` than the calls made while lowering it.
    tracer_args: list that receives one ``CallInfo`` per newly seen call.
    f: the wrapped callable; must expose ``__name__``, ``lower(...)`` and a
      ``seen`` set.
    node_stats: mapping from an object's ``_nnx_tabulate_id`` to an entry
      with a ``path`` attribute.
    compute_flops: when true, record the forward FLOPs of the lowered call.
    compute_vjp_flops: when true, also lower a VJP of ``f`` and record its
      FLOPs.

  Returns:
    The wrapping function, taking ``(obj, *args, **kwargs)``.
  """

  # Used when computing vjp flops
  def do_vjp(*args, **kwargs):
    primals, f_vjp = jax.vjp(f, *args, **kwargs)
    return f_vjp(primals)

  method_name = f.__name__

  @wraps(f)
  def wrapper(obj, *args, **kwargs):
    inputs_repr = _get_inputs_repr(args, kwargs)
    object_id = getattr(obj, '_nnx_tabulate_id')
    identifier = (inputs_repr, object_id)
    if identifier not in f.seen:
      # The path is only needed when a row is recorded, so look it up here
      # rather than on every call.
      path = node_stats[object_id].path
      if method_name != '__call__':
        path = (*path, method_name)

      # Reserve the call order before lowering: lowering traces the nested
      # calls, which must be numbered after this one.
      counter_val = counter[0]
      counter[0] += 1

      lowered = f.lower(obj, *args, **kwargs)
      flops = _get_flops(lowered) if compute_flops else None
      outputs = lowered.lowered.out_info[2]
      output_repr = jax.tree.map(_to_dummy_array, outputs)
      if compute_vjp_flops:
        vjp_flops = _get_flops(jax.jit(do_vjp).lower(obj, *args, **kwargs))
      else:
        vjp_flops = None

      # Keyword construction keeps this correct even if the field order of
      # CallInfo changes.
      tracer_args.append(
          CallInfo(
              call_order=counter_val,
              object_id=object_id,
              type=type(obj),
              path=path,
              inputs_repr=inputs_repr,
              outputs=output_repr,
              flops=flops,
              vjp_flops=vjp_flops,
          )
      )
      f.seen.add(identifier)
    return f(obj, *args, **kwargs)

  return wrapper
@@ -278,7 +272,7 @@ def _overwrite_methods(env):
278272
def _get_flops(e) -> int:
  """Return the estimated FLOP count of a lowered computation, or 0.

  ``e`` is expected to look like a lowered JAX computation: a callable
  ``cost_analysis`` attribute and, as a fallback when that yields nothing, a
  ``compile()`` method whose result has its own ``cost_analysis()``.

  0 is returned when no FLOP estimate is available, i.e. when

  * ``e.cost_analysis`` is not callable (the mocked ``Lowered`` built for
    functions that could not be jitted stores a plain value there),
  * the cost analysis is ``None`` or empty, or
  * the cost analysis has no ``'flops'`` entry.
  """
  cost_analysis = e.cost_analysis
  if not callable(cost_analysis):
    # Mocked lowering: nothing was actually lowered, so there is no estimate.
    return 0

  cost = cost_analysis() or e.compile().cost_analysis()
  # Some JAX versions report the analysis as a list of per-computation dicts;
  # use the first entry in that case.
  if isinstance(cost, (list, tuple)):
    cost = cost[0] if cost else None

  if cost is None or 'flops' not in cost:
    return 0
  return int(cost['flops'])
282276
283277def tabulate (
284278 obj ,
@@ -428,18 +422,16 @@ def tabulate(
428422 counter = [0 ]
429423
430424 # Modify all the object's methods to save their lowered JIT representations.
431- tracer_args = []
432- jits = {k : _argsave (counter , tracer_args , MaybeJit (v ), compute_vjp_flops ) for k ,v in env .items ()}
425+ rows = []
426+ jits = {k : _argsave (counter , rows , MaybeJit (v ), node_stats , compute_flops , compute_vjp_flops )
427+ for k ,v in env .items ()}
433428 _overwrite_methods (jits )
434429
435430 # Trace the top function (which indirectly traces all the others)
436431 jits [(type (obj ), method )](obj , * input_args , ** input_kwargs )
437432
438- # Get call_info
439- rows : list [CallInfo ] = [_get_call_info (
440- lowered , name , node_stats , object ,
441- compute_flops , inputs_repr , vjp_flops )
442- for (_ , object , name , lowered , inputs_repr , vjp_flops ) in sorted (tracer_args , key = lambda x : x [0 ])]
433+ # Sort call info in pre-order traversal order
434+ rows .sort (key = lambda x : x .call_order )
443435
444436 # Restore the object's original methods
445437 _overwrite_methods (env )
0 commit comments