File tree Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Original file line number Diff line number Diff line change @@ -225,6 +225,32 @@ def __call__(self, x):
225225 # We should see 3 calls per block, plus one overall call
226226 self .assertEqual (sum ([s .startswith ("├─" ) for s in table .splitlines ()]), 7 )
227227
228+ def test_complexity (self ):
229+ counter = []
230+
231+ class Block (nnx .Module ):
232+ def __init__ (self , rngs ):
233+ self .linear = nnx .Linear (2 , 2 , rngs = rngs )
234+
235+ def __call__ (self , x ):
236+ counter .append (1 )
237+ return self .linear (x )
238+
239+ class Model (nnx .Module ):
240+ def __init__ (self , rngs ):
241+ for d in range (10 ):
242+ setattr (self , f"linear{ d } " , Block (rngs ))
243+
244+ def __call__ (self , x ):
245+ for d in range (10 ):
246+ x = getattr (self , f"linear{ d } " )(x )
247+ return x
248+
249+ m = Model (nnx .Rngs (0 ))
250+ x = jnp .ones ((4 , 2 ))
251+ nnx .tabulate (m , x , compute_flops = True , compute_vjp_flops = False )
252+ self .assertEqual (len (counter ), 10 )
253+
228254 def test_shared (self ):
229255 class Block (nnx .Module ):
230256 def __init__ (self , linear : nnx .Linear , * , rngs ):
You can’t perform that action at this time.
0 commit comments