2121from types import MappingProxyType
2222import functools
2323import itertools
24- from types import SimpleNamespace
2524
2625import jax
2726import numpy as np
@@ -54,38 +53,6 @@ class NoneDumper(yaml.SafeDumper):
5453 lambda dumper , data : dumper .represent_scalar ('tag:yaml.org,2002:str' , 'None' ),
5554)
5655
57- class MaybeJit :
58- """
59- Wraps a function with nnx.jit, but saves the original to run
60- if the function turns out to be non-concrete. We can't get the flops of non-concrete functions,
61- but we should still be able to trace the input and output shapes.
62- """
63- def __init__ (self , f ):
64- functools .update_wrapper (self , f )
65- self .f = f
66- self .jitted = nnx .jit (f )
67- self .seen = set ()
68-
69- # implement descriptor protocol so that we can use this as a method
70- def __get__ (self , obj , objtype = None ):
71- if obj is None :
72- return self
73- return functools .partial (self , obj )
74-
75- def __call__ (self , * args , ** kwargs ):
76- try :
77- return self .jitted (* args , ** kwargs )
78- except TypeError as e :
79- return self .f (* args , ** kwargs )
80-
81- def lower (self , * args , ** kwargs ):
82- try :
83- return self .jitted .lower (* args , ** kwargs )
84- except TypeError as e :
85- result = self .f (* args , ** kwargs )
86- # Mock a `Lowered` instance with a SimpleNamespace
87- return SimpleNamespace (cost_analysis = "0" , lowered = SimpleNamespace (out_info = (None , None , result )))
88-
8956class SizeBytes (typing .SizeBytes ):
9057 def __repr__ (self ) -> str :
9158 bytes_repr = _bytes_repr (self .bytes )
@@ -232,7 +199,7 @@ def _get_inputs_repr(args, kwargs):
232199 inputs_repr += _as_yaml_str (input_kwargs )
233200 return inputs_repr
234201
235- def _save_call_info (counter , tracer_args , f , node_stats , compute_flops , compute_vjp_flops ):
202+ def _save_call_info (counter , tracer_args , f , node_stats , compute_flops , compute_vjp_flops , seen ):
236203 "Wrap a function to save its arguments"
237204
238205 # Used when computing vjp flops
@@ -242,6 +209,11 @@ def do_vjp(*args, **kwargs):
242209
243210 method_name = f .__name__
244211
212+ @functools .partial (jax .jit )
213+ def jit_f (graphdef , state ):
214+ args , kwargs = nnx .merge (graphdef , state )
215+ return f (* args , ** kwargs )
216+
245217 @wraps (f )
246218 def wrapper (obj , * args , ** kwargs ):
247219 inputs_repr = _get_inputs_repr (args , kwargs )
@@ -251,19 +223,20 @@ def wrapper(obj, *args, **kwargs):
251223 if method_name != '__call__' :
252224 path = (* path , method_name )
253225 identifier = (inputs_repr , object_id )
254- if identifier not in f .seen :
255- counter_val = next (counter )
256- lowered = f .lower (obj , * args , ** kwargs )
226+ counter_val = next (counter )
227+ graphdef , state = nnx .split (((obj , * args ), kwargs ))
228+ lowered = jit_f .lower (graphdef , state )
229+ if identifier not in seen :
230+ seen .add (identifier )
257231 flops = _get_flops (lowered ) if compute_flops else None
258- outputs = lowered .lowered . out_info [ 2 ]
232+ outputs = lowered .out_info
259233 output_repr = jax .tree .map (_to_dummy_array , outputs )
260234 vjp_flops = _get_flops (jax .jit (do_vjp ).lower (
261235 obj , * args , ** kwargs )) if compute_vjp_flops else None
262236 tracer_args .append (
263237 CallInfo (counter_val , object_id , type (obj ), path , inputs_repr ,
264238 output_repr , flops , vjp_flops ))
265- f .seen .add (identifier )
266- return f (obj , * args , ** kwargs )
239+ return jit_f (graphdef , state )
267240 return wrapper
268241
269242def _overwrite_methods (env ):
@@ -424,12 +397,13 @@ def tabulate(
424397
425398 # Modify all the object's methods to save their lowered JIT representations.
426399 rows : list [CallInfo ] = []
427- maybejits = {k : _save_call_info (counter , rows , MaybeJit (v ), node_stats , compute_flops , compute_vjp_flops )
400+ seen : set = set ()
401+ jits = {k : _save_call_info (counter , rows , v , node_stats , compute_flops , compute_vjp_flops , seen )
428402 for k , v in env .items ()}
429- _overwrite_methods (maybejits )
403+ _overwrite_methods (jits )
430404
431405 # Trace the top function (which indirectly traces all the others)
432- maybejits [(type (obj ), method )](obj , * input_args , ** input_kwargs )
406+ jits [(type (obj ), method )](obj , * input_args , ** input_kwargs )
433407
434408 # Sort call info in pre-order traversal order
435409 rows .sort (key = lambda x : x .call_order )
0 commit comments