Skip to content

Commit b116f16

Browse files
committed
Clean up code
1 parent a73bbde commit b116f16

File tree

2 files changed

+49
-59
lines changed

2 files changed

+49
-59
lines changed

flax/nnx/summary.py

Lines changed: 49 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import typing as tp
2121
from types import MappingProxyType
2222
import functools
23+
import itertools
2324
from types import SimpleNamespace
2425

2526
import jax
@@ -83,7 +84,7 @@ def lower(self, *args, **kwargs):
8384
except TypeError as e:
8485
result = self.f(*args, **kwargs)
8586
# Mock a `Lowered` instance with a SimpleNamespace
86-
return SimpleNamespace(cost_analysis=-1, lowered=SimpleNamespace(out_info=(None, None, result)))
87+
return SimpleNamespace(cost_analysis="0", lowered=SimpleNamespace(out_info=(None, None, result)))
8788

8889
class SizeBytes(typing.SizeBytes):
8990
def __repr__(self) -> str:
@@ -165,6 +166,7 @@ def __str__(self):
165166

166167
@dataclasses.dataclass
167168
class CallInfo:
169+
call_order: int
168170
object_id: int
169171
type: type
170172
path: statelib.PathParts
@@ -201,28 +203,6 @@ def inner(state, *args, **kwargs):
201203
return f(model, *args, **kwargs)
202204
return jax.vjp(inner, state, *args, **kwargs)
203205

204-
def _get_call_info(lowered, method_name, node_stats, obj, compute_flops, inputs_repr, vjp_flops):
205-
flops = _get_flops(lowered) if compute_flops else None
206-
outputs = lowered.lowered.out_info[2]
207-
output_repr = jax.tree.map(_to_dummy_array, outputs)
208-
object_id: int = getattr(obj, '_nnx_tabulate_id')
209-
node_info = node_stats[object_id]
210-
assert node_info is not None
211-
path = node_info.path
212-
if method_name != '__call__':
213-
path = (*path, method_name)
214-
215-
return CallInfo(
216-
object_id=object_id,
217-
type=type(obj),
218-
path=path,
219-
inputs_repr=inputs_repr,
220-
outputs=output_repr,
221-
flops=flops,
222-
vjp_flops=vjp_flops
223-
)
224-
225-
226206
def filter_rng_streams(row: CallInfo):
227207
return not issubclass(row.type, nnx.RngStream)
228208

@@ -235,38 +215,52 @@ def _create_obj_env(object_types):
235215
result[(obj_type, name)] = top_method
236216
return result
237217

238-
def _argsave(counter, tracer_args, f, compute_vjp_flops):
218+
def _get_inputs_repr(args, kwargs):
219+
input_args, input_kwargs = jax.tree.map(
220+
_to_dummy_array, (args, kwargs)
221+
)
222+
inputs_repr = ''
223+
if input_args:
224+
if len(input_args) == 1 and not input_kwargs:
225+
inputs_repr += _as_yaml_str(input_args[0])
226+
else:
227+
inputs_repr += _as_yaml_str(input_args)
228+
if input_kwargs:
229+
inputs_repr += '\n'
230+
if input_kwargs:
231+
inputs_repr += _as_yaml_str(input_kwargs)
232+
return inputs_repr
233+
234+
def _argsave(counter, tracer_args, f, node_stats, compute_flops, compute_vjp_flops):
239235
"Wrap a function to save its arguments"
236+
237+
# Used when computing vjp flops
240238
def do_vjp(*args, **kwargs):
241239
primals, f_vjp = jax.vjp(f, *args, **kwargs)
242240
return f_vjp(primals)
243-
n = f.__name__
241+
242+
method_name = f.__name__
243+
244244
@wraps(f)
245245
def wrapper(obj, *args, **kwargs):
246-
input_args, input_kwargs = jax.tree.map(
247-
_to_dummy_array, (args, kwargs)
248-
)
249-
inputs_repr = ''
250-
if input_args:
251-
if len(input_args) == 1 and not input_kwargs:
252-
inputs_repr += _as_yaml_str(input_args[0])
253-
else:
254-
inputs_repr += _as_yaml_str(input_args)
255-
if input_kwargs:
256-
inputs_repr += '\n'
257-
if input_kwargs:
258-
inputs_repr += _as_yaml_str(input_kwargs)
259-
260-
identifier = (inputs_repr, getattr(obj, '_nnx_tabulate_id'))
246+
inputs_repr = _get_inputs_repr(args, kwargs)
247+
object_id = getattr(obj, '_nnx_tabulate_id')
248+
node_info = node_stats[object_id]
249+
path = node_info.path
250+
if method_name != '__call__':
251+
path = (*path, method_name)
252+
identifier = (inputs_repr, object_id)
261253
if identifier not in f.seen:
262-
counter_val = counter[0]
263-
counter[0] += 1
254+
counter_val = next(counter)
264255
lowered = f.lower(obj, *args, **kwargs)
265-
if compute_vjp_flops:
266-
vjp_flops = _get_flops(jax.jit(do_vjp).lower(obj, *args, **kwargs))
267-
else:
268-
vjp_flops = None
269-
tracer_args.append((counter_val, obj, n, lowered, inputs_repr, vjp_flops))
256+
flops = _get_flops(lowered) if compute_flops else None
257+
outputs = lowered.lowered.out_info[2]
258+
output_repr = jax.tree.map(_to_dummy_array, outputs)
259+
vjp_flops = _get_flops(jax.jit(do_vjp).lower(
260+
obj, *args, **kwargs)) if compute_vjp_flops else None
261+
tracer_args.append(
262+
CallInfo(counter_val, object_id, type(obj), path, inputs_repr,
263+
output_repr, flops, vjp_flops))
270264
f.seen.add(identifier)
271265
return f(obj, *args, **kwargs)
272266
return wrapper
@@ -278,7 +272,7 @@ def _overwrite_methods(env):
278272

279273
def _get_flops(e) -> int:
280274
cost = e.cost_analysis() or e.compile().cost_analysis()
281-
return -1 if cost is None or 'flops' not in cost else int(cost['flops'])
275+
return 0 if cost is None or 'flops' not in cost else int(cost['flops'])
282276

283277
def tabulate(
284278
obj,
@@ -424,22 +418,20 @@ def tabulate(
424418
env = _create_obj_env(object_types)
425419

426420
# Information is recorded in post-order, but should be presented as a pre-order traversal.
427-
# This counter is incremented in pre-order traversal to keep track of the order of calls.
428-
counter = [0]
421+
# This keeps track of the order of calls.
422+
counter = itertools.count(0)
429423

430424
# Modify all the object's methods to save their lowered JIT representations.
431-
tracer_args = []
432-
jits = {k: _argsave(counter, tracer_args, MaybeJit(v), compute_vjp_flops) for k,v in env.items()}
425+
rows : list[CallInfo] = []
426+
jits = {k: _argsave(counter, rows, MaybeJit(v), node_stats, compute_flops, compute_vjp_flops)
427+
for k,v in env.items()}
433428
_overwrite_methods(jits)
434429

435430
# Trace the top function (which indirectly traces all the others)
436431
jits[(type(obj), method)](obj, *input_args, **input_kwargs)
437432

438-
# Get call_info
439-
rows : list[CallInfo] = [_get_call_info(
440-
lowered, name, node_stats, object,
441-
compute_flops, inputs_repr, vjp_flops)
442-
for (_, object, name, lowered, inputs_repr, vjp_flops) in sorted(tracer_args, key=lambda x: x[0])]
433+
# Sort call info in pre-order traversal order
434+
rows.sort(key=lambda x: x.call_order)
443435

444436
# Restore the object's original methods
445437
_overwrite_methods(env)

tests/nnx/summary_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,5 @@ def __call__(self, x):
308308
x = jnp.zeros((4, 8))
309309
nnx.tabulate(net, x, console_kwargs={"width": 200})
310310

311-
# TODO: should test dynamic shapes with nested calls. This will probably lead to duplicate rows.
312-
313311
if __name__ == '__main__':
314312
absltest.main()

0 commit comments

Comments
 (0)