|
27 | 27 | from flax.nnx.pytreelib import Pytree, PytreeMeta |
28 | 28 | from flax.nnx.graph import GraphState |
29 | 29 | from flax.typing import Key, Path, PathParts |
| 30 | +import warnings |
30 | 31 |
|
31 | 32 | A = tp.TypeVar('A') |
32 | 33 | B = tp.TypeVar('B') |
@@ -268,85 +269,33 @@ def perturb( |
268 | 269 | return old_value.value + value |
269 | 270 |
|
270 | 271 | def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]: |
271 | | - """Recursively iterates over all nested :class:`Module`'s of the current Module, including |
272 | | - the current Module. |
273 | | -
|
274 | | - ``iter_modules`` creates a generator that yields the path and the Module instance, where |
275 | | - the path is a tuple of strings or integers representing the path to the Module from the |
276 | | - root Module. |
277 | | -
|
278 | | - Example:: |
| 272 | + """ |
| 273 | + Warning: this method is method is deprecated; use :func:`iter_modules` instead. |
279 | 274 |
|
280 | | - >>> from flax import nnx |
281 | | - ... |
282 | | - >>> class SubModule(nnx.Module): |
283 | | - ... def __init__(self, din, dout, rngs): |
284 | | - ... self.linear1 = nnx.Linear(din, dout, rngs=rngs) |
285 | | - ... self.linear2 = nnx.Linear(din, dout, rngs=rngs) |
286 | | - ... |
287 | | - >>> class Block(nnx.Module): |
288 | | - ... def __init__(self, din, dout, *, rngs: nnx.Rngs): |
289 | | - ... self.linear = nnx.Linear(din, dout, rngs=rngs) |
290 | | - ... self.submodule = SubModule(din, dout, rngs=rngs) |
291 | | - ... self.dropout = nnx.Dropout(0.5) |
292 | | - ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs) |
293 | | - ... |
294 | | - >>> model = Block(2, 5, rngs=nnx.Rngs(0)) |
295 | | - >>> for path, module in model.iter_modules(): |
296 | | - ... print(path, type(module).__name__) |
297 | | - ... |
298 | | - ('batch_norm',) BatchNorm |
299 | | - ('dropout',) Dropout |
300 | | - ('linear',) Linear |
301 | | - ('submodule', 'linear1') Linear |
302 | | - ('submodule', 'linear2') Linear |
303 | | - ('submodule',) SubModule |
304 | | - () Block |
| 275 | + Recursively iterates over all nested :class:`Module`'s of the current Module, including |
| 276 | + the current Module. Alias of :func:`iter_modules`. |
305 | 277 | """ |
306 | | - for path, value in graph.iter_graph(self): |
307 | | - if isinstance(value, Module): |
308 | | - yield path, value |
| 278 | + warnings.warn( |
| 279 | + "The 'm.iter_modules()' method is deprecated; use the 'nnx.iter_modules(m)' function instead.", |
| 280 | + DeprecationWarning, |
| 281 | + stacklevel=2, |
| 282 | + ) |
| 283 | + yield from iter_modules(self) |
309 | 284 |
|
310 | 285 | def iter_children(self) -> tp.Iterator[tuple[Key, Module]]: |
311 | | - """Iterates over all children :class:`Module`'s of the current Module. This |
312 | | - method is similar to :func:`iter_modules`, except it only iterates over the |
313 | | - immediate children, and does not recurse further down. |
314 | | -
|
315 | | - ``iter_children`` creates a generator that yields the key and the Module instance, |
316 | | - where the key is a string representing the attribute name of the Module to access |
317 | | - the corresponding child Module. |
318 | | -
|
319 | | - Example:: |
| 286 | + """ |
| 287 | + Warning: this method is method is deprecated; use :func:`iter_children` instead. |
320 | 288 |
|
321 | | - >>> from flax import nnx |
322 | | - ... |
323 | | - >>> class SubModule(nnx.Module): |
324 | | - ... def __init__(self, din, dout, rngs): |
325 | | - ... self.linear1 = nnx.Linear(din, dout, rngs=rngs) |
326 | | - ... self.linear2 = nnx.Linear(din, dout, rngs=rngs) |
327 | | - ... |
328 | | - >>> class Block(nnx.Module): |
329 | | - ... def __init__(self, din, dout, *, rngs: nnx.Rngs): |
330 | | - ... self.linear = nnx.Linear(din, dout, rngs=rngs) |
331 | | - ... self.submodule = SubModule(din, dout, rngs=rngs) |
332 | | - ... self.dropout = nnx.Dropout(0.5) |
333 | | - ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs) |
334 | | - ... |
335 | | - >>> model = Block(2, 5, rngs=nnx.Rngs(0)) |
336 | | - >>> for path, module in model.iter_children(): |
337 | | - ... print(path, type(module).__name__) |
338 | | - ... |
339 | | - batch_norm BatchNorm |
340 | | - dropout Dropout |
341 | | - linear Linear |
342 | | - submodule SubModule |
| 289 | + Iterates over all children :class:`Module`'s of the current Module. This |
| 290 | + method is similar to :func:`iter_modules`, except it only iterates over the |
| 291 | + immediate children, and does not recurse further down. Alias of :func:`iter_children`. |
343 | 292 | """ |
344 | | - node_impl = graph.get_node_impl(self) |
345 | | - assert node_impl is not None |
346 | | - node_dict = node_impl.node_dict(self) |
347 | | - for key, value in node_dict.items(): |
348 | | - if isinstance(value, Module): |
349 | | - yield key, value |
| 293 | + warnings.warn( |
| 294 | + "The 'm.iter_children()' method is deprecated; use the 'nnx.iter_children(m)' function instead.", |
| 295 | + DeprecationWarning, |
| 296 | + stacklevel=2, |
| 297 | + ) |
| 298 | + yield from iter_children(self) |
350 | 299 |
|
351 | 300 | def set_attributes( |
352 | 301 | self, |
@@ -392,7 +341,7 @@ def set_attributes( |
392 | 341 | if not filters: |
393 | 342 | filters = (True,) |
394 | 343 | predicates = tuple(map(filterlib.to_predicate, filters)) |
395 | | - for path, module in self.iter_modules(): |
| 344 | + for path, module in iter_modules(self): |
396 | 345 | for predicate in predicates: |
397 | 346 | if predicate(path, module): |
398 | 347 | for name, value in attributes.items(): |
@@ -494,3 +443,84 @@ def first_from(*args: tp.Optional[A], error_msg: str) -> A: |
494 | 443 | if arg is not None: |
495 | 444 | return arg |
496 | 445 | raise ValueError(error_msg) |
| 446 | + |
| 447 | +def iter_modules(module: Module) -> tp.Iterator[tuple[PathParts, Module]]: |
| 448 | + """Recursively iterates over all nested :class:`Module`'s of the given Module, including |
| 449 | + the argument. |
| 450 | +
|
| 451 | + Specifically, this function creates a generator that yields the path and the Module instance, where |
| 452 | + the path is a tuple of strings or integers representing the path to the Module from the |
| 453 | + root Module. |
| 454 | +
|
| 455 | + Example:: |
| 456 | +
|
| 457 | + >>> from flax import nnx |
| 458 | + ... |
| 459 | + >>> class SubModule(nnx.Module): |
| 460 | + ... def __init__(self, din, dout, rngs): |
| 461 | + ... self.linear1 = nnx.Linear(din, dout, rngs=rngs) |
| 462 | + ... self.linear2 = nnx.Linear(din, dout, rngs=rngs) |
| 463 | + ... |
| 464 | + >>> class Block(nnx.Module): |
| 465 | + ... def __init__(self, din, dout, *, rngs: nnx.Rngs): |
| 466 | + ... self.linear = nnx.Linear(din, dout, rngs=rngs) |
| 467 | + ... self.submodule = SubModule(din, dout, rngs=rngs) |
| 468 | + ... self.dropout = nnx.Dropout(0.5) |
| 469 | + ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs) |
| 470 | + ... |
| 471 | + >>> model = Block(2, 5, rngs=nnx.Rngs(0)) |
| 472 | + >>> for path, module in nnx.iter_modules(model): |
| 473 | + ... print(path, type(module).__name__) |
| 474 | + ... |
| 475 | + ('batch_norm',) BatchNorm |
| 476 | + ('dropout',) Dropout |
| 477 | + ('linear',) Linear |
| 478 | + ('submodule', 'linear1') Linear |
| 479 | + ('submodule', 'linear2') Linear |
| 480 | + ('submodule',) SubModule |
| 481 | + () Block |
| 482 | + """ |
| 483 | + for path, value in graph.iter_graph(module): |
| 484 | + if isinstance(value, Module): |
| 485 | + yield path, value |
| 486 | + |
| 487 | +def iter_children(module: Module) -> tp.Iterator[tuple[Key, Module]]: |
| 488 | + """Iterates over all children :class:`Module`'s of a given Module. This |
| 489 | + method is similar to :func:`iter_modules`, except it only iterates over the |
| 490 | + immediate children, and does not recurse further down. |
| 491 | +
|
| 492 | + Specifically, this function creates a generator that yields the key and the Module instance, |
| 493 | + where the key is a string representing the attribute name of the Module to access |
| 494 | + the corresponding child Module. |
| 495 | +
|
| 496 | + Example:: |
| 497 | +
|
| 498 | + >>> from flax import nnx |
| 499 | + ... |
| 500 | + >>> class SubModule(nnx.Module): |
| 501 | + ... def __init__(self, din, dout, rngs): |
| 502 | + ... self.linear1 = nnx.Linear(din, dout, rngs=rngs) |
| 503 | + ... self.linear2 = nnx.Linear(din, dout, rngs=rngs) |
| 504 | + ... |
| 505 | + >>> class Block(nnx.Module): |
| 506 | + ... def __init__(self, din, dout, *, rngs: nnx.Rngs): |
| 507 | + ... self.linear = nnx.Linear(din, dout, rngs=rngs) |
| 508 | + ... self.submodule = SubModule(din, dout, rngs=rngs) |
| 509 | + ... self.dropout = nnx.Dropout(0.5) |
| 510 | + ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs) |
| 511 | + ... |
| 512 | + >>> model = Block(2, 5, rngs=nnx.Rngs(0)) |
| 513 | + >>> for path, module in nnx.iter_children(model): |
| 514 | + ... print(path, type(module).__name__) |
| 515 | + ... |
| 516 | + batch_norm BatchNorm |
| 517 | + dropout Dropout |
| 518 | + linear Linear |
| 519 | + submodule SubModule |
| 520 | + """ |
| 521 | + node_impl = graph.get_node_impl(module) |
| 522 | + assert node_impl is not None |
| 523 | + node_dict = node_impl.node_dict(module) |
| 524 | + for key, value in node_dict.items(): |
| 525 | + if isinstance(value, Module): |
| 526 | + yield key, value |
0 commit comments