Added remove_duplicate parameter to nn.Module (#6) #39

Closed · wants to merge 1 commit
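
This PR adds a remove_duplicate argument to TorchRec's named_parameters / named_buffers overrides so their signatures match torch.nn.Module. On nn.Module, remove_duplicate=True (the default) yields each underlying tensor once even when it is registered under several names, while remove_duplicate=False reports every registration. A minimal sketch of that behavior, assuming a PyTorch version whose nn.Module already accepts the argument (the Shared class is illustrative, not part of this PR):

import torch.nn as nn

class Shared(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(2, 2)
        self.b = self.a  # the same module object registered under two names

m = Shared()
# Default: each parameter tensor is reported once.
print(len(list(m.named_parameters(remove_duplicate=True))))   # 2
# With duplicates kept, both registrations of each tensor show up.
print(len(list(m.named_parameters(remove_duplicate=False))))  # 4
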
12 changes: 6 additions & 6 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -424,12 +424,12 @@ def fused_optimizer(self) -> FusedOptimizer:
return self._optim

def named_parameters(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
yield from ()

def named_buffers(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for config, param in zip(
self._config.embedding_tables,
@@ -471,7 +471,7 @@ def emb_module(
return self._emb_module

def named_parameters(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
combined_key = "/".join(
[config.name for config in self._config.embedding_tables]
@@ -678,12 +678,12 @@ def fused_optimizer(self) -> FusedOptimizer:
return self._optim

def named_parameters(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
yield from ()

def named_buffers(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for config, param in zip(
self._config.embedding_tables,
@@ -725,7 +725,7 @@ def emb_module(
return self._emb_module

def named_parameters(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
combined_key = "/".join(
[config.name for config in self._config.embedding_tables]
4 changes: 2 additions & 2 deletions torchrec/distributed/embedding_kernel.py
@@ -174,7 +174,7 @@ def state_dict(
)

def named_parameters(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
for config, emb_module in zip(
self._config.embedding_tables,
@@ -320,7 +320,7 @@ def state_dict(
)

def named_parameters(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
for config, emb_module in zip(
self._config.embedding_tables,
12 changes: 6 additions & 6 deletions torchrec/distributed/embedding_lookup.py
@@ -184,13 +184,13 @@ def load_state_dict(
return _IncompatibleKeys(missing_keys=m, unexpected_keys=u)

def named_parameters(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
for emb_module in self._emb_modules:
yield from emb_module.named_parameters(prefix, recurse)

def named_buffers(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for emb_module in self._emb_modules:
yield from emb_module.named_buffers(prefix, recurse)
@@ -370,15 +370,15 @@ def load_state_dict(
return _IncompatibleKeys(missing_keys=m1 + m2, unexpected_keys=u1 + u2)

def named_parameters(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
for emb_module in self._emb_modules:
yield from emb_module.named_parameters(prefix, recurse)
for emb_module in self._score_emb_modules:
yield from emb_module.named_parameters(prefix, recurse)

def named_buffers(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for emb_module in self._emb_modules:
yield from emb_module.named_buffers(prefix, recurse)
@@ -466,13 +466,13 @@ def load_state_dict(
)

def named_parameters(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
for rank_modules in self._embedding_lookups_per_rank:
yield from rank_modules.named_parameters(prefix, recurse)

def named_buffers(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for rank_modules in self._embedding_lookups_per_rank:
yield from rank_modules.named_buffers(prefix, recurse)
8 changes: 4 additions & 4 deletions torchrec/distributed/embeddingbag.py
@@ -441,7 +441,7 @@ def named_modules(
yield from [(prefix, self)]

def named_parameters(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
for lookup in self._lookups:
yield from lookup.named_parameters(
@@ -460,7 +460,7 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
yield name

def named_buffers(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for lookup in self._lookups:
yield from lookup.named_buffers(
@@ -731,7 +731,7 @@ def named_modules(
yield from [(prefix, self)]

def named_parameters(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
for name, parameter in self._lookup.named_parameters("", recurse):
# update name to match embeddingBag parameter name
@@ -745,7 +745,7 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
yield append_prefix(prefix, name.split(".")[-1])

def named_buffers(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for name, buffer in self._lookup.named_buffers("", recurse):
yield append_prefix(prefix, name.split(".")[-1]), buffer
4 changes: 2 additions & 2 deletions torchrec/distributed/grouped_position_weighted.py
@@ -69,13 +69,13 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
)

def named_parameters(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
for name, param in self.position_weights.items():
yield append_prefix(prefix, f"position_weights.{name}"), param

def named_buffers(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
yield from ()

35 changes: 26 additions & 9 deletions torchrec/distributed/model_parallel.py
@@ -396,7 +396,11 @@ def _load_state_dict(
)

def _named_parameters(
- self, module: nn.Module, prefix: str = "", recurse: bool = True
+ self,
+ module: nn.Module,
+ prefix: str = "",
+ recurse: bool = True,
+ remove_duplicate: bool = True,
) -> Iterator[Tuple[str, torch.nn.Parameter]]:
if isinstance(module, ShardedModule):
yield from module.named_parameters(prefix, recurse)
@@ -408,9 +412,11 @@ def _named_parameters(
)

def named_parameters(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.nn.Parameter]]:
- yield from self._named_parameters(self.dmp_module, prefix, recurse)
+ yield from self._named_parameters(
+ self.dmp_module, prefix, recurse, remove_duplicate
+ )

@staticmethod
def _sharded_parameter_names(module: nn.Module, prefix: str = "") -> Iterator[str]:
@@ -423,21 +429,32 @@ def _sharded_parameter_names(module: nn.Module, prefix: str = "") -> Iterator[str]:
)

def _named_buffers(
- self, module: nn.Module, prefix: str = "", recurse: bool = True
+ self,
+ module: nn.Module,
+ prefix: str = "",
+ recurse: bool = True,
+ remove_duplicate: bool = True,
) -> Iterator[Tuple[str, torch.Tensor]]:
if isinstance(module, ShardedModule):
- yield from module.named_buffers(prefix, recurse)
+ yield from module.named_buffers(prefix, recurse, remove_duplicate)
else:
- yield from module.named_buffers(prefix, recurse=False)
+ yield from module.named_buffers(
+ prefix, recurse=False, remove_duplicate=True
+ )
for name, child in module.named_children():
yield from self._named_buffers(
- child, append_prefix(prefix, name), recurse
+ child, append_prefix(prefix, name), recurse, remove_duplicate
)

def named_buffers(
- self, prefix: str = "", recurse: bool = True
+ self,
+ prefix: str = "",
+ recurse: bool = True,
+ remove_duplicate: bool = True,
) -> Iterator[Tuple[str, torch.Tensor]]:
- yield from self._named_buffers(self.dmp_module, prefix, recurse)
+ yield from self._named_buffers(
+ self.dmp_module, prefix, recurse, remove_duplicate
+ )

@property
def fused_optimizer(self) -> KeyedOptimizer:
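
The model_parallel.py changes above forward remove_duplicate from the public named_parameters / named_buffers into the internal _named_parameters / _named_buffers helpers. A minimal sketch of the same pass-through pattern for a custom wrapper module (Wrapper and inner are illustrative names, not part of this PR), assuming a PyTorch version whose nn.Module accepts remove_duplicate:

from typing import Iterator, Tuple

import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, inner: nn.Module) -> None:
        super().__init__()
        self.inner = inner

    def named_buffers(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        # Forward the flag instead of dropping it, so callers that ask for
        # duplicates (remove_duplicate=False) actually get them.
        yield from self.inner.named_buffers(prefix, recurse, remove_duplicate)
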
2 changes: 1 addition & 1 deletion torchrec/distributed/quant_embedding_kernel.py
@@ -103,7 +103,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
)

def named_buffers(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for config, weight in zip(
self._config.embedding_tables,
2 changes: 1 addition & 1 deletion torchrec/quant/embedding_modules.py
@@ -211,7 +211,7 @@ def state_dict(
return destination

def named_buffers(
- self, prefix: str = "", recurse: bool = True
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
state_dict = self.state_dict(prefix=prefix, keep_vars=True)
for key, value in state_dict.items():