
Commit 807a214

Levy Zhao authored and facebook-github-bot committed
Multiple fixes to MC modules to facilitate integration (#1391)
Summary: Some bug fixes found during the integration test in PyPER O3:

### Fix #1
`_embedding_bag_collection` (`ShardedEmbeddingBagCollection`) is not actually called by input_dist (the same input is already distributed by ShardedManagedCollisionCollection), so it never gets a chance to initiate `_input_dist`. As a result, TREC pipelining thinks it is not ready for input distribution. This is not expected, since the module is not used in that stage anyway, nor should it be put into the fused a2a communication. With this change, it satisfies the assertion at https://fburl.com/code/ud8lnixv while carrying no `_input_dists`, so it won't be included in the fused a2a.

### Fix #2
`ManagedCollisionCollection.forward` is not traceable because it uses an unwrapped `KeyedJaggedTensor.from_jt_dict`. We don't care about its internal details, so just keep it atomic.

### Fix #3
Due to how the remap table is set up, `MCHManagedCollisionModule` doesn't support i32 ID lists for now. An easy fix is to convert to i64 regardless; a more memory-efficient fix would probably be to change the remapper to i32 where necessary.

Differential Revision: D48804332
1 parent 1ffff9b commit 807a214

File tree: 4 files changed (+63, -2)

torchrec/distributed/mc_embeddingbag.py

Lines changed: 3 additions & 0 deletions
@@ -91,6 +91,9 @@ def __init__(
                 device=device,
             )
         )
+        # TODO: This is a hack since _embedding_bag_collection doesn't need input
+        # dist, so eliminating it so all fused a2a will ignore it.
+        self._embedding_bag_collection._has_uninitialized_input_dist = False
         self._managed_collision_collection: ShardedManagedCollisionCollection = mc_sharder.shard(
             module._managed_collision_collection,
             table_name_to_parameter_sharding,
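For context on fix #1, the sketch below is a hypothetical illustration (not actual TorchRec/TREC pipeline code) of the kind of readiness check and fused-a2a selection the summary describes; the helper `modules_to_fuse` and its exact attribute checks are assumptions made purely for illustration.

```python
# Hypothetical sketch only -- this pipeline helper does not exist in TorchRec.
# Fix #1 pre-sets _has_uninitialized_input_dist = False on _embedding_bag_collection,
# so a readiness assertion like this one passes, while the module exposes no
# _input_dists and therefore contributes nothing to the fused a2a.
from typing import List

import torch.nn as nn


def modules_to_fuse(sharded_modules: List[nn.Module]) -> List[nn.Module]:
    fusable: List[nn.Module] = []
    for module in sharded_modules:
        # A module still reporting an uninitialized input dist would fail here.
        assert not getattr(module, "_has_uninitialized_input_dist", False), (
            f"{type(module).__name__} has not initialized its input dist"
        )
        # Only modules that actually carry input dists join the fused a2a.
        if getattr(module, "_input_dists", []):
            fusable.append(module)
    return fusable
```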

torchrec/distributed/tests/test_mc_embeddingbag.py

Lines changed: 17 additions & 0 deletions
@@ -11,6 +11,7 @@
 
 import torch
 import torch.nn as nn
+from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
 from torchrec.distributed.mc_embeddingbag import (
     ManagedCollisionEmbeddingBagCollectionSharder,
     ShardedManagedCollisionEmbeddingBagCollection,
@@ -160,6 +161,22 @@ def _test_sharding(  # noqa C901
     assert isinstance(
         sharded_sparse_arch._mc_ebc, ShardedManagedCollisionEmbeddingBagCollection
     )
+    assert isinstance(
+        sharded_sparse_arch._mc_ebc._embedding_bag_collection,
+        ShardedEmbeddingBagCollection,
+    )
+    assert (
+        sharded_sparse_arch._mc_ebc._embedding_bag_collection._has_uninitialized_input_dist
+        is False
+    )
+    assert (
+        not hasattr(
+            sharded_sparse_arch._mc_ebc._embedding_bag_collection, "_input_dists"
+        )
+        or len(sharded_sparse_arch._mc_ebc._embedding_bag_collection._input_dists)
+        == 0
+    )
+
     assert isinstance(
         sharded_sparse_arch._mc_ebc._managed_collision_collection,
         ShardedManagedCollisionCollection,

torchrec/modules/mc_modules.py

Lines changed: 10 additions & 2 deletions
@@ -43,6 +43,11 @@ def apply_mc_method_to_jt_dict(
     return mc_output
 
 
+@torch.fx.wrap
+def coalesce_feature_dict(features_dict: Dict[str, JaggedTensor]) -> KeyedJaggedTensor:
+    return KeyedJaggedTensor.from_jt_dict(features_dict)
+
+
 class ManagedCollisionModule(nn.Module):
     """
     Abstract base class for ManagedCollisionModule.
@@ -190,7 +195,7 @@ def forward(
             table_to_features=self._table_to_features,
             managed_collisions=self._managed_collision_modules,
         )
-        return KeyedJaggedTensor.from_jt_dict(features_dict)
+        return coalesce_feature_dict(features_dict)
 
     def evict(self) -> Dict[str, Optional[torch.Tensor]]:
         evictions: Dict[str, Optional[torch.Tensor]] = {}
@@ -818,7 +823,10 @@ def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
 
         remapped_features: Dict[str, JaggedTensor] = {}
         for name, feature in features.items():
-            values = feature.values()
+            # TODO: This is a temporary hack to support i32 ID list so it could
+            # match remapper size. A more memory-efficient fix would be make
+            # remapper i32-tensor instead.
+            values = feature.values().to(torch.int64)
             remapped_ids = torch.empty_like(values)
 
             # compute overlap between incoming IDs and remapping table
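Fix #2 works because `torch.fx.wrap` registers a module-level function as a leaf: symbolic tracing records the call as a single node instead of tracing into its body. Below is a minimal, self-contained sketch of that idea; the `coalesce` helper and `Collection` module are made up for illustration and are not TorchRec code.

```python
from typing import Dict

import torch
import torch.fx


@torch.fx.wrap
def coalesce(tensors: Dict[str, torch.Tensor]) -> torch.Tensor:
    # Iterating over a runtime dict is data-dependent Python control flow that
    # symbolic tracing cannot follow; wrapping keeps this whole call atomic.
    return torch.cat([tensors[k] for k in sorted(tensors)])


class Collection(torch.nn.Module):
    def forward(self, tensors: Dict[str, torch.Tensor]) -> torch.Tensor:
        return coalesce(tensors)


gm = torch.fx.symbolic_trace(Collection())
print(gm.graph)  # coalesce shows up as a single call_function node
```

Fix #3 can be read the same way: assuming the MCH remap table is built as an int64 tensor (as the summary implies), the incoming ID list is simply upcast before remapping. A tiny sketch with made-up values:

```python
import torch

remap_table = torch.tensor([7, 11, 13], dtype=torch.int64)  # table built as int64
ids_i32 = torch.tensor([2, 0, 1], dtype=torch.int32)        # i32 ID list from input

values = ids_i32.to(torch.int64)      # the fix: upcast regardless of input dtype
remapped = torch.empty_like(values)   # int64 output buffer, matching the table
remapped.copy_(remap_table[values])   # -> tensor([13, 7, 11])
```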

torchrec/modules/tests/test_mc_embedding_modules.py

Lines changed: 33 additions & 0 deletions
@@ -256,6 +256,39 @@ def test_zch_ebc_eval(self) -> None:
 
         assert torch.all(remapped_kjt4["f2"].values() == remapped_kjt2["f2"].values())
 
+    def test_mc_collection_traceable(self) -> None:
+        device = torch.device("cpu")
+        zch_size = 20
+        update_interval = 2
+
+        embedding_configs = [
+            EmbeddingBagConfig(
+                name="t1",
+                embedding_dim=8,
+                num_embeddings=zch_size,
+                feature_names=["f1", "f2"],
+            ),
+        ]
+        mc_modules = {
+            "t1": cast(
+                ManagedCollisionModule,
+                MCHManagedCollisionModule(
+                    zch_size=zch_size,
+                    device=device,
+                    input_hash_size=2 * zch_size,
+                    eviction_interval=update_interval,
+                    eviction_policy=DistanceLFU_EvictionPolicy(),
+                ),
+            ),
+        }
+        mcc = ManagedCollisionCollection(
+            managed_collision_modules=mc_modules,
+            # pyre-ignore[6]
+            embedding_configs=embedding_configs,
+        )
+        gm: torch.fx.GraphModule = torch.fx.symbolic_trace(mcc)
+        gm.print_readable()
+
     def test_mch_ebc(self) -> None:
         device = torch.device("cpu")
         zch_size = 10
