1919import io
2020import typing as tp
2121from types import MappingProxyType
22+ import functools
23+ from types import SimpleNamespace
2224
2325import jax
2426import numpy as np
3436
3537from functools import wraps
3638
39+
3740try :
3841 from IPython import get_ipython
3942
@@ -50,6 +53,43 @@ class NoneDumper(yaml.SafeDumper):
5053 lambda dumper , data : dumper .represent_scalar ('tag:yaml.org,2002:str' , 'None' ),
5154)
5255
56+
57+ # Although we call 'uniform', it doesn't seem like uniform is stored anywhere.
58+ class MaybeJit :
59+ """
60+ Wraps a function with nnx.jit, but saves the original to run
61+ if the function turns out to be non-concrete. We can't get the flops of non-concrete functions,
62+ but we should still be able to trace the input and output shapes.
63+ """
64+ def __init__ (self , f ):
65+ self .f = f
66+ self .jitted = nnx .jit (f )
67+ functools .update_wrapper (self , self .f )
68+
69+ # implement descriptor protocol so that we can use this as a method
70+ def __get__ (self , obj , objtype = None ):
71+ if obj is None :
72+ return self
73+ return functools .partial (self , obj )
74+
75+ def __call__ (self , * args ):
76+ try :
77+ return self .jitted (* args )
78+ except TypeError as e :
79+ return self .f (* args )
80+
81+ # This will only be used for the top-level method
82+ def trace (self , * args ):
83+ return self (* args )
84+
85+ def lower (self , * args ):
86+ try :
87+ return self .jitted .lower (* args )
88+ except TypeError as e :
89+ result = self .f (* args )
90+ # Mock a `Lowered` instance with a SimpleNamespace
91+ return SimpleNamespace (cost_analysis = - 1 , lowered = SimpleNamespace (out_info = (None , None , result )))
92+
5393class SizeBytes (typing .SizeBytes ):
5494 def __repr__ (self ) -> str :
5595 bytes_repr = _bytes_repr (self .bytes )
@@ -222,7 +262,7 @@ def _overwrite_methods(env):
222262
223263def _get_flops (e ) -> int :
224264 cost = e .cost_analysis () or e .compile ().cost_analysis ()
225- return 0 if cost is None or 'flops' not in cost else int (cost ['flops' ])
265+ return - 1 if cost is None or 'flops' not in cost else int (cost ['flops' ])
226266
227267def tabulate (
228268 obj ,
@@ -378,7 +418,8 @@ def tabulate(
378418 # that each method will only be traced (and added to the table) once.
379419 jits = {} # Maps (class, method_name) to jit
380420 for key , value in saver_env .items ():
381- jits [key ] = nnx .jit (value )
421+ jits [key ] = MaybeJit (value )
422+
382423 _overwrite_methods (jits )
383424
384425 # Trace the top function (which indirectly traces all the others)
0 commit comments