Skip to content

Commit 771eadb

Browse files
author
Flax Authors
committed
Merge pull request #4961 from samanklesaria:issues/4943
PiperOrigin-RevId: 814774751
2 parents 6080725 + 1bd6b07 commit 771eadb

File tree

4 files changed

+108
-77
lines changed

4 files changed

+108
-77
lines changed

docs_nnx/api_reference/flax.nnx/module.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module
22
------------------------
33

44
.. automodule:: flax.nnx
5+
:members: iter_children, iter_modules
56
.. currentmodule:: flax.nnx
6-
77
.. autoclass:: Module
88
:members:

flax/nnx/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from .helpers import TrainState as TrainState
4848
from .module import M as M
4949
from .module import Module as Module
50+
from .module import iter_children as iter_children, iter_modules as iter_modules
5051
from .graph import merge as merge
5152
from .graph import UpdateContext as UpdateContext
5253
from .graph import update_context as update_context

flax/nnx/module.py

Lines changed: 104 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from flax.nnx.pytreelib import Pytree, PytreeMeta
2828
from flax.nnx.graph import GraphState
2929
from flax.typing import Key, Path, PathParts
30+
import warnings
3031

3132
A = tp.TypeVar('A')
3233
B = tp.TypeVar('B')
@@ -268,85 +269,33 @@ def perturb(
268269
return old_value.value + value
269270

270271
def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]:
271-
"""Recursively iterates over all nested :class:`Module`'s of the current Module, including
272-
the current Module.
273-
274-
``iter_modules`` creates a generator that yields the path and the Module instance, where
275-
the path is a tuple of strings or integers representing the path to the Module from the
276-
root Module.
277-
278-
Example::
272+
"""
273+
Warning: this method is method is deprecated; use :func:`iter_modules` instead.
279274
280-
>>> from flax import nnx
281-
...
282-
>>> class SubModule(nnx.Module):
283-
... def __init__(self, din, dout, rngs):
284-
... self.linear1 = nnx.Linear(din, dout, rngs=rngs)
285-
... self.linear2 = nnx.Linear(din, dout, rngs=rngs)
286-
...
287-
>>> class Block(nnx.Module):
288-
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
289-
... self.linear = nnx.Linear(din, dout, rngs=rngs)
290-
... self.submodule = SubModule(din, dout, rngs=rngs)
291-
... self.dropout = nnx.Dropout(0.5)
292-
... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
293-
...
294-
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
295-
>>> for path, module in model.iter_modules():
296-
... print(path, type(module).__name__)
297-
...
298-
('batch_norm',) BatchNorm
299-
('dropout',) Dropout
300-
('linear',) Linear
301-
('submodule', 'linear1') Linear
302-
('submodule', 'linear2') Linear
303-
('submodule',) SubModule
304-
() Block
275+
Recursively iterates over all nested :class:`Module`'s of the current Module, including
276+
the current Module. Alias of :func:`iter_modules`.
305277
"""
306-
for path, value in graph.iter_graph(self):
307-
if isinstance(value, Module):
308-
yield path, value
278+
warnings.warn(
279+
"The 'm.iter_modules()' method is deprecated; use the 'nnx.iter_modules(m)' function instead.",
280+
DeprecationWarning,
281+
stacklevel=2,
282+
)
283+
yield from iter_modules(self)
309284

310285
def iter_children(self) -> tp.Iterator[tuple[Key, Module]]:
311-
"""Iterates over all children :class:`Module`'s of the current Module. This
312-
method is similar to :func:`iter_modules`, except it only iterates over the
313-
immediate children, and does not recurse further down.
314-
315-
``iter_children`` creates a generator that yields the key and the Module instance,
316-
where the key is a string representing the attribute name of the Module to access
317-
the corresponding child Module.
318-
319-
Example::
286+
"""
287+
Warning: this method is method is deprecated; use :func:`iter_children` instead.
320288
321-
>>> from flax import nnx
322-
...
323-
>>> class SubModule(nnx.Module):
324-
... def __init__(self, din, dout, rngs):
325-
... self.linear1 = nnx.Linear(din, dout, rngs=rngs)
326-
... self.linear2 = nnx.Linear(din, dout, rngs=rngs)
327-
...
328-
>>> class Block(nnx.Module):
329-
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
330-
... self.linear = nnx.Linear(din, dout, rngs=rngs)
331-
... self.submodule = SubModule(din, dout, rngs=rngs)
332-
... self.dropout = nnx.Dropout(0.5)
333-
... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
334-
...
335-
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
336-
>>> for path, module in model.iter_children():
337-
... print(path, type(module).__name__)
338-
...
339-
batch_norm BatchNorm
340-
dropout Dropout
341-
linear Linear
342-
submodule SubModule
289+
Iterates over all children :class:`Module`'s of the current Module. This
290+
method is similar to :func:`iter_modules`, except it only iterates over the
291+
immediate children, and does not recurse further down. Alias of :func:`iter_children`.
343292
"""
344-
node_impl = graph.get_node_impl(self)
345-
assert node_impl is not None
346-
node_dict = node_impl.node_dict(self)
347-
for key, value in node_dict.items():
348-
if isinstance(value, Module):
349-
yield key, value
293+
warnings.warn(
294+
"The 'm.iter_children()' method is deprecated; use the 'nnx.iter_children(m)' function instead.",
295+
DeprecationWarning,
296+
stacklevel=2,
297+
)
298+
yield from iter_children(self)
350299

351300
def set_attributes(
352301
self,
@@ -392,7 +341,7 @@ def set_attributes(
392341
if not filters:
393342
filters = (True,)
394343
predicates = tuple(map(filterlib.to_predicate, filters))
395-
for path, module in self.iter_modules():
344+
for path, module in iter_modules(self):
396345
for predicate in predicates:
397346
if predicate(path, module):
398347
for name, value in attributes.items():
@@ -494,3 +443,84 @@ def first_from(*args: tp.Optional[A], error_msg: str) -> A:
494443
if arg is not None:
495444
return arg
496445
raise ValueError(error_msg)
446+
447+
def iter_modules(module: Module) -> tp.Iterator[tuple[PathParts, Module]]:
448+
"""Recursively iterates over all nested :class:`Module`'s of the given Module, including
449+
the argument.
450+
451+
Specifically, this function creates a generator that yields the path and the Module instance, where
452+
the path is a tuple of strings or integers representing the path to the Module from the
453+
root Module.
454+
455+
Example::
456+
457+
>>> from flax import nnx
458+
...
459+
>>> class SubModule(nnx.Module):
460+
... def __init__(self, din, dout, rngs):
461+
... self.linear1 = nnx.Linear(din, dout, rngs=rngs)
462+
... self.linear2 = nnx.Linear(din, dout, rngs=rngs)
463+
...
464+
>>> class Block(nnx.Module):
465+
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
466+
... self.linear = nnx.Linear(din, dout, rngs=rngs)
467+
... self.submodule = SubModule(din, dout, rngs=rngs)
468+
... self.dropout = nnx.Dropout(0.5)
469+
... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
470+
...
471+
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
472+
>>> for path, module in nnx.iter_modules(model):
473+
... print(path, type(module).__name__)
474+
...
475+
('batch_norm',) BatchNorm
476+
('dropout',) Dropout
477+
('linear',) Linear
478+
('submodule', 'linear1') Linear
479+
('submodule', 'linear2') Linear
480+
('submodule',) SubModule
481+
() Block
482+
"""
483+
for path, value in graph.iter_graph(module):
484+
if isinstance(value, Module):
485+
yield path, value
486+
487+
def iter_children(module: Module) -> tp.Iterator[tuple[Key, Module]]:
488+
"""Iterates over all children :class:`Module`'s of a given Module. This
489+
method is similar to :func:`iter_modules`, except it only iterates over the
490+
immediate children, and does not recurse further down.
491+
492+
Specifically, this function creates a generator that yields the key and the Module instance,
493+
where the key is a string representing the attribute name of the Module to access
494+
the corresponding child Module.
495+
496+
Example::
497+
498+
>>> from flax import nnx
499+
...
500+
>>> class SubModule(nnx.Module):
501+
... def __init__(self, din, dout, rngs):
502+
... self.linear1 = nnx.Linear(din, dout, rngs=rngs)
503+
... self.linear2 = nnx.Linear(din, dout, rngs=rngs)
504+
...
505+
>>> class Block(nnx.Module):
506+
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
507+
... self.linear = nnx.Linear(din, dout, rngs=rngs)
508+
... self.submodule = SubModule(din, dout, rngs=rngs)
509+
... self.dropout = nnx.Dropout(0.5)
510+
... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
511+
...
512+
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
513+
>>> for path, module in nnx.iter_children(model):
514+
... print(path, type(module).__name__)
515+
...
516+
batch_norm BatchNorm
517+
dropout Dropout
518+
linear Linear
519+
submodule SubModule
520+
"""
521+
node_impl = graph.get_node_impl(module)
522+
assert node_impl is not None
523+
node_dict = node_impl.node_dict(module)
524+
for key, value in node_dict.items():
525+
if isinstance(value, Module):
526+
yield key, value

tests/nnx/module_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ def __init__(self, *, rngs: nnx.Rngs):
829829

830830
module = Foo(rngs=nnx.Rngs(0))
831831

832-
modules = list(module.iter_modules())
832+
modules = list(nnx.iter_modules(module))
833833

834834
assert len(modules) == 5
835835
assert modules[0][0] == ('dropout',)
@@ -855,7 +855,7 @@ def __init__(self, *, rngs: nnx.Rngs):
855855

856856
module = Foo(rngs=nnx.Rngs(0))
857857

858-
modules = list(module.iter_children())
858+
modules = list(nnx.iter_children(module))
859859

860860
assert len(modules) == 2
861861
assert modules[0][0] == 'dropout'

0 commit comments

Comments
 (0)