expose all DDP params in DefaultDDPWrapper (pytorch#2329)
Summary:
Pull Request resolved: pytorch#2329

TSIA (title says it all).

Reviewed By: yuhuishi-convect, joshuadeng

Differential Revision: D61613723

fbshipit-source-id: 343c9631dd714cfd0d64fdac3c3210afc2d744f5
iamzainhuda authored and facebook-github-bot committed Aug 22, 2024
1 parent 4ed441e commit 9418355
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions torchrec/distributed/model_parallel.py
@@ -78,12 +78,14 @@ def __init__(
         find_unused_parameters: bool = False,
         allreduce_comm_precision: Optional[str] = None,
         params_to_ignore: Optional[List[str]] = None,
+        ddp_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         self._bucket_cap_mb: int = bucket_cap_mb
         self._static_graph: bool = static_graph
         self._find_unused_parameters: bool = find_unused_parameters
         self._allreduce_comm_precision = allreduce_comm_precision
         self._additional_params_to_ignore: Set[str] = set(params_to_ignore or [])
+        self._ddp_kwargs: Dict[str, Any] = ddp_kwargs or {}

     def _ddp_wrap(
         self,
@@ -114,6 +116,7 @@ def _ddp_wrap(
                 static_graph=self._static_graph,
                 find_unused_parameters=self._find_unused_parameters,
                 bucket_cap_mb=self._bucket_cap_mb,
+                **self._ddp_kwargs,
             ),
         )
         if self._allreduce_comm_precision == "fp16":
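
For readers of this commit, a minimal usage sketch of the new pass-through. It assumes the wrapper class in this file is DefaultDataParallelWrapper (the "DefaultDDPWrapper" of the title) and that DistributedModelParallel accepts it through a data_parallel_wrapper argument; my_model, the device choice, and the example "dim" entry are illustrative, not taken from this diff. Keys in ddp_kwargs are forwarded verbatim to torch.nn.parallel.DistributedDataParallel, so they should be DDP arguments the wrapper does not already set explicitly (e.g. not bucket_cap_mb, static_graph, or find_unused_parameters), otherwise DDP would receive duplicate keyword arguments.

import torch
from torchrec.distributed.model_parallel import (
    DefaultDataParallelWrapper,  # assumed class name for "DefaultDDPWrapper"
    DistributedModelParallel,
)

# Wrapper-level knobs stay explicit parameters; anything else that
# torch.nn.parallel.DistributedDataParallel accepts can now ride along in
# ddp_kwargs and is splatted into the DDP constructor via **self._ddp_kwargs.
ddp_wrapper = DefaultDataParallelWrapper(
    bucket_cap_mb=25,
    static_graph=True,
    find_unused_parameters=False,
    ddp_kwargs={"dim": 0},  # illustrative DDP kwarg; use only ones the wrapper does not already pass
)

# Hypothetical wiring: hand the wrapper to DistributedModelParallel so the
# data-parallel part of my_model (an nn.Module defined elsewhere) is wrapped
# in DDP with the extra kwargs applied.
model = DistributedModelParallel(
    module=my_model,
    device=torch.device("cuda"),
    data_parallel_wrapper=ddp_wrapper,  # assumed argument name
)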
