Skip to content

Commit 82dba72

Browse files
committed
Use nnx.split to avoid compilation issue
1 parent cb6aa68 commit 82dba72

File tree

1 file changed

+17
-43
lines changed

1 file changed

+17
-43
lines changed

flax/nnx/summary.py

Lines changed: 17 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from types import MappingProxyType
2222
import functools
2323
import itertools
24-
from types import SimpleNamespace
2524

2625
import jax
2726
import numpy as np
@@ -54,38 +53,6 @@ class NoneDumper(yaml.SafeDumper):
5453
lambda dumper, data: dumper.represent_scalar('tag:yaml.org,2002:str', 'None'),
5554
)
5655

57-
class MaybeJit:
58-
"""
59-
Wraps a function with nnx.jit, but saves the original to run
60-
if the function turns out to be non-concrete. We can't get the flops of non-concrete functions,
61-
but we should still be able to trace the input and output shapes.
62-
"""
63-
def __init__(self, f):
64-
functools.update_wrapper(self, f)
65-
self.f = f
66-
self.jitted = nnx.jit(f)
67-
self.seen = set()
68-
69-
# implement descriptor protocol so that we can use this as a method
70-
def __get__(self, obj, objtype=None):
71-
if obj is None:
72-
return self
73-
return functools.partial(self, obj)
74-
75-
def __call__(self, *args, **kwargs):
76-
try:
77-
return self.jitted(*args, **kwargs)
78-
except TypeError as e:
79-
return self.f(*args, **kwargs)
80-
81-
def lower(self, *args, **kwargs):
82-
try:
83-
return self.jitted.lower(*args, **kwargs)
84-
except TypeError as e:
85-
result = self.f(*args, **kwargs)
86-
# Mock a `Lowered` instance with a SimpleNamespace
87-
return SimpleNamespace(cost_analysis="0", lowered=SimpleNamespace(out_info=(None, None, result)))
88-
8956
class SizeBytes(typing.SizeBytes):
9057
def __repr__(self) -> str:
9158
bytes_repr = _bytes_repr(self.bytes)
@@ -232,7 +199,7 @@ def _get_inputs_repr(args, kwargs):
232199
inputs_repr += _as_yaml_str(input_kwargs)
233200
return inputs_repr
234201

235-
def _save_call_info(counter, tracer_args, f, node_stats, compute_flops, compute_vjp_flops):
202+
def _save_call_info(counter, tracer_args, f, node_stats, compute_flops, compute_vjp_flops, seen):
236203
"Wrap a function to save its arguments"
237204

238205
# Used when computing vjp flops
@@ -242,6 +209,11 @@ def do_vjp(*args, **kwargs):
242209

243210
method_name = f.__name__
244211

212+
@functools.partial(jax.jit)
213+
def jit_f(graphdef, state):
214+
args, kwargs = nnx.merge(graphdef, state)
215+
return f(*args, **kwargs)
216+
245217
@wraps(f)
246218
def wrapper(obj, *args, **kwargs):
247219
inputs_repr = _get_inputs_repr(args, kwargs)
@@ -251,19 +223,20 @@ def wrapper(obj, *args, **kwargs):
251223
if method_name != '__call__':
252224
path = (*path, method_name)
253225
identifier = (inputs_repr, object_id)
254-
if identifier not in f.seen:
255-
counter_val = next(counter)
256-
lowered = f.lower(obj, *args, **kwargs)
226+
counter_val = next(counter)
227+
graphdef, state = nnx.split(((obj, *args), kwargs))
228+
lowered = jit_f.lower(graphdef, state)
229+
if identifier not in seen:
230+
seen.add(identifier)
257231
flops = _get_flops(lowered) if compute_flops else None
258-
outputs = lowered.lowered.out_info[2]
232+
outputs = lowered.out_info
259233
output_repr = jax.tree.map(_to_dummy_array, outputs)
260234
vjp_flops = _get_flops(jax.jit(do_vjp).lower(
261235
obj, *args, **kwargs)) if compute_vjp_flops else None
262236
tracer_args.append(
263237
CallInfo(counter_val, object_id, type(obj), path, inputs_repr,
264238
output_repr, flops, vjp_flops))
265-
f.seen.add(identifier)
266-
return f(obj, *args, **kwargs)
239+
return jit_f(graphdef, state)
267240
return wrapper
268241

269242
def _overwrite_methods(env):
@@ -424,12 +397,13 @@ def tabulate(
424397

425398
# Modify all the object's methods to save their lowered JIT representations.
426399
rows : list[CallInfo] = []
427-
maybejits = {k: _save_call_info(counter, rows, MaybeJit(v), node_stats, compute_flops, compute_vjp_flops)
400+
seen : set = set()
401+
jits = {k: _save_call_info(counter, rows, v, node_stats, compute_flops, compute_vjp_flops, seen)
428402
for k, v in env.items()}
429-
_overwrite_methods(maybejits)
403+
_overwrite_methods(jits)
430404

431405
# Trace the top function (which indirectly traces all the others)
432-
maybejits[(type(obj), method)](obj, *input_args, **input_kwargs)
406+
jits[(type(obj), method)](obj, *input_args, **input_kwargs)
433407

434408
# Sort call info in pre-order traversal order
435409
rows.sort(key=lambda x: x.call_order)

0 commit comments

Comments
 (0)