Skip to content

Commit 6d62119

Browse files
committed
Clean up code
1 parent a73bbde commit 6d62119

File tree

2 files changed

+45
-55
lines changed

2 files changed

+45
-55
lines changed

flax/nnx/summary.py

Lines changed: 45 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def lower(self, *args, **kwargs):
8383
except TypeError as e:
8484
result = self.f(*args, **kwargs)
8585
# Mock a `Lowered` instance with a SimpleNamespace
86-
return SimpleNamespace(cost_analysis=-1, lowered=SimpleNamespace(out_info=(None, None, result)))
86+
return SimpleNamespace(cost_analysis="0", lowered=SimpleNamespace(out_info=(None, None, result)))
8787

8888
class SizeBytes(typing.SizeBytes):
8989
def __repr__(self) -> str:
@@ -165,6 +165,7 @@ def __str__(self):
165165

166166
@dataclasses.dataclass
167167
class CallInfo:
168+
call_order: int
168169
object_id: int
169170
type: type
170171
path: statelib.PathParts
@@ -201,28 +202,6 @@ def inner(state, *args, **kwargs):
201202
return f(model, *args, **kwargs)
202203
return jax.vjp(inner, state, *args, **kwargs)
203204

204-
def _get_call_info(lowered, method_name, node_stats, obj, compute_flops, inputs_repr, vjp_flops):
205-
flops = _get_flops(lowered) if compute_flops else None
206-
outputs = lowered.lowered.out_info[2]
207-
output_repr = jax.tree.map(_to_dummy_array, outputs)
208-
object_id: int = getattr(obj, '_nnx_tabulate_id')
209-
node_info = node_stats[object_id]
210-
assert node_info is not None
211-
path = node_info.path
212-
if method_name != '__call__':
213-
path = (*path, method_name)
214-
215-
return CallInfo(
216-
object_id=object_id,
217-
type=type(obj),
218-
path=path,
219-
inputs_repr=inputs_repr,
220-
outputs=output_repr,
221-
flops=flops,
222-
vjp_flops=vjp_flops
223-
)
224-
225-
226205
def filter_rng_streams(row: CallInfo):
227206
return not issubclass(row.type, nnx.RngStream)
228207

@@ -235,38 +214,53 @@ def _create_obj_env(object_types):
235214
result[(obj_type, name)] = top_method
236215
return result
237216

238-
def _argsave(counter, tracer_args, f, compute_vjp_flops):
217+
def _get_inputs_repr(args, kwargs):
218+
input_args, input_kwargs = jax.tree.map(
219+
_to_dummy_array, (args, kwargs)
220+
)
221+
inputs_repr = ''
222+
if input_args:
223+
if len(input_args) == 1 and not input_kwargs:
224+
inputs_repr += _as_yaml_str(input_args[0])
225+
else:
226+
inputs_repr += _as_yaml_str(input_args)
227+
if input_kwargs:
228+
inputs_repr += '\n'
229+
if input_kwargs:
230+
inputs_repr += _as_yaml_str(input_kwargs)
231+
return inputs_repr
232+
233+
def _argsave(counter, tracer_args, f, node_stats, compute_flops, compute_vjp_flops):
239234
"Wrap a function to save its arguments"
235+
236+
# Used when computing vjp flops
240237
def do_vjp(*args, **kwargs):
241238
primals, f_vjp = jax.vjp(f, *args, **kwargs)
242239
return f_vjp(primals)
243-
n = f.__name__
240+
241+
method_name = f.__name__
242+
244243
@wraps(f)
245244
def wrapper(obj, *args, **kwargs):
246-
input_args, input_kwargs = jax.tree.map(
247-
_to_dummy_array, (args, kwargs)
248-
)
249-
inputs_repr = ''
250-
if input_args:
251-
if len(input_args) == 1 and not input_kwargs:
252-
inputs_repr += _as_yaml_str(input_args[0])
253-
else:
254-
inputs_repr += _as_yaml_str(input_args)
255-
if input_kwargs:
256-
inputs_repr += '\n'
257-
if input_kwargs:
258-
inputs_repr += _as_yaml_str(input_kwargs)
259-
260-
identifier = (inputs_repr, getattr(obj, '_nnx_tabulate_id'))
245+
inputs_repr = _get_inputs_repr(args, kwargs)
246+
object_id = getattr(obj, '_nnx_tabulate_id')
247+
node_info = node_stats[object_id]
248+
path = node_info.path
249+
if method_name != '__call__':
250+
path = (*path, method_name)
251+
identifier = (inputs_repr, object_id)
261252
if identifier not in f.seen:
262253
counter_val = counter[0]
263254
counter[0] += 1
264255
lowered = f.lower(obj, *args, **kwargs)
265-
if compute_vjp_flops:
266-
vjp_flops = _get_flops(jax.jit(do_vjp).lower(obj, *args, **kwargs))
267-
else:
268-
vjp_flops = None
269-
tracer_args.append((counter_val, obj, n, lowered, inputs_repr, vjp_flops))
256+
flops = _get_flops(lowered) if compute_flops else None
257+
outputs = lowered.lowered.out_info[2]
258+
output_repr = jax.tree.map(_to_dummy_array, outputs)
259+
vjp_flops = _get_flops(jax.jit(do_vjp).lower(
260+
obj, *args, **kwargs)) if compute_vjp_flops else None
261+
tracer_args.append(
262+
CallInfo(counter_val, object_id, type(obj), path, inputs_repr,
263+
output_repr, flops, vjp_flops))
270264
f.seen.add(identifier)
271265
return f(obj, *args, **kwargs)
272266
return wrapper
@@ -278,7 +272,7 @@ def _overwrite_methods(env):
278272

279273
def _get_flops(e) -> int:
280274
cost = e.cost_analysis() or e.compile().cost_analysis()
281-
return -1 if cost is None or 'flops' not in cost else int(cost['flops'])
275+
return 0 if cost is None or 'flops' not in cost else int(cost['flops'])
282276

283277
def tabulate(
284278
obj,
@@ -428,18 +422,16 @@ def tabulate(
428422
counter = [0]
429423

430424
# Modify all the object's methods to save their lowered JIT representations.
431-
tracer_args = []
432-
jits = {k: _argsave(counter, tracer_args, MaybeJit(v), compute_vjp_flops) for k,v in env.items()}
425+
rows = []
426+
jits = {k: _argsave(counter, rows, MaybeJit(v), node_stats, compute_flops, compute_vjp_flops)
427+
for k,v in env.items()}
433428
_overwrite_methods(jits)
434429

435430
# Trace the top function (which indirectly traces all the others)
436431
jits[(type(obj), method)](obj, *input_args, **input_kwargs)
437432

438-
# Get call_info
439-
rows : list[CallInfo] = [_get_call_info(
440-
lowered, name, node_stats, object,
441-
compute_flops, inputs_repr, vjp_flops)
442-
for (_, object, name, lowered, inputs_repr, vjp_flops) in sorted(tracer_args, key=lambda x: x[0])]
433+
# Sort call info in pre-order traversal order
434+
rows.sort(key=lambda x: x.call_order)
443435

444436
# Restore the object's original methods
445437
_overwrite_methods(env)

tests/nnx/summary_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,5 @@ def __call__(self, x):
308308
x = jnp.zeros((4, 8))
309309
nnx.tabulate(net, x, console_kwargs={"width": 200})
310310

311-
# TODO: should test dynamic shapes with nested calls. This will probably lead to duplicate rows.
312-
313311
if __name__ == '__main__':
314312
absltest.main()

0 commit comments

Comments
 (0)