from torch.distributed.distributed_c10d import _get_default_group
from torch.nn.parameter import Parameter

-from .flatten_params_wrapper import (
-    FLAT_PARAM,
-    FPW_MODULE,
-    FlatParameter,
-    FlattenParamsWrapper,
-)
from ._optim_utils import (
    _broadcast_pos_dim_tensor_states,
    _broadcast_processed_optim_state_dict,
    ...
    _process_pos_dim_tensor_state,
    _unflatten_optim_state,
)
-from ._utils import _apply_to_tensors, _replace_by_prefix
+from ._utils import _apply_to_modules, _apply_to_tensors, _replace_by_prefix
+from .flatten_params_wrapper import (
+    FLAT_PARAM,
+    FPW_MODULE,
+    FlatParameter,
+    FlattenParamsWrapper,
+)
from .wrap import _recursive_wrap

if TYPE_CHECKING:
    from collections import OrderedDict  # noqa: F401

_TORCHDISTX_AVAIL = True
try:
-    from torchdistx import fake, deferred_init
+    from torchdistx import deferred_init, fake
except ImportError:
    _TORCHDISTX_AVAIL = False

@@ -490,10 +490,10 @@ class FullyShardedDataParallel(nn.Module):
            accuracy during model training. If ``None``, no mixed precision is applied.
            (Default: ``None``)
        ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose
-            own parameters and child modules' parameters are ignored by this
-            instance. None of the modules directly in ``ignored_modules``
-            should be :class:`FullyShardedDataParallel` instances, and any
-            child modules that are already-constructed
+            own parameters and child modules' parameters and buffers are
+            ignored by this instance. None of the modules directly in
+            ``ignored_modules`` should be :class:`FullyShardedDataParallel`
+            instances, and any child modules that are already-constructed
            :class:`FullyShardedDataParallel` instances will not be ignored if
            they are nested under this instance. This argument may be used to
            avoid sharding specific parameters when using an
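
A minimal usage sketch of the semantics documented above. The Toy module and its attribute names are invented for illustration, and it is assumed that torch.distributed is already initialized so the FSDP constructor can pick up the default process group:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(8, 8)
        self.head = nn.Linear(8, 2)  # to be left unsharded

    def forward(self, x):
        return self.head(self.trunk(x))

model = Toy()
# Parameters and buffers owned by `model.head` (and any of its children) are
# ignored by this FSDP instance; everything else is flattened and sharded.
fsdp_model = FSDP(model, ignored_modules=[model.head])
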
@@ -549,16 +549,7 @@ def __init__(
        # Save the ignored modules and their parameters, including the
        # parameter names, which are needed to filter the model state dict
        self._ignored_modules = self._get_ignored_modules(ignored_modules)
-        ignored_params = self._get_ignored_params(self._ignored_modules)
-        param_to_unflat_param_names = _get_param_to_unflat_param_names(module)
-        self._ignored_param_to_param_name = {}
-        for param in ignored_params:
-            unflat_param_names = param_to_unflat_param_names[param]
-            assert len(unflat_param_names) == 1, \
-                "Only `FlatParameter`s can map to >1 unflattened parameter " \
-                "name, and `_get_ignored_params()` should have excluded " \
-                "them; check `_get_param_to_unflat_param_names()`"
-            self._ignored_param_to_param_name[param] = unflat_param_names[0]
+        ignored_params = self._get_ignored_parameters()
        # if auto_wrap_policy is specified, submodules should not be
        # already wrapped, otherwise we'd attempt to double wrap them resulting
        # in errors.
@@ -776,19 +767,67 @@ def _get_ignored_modules(
        )
        return ignored_modules

-    def _get_ignored_params(
-        self,
-        ignored_modules: Set[torch.nn.Module],
-    ) -> Set[torch.nn.Parameter]:
-        """
-        Returns the parameters of the modules in ``ignored_modules`` as a
+    def _get_ignored_parameters(self) -> Set[torch.nn.Parameter]:
+        """Returns the parameters of the modules in ``ignored_modules`` as a
        :class:`set`, excluding any :class:`FlatParameter` s.
        """
+        assert hasattr(self, "_ignored_modules"), \
+            "Expects `self._ignored_modules` to be initialized"
        return set(
-            p for m in ignored_modules for p in m.parameters()
+            p for m in self._ignored_modules for p in m.parameters()
            if not isinstance(p, FlatParameter)
        )

+    def _get_ignored_named_tensors(
+        self,
+        ignored_modules: Set[torch.nn.Module],
+        named_tensor_fn: Callable,
+    ) -> Set[Tuple[str, torch.Tensor]]:
+        """
+        This performs a module walk to get the full parameter and buffer names
+        depending on ``named_tensor_fn``, which should either be
+        ``named_parameters()`` or ``named_buffers()``. We require a separate
+        :meth:`_get_ignored_parameters` that does not use this module walk
+        since that method needs to be called in the FSDP constructor before any
+        wrapping occurs, which means that we cannot start a module walk from
+        ``self`` as in this method.
+        """
+        def module_fn(module, prefix, ignored_named_tensors, ignored_modules):
+            if module in ignored_modules:
+                assert not isinstance(module, FullyShardedDataParallel) and \
+                    not isinstance(module, FlattenParamsWrapper), \
+                    "Ignoring FSDP modules is meaningless since their " \
+                    "parameters are not flattened into this FSDP module anyway"
+                for param_name, param in named_tensor_fn(module):
+                    prefixed_param_name = clean_param_name(prefix + param_name)
+                    ignored_named_tensors.add((prefixed_param_name, param))
+
+        def return_fn(ignored_named_tensors, *args):
+            return ignored_named_tensors
+
+        ignored_named_tensors = set()
+        return _apply_to_modules(
+            self, module_fn, return_fn, ignored_named_tensors, ignored_modules,
+        )
+
+    def _get_ignored_named_parameters(self) -> Set[Tuple[str, torch.Tensor]]:
+        """Returns the named parameters of the modules in ``ignored_modules``,
+        excluding any :class:`FlatParameter` s."""
+        assert hasattr(self, "_ignored_modules"), \
+            "Expects `self._ignored_modules` to be initialized"
+        return self._get_ignored_named_tensors(
+            self._ignored_modules, lambda m: m.named_parameters(recurse=False),
+        )
+
+    def _get_ignored_named_buffers(self) -> Set[Tuple[str, torch.Tensor]]:
+        """Returns the named buffers of the modules in ``ignored_modules``,
+        excluding any :class:`FlatParameter` s."""
+        assert hasattr(self, "_ignored_modules"), \
+            "Expects `self._ignored_modules` to be initialized"
+        return self._get_ignored_named_tensors(
+            self._ignored_modules, lambda m: m.named_buffers(recurse=False),
+        )
+
    @classmethod
    def _check_wrapped(cls, begin_module, check_fn, err_fn):
        for _, mod in begin_module.named_modules():
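
The module_fn/return_fn pattern above relies on the _apply_to_modules helper newly imported from ._utils, whose implementation is not part of this diff. The following is a rough sketch, inferred only from the call sites here, of the prefix-tracking walk it is assumed to perform; the name _walk_modules_sketch is invented for illustration and may not match the real helper:

from typing import Callable

import torch.nn as nn

def _walk_modules_sketch(
    root: nn.Module,
    module_fn: Callable,
    return_fn: Callable,
    *args,
):
    # Visit every submodule together with its fully qualified name prefix,
    # let `module_fn` accumulate results into `*args`, then let `return_fn`
    # extract the value to return.
    def visit(module: nn.Module, prefix: str) -> None:
        module_fn(module, prefix, *args)
        for name, child in module.named_children():
            if child is not None:
                visit(child, prefix + name + ".")

    visit(root, "")
    return return_fn(*args)

# Example: collect every (prefixed name, parameter) pair in a model.
pairs = set()
_walk_modules_sketch(
    nn.Linear(2, 2),
    lambda m, prefix, out: out.update(
        (prefix + n, p) for n, p in m.named_parameters(recurse=False)
    ),
    lambda out: out,
    pairs,
)
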
@@ -1496,12 +1535,14 @@ def _full_post_state_dict_hook(
        if not state_dict:
            return state_dict

-        ignored_param_names = set(self._ignored_param_to_param_name.values())
+        ignored_named_params = self._get_ignored_named_parameters()
+        ignored_named_buffers = self._get_ignored_named_buffers()
+        ignored_names = set(n for n, _ in ignored_named_params)
+        ignored_names.update(n for n, _ in ignored_named_buffers)
        for key in state_dict:
-            # Do not need to clone ignored parameters since they are not
-            # sharded
-            clean_param_name = key.replace(FSDP_WRAPPED_MODULE + ".", "").replace(FPW_MODULE + ".", "")
-            if clean_param_name in ignored_param_names:
+            # Do not need to clone ignored parameters and buffers since they
+            # are not sharded
+            if clean_param_name(key) in ignored_names:
                continue
            # Due to recursive call of summon_full_params, avoid unnecessary
            # reclone of tensors in case they have already been cloned.
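
A toy illustration of the skip logic above, showing how the ignored-name set is built and queried; the tensor names are invented and None stands in for the real tensors:

# Sets of (cleaned name, tensor) pairs, as returned by the helpers above.
ignored_named_params = {("head.weight", None), ("head.bias", None)}
ignored_named_buffers = {("norm.running_mean", None)}

ignored_names = set(n for n, _ in ignored_named_params)
ignored_names.update(n for n, _ in ignored_named_buffers)

# A state dict key is skipped (not cloned) once its FSDP prefixes are removed.
assert "head.weight" in ignored_names
assert "trunk.weight" not in ignored_names
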
@@ -2547,11 +2588,7 @@ def _finalize_params(fsdp_module: FullyShardedDataParallel) -> None:
            if isinstance(m, FullyShardedDataParallel):
                _finalize_params(m)
                m._pre_backward_hook_has_run = False
-                if any(
-                    p not in self._ignored_param_to_param_name
-                    and p.requires_grad
-                    for p in m.parameters()
-                ):
+                if any(p.requires_grad for p in m.parameters()):
                    # Check if the module has params and if any of them has
                    # the `requires_grad` field set. If `requires_grad=False` for
                    # all the params, the post_backward hook will not fire and the
@@ -3477,25 +3514,19 @@ def _get_param_to_unflat_param_names(
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance).
    """
-    param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
-
-    def clean_param_name(prefix, param_info):
+    def _clean_param_name(prefix, param_info):
        """This replicates the parameter name cleaning logic in model state
        dict but avoids gathering any parameters."""
-        name = prefix + param_info.module_name + "." + param_info.param_name
-        # FSDP full parameter names may not have both (i.e. `FSDP_PREFIX`), so
-        # we call `replace()` twice separately
-        name = name.replace(FSDP_WRAPPED_MODULE + ".", "")
-        name = name.replace(FPW_MODULE + ".", "")
+        name = clean_param_name(
+            prefix + param_info.module_name + "." + param_info.param_name
+        )
        return name

-    def f(param_to_unflat_param_names, module: torch.nn.Module, prefix: str):
-        # For FSDP modules, only add the entry when considering the contained
-        # `FlattenParamsWrapper` to avoid duplication
+    def module_fn(module, prefix, param_to_unflat_param_names):
        if not isinstance(module, FullyShardedDataParallel):
            for param_name, param in module.named_parameters(recurse=False):
                prefixed_param_names = [
-                    clean_param_name(prefix, param_info)
+                    _clean_param_name(prefix, param_info)
                    for param_info in param._param_infos
                ] if isinstance(param, FlatParameter) else [prefix + param_name]
                # If this parameter has already been visited, then it is a
@@ -3504,13 +3535,13 @@ def f(param_to_unflat_param_names, module: torch.nn.Module, prefix: str):
                if not is_shared_param:
                    param_to_unflat_param_names[param] = prefixed_param_names

-        for submodule_name, submodule in module.named_children():
-            if submodule is not None:
-                new_prefix = prefix + submodule_name + "."
-                f(param_to_unflat_param_names, submodule, new_prefix)
+    def return_fn(param_to_unflat_param_names):
+        return param_to_unflat_param_names

-    f(param_to_unflat_param_names, model, "")
-    return param_to_unflat_param_names
+    param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
+    return _apply_to_modules(
+        model, module_fn, return_fn, param_to_unflat_param_names,
+    )


def _get_param_to_param_name(
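
As an illustration of the mapping that _get_param_to_unflat_param_names builds, a small run on a plain, unwrapped model; with FSDP wrapping, a FlatParameter key would instead map to every original name flattened into it. The output shown in the comment is indicative only:

import torch.nn as nn
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    _get_param_to_unflat_param_names,
)

# Plain model: each parameter maps to a one-element list with its full name.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
mapping = _get_param_to_unflat_param_names(model)
for names in mapping.values():
    print(names)  # e.g. ['0.weight'], ['0.bias'], ['2.weight'], ['2.bias']
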
@@ -3550,3 +3581,12 @@ def _get_param_name_to_param(
    """Constructs the inverse mapping of :meth:`_get_param_to_param_name`."""
    param_to_param_name = _get_param_to_param_name(model)
    return dict(zip(param_to_param_name.values(), param_to_param_name.keys()))
+
+
+def clean_param_name(param_name: str) -> str:
+    """Cleans the parameter name by removing any FSDP-related prefixes."""
+    # FSDP full parameter names may not have both (i.e. `FSDP_PREFIX`), so we
+    # call `replace()` twice separately
+    param_name = param_name.replace(FSDP_WRAPPED_MODULE + ".", "")
+    param_name = param_name.replace(FPW_MODULE + ".", "")
+    return param_name
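
A quick usage note for the new clean_param_name helper: full state dict keys carry the FSDP wrapper prefixes, which this strips before any name lookup. The literal prefix values below are assumptions about what FSDP_WRAPPED_MODULE and FPW_MODULE expand to, shown only for illustration:

# Assuming FSDP_WRAPPED_MODULE == "_fsdp_wrapped_module" and
# FPW_MODULE == "_fpw_module" (illustrative values only):
key = "_fsdp_wrapped_module._fpw_module.head.weight"
assert clean_param_name(key) == "head.weight"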