Skip to content

Commit 442b091

Browse files
committed
Clean up code
1 parent a73bbde commit 442b091

File tree

2 files changed

+58
-67
lines changed

2 files changed

+58
-67
lines changed

flax/nnx/summary.py

Lines changed: 58 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import typing as tp
2121
from types import MappingProxyType
2222
import functools
23+
import itertools
2324
from types import SimpleNamespace
2425

2526
import jax
@@ -83,7 +84,7 @@ def lower(self, *args, **kwargs):
8384
except TypeError as e:
8485
result = self.f(*args, **kwargs)
8586
# Mock a `Lowered` instance with a SimpleNamespace
86-
return SimpleNamespace(cost_analysis=-1, lowered=SimpleNamespace(out_info=(None, None, result)))
87+
return SimpleNamespace(cost_analysis="0", lowered=SimpleNamespace(out_info=(None, None, result)))
8788

8889
class SizeBytes(typing.SizeBytes):
8990
def __repr__(self) -> str:
@@ -165,6 +166,7 @@ def __str__(self):
165166

166167
@dataclasses.dataclass
167168
class CallInfo:
169+
call_order: int
168170
object_id: int
169171
type: type
170172
path: statelib.PathParts
@@ -185,13 +187,14 @@ def __repr__(self):
185187

186188

187189
def _to_dummy_array(x):
188-
try:
190+
if isinstance(x,jax.ShapeDtypeStruct):
189191
return ArrayRepr(x.shape, x.dtype)
190-
except:
191-
if graph.is_graph_node(x):
192-
return SimpleObjectRepr(x)
193-
else:
194-
return x
192+
elif isinstance(x, jax.Array | np.ndarray):
193+
return ArrayRepr.from_array(x)
194+
elif graph.is_graph_node(x):
195+
return SimpleObjectRepr(x)
196+
else:
197+
return x
195198

196199
def _pure_nnx_vjp(f, model, *args, **kwargs):
197200
"Wrap nnx functional api around jax.vjp. Only handles pure method calls."
@@ -201,28 +204,6 @@ def inner(state, *args, **kwargs):
201204
return f(model, *args, **kwargs)
202205
return jax.vjp(inner, state, *args, **kwargs)
203206

204-
def _get_call_info(lowered, method_name, node_stats, obj, compute_flops, inputs_repr, vjp_flops):
205-
flops = _get_flops(lowered) if compute_flops else None
206-
outputs = lowered.lowered.out_info[2]
207-
output_repr = jax.tree.map(_to_dummy_array, outputs)
208-
object_id: int = getattr(obj, '_nnx_tabulate_id')
209-
node_info = node_stats[object_id]
210-
assert node_info is not None
211-
path = node_info.path
212-
if method_name != '__call__':
213-
path = (*path, method_name)
214-
215-
return CallInfo(
216-
object_id=object_id,
217-
type=type(obj),
218-
path=path,
219-
inputs_repr=inputs_repr,
220-
outputs=output_repr,
221-
flops=flops,
222-
vjp_flops=vjp_flops
223-
)
224-
225-
226207
def filter_rng_streams(row: CallInfo):
227208
return not issubclass(row.type, nnx.RngStream)
228209

@@ -235,38 +216,52 @@ def _create_obj_env(object_types):
235216
result[(obj_type, name)] = top_method
236217
return result
237218

238-
def _argsave(counter, tracer_args, f, compute_vjp_flops):
219+
def _get_inputs_repr(args, kwargs):
220+
input_args, input_kwargs = jax.tree.map(
221+
_to_dummy_array, (args, kwargs)
222+
)
223+
inputs_repr = ''
224+
if input_args:
225+
if len(input_args) == 1 and not input_kwargs:
226+
inputs_repr += _as_yaml_str(input_args[0])
227+
else:
228+
inputs_repr += _as_yaml_str(input_args)
229+
if input_kwargs:
230+
inputs_repr += '\n'
231+
if input_kwargs:
232+
inputs_repr += _as_yaml_str(input_kwargs)
233+
return inputs_repr
234+
235+
def _save_call_info(counter, tracer_args, f, node_stats, compute_flops, compute_vjp_flops):
239236
"Wrap a function to save its arguments"
237+
238+
# Used when computing vjp flops
240239
def do_vjp(*args, **kwargs):
241240
primals, f_vjp = jax.vjp(f, *args, **kwargs)
242241
return f_vjp(primals)
243-
n = f.__name__
242+
243+
method_name = f.__name__
244+
244245
@wraps(f)
245246
def wrapper(obj, *args, **kwargs):
246-
input_args, input_kwargs = jax.tree.map(
247-
_to_dummy_array, (args, kwargs)
248-
)
249-
inputs_repr = ''
250-
if input_args:
251-
if len(input_args) == 1 and not input_kwargs:
252-
inputs_repr += _as_yaml_str(input_args[0])
253-
else:
254-
inputs_repr += _as_yaml_str(input_args)
255-
if input_kwargs:
256-
inputs_repr += '\n'
257-
if input_kwargs:
258-
inputs_repr += _as_yaml_str(input_kwargs)
259-
260-
identifier = (inputs_repr, getattr(obj, '_nnx_tabulate_id'))
247+
inputs_repr = _get_inputs_repr(args, kwargs)
248+
object_id = getattr(obj, '_nnx_tabulate_id')
249+
node_info = node_stats[object_id]
250+
path = node_info.path
251+
if method_name != '__call__':
252+
path = (*path, method_name)
253+
identifier = (inputs_repr, object_id)
261254
if identifier not in f.seen:
262-
counter_val = counter[0]
263-
counter[0] += 1
255+
counter_val = next(counter)
264256
lowered = f.lower(obj, *args, **kwargs)
265-
if compute_vjp_flops:
266-
vjp_flops = _get_flops(jax.jit(do_vjp).lower(obj, *args, **kwargs))
267-
else:
268-
vjp_flops = None
269-
tracer_args.append((counter_val, obj, n, lowered, inputs_repr, vjp_flops))
257+
flops = _get_flops(lowered) if compute_flops else None
258+
outputs = lowered.lowered.out_info[2]
259+
output_repr = jax.tree.map(_to_dummy_array, outputs)
260+
vjp_flops = _get_flops(jax.jit(do_vjp).lower(
261+
obj, *args, **kwargs)) if compute_vjp_flops else None
262+
tracer_args.append(
263+
CallInfo(counter_val, object_id, type(obj), path, inputs_repr,
264+
output_repr, flops, vjp_flops))
270265
f.seen.add(identifier)
271266
return f(obj, *args, **kwargs)
272267
return wrapper
@@ -278,7 +273,7 @@ def _overwrite_methods(env):
278273

279274
def _get_flops(e) -> int:
280275
cost = e.cost_analysis() or e.compile().cost_analysis()
281-
return -1 if cost is None or 'flops' not in cost else int(cost['flops'])
276+
return 0 if cost is None or 'flops' not in cost else int(cost['flops'])
282277

283278
def tabulate(
284279
obj,
@@ -424,22 +419,20 @@ def tabulate(
424419
env = _create_obj_env(object_types)
425420

426421
# Information is recorded in post-order, but should be presented as a pre-order traversal.
427-
# This counter is incremented in pre-order traversal to keep track of the order of calls.
428-
counter = [0]
422+
# This keeps track of the order of calls.
423+
counter = itertools.count(0)
429424

430425
# Modify all the object's methods to save their lowered JIT representations.
431-
tracer_args = []
432-
jits = {k: _argsave(counter, tracer_args, MaybeJit(v), compute_vjp_flops) for k,v in env.items()}
433-
_overwrite_methods(jits)
426+
rows : list[CallInfo] = []
427+
maybejits = {k: _save_call_info(counter, rows, MaybeJit(v), node_stats, compute_flops, compute_vjp_flops)
428+
for k, v in env.items()}
429+
_overwrite_methods(maybejits)
434430

435431
# Trace the top function (which indirectly traces all the others)
436-
jits[(type(obj), method)](obj, *input_args, **input_kwargs)
432+
maybejits[(type(obj), method)](obj, *input_args, **input_kwargs)
437433

438-
# Get call_info
439-
rows : list[CallInfo] = [_get_call_info(
440-
lowered, name, node_stats, object,
441-
compute_flops, inputs_repr, vjp_flops)
442-
for (_, object, name, lowered, inputs_repr, vjp_flops) in sorted(tracer_args, key=lambda x: x[0])]
434+
# Sort call info in pre-order traversal order
435+
rows.sort(key=lambda x: x.call_order)
443436

444437
# Restore the object's original methods
445438
_overwrite_methods(env)

tests/nnx/summary_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,5 @@ def __call__(self, x):
308308
x = jnp.zeros((4, 8))
309309
nnx.tabulate(net, x, console_kwargs={"width": 200})
310310

311-
# TODO: should test dynamic shapes with nested calls. This will probably lead to duplicate rows.
312-
313311
if __name__ == '__main__':
314312
absltest.main()

0 commit comments

Comments
 (0)