2020import typing as tp
2121from types import MappingProxyType
2222import functools
23+ import itertools
2324from types import SimpleNamespace
2425
2526import jax
@@ -83,7 +84,7 @@ def lower(self, *args, **kwargs):
8384 except TypeError as e :
8485 result = self .f (* args , ** kwargs )
8586 # Mock a `Lowered` instance with a SimpleNamespace
86- return SimpleNamespace (cost_analysis = - 1 , lowered = SimpleNamespace (out_info = (None , None , result )))
87+ return SimpleNamespace (cost_analysis = "0" , lowered = SimpleNamespace (out_info = (None , None , result )))
8788
8889class SizeBytes (typing .SizeBytes ):
8990 def __repr__ (self ) -> str :
@@ -165,6 +166,7 @@ def __str__(self):
165166
166167@dataclasses .dataclass
167168class CallInfo :
169+ call_order : int
168170 object_id : int
169171 type : type
170172 path : statelib .PathParts
@@ -185,13 +187,14 @@ def __repr__(self):
185187
186188
187189def _to_dummy_array (x ):
188- try :
190+ if isinstance ( x , jax . ShapeDtypeStruct ) :
189191 return ArrayRepr (x .shape , x .dtype )
190- except :
191- if graph .is_graph_node (x ):
192- return SimpleObjectRepr (x )
193- else :
194- return x
192+ elif isinstance (x , jax .Array | np .ndarray ):
193+ return ArrayRepr .from_array (x )
194+ elif graph .is_graph_node (x ):
195+ return SimpleObjectRepr (x )
196+ else :
197+ return x
195198
196199def _pure_nnx_vjp (f , model , * args , ** kwargs ):
197200 "Wrap nnx functional api around jax.vjp. Only handles pure method calls."
@@ -201,28 +204,6 @@ def inner(state, *args, **kwargs):
201204 return f (model , * args , ** kwargs )
202205 return jax .vjp (inner , state , * args , ** kwargs )
203206
204- def _get_call_info (lowered , method_name , node_stats , obj , compute_flops , inputs_repr , vjp_flops ):
205- flops = _get_flops (lowered ) if compute_flops else None
206- outputs = lowered .lowered .out_info [2 ]
207- output_repr = jax .tree .map (_to_dummy_array , outputs )
208- object_id : int = getattr (obj , '_nnx_tabulate_id' )
209- node_info = node_stats [object_id ]
210- assert node_info is not None
211- path = node_info .path
212- if method_name != '__call__' :
213- path = (* path , method_name )
214-
215- return CallInfo (
216- object_id = object_id ,
217- type = type (obj ),
218- path = path ,
219- inputs_repr = inputs_repr ,
220- outputs = output_repr ,
221- flops = flops ,
222- vjp_flops = vjp_flops
223- )
224-
225-
226207def filter_rng_streams (row : CallInfo ):
227208 return not issubclass (row .type , nnx .RngStream )
228209
@@ -235,38 +216,52 @@ def _create_obj_env(object_types):
235216 result [(obj_type , name )] = top_method
236217 return result
237218
238- def _argsave (counter , tracer_args , f , compute_vjp_flops ):
219+ def _get_inputs_repr (args , kwargs ):
220+ input_args , input_kwargs = jax .tree .map (
221+ _to_dummy_array , (args , kwargs )
222+ )
223+ inputs_repr = ''
224+ if input_args :
225+ if len (input_args ) == 1 and not input_kwargs :
226+ inputs_repr += _as_yaml_str (input_args [0 ])
227+ else :
228+ inputs_repr += _as_yaml_str (input_args )
229+ if input_kwargs :
230+ inputs_repr += '\n '
231+ if input_kwargs :
232+ inputs_repr += _as_yaml_str (input_kwargs )
233+ return inputs_repr
234+
235+ def _save_call_info (counter , tracer_args , f , node_stats , compute_flops , compute_vjp_flops ):
239236 "Wrap a function to save its arguments"
237+
238+ # Used when computing vjp flops
240239 def do_vjp (* args , ** kwargs ):
241240 primals , f_vjp = jax .vjp (f , * args , ** kwargs )
242241 return f_vjp (primals )
243- n = f .__name__
242+
243+ method_name = f .__name__
244+
244245 @wraps (f )
245246 def wrapper (obj , * args , ** kwargs ):
246- input_args , input_kwargs = jax .tree .map (
247- _to_dummy_array , (args , kwargs )
248- )
249- inputs_repr = ''
250- if input_args :
251- if len (input_args ) == 1 and not input_kwargs :
252- inputs_repr += _as_yaml_str (input_args [0 ])
253- else :
254- inputs_repr += _as_yaml_str (input_args )
255- if input_kwargs :
256- inputs_repr += '\n '
257- if input_kwargs :
258- inputs_repr += _as_yaml_str (input_kwargs )
259-
260- identifier = (inputs_repr , getattr (obj , '_nnx_tabulate_id' ))
247+ inputs_repr = _get_inputs_repr (args , kwargs )
248+ object_id = getattr (obj , '_nnx_tabulate_id' )
249+ node_info = node_stats [object_id ]
250+ path = node_info .path
251+ if method_name != '__call__' :
252+ path = (* path , method_name )
253+ identifier = (inputs_repr , object_id )
261254 if identifier not in f .seen :
262- counter_val = counter [0 ]
263- counter [0 ] += 1
255+ counter_val = next (counter )
264256 lowered = f .lower (obj , * args , ** kwargs )
265- if compute_vjp_flops :
266- vjp_flops = _get_flops (jax .jit (do_vjp ).lower (obj , * args , ** kwargs ))
267- else :
268- vjp_flops = None
269- tracer_args .append ((counter_val , obj , n , lowered , inputs_repr , vjp_flops ))
257+ flops = _get_flops (lowered ) if compute_flops else None
258+ outputs = lowered .lowered .out_info [2 ]
259+ output_repr = jax .tree .map (_to_dummy_array , outputs )
260+ vjp_flops = _get_flops (jax .jit (do_vjp ).lower (
261+ obj , * args , ** kwargs )) if compute_vjp_flops else None
262+ tracer_args .append (
263+ CallInfo (counter_val , object_id , type (obj ), path , inputs_repr ,
264+ output_repr , flops , vjp_flops ))
270265 f .seen .add (identifier )
271266 return f (obj , * args , ** kwargs )
272267 return wrapper
@@ -278,7 +273,7 @@ def _overwrite_methods(env):
278273
279274def _get_flops (e ) -> int :
280275 cost = e .cost_analysis () or e .compile ().cost_analysis ()
281- return - 1 if cost is None or 'flops' not in cost else int (cost ['flops' ])
276+ return 0 if cost is None or 'flops' not in cost else int (cost ['flops' ])
282277
283278def tabulate (
284279 obj ,
@@ -424,22 +419,20 @@ def tabulate(
424419 env = _create_obj_env (object_types )
425420
426421 # Information is recorded in post-order, but should be presented as a pre-order traversal.
427- # This counter is incremented in pre-order traversal to keep track of the order of calls.
428- counter = [ 0 ]
422+ # This keeps track of the order of calls.
423+ counter = itertools . count ( 0 )
429424
430425 # Modify all the object's methods to save their lowered JIT representations.
431- tracer_args = []
432- jits = {k : _argsave (counter , tracer_args , MaybeJit (v ), compute_vjp_flops ) for k ,v in env .items ()}
433- _overwrite_methods (jits )
426+ rows : list [CallInfo ] = []
427+ maybejits = {k : _save_call_info (counter , rows , MaybeJit (v ), node_stats , compute_flops , compute_vjp_flops )
428+ for k , v in env .items ()}
429+ _overwrite_methods (maybejits )
434430
435431 # Trace the top function (which indirectly traces all the others)
436- jits [(type (obj ), method )](obj , * input_args , ** input_kwargs )
432+ maybejits [(type (obj ), method )](obj , * input_args , ** input_kwargs )
437433
438- # Get call_info
439- rows : list [CallInfo ] = [_get_call_info (
440- lowered , name , node_stats , object ,
441- compute_flops , inputs_repr , vjp_flops )
442- for (_ , object , name , lowered , inputs_repr , vjp_flops ) in sorted (tracer_args , key = lambda x : x [0 ])]
434+ # Sort call info in pre-order traversal order
435+ rows .sort (key = lambda x : x .call_order )
443436
444437 # Restore the object's original methods
445438 _overwrite_methods (env )
0 commit comments