Skip to content

Commit 12b2e42

Browse files
test pathcontains with exact=false
1 parent 3d34055 commit 12b2e42

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tests/nnx/filters_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,20 @@ class TestFilters(absltest.TestCase):
2121
def test_path_contains(self):
2222
class Model(nnx.Module):
2323
def __init__(self, rngs):
24-
self.backbone = nnx.Linear(2, 3, rngs=rngs)
24+
self.backbone1 = nnx.Linear(2, 3, rngs=rngs)
25+
self.backbone2 = nnx.Linear(3, 3, rngs=rngs)
2526
self.head = nnx.Linear(3, 10, rngs=rngs)
2627

2728
model = Model(nnx.Rngs(0))
2829

2930
head_state = nnx.state(model, nnx.PathContains('head'))
31+
backbones_state = nnx.state(model, nnx.PathContains('backbone', exact=False))
3032

3133
self.assertIn('head', head_state)
3234
self.assertNotIn('backbone', head_state)
35+
self.assertIn('backbone1', backbones_state)
36+
self.assertIn('backbone2', backbones_state)
37+
self.assertNotIn('head', backbones_state)
3338

3439
if __name__ == '__main__':
3540
absltest.main()

0 commit comments

Comments
 (0)