Skip to content

MCM Fix for ig integration #1541

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions torchrec/distributed/mc_embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def __init__(
device=device,
)
)
# TODO: This is a hack since _embedding_bag_collection doesn't need input
# dist, so eliminating it so all fused a2a will ignore it.
self._embedding_bag_collection._has_uninitialized_input_dist = False
self._managed_collision_collection: ShardedManagedCollisionCollection = mc_sharder.shard(
module._managed_collision_collection,
table_name_to_parameter_sharding,
Expand Down
17 changes: 17 additions & 0 deletions torchrec/distributed/tests/test_mc_embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import torch
import torch.nn as nn
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
from torchrec.distributed.mc_embeddingbag import (
ManagedCollisionEmbeddingBagCollectionSharder,
ShardedManagedCollisionEmbeddingBagCollection,
Expand Down Expand Up @@ -220,6 +221,22 @@ def _test_sharding_and_remapping( # noqa C901
assert isinstance(
sharded_sparse_arch._mc_ebc, ShardedManagedCollisionEmbeddingBagCollection
)
assert isinstance(
sharded_sparse_arch._mc_ebc._embedding_bag_collection,
ShardedEmbeddingBagCollection,
)
assert (
sharded_sparse_arch._mc_ebc._embedding_bag_collection._has_uninitialized_input_dist
is False
)
assert (
not hasattr(
sharded_sparse_arch._mc_ebc._embedding_bag_collection, "_input_dists"
)
or len(sharded_sparse_arch._mc_ebc._embedding_bag_collection._input_dists)
== 0
)

assert isinstance(
sharded_sparse_arch._mc_ebc._managed_collision_collection,
ShardedManagedCollisionCollection,
Expand Down
9 changes: 7 additions & 2 deletions torchrec/modules/mc_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def apply_mc_method_to_jt_dict(
return mc_output


@torch.fx.wrap
def coalesce_feature_dict(features_dict: Dict[str, JaggedTensor]) -> KeyedJaggedTensor:
return KeyedJaggedTensor.from_jt_dict(features_dict)


class ManagedCollisionModule(nn.Module):
"""
Abstract base class for ManagedCollisionModule.
Expand Down Expand Up @@ -190,7 +195,7 @@ def forward(
table_to_features=self._table_to_features,
managed_collisions=self._managed_collision_modules,
)
return KeyedJaggedTensor.from_jt_dict(features_dict)
return coalesce_feature_dict(features_dict)

def evict(self) -> Dict[str, Optional[torch.Tensor]]:
evictions: Dict[str, Optional[torch.Tensor]] = {}
Expand Down Expand Up @@ -818,7 +823,7 @@ def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:

remapped_features: Dict[str, JaggedTensor] = {}
for name, feature in features.items():
values = feature.values()
values = feature.values().to(torch.int64)
remapped_ids = torch.empty_like(values)

# compute overlap between incoming IDs and remapping table
Expand Down
33 changes: 33 additions & 0 deletions torchrec/modules/tests/test_mc_embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,39 @@ def test_zch_ebc_eval(self) -> None:

assert torch.all(remapped_kjt4["f2"].values() == remapped_kjt2["f2"].values())

def test_mc_collection_traceable(self) -> None:
device = torch.device("cpu")
zch_size = 20
update_interval = 2

embedding_configs = [
EmbeddingBagConfig(
name="t1",
embedding_dim=8,
num_embeddings=zch_size,
feature_names=["f1", "f2"],
),
]
mc_modules = {
"t1": cast(
ManagedCollisionModule,
MCHManagedCollisionModule(
zch_size=zch_size,
device=device,
input_hash_size=2 * zch_size,
eviction_interval=update_interval,
eviction_policy=DistanceLFU_EvictionPolicy(),
),
),
}
mcc = ManagedCollisionCollection(
managed_collision_modules=mc_modules,
# pyre-ignore[6]
embedding_configs=embedding_configs,
)
gm: torch.fx.GraphModule = torch.fx.symbolic_trace(mcc)
gm.print_readable()

def test_mch_ebc(self) -> None:
device = torch.device("cpu")
zch_size = 10
Expand Down