Skip to content

fix dmp copy #1100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions torchrec/distributed/model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

import abc
import copy
from collections import OrderedDict
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple

Expand Down Expand Up @@ -34,6 +35,7 @@
append_prefix,
copy_to_device,
filter_state_dict,
sharded_model_copy,
)
from torchrec.optim.fused import FusedOptimizerModule
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
Expand Down Expand Up @@ -278,8 +280,10 @@ def copy(
`ShardedModule` for inference).
"""
assert isinstance(device, torch.device)
copy_dmp = copy_to_device(self, self.device, device)
return cast(DistributedModelParallel, copy_dmp)
with sharded_model_copy(device=None):
dmp = copy.deepcopy(self)
dmp._dmp_wrapped_module = copy_to_device(self.module, self.device, device)
return dmp

def _init_dmp(self, module: nn.Module) -> nn.Module:
return self._shard_modules_impl(module)
Expand Down
2 changes: 1 addition & 1 deletion torchrec/distributed/tests/test_quant_model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def test_copy_mixin(self) -> None:
tables=self.tables,
weighted_tables=self.weighted_tables,
num_float_features=10,
dense_device=device,
dense_device=torch.device("cpu"),
sparse_device=torch.device("meta"),
)
# pyre-ignore [16]
Expand Down