Skip to content

Commit 2c2c7d2

Browse files
committed
Avoid lowering unless compute_flops is true
1 parent e379298 commit 2c2c7d2

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

flax/nnx/summary.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,11 +225,15 @@ def wrapper(obj, *args, **kwargs):
225225
identifier = (inputs_repr, object_id)
226226
counter_val = next(counter)
227227
graphdef, state = nnx.split(((obj, *args), kwargs))
228-
lowered = jit_f.lower(graphdef, state)
228+
if compute_flops:
229+
lowered = jit_f.lower(graphdef, state)
230+
flops = _get_flops(lowered)
231+
outputs = lowered.out_info
232+
else:
233+
flops = None
234+
outputs = jit_f(graphdef, state)
229235
if identifier not in seen:
230236
seen.add(identifier)
231-
flops = _get_flops(lowered) if compute_flops else None
232-
outputs = lowered.out_info
233237
output_repr = jax.tree.map(_to_dummy_array, outputs)
234238
vjp_flops = _get_flops(jax.jit(do_vjp).lower(
235239
obj, *args, **kwargs)) if compute_vjp_flops else None

0 commit comments

Comments
 (0)