Commit 73b33de

Andrew Gu authored and pytorchmergebot committed
[FSDP] Include buffers in ignored_modules
Pull Request resolved: pytorch#76784
Approved by: https://github.com/rohan-varma
1 parent 33fabe9 commit 73b33de

3 files changed: +150 −74 lines changed

test/distributed/fsdp/test_fsdp_state_dict.py

Lines changed: 24 additions & 13 deletions
@@ -52,6 +52,7 @@

 INNER_SHAPE = [4, 4]
 OUTER_SHAPE = [4, 5]
+BUFFER_SHAPE = [5, 5]

 _SUPPORTED_STATE_DICT_IMPLS = ["state_dict", "local_state_dict"]

@@ -63,12 +64,14 @@


 class Model(Module):
-    def __init__(self, wrap_fsdp):
+    def __init__(self, wrap_fsdp, register_buffer=False):
         super().__init__()
         self.inner = Linear(*INNER_SHAPE)
         if wrap_fsdp:
             self.inner = FSDP(self.inner)
         self.outer = Linear(*OUTER_SHAPE)
+        if register_buffer:
+            self.outer.register_buffer("buffer", torch.randn(BUFFER_SHAPE))

     def forward(self, x):
         # Forward twice.
@@ -444,34 +447,42 @@ def test_wrong_state_dict_config(self):

     @skip_if_lt_x_gpu(2)
     def test_state_dict_with_ignored_modules(self):
-        # Initialize an FSDP-wrapped model with an ignored module
-        model = Model(wrap_fsdp=True).cuda()
+        # Initialize an FSDP-wrapped model with an ignored module that includes
+        # both parameters and a buffer
+        model = Model(wrap_fsdp=True, register_buffer=True).cuda()
         ignored_modules = [model.outer]
-        ignored_param_to_param_name = {
+        ignored_tensor_to_tensor_name = {
             model.outer.bias: "outer.bias", model.outer.weight: "outer.weight",
+            model.outer.buffer: "outer.buffer",
         }
         fsdp_model = FSDP(model, ignored_modules=ignored_modules)
         with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
-            sd = fsdp_model.state_dict()
-
+            sd1 = fsdp_model.state_dict()
         with FSDP.summon_full_params(fsdp_model):
             fsdp_params = deepcopy(list(fsdp_model.parameters()))
         # Check that the ignored parameters are not cloned
-
-        for param, param_name in ignored_param_to_param_name.items():
-            self.assertTrue(param_name in sd)
-            self.assertEqual(param.data_ptr(), sd[param_name].data_ptr())
+        for tensor, tensor_name in ignored_tensor_to_tensor_name.items():
+            self.assertTrue(tensor_name in sd1)
+            self.assertEqual(tensor.data_ptr(), sd1[tensor_name].data_ptr())
         # Check that the state dict can be loaded into a non-wrapped version of
         # the model
-        nonwrapped_model = Model(wrap_fsdp=False).cuda()
+        nonwrapped_model = Model(wrap_fsdp=False, register_buffer=True).cuda()
         for param in nonwrapped_model.parameters():
             with torch.no_grad():
                 param.zero_()
-
-        nonwrapped_model.load_state_dict(sd)
+        nonwrapped_model.load_state_dict(sd1)
         local_params = list(nonwrapped_model.parameters())
         for fsdp_param, local_param in zip(fsdp_params, local_params):
             self.assertEqual(fsdp_param, local_param)
+        # Check that if we save a state dict again, the ignored parameters and
+        # buffers still have the same data pointer
+        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
+            sd2 = fsdp_model.state_dict()
+        for tensor, tensor_name in ignored_tensor_to_tensor_name.items():
+            self.assertTrue(tensor_name in sd1)  # check again just in case
+            self.assertTrue(tensor_name in sd2)
+            self.assertEqual(tensor.data_ptr(), sd2[tensor_name].data_ptr())
+            self.assertEqual(sd1[tensor_name].data_ptr(), sd2[tensor_name].data_ptr())


 instantiate_parametrized_tests(TestFSDPStateDict)
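The assertions above hinge on one property: because ignored parameters and buffers are never flattened or sharded, the full state dict can simply reference them rather than clone them. A minimal sketch of that check outside the test harness (assuming an already-constructed fsdp_model whose ignored module owns a buffer, and that StateDictType is importable as in this test file):

# Sketch only; mirrors the assertions above, not code from the commit.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

def check_ignored_tensors_alias(fsdp_model, ignored_named_tensors):
    # ignored_named_tensors maps a tensor to its expected state-dict key,
    # e.g. {model.outer.buffer: "outer.buffer"}.
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd = fsdp_model.state_dict()
    for tensor, name in ignored_named_tensors.items():
        assert name in sd
        # Same storage: the ignored tensor was neither sharded nor cloned.
        assert tensor.data_ptr() == sd[name].data_ptr()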

torch/distributed/fsdp/_utils.py

Lines changed: 29 additions & 4 deletions
@@ -1,9 +1,8 @@
-from typing import Dict, List, Tuple, Union, Any, Callable, Set
-from torch.nn.utils.rnn import PackedSequence
+from collections import OrderedDict
+from typing import Any, Callable, Dict, List, Set, Tuple, Union

 import torch
-
-from collections import OrderedDict
+from torch.nn.utils.rnn import PackedSequence

 """Useful functions to deal with tensor types with other python container types."""

@@ -56,3 +55,29 @@ def _replace_by_prefix(
         new_key = new_prefix + key[len(old_prefix) :]
         state_dict[new_key] = state_dict[key]
         del state_dict[key]
+
+
+def _apply_to_modules(
+    root_module: torch.nn.Module,
+    module_fn: Callable,
+    return_fn: Callable,
+    *args,
+    **kwargs,
+):
+    """
+    Performs a pre-order traversal of the modules in the hierarchy rooted at
+    ``root_module``, applying ``module_fn`` at each module and finally
+    returning a value using ``return_fn``. The traversal constructs the full
+    module prefix name (e.g. "module.submodule." just like in model state dict)
+    and makes that available to ``module_fn``.
+    """
+    def f(module: torch.nn.Module, prefix: str, *args, **kwargs):
+        # Call the module function before recursing over children (pre-order)
+        module_fn(module, prefix, *args, **kwargs)
+        for submodule_name, submodule in module.named_children():
+            if submodule is not None:
+                new_prefix = prefix + submodule_name + "."
+                f(submodule, new_prefix, *args, **kwargs)
+
+    f(root_module, "", *args, **kwargs)
+    return return_fn(*args, **kwargs)
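The docstring of _apply_to_modules describes a pre-order walk that hands each module its state-dict-style prefix. A small usage sketch (hypothetical helper, not part of this commit) that collects every parameter's fully prefixed name:

# Hypothetical usage of _apply_to_modules: gather prefixed parameter names.
import torch
from torch.distributed.fsdp._utils import _apply_to_modules

def collect_param_names(root: torch.nn.Module) -> set:
    def module_fn(module, prefix, names):
        # `prefix` is "" for the root and "sub.child." for nested modules.
        for param_name, _ in module.named_parameters(recurse=False):
            names.add(prefix + param_name)

    def return_fn(names):
        return names

    # The accumulator set is threaded through as *args to both callbacks.
    return _apply_to_modules(root, module_fn, return_fn, set())

For a model with submodules inner and outer (as in the test above), this returns {"inner.weight", "inner.bias", "outer.weight", "outer.bias"}.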

torch/distributed/fsdp/fully_sharded_data_parallel.py

Lines changed: 97 additions & 57 deletions
@@ -40,12 +40,6 @@
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.nn.parameter import Parameter

-from .flatten_params_wrapper import (
-    FLAT_PARAM,
-    FPW_MODULE,
-    FlatParameter,
-    FlattenParamsWrapper,
-)
 from ._optim_utils import (
     _broadcast_pos_dim_tensor_states,
     _broadcast_processed_optim_state_dict,
@@ -56,15 +50,21 @@
     _process_pos_dim_tensor_state,
     _unflatten_optim_state,
 )
-from ._utils import _apply_to_tensors, _replace_by_prefix
+from ._utils import _apply_to_modules, _apply_to_tensors, _replace_by_prefix
+from .flatten_params_wrapper import (
+    FLAT_PARAM,
+    FPW_MODULE,
+    FlatParameter,
+    FlattenParamsWrapper,
+)
 from .wrap import _recursive_wrap

 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401

 _TORCHDISTX_AVAIL = True
 try:
-    from torchdistx import fake, deferred_init
+    from torchdistx import deferred_init, fake
 except ImportError:
     _TORCHDISTX_AVAIL = False

@@ -490,10 +490,10 @@ class FullyShardedDataParallel(nn.Module):
             accuracy during model training. If ``None``, no mixed precision is applied.
             (Default: ``None``)
         ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose
-            own parameters and child modules' parameters are ignored by this
-            instance. None of the modules directly in ``ignored_modules``
-            should be :class:`FullyShardedDataParallel` instances, and any
-            child modules that are already-constructed
+            own parameters and child modules' parameters and buffers are
+            ignored by this instance. None of the modules directly in
+            ``ignored_modules`` should be :class:`FullyShardedDataParallel`
+            instances, and any child modules that are already-constructed
             :class:`FullyShardedDataParallel` instances will not be ignored if
             they are nested under this instance. This argument may be used to
             avoid sharding specific parameters when using an
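In user-facing terms, the amended docstring means a buffer registered on an ignored module is now treated like that module's parameters. A brief illustrative sketch (module names are made up, not from the commit):

# Illustrative only: head's weight, bias, and running_stat buffer are all
# excluded from FSDP flattening and sharding because head is ignored.
# Assumes torch.distributed has been initialized and the module is on GPU.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.core = nn.Linear(8, 8)
        self.head = nn.Linear(8, 4)
        self.head.register_buffer("running_stat", torch.zeros(4))

def wrap(net: Net) -> FSDP:
    return FSDP(net, ignored_modules=[net.head])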
@@ -549,16 +549,7 @@ def __init__(
         # Save the ignored modules and their parameters, including the
         # parameter names, which are needed to filter the model state dict
         self._ignored_modules = self._get_ignored_modules(ignored_modules)
-        ignored_params = self._get_ignored_params(self._ignored_modules)
-        param_to_unflat_param_names = _get_param_to_unflat_param_names(module)
-        self._ignored_param_to_param_name = {}
-        for param in ignored_params:
-            unflat_param_names = param_to_unflat_param_names[param]
-            assert len(unflat_param_names) == 1, \
-                "Only `FlatParameter`s can map to >1 unflattened parameter " \
-                "name, and `_get_ignored_params()` should have excluded " \
-                "them; check `_get_param_to_unflat_param_names()`"
-            self._ignored_param_to_param_name[param] = unflat_param_names[0]
+        ignored_params = self._get_ignored_parameters()
         # if auto_wrap_policy is specified, submodules should not be
         # already wrapped, otherwise we'd attempt to double wrap them resulting
         # in errors.
@@ -776,19 +767,67 @@ def _get_ignored_modules(
             )
         return ignored_modules

-    def _get_ignored_params(
-        self,
-        ignored_modules: Set[torch.nn.Module],
-    ) -> Set[torch.nn.Parameter]:
-        """
-        Returns the parameters of the modules in ``ignored_modules`` as a
+    def _get_ignored_parameters(self) -> Set[torch.nn.Parameter]:
+        """Returns the parameters of the modules in ``ignored_modules`` as a
         :class:`set`, excluding any :class:`FlatParameter` s.
         """
+        assert hasattr(self, "_ignored_modules"), \
+            "Expects `self._ignored_modules` to be initialized"
         return set(
-            p for m in ignored_modules for p in m.parameters()
+            p for m in self._ignored_modules for p in m.parameters()
             if not isinstance(p, FlatParameter)
         )

+    def _get_ignored_named_tensors(
+        self,
+        ignored_modules: Set[torch.nn.Module],
+        named_tensor_fn: Callable,
+    ) -> Set[Tuple[str, torch.Tensor]]:
+        """
+        This performs a module walk to get the full parameter and buffer names
+        depending on ``named_tensor_fn``, which should either be
+        ``named_parameters()`` or ``named_buffers()`. We require a separate
+        :meth:`_get_ignored_parameters` that does not use this module walk
+        since that method needs to be called in the FSDP constructor before any
+        wrapping occurs, which means that we cannot start a module walk from
+        ``self`` as in this method.
+        """
+        def module_fn(module, prefix, ignored_named_tensors, ignored_modules):
+            if module in ignored_modules:
+                assert not isinstance(module, FullyShardedDataParallel) and \
+                    not isinstance(module, FlattenParamsWrapper), \
+                    "Ignoring FSDP modules is meaningless since their " \
+                    "parameters are not flattened into this FSDP module anyway"
+                for param_name, param in named_tensor_fn(module):
+                    prefixed_param_name = clean_param_name(prefix + param_name)
+                    ignored_named_tensors.add((prefixed_param_name, param))
+
+        def return_fn(ignored_named_tensors, *args):
+            return ignored_named_tensors
+
+        ignored_named_tensors = set()
+        return _apply_to_modules(
+            self, module_fn, return_fn, ignored_named_tensors, ignored_modules,
+        )
+
+    def _get_ignored_named_parameters(self) -> Set[Tuple[str, torch.Tensor]]:
+        """Returns the named parameters of the modules in ``ignored_modules``,
+        excluding any :class:`FlatParameter` s."""
+        assert hasattr(self, "_ignored_modules"), \
+            "Expects `self._ignored_modules` to be initialized"
+        return self._get_ignored_named_tensors(
+            self._ignored_modules, lambda m: m.named_parameters(recurse=False),
+        )
+
+    def _get_ignored_named_buffers(self) -> Set[Tuple[str, torch.Tensor]]:
+        """Returns the named buffers of the modules in ``ignored_modules``,
+        excluding any :class:`FlatParameter` s."""
+        assert hasattr(self, "_ignored_modules"), \
+            "Expects `self._ignored_modules` to be initialized"
+        return self._get_ignored_named_tensors(
+            self._ignored_modules, lambda m: m.named_buffers(recurse=False),
+        )
+
     @classmethod
     def _check_wrapped(cls, begin_module, check_fn, err_fn):
         for _, mod in begin_module.named_modules():
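For the Model used in the updated test (an ignored outer Linear with a registered buffer), the new helpers would produce name/tensor pairs whose names already match full-state-dict keys; roughly the following (illustrative, not executed in the commit):

# fsdp_model._get_ignored_named_parameters()
#   -> {("outer.weight", <Parameter>), ("outer.bias", <Parameter>)}
# fsdp_model._get_ignored_named_buffers()
#   -> {("outer.buffer", <Tensor>)}
# The names are cleaned of FSDP wrapper prefixes via clean_param_name, so they
# line up with the keys filtered in _full_post_state_dict_hook below.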
@@ -1496,12 +1535,14 @@ def _full_post_state_dict_hook(
         if not state_dict:
             return state_dict

-        ignored_param_names = set(self._ignored_param_to_param_name.values())
+        ignored_named_params = self._get_ignored_named_parameters()
+        ignored_named_buffers = self._get_ignored_named_buffers()
+        ignored_names = set(n for n, _ in ignored_named_params)
+        ignored_names.update(n for n, _ in ignored_named_buffers)
         for key in state_dict:
-            # Do not need to clone ignored parameters since they are not
-            # sharded
-            clean_param_name = key.replace(FSDP_WRAPPED_MODULE + ".", "").replace(FPW_MODULE + ".", "")
-            if clean_param_name in ignored_param_names:
+            # Do not need to clone ignored parameters and buffers since they
+            # are not sharded
+            if clean_param_name(key) in ignored_names:
                 continue
             # Due to recursive call of summon_full_params, avoid unnecessary
             # reclone of tensors in case they have already been cloned.
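The hook change above relies on ignored tensors being unsharded, so they can pass through without cloning. A standalone approximation of the filtering step (a sketch only; the actual hook mutates state_dict in place under summon_full_params):

# Sketch: skip cloning for keys whose cleaned name belongs to an ignored
# parameter or buffer; clone everything else so it outlives the summoned
# full parameters.
def filter_and_clone(state_dict, ignored_names, clean_param_name):
    out = {}
    for key, tensor in state_dict.items():
        if clean_param_name(key) in ignored_names:
            out[key] = tensor          # never sharded, safe to reference
        else:
            out[key] = tensor.clone()  # sharded storage is freed afterwards
    return out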
@@ -2547,11 +2588,7 @@ def _finalize_params(fsdp_module: FullyShardedDataParallel) -> None:
             if isinstance(m, FullyShardedDataParallel):
                 _finalize_params(m)
                 m._pre_backward_hook_has_run = False
-                if any(
-                    p not in self._ignored_param_to_param_name
-                    and p.requires_grad
-                    for p in m.parameters()
-                ):
+                if any(p.requires_grad for p in m.parameters()):
                     # Check if the module has params and if any of them has
                     # the `requires_grad` field set. If `requires_grad=False` for
                     # all the params, the post_backward hook will not fire and the
@@ -3477,25 +3514,19 @@ def _get_param_to_unflat_param_names(
         model (torch.nn.Module): Root module (which may or may not be a
             :class:`FullyShardedDataParallel` instance).
     """
-    param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
-
-    def clean_param_name(prefix, param_info):
+    def _clean_param_name(prefix, param_info):
         """This replicates the parameter name cleaning logic in model state
         dict but avoids gathering any parameters."""
-        name = prefix + param_info.module_name + "." + param_info.param_name
-        # FSDP full parameter names may not have both (i.e. `FSDP_PREFIX`), so
-        # we call `replace()` twice separately
-        name = name.replace(FSDP_WRAPPED_MODULE + ".", "")
-        name = name.replace(FPW_MODULE + ".", "")
+        name = clean_param_name(
+            prefix + param_info.module_name + "." + param_info.param_name
+        )
         return name

-    def f(param_to_unflat_param_names, module: torch.nn.Module, prefix: str):
-        # For FSDP modules, only add the entry when considering the contained
-        # `FlattenParamsWrapper` to avoid duplication
+    def module_fn(module, prefix, param_to_unflat_param_names):
         if not isinstance(module, FullyShardedDataParallel):
             for param_name, param in module.named_parameters(recurse=False):
                 prefixed_param_names = [
-                    clean_param_name(prefix, param_info)
+                    _clean_param_name(prefix, param_info)
                     for param_info in param._param_infos
                 ] if isinstance(param, FlatParameter) else [prefix + param_name]
                 # If this parameter has already been visited, then it is a
@@ -3504,13 +3535,13 @@ def f(param_to_unflat_param_names, module: torch.nn.Module, prefix: str):
                 if not is_shared_param:
                     param_to_unflat_param_names[param] = prefixed_param_names

-        for submodule_name, submodule in module.named_children():
-            if submodule is not None:
-                new_prefix = prefix + submodule_name + "."
-                f(param_to_unflat_param_names, submodule, new_prefix)
+    def return_fn(param_to_unflat_param_names):
+        return param_to_unflat_param_names

-    f(param_to_unflat_param_names, model, "")
-    return param_to_unflat_param_names
+    param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
+    return _apply_to_modules(
+        model, module_fn, return_fn, param_to_unflat_param_names,
+    )


 def _get_param_to_param_name(
@@ -3550,3 +3581,12 @@ def _get_param_name_to_param(
     """Constructs the inverse mapping of :meth:`_get_param_to_param_name`."""
     param_to_param_name = _get_param_to_param_name(model)
     return dict(zip(param_to_param_name.values(), param_to_param_name.keys()))
+
+
+def clean_param_name(param_name: str) -> str:
+    """Cleans the parameter name by removing any FSDP-related prefixes."""
+    # FSDP full parameter names may not have both (i.e. `FSDP_PREFIX`), so we
+    # call `replace()` twice separately
+    param_name = param_name.replace(FSDP_WRAPPED_MODULE + ".", "")
+    param_name = param_name.replace(FPW_MODULE + ".", "")
+    return param_name
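Since clean_param_name strips every occurrence of the wrapper prefixes, state-dict keys produced under FSDP compare equal to the unwrapped model's keys. Illustrative behavior (not code from the commit):

# clean_param_name(FSDP_WRAPPED_MODULE + "." + FPW_MODULE + "." + "outer.bias")
#   -> "outer.bias"
# clean_param_name("outer.bias") -> "outer.bias"   (already clean; unchanged)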
