Skip to content

Commit 287f55c

Browse files
committed
Use MaybeJit for summaries
1 parent 74985b2 commit 287f55c

File tree

3 files changed

+56
-2
lines changed

3 files changed

+56
-2
lines changed

flax/nnx/rnglib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def _to_keyless(
5858

5959

6060
def _function_to_method(random_f):
61+
@functools.wraps(random_f)
6162
def rngs_random_method(self: Rngs | RngStream, *args, **kwargs) -> jax.Array:
6263
return random_f(self(), *args, **kwargs)
6364

flax/nnx/summary.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import io
2020
import typing as tp
2121
from types import MappingProxyType
22+
import functools
23+
from types import SimpleNamespace
2224

2325
import jax
2426
import numpy as np
@@ -34,6 +36,7 @@
3436

3537
from functools import wraps
3638

39+
3740
try:
3841
from IPython import get_ipython
3942

@@ -50,6 +53,43 @@ class NoneDumper(yaml.SafeDumper):
5053
lambda dumper, data: dumper.represent_scalar('tag:yaml.org,2002:str', 'None'),
5154
)
5255

56+
57+
# Although we call 'uniform', it doesn't seem like uniform is stored anywhere.
58+
class MaybeJit:
  """Call a function through ``nnx.jit``, falling back to the plain function.

  The wrapped function is kept both jitted (``self.jitted``) and unmodified
  (``self.f``). Whenever the jitted path raises ``TypeError`` -- e.g. because
  the function turns out to be non-concrete -- the original function is run
  instead. We can't get the flops of non-concrete functions, but we should
  still be able to trace the input and output shapes.
  """

  def __init__(self, f):
    """Wrap ``f``, storing both the original and its ``nnx.jit`` version."""
    # Copy metadata first: update_wrapper merges ``f.__dict__`` into ours, so
    # doing it before the assignments keeps an attribute of ``f`` that happens
    # to be named ``f`` or ``jitted`` from clobbering our own fields.
    functools.update_wrapper(self, f)
    self.f = f
    self.jitted = nnx.jit(f)

  # Implement the descriptor protocol so that we can use this as a method.
  def __get__(self, obj, objtype=None):
    """Return ``self`` on class access, else ``self`` bound to ``obj``."""
    if obj is None:
      return self
    return functools.partial(self, obj)

  def __call__(self, *args, **kwargs):
    """Run the jitted function; on ``TypeError`` run the original instead."""
    try:
      return self.jitted(*args, **kwargs)
    except TypeError:
      return self.f(*args, **kwargs)

  # This will only be used for the top-level method.
  def trace(self, *args, **kwargs):
    """Call the wrapper; same behavior as ``__call__``."""
    return self(*args, **kwargs)

  def lower(self, *args, **kwargs):
    """Return ``self.jitted.lower(...)``, or a stand-in when that fails.

    If lowering raises ``TypeError`` the original function is executed and its
    result is placed in a ``SimpleNamespace`` that mimics the attributes of a
    lowered object: ``cost_analysis`` is ``-1`` and ``lowered.out_info`` is
    ``(None, None, result)``.
    """
    try:
      return self.jitted.lower(*args, **kwargs)
    except TypeError:
      result = self.f(*args, **kwargs)
      # Mock a `Lowered` instance with a SimpleNamespace.
      # NOTE(review): `_get_flops` calls `e.cost_analysis()`; the int stored
      # here is not callable -- confirm this mock never reaches that call.
      return SimpleNamespace(
        cost_analysis=-1,
        lowered=SimpleNamespace(out_info=(None, None, result)),
      )
92+
5393
class SizeBytes(typing.SizeBytes):
5494
def __repr__(self) -> str:
5595
bytes_repr = _bytes_repr(self.bytes)
@@ -222,7 +262,7 @@ def _overwrite_methods(env):
222262

223263
def _get_flops(e) -> int:
224264
cost = e.cost_analysis() or e.compile().cost_analysis()
225-
return 0 if cost is None or 'flops' not in cost else int(cost['flops'])
265+
return -1 if cost is None or 'flops' not in cost else int(cost['flops'])
226266

227267
def tabulate(
228268
obj,
@@ -378,7 +418,8 @@ def tabulate(
378418
# that each method will only be traced (and added to the table) once.
379419
jits = {} # Maps (class, method_name) to jit
380420
for key, value in saver_env.items():
381-
jits[key] = nnx.jit(value)
421+
jits[key] = MaybeJit(value)
422+
382423
_overwrite_methods(jits)
383424

384425
# Trace the top function (which indirectly traces all the others)

tests/nnx/summary_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,5 +295,17 @@ def __call__(self, x):
295295
self.assertEqual(module.hooked_param.get_metadata('description'), 'Custom parameter')
296296
self.assertEqual(module.hooked_param.get_metadata('trainable'), True)
297297

298+
def test_tabulate_concrete_shape(self):
  """Tabulate a module whose output shape is read from the input's shape."""

  class Net(nnx.Module):
    def __init__(self):
      self.rngs = nnx.Rngs(0)

    def __call__(self, x):
      # The sample's leading dimension comes from ``x.shape[0]``.
      batch = x.shape[0]
      return self.rngs.uniform((batch, 10))

  inputs = jnp.zeros((4, 8))
  model = Net()
  summary = nnx.tabulate(model, inputs, console_kwargs={"width": 200})
  print(summary)
309+
298310
if __name__ == '__main__':
299311
absltest.main()

0 commit comments

Comments
 (0)