Skip to content

Commit ee32998

Browse files
committed
Add test that tabulate has linear time complexity
1 parent 2c2c7d2 commit ee32998

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

tests/nnx/summary_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,32 @@ def __call__(self, x):
225225
# We should see 3 calls per block, plus one overall call
226226
self.assertEqual(sum([s.startswith("├─") for s in table.splitlines()]), 7)
227227

228+
def test_complexity(self):
229+
counter = []
230+
231+
class Block(nnx.Module):
232+
def __init__(self, rngs):
233+
self.linear = nnx.Linear(2, 2, rngs=rngs)
234+
235+
def __call__(self, x):
236+
counter.append(1)
237+
return self.linear(x)
238+
239+
class Model(nnx.Module):
240+
def __init__(self, rngs):
241+
for d in range(10):
242+
setattr(self, f"linear{d}", Block(rngs))
243+
244+
def __call__(self, x):
245+
for d in range(10):
246+
x = getattr(self, f"linear{d}")(x)
247+
return x
248+
249+
m = Model(nnx.Rngs(0))
250+
x = jnp.ones((4, 2))
251+
nnx.tabulate(m, x, compute_flops=True, compute_vjp_flops=False)
252+
self.assertEqual(len(counter), 10)
253+
228254
def test_shared(self):
229255
class Block(nnx.Module):
230256
def __init__(self, linear: nnx.Linear, *, rngs):

0 commit comments

Comments
 (0)