fix state_dict compatibility with ebc family #703

Open · wants to merge 1 commit into base: main

59 changes: 38 additions & 21 deletions torchrec/distributed/quant_embeddingbag.py
@@ -5,10 +5,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
from typing import Any, Dict, List, Optional, Type

import torch
from torch import nn
from torch.nn.modules.module import _IncompatibleKeys
from torchrec.distributed.embedding_sharding import (
    EmbeddingSharding,
    EmbeddingShardingInfo,
@@ -36,6 +38,7 @@
    ShardingEnv,
    ShardingType,
)
from torchrec.distributed.utils import filter_state_dict
from torchrec.modules.embedding_configs import (
    data_type_to_sparse_type,
    dtype_to_data_type,
@@ -118,27 +121,6 @@ def __init__(
        self._has_uninitialized_output_dist: bool = True
        self._has_features_permute: bool = True

        # This provides consistency between this class and the EmbeddingBagCollection's
        # nn.Module API calls (state_dict, named_modules, etc.).
        # Currently, Sharded Quant EBC only uses TW sharding and returns non-sharded tensors as part of the state dict.
        # TODO - revisit whether the state_dict can be represented as a sharded tensor.
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()
        for table in self._embedding_bag_configs:
            self.embedding_bags[table.name] = torch.nn.Module()

        for _sharding_type, lookup in zip(
            self._sharding_type_to_sharding.keys(), self._lookups
        ):
            lookup_state_dict = lookup.state_dict()
            for key in lookup_state_dict:
                if not key.endswith(".weight"):
                    continue
                table_name = key[: -len(".weight")]
                # Register as a buffer because this is an inference model and may use uint8 types.
                self.embedding_bags[table_name].register_buffer(
                    "weight", lookup_state_dict[key]
                )

    def _create_input_dist(
        self,
        input_feature_names: List[str],
@@ -246,6 +228,41 @@ def compute_and_output_dist(
    ) -> LazyAwaitable[KeyedTensor]:
        return self.output_dist(ctx, self.compute(ctx, input))

    # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
    def state_dict(
        self,
        destination: Optional[Dict[str, Any]] = None,
        prefix: str = "",
        keep_vars: bool = False,
    ) -> Dict[str, Any]:
        if destination is None:
            destination = OrderedDict()
            # pyre-ignore [16]
            destination._metadata = OrderedDict()
        for lookup in self._lookups:
            lookup.state_dict(destination, prefix + "embedding_bags.", keep_vars)
        return destination

    # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`
    # inconsistently.
    def load_state_dict(
        self,
        state_dict: "OrderedDict[str, torch.Tensor]",
        strict: bool = True,
    ) -> _IncompatibleKeys:
        missing_keys = []
        unexpected_keys = []
        for lookup in self._lookups:
            missing, unexpected = lookup.load_state_dict(
                filter_state_dict(state_dict, "embedding_bags"),
                strict,
            )
            missing_keys.extend(missing)
            unexpected_keys.extend(unexpected)
        return _IncompatibleKeys(
            missing_keys=missing_keys, unexpected_keys=unexpected_keys
        )

    def copy(self, device: torch.device) -> nn.Module:
        if self._has_uninitialized_output_dist:
            self._create_output_dist(device)
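
The `load_state_dict` override added above delegates to each lookup after narrowing the incoming state dict with `torchrec.distributed.utils.filter_state_dict`. As a reference point for reviewers, here is a minimal sketch of what that kind of prefix filter does; this is an illustrative reimplementation, not the actual torchrec helper:

```python
from collections import OrderedDict

import torch


def filter_state_dict_sketch(
    state_dict: "OrderedDict[str, torch.Tensor]", name: str
) -> "OrderedDict[str, torch.Tensor]":
    # Keep only the entries under `name.` and strip that prefix, so each
    # child lookup sees keys relative to itself.
    filtered: "OrderedDict[str, torch.Tensor]" = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(name + "."):
            filtered[key[len(name) + 1 :]] = value
    return filtered
```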
3 changes: 2 additions & 1 deletion torchrec/distributed/tests/test_quant_model_parallel.py
@@ -282,7 +282,8 @@ def test_quant_pred_state_dict(self, output_type: torch.dtype) -> None:
            tables=self.tables,
            weighted_tables=self.weighted_tables,
        )

        print(dmp_copy.state_dict().keys())
        print(dmp.state_dict().keys())
        # pyre-ignore
        dmp_copy.load_state_dict(dmp.state_dict())
        torch.testing.assert_close(
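
For context on what the updated test exercises: after this change the sharded quant module reports its weights under the same `embedding_bags.<table_name>.weight` namespace as the unsharded EmbeddingBagCollection family, so a state dict taken from one wrapped model can be loaded into the other, which is what the `dmp_copy.load_state_dict(dmp.state_dict())` call checks. A rough usage sketch, assuming `dmp` and `dmp_copy` wrap models built as in the test above (the helper name is illustrative):

```python
import torch


def check_state_dict_roundtrip(dmp: torch.nn.Module, dmp_copy: torch.nn.Module) -> None:
    # Both models should expose identical state dict keys, so the load
    # succeeds with the default strict=True.
    assert set(dmp.state_dict().keys()) == set(dmp_copy.state_dict().keys())
    dmp_copy.load_state_dict(dmp.state_dict())
```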