[FSDP] Do not clean FQNs even for use_orig_params=True (pytorch#91767)
Cleaning FQNs for `FullyShardedDataParallel(use_orig_params=True)` can cause discrepancies between the FQNs returned by `named_parameters()` and those obtained by manually looping over `named_modules()` and `named_parameters()` together.

There is no requirement for the FQNs to be clean when using wrapper FSDP + `use_orig_params=True`. We can leave clean FQNs to `fully_shard`.
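
As an illustration (not part of this commit), a minimal sketch of the mismatch, assuming a wrapper-FSDP model `fsdp_model` and the existing `clean_tensor_name` helper from `torch.distributed.fsdp._common_utils`:

from torch.distributed.fsdp._common_utils import clean_tensor_name

def manual_fqns(model):
    # Rebuild FQNs by pairing named_modules() with per-module
    # named_parameters(recurse=False); for wrapper FSDP these names include
    # FSDP-specific prefixes such as "_fsdp_wrapped_module.".
    fqns = []
    for module_name, module in model.named_modules():
        for param_name, _ in module.named_parameters(recurse=False):
            fqns.append(f"{module_name}.{param_name}" if module_name else param_name)
    return fqns

# If named_parameters() cleaned its names but manual_fqns() did not, the two
# would disagree. After this change, both keep the prefixes, and callers strip
# them explicitly when needed:
#     [clean_tensor_name(n) for n, _ in fsdp_model.named_parameters()]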
Pull Request resolved: pytorch#91767
Approved by: https://github.com/zhaojuanmao
awgu authored and pytorchmergebot committed Jan 12, 2023
1 parent 7f50ff1 commit a383789
Showing 3 changed files with 18 additions and 61 deletions.
test/distributed/_composable/test_fully_shard.py (10 changes: 8 additions & 2 deletions)
@@ -15,7 +15,11 @@
 import torch.nn as nn
 from torch.distributed._composable import fully_shard
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp._common_utils import _FSDPState, _is_fsdp_flattened
+from torch.distributed.fsdp._common_utils import (
+    _FSDPState,
+    _is_fsdp_flattened,
+    clean_tensor_name,
+)
 from torch.distributed.fsdp.api import MixedPrecision
 from torch.distributed.fsdp.flat_param import _HandlesKey, FlatParamHandle
 from torch.distributed.fsdp.wrap import _FSDPPolicy, ModuleWrapPolicy
@@ -258,7 +262,9 @@ def _param_init_fn(module: nn.Module):
             composable_module.named_parameters(),
             fsdp_wrapped_model.named_parameters(),
         ):
-            self.assertEqual(composable_param_name, fsdp_wrapped_param_name)
+            self.assertEqual(
+                composable_param_name, clean_tensor_name(fsdp_wrapped_param_name)
+            )
             self.assertEqual(
                 composable_param.device,
                 torch.device("cuda", torch.cuda.current_device()),
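
For reference, the `clean_tensor_name` helper used above is expected to strip FSDP-specific wrapper prefixes from a name (illustrative values only; the exact prefix is an FSDP-internal detail):

from torch.distributed.fsdp._common_utils import clean_tensor_name

# e.g. "_fsdp_wrapped_module.lin.weight" -> "lin.weight"
print(clean_tensor_name("_fsdp_wrapped_module.lin.weight"))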
test/distributed/fsdp/test_fsdp_use_orig_params.py (59 changes: 8 additions & 51 deletions)
@@ -4,7 +4,7 @@
 import functools
 import itertools
 import sys
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -189,7 +189,8 @@ def _check_ddp_fsdp_param_parity(self, ddp_model: DDP, fsdp_model: FSDP):
         for (n1, p1), (n2, p2) in zip(
             ddp_model.module.named_parameters(), fsdp_model.named_parameters()
         ):
-            self.assertEqual(n1, n2)
+            # Allow for FSDP prefixes
+            self.assertEqual(n1, clean_tensor_name(n2))
             torch.testing.assert_close(p1, p2)
 
     def _get_sharding_strategy_from_str(
@@ -448,7 +449,7 @@ def run_iter():
             ddp_model.module.named_parameters(),
             fsdp_model.named_parameters(),
         ):
-            self.assertEqual(ddp_n, fsdp_n)
+            self.assertEqual(ddp_n, clean_tensor_name(fsdp_n))
             if fsdp_p.numel() == 0:
                 # Not in this rank's shard
                 self.assertTrue(fsdp_p.grad is None)
@@ -961,53 +962,6 @@ def test_writeback_shape_mismatch(self):
 
 
 class TestFSDPUseOrigParamsFQNs(FSDPTest):
-    @skip_if_lt_x_gpu(2)
-    def test_param_and_buffer_names(self):
-        """
-        Tests that, for ``use_orig_params=True``, the parameter and buffer
-        names match those of a local model even when sharded, meaning that they
-        do not include FSDP-specific prefixes.
-        """
-        self.run_subtests(
-            {"auto_wrap_policy": [None, always_wrap_policy]},
-            self._test_param_and_buffer_names,
-        )
-
-    def _test_param_and_buffer_names(self, auto_wrap_policy: Optional[Callable]):
-        class Container(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.param = nn.Parameter(torch.randn((5, 5)))
-                self.register_buffer("buf", torch.randn((5, 5)))
-
-            def forward(self, x):
-                return x @ self.param + self.buf
-
-        class Model(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.param = nn.Parameter(torch.randn((5, 5)))
-                self.lin = nn.Linear(5, 5)
-                self.container = Container()
-                self.register_buffer("buf", torch.randn((5, 5)))
-
-            def forward(self, x):
-                z = self.container(x)
-                z = z @ self.param + self.buf
-                z = self.lin(z)
-                return z
-
-        model = Model()
-        fsdp_model = FSDP(
-            Model(), auto_wrap_policy=auto_wrap_policy, use_orig_params=True
-        )
-        param_names = [n for n, _ in model.named_parameters()]
-        fsdp_param_names = [n for n, _ in fsdp_model.named_parameters()]
-        self.assertEqual(param_names, fsdp_param_names)
-        buffer_names = [n for n, _ in model.named_buffers()]
-        fsdp_buffer_names = [n for n, _ in fsdp_model.named_buffers()]
-        self.assertEqual(buffer_names, fsdp_buffer_names)
-
     @skip_if_lt_x_gpu(2)
     def test_named_parameters_in_forward(self):
         """
@@ -1024,7 +978,10 @@ def __init__(self) -> None:
 
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 nonlocal param_shapes
-                param_names = [tup[0] for tup in self.named_parameters()]
+                # Allow for FSDP prefixes
+                param_names = [
+                    clean_tensor_name(tup[0]) for tup in self.named_parameters()
+                ]
                 params = [tup[1] for tup in self.named_parameters()]
                 assert (
                     param_shapes[0] is not None and param_shapes[1] is not None
torch/distributed/fsdp/fully_sharded_data_parallel.py (10 changes: 2 additions & 8 deletions)
@@ -883,10 +883,7 @@ def named_buffers(
         remove all occurrences of the FSDP-specific flattened buffer prefix
         when inside the :meth:`summon_full_params` context manager.
         """
-        should_clean_name = (
-            self.training_state == TrainingState.SUMMON_FULL_PARAMS
-            or self._use_orig_params
-        )
+        should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
         for buffer_name, buffer in super().named_buffers(*args, **kwargs):
             if should_clean_name:
                 # Remove any instances of the FSDP-specific prefix; there can
@@ -904,10 +901,7 @@ def named_parameters(
         remove all occurrences of the FSDP-specific flattened parameter prefix
         when inside the :meth:`summon_full_params` context manager.
         """
-        should_clean_name = (
-            self.training_state == TrainingState.SUMMON_FULL_PARAMS
-            or self._use_orig_params
-        )
+        should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
         for param_name, param in super().named_parameters(*args, **kwargs):
             if should_clean_name:
                 # Remove any instances of the FSDP-specific prefix; there can
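
A usage sketch of the resulting behavior (assumes an initialized process group and a hypothetical `fsdp_model = FSDP(model, use_orig_params=True)`): outside `summon_full_params()`, `named_parameters()` now yields prefixed names, while inside the context manager FSDP still cleans them.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import clean_tensor_name

def collect_param_names(fsdp_model: FSDP):
    # Outside summon_full_params(): names keep FSDP-specific prefixes,
    # e.g. "_fsdp_wrapped_module.lin.weight"; strip them explicitly if needed.
    prefixed = [name for name, _ in fsdp_model.named_parameters()]
    stripped = [clean_tensor_name(name) for name in prefixed]
    # Inside summon_full_params(): FSDP cleans the names itself.
    with FSDP.summon_full_params(fsdp_model):
        cleaned = [name for name, _ in fsdp_model.named_parameters()]
    return prefixed, stripped, cleaned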
