
Commit 9e41c88

iamzainhuda authored and facebook-github-bot committed
2D for EmbeddingCollection
Summary: Adds support for EmbeddingCollection modules in 2D parallel, covering all sharding types supported for EC. Also fixes the TWRW DTensor placement in the 2D case. Differential Revision: D68980589
1 parent 1afbf08 commit 9e41c88
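
The "2D parallel" in the summary refers to splitting the world into a sharding dimension (ranks that jointly hold a table's shards) and a replication dimension (groups that each hold a full copy of the model-parallel state). As a hedged sketch with made-up helper names (not torchrec API, and whether torchrec assigns sharding groups contiguously is an assumption here), the rank grid can be derived like this:

def two_d_rank_groups(world_size: int, sharding_group_size: int):
    # Partition ranks into contiguous sharding groups; ranks holding the same
    # shard position across groups form a replication group. Illustrative only.
    assert world_size % sharding_group_size == 0
    num_groups = world_size // sharding_group_size
    sharding_groups = [
        list(range(g * sharding_group_size, (g + 1) * sharding_group_size))
        for g in range(num_groups)
    ]
    replication_groups = [
        [g * sharding_group_size + r for g in range(num_groups)]
        for r in range(sharding_group_size)
    ]
    return sharding_groups, replication_groups

print(two_d_rank_groups(8, 4))
# ([[0, 1, 2, 3], [4, 5, 6, 7]], [[0, 4], [1, 5], [2, 6], [3, 7]])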

File tree

4 files changed: +322 -16 lines


torchrec/distributed/embedding.py

Lines changed: 21 additions & 6 deletions
@@ -32,6 +32,7 @@
 from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
 from torch.distributed._tensor import DTensor
 from torch.nn.parallel import DistributedDataParallel
+from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_sharding import (
     EmbeddingSharding,
     EmbeddingShardingInfo,
@@ -69,13 +70,16 @@
     QuantizedCommCodecs,
     ShardedTensor,
     ShardingEnv,
+    ShardingEnv2D,
     ShardMetadata,
 )
 from torchrec.distributed.utils import (
     add_params_from_parameter_sharding,
     convert_to_fbgemm_types,
+    create_global_tensor_shape_stride_from_metadata,
     maybe_annotate_embedding_event,
     merge_fused_params,
+    none_throws,
     optimizer_type_to_emb_opt_type,
 )
 from torchrec.modules.embedding_configs import (
@@ -534,12 +538,9 @@ def __init__(
                 if table_name in self._table_names
             },
         )
-        # output parameters as DTensor in state dict
-        self._output_dtensor: bool = (
-            fused_params.get("output_dtensor", False) if fused_params else False
-        )
-
         self._env = env
+        # output parameters as DTensor in state dict
+        self._output_dtensor: bool = env.output_dtensor
         # TODO get rid of get_ec_index_dedup global flag
         self._use_index_dedup: bool = use_index_dedup or get_ec_index_dedup()
         sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
@@ -842,6 +843,14 @@ def _initialize_torch_state(self) -> None:  # noqa
                         )
                     )
                 else:
+                    shape, stride = create_global_tensor_shape_stride_from_metadata(
+                        none_throws(self.module_sharding_plan[table_name]),
+                        (
+                            self._env.node_group_size
+                            if isinstance(self._env, ShardingEnv2D)
+                            else get_local_size(self._env.world_size)
+                        ),
+                    )
                     # empty shard case
                     self._model_parallel_name_to_dtensor[table_name] = (
                         DTensor.from_local(
@@ -851,6 +860,8 @@ def _initialize_torch_state(self) -> None:  # noqa
                             ),
                             device_mesh=self._env.device_mesh,
                             run_check=False,
+                            shape=shape,
+                            stride=stride,
                         )
                     )
                 else:
@@ -861,7 +872,11 @@ def _initialize_torch_state(self) -> None:  # noqa
                     ShardedTensor._init_from_local_shards(
                         local_shards,
                         self._name_to_table_size[table_name],
-                        process_group=self._env.process_group,
+                        process_group=(
+                            self._env.sharding_pg
+                            if isinstance(self._env, ShardingEnv2D)
+                            else self._env.process_group
+                        ),
                     )
                 )

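The empty-shard branch above is why shape and stride are now passed explicitly: a rank that owns no rows of a table cannot infer the table's global size from its local tensor, so create_global_tensor_shape_stride_from_metadata supplies it. A minimal, hedged sketch of the same DTensor.from_local pattern, using a single-process gloo group and made-up sizes rather than a real sharding plan:

import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh

# Single-process setup purely so the example runs; real usage is multi-rank.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

# Pretend this rank holds an empty shard of a (100, 8) embedding table.
empty_local_shard = torch.empty(0, 8)
dt = DTensor.from_local(
    empty_local_shard,
    device_mesh=mesh,
    placements=[Shard(0)],
    run_check=False,
    shape=torch.Size([100, 8]),  # global (num_embeddings, embedding_dim)
    stride=(8, 1),               # stride of the contiguous global tensor
)
print(dt.size())  # reports the global shape, not the empty local one
dist.destroy_process_group()
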
torchrec/distributed/sharding/tw_sharding.py

Lines changed: 1 addition & 3 deletions
@@ -128,7 +128,7 @@ def _shard(
         )
 
         dtensor_metadata = None
-        if info.fused_params.get("output_dtensor", False):  # pyre-ignore[16]
+        if self._env.output_dtensor:
             dtensor_metadata = DTensorMetadata(
                 mesh=(
                     self._env.device_mesh["replicate"]  # pyre-ignore[16]
@@ -142,8 +142,6 @@
                 ),
                 stride=info.param.stride(),
             )
-            # to not pass onto TBE
-            info.fused_params.pop("output_dtensor", None)  # pyre-ignore[16]
 
         rank = (
             # pyre-ignore [16]

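The effect of the tw_sharding change is an ownership move: the output_dtensor flag now lives on the sharding environment rather than in per-table fused_params, so there is no longer anything to pop before fused_params reach the TBE kernels. A hedged toy sketch of the pattern with simplified stand-in classes (not the real torchrec types):

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class FakeShardingEnv:
    # Stand-in for the env-level flag the diff reads via self._env.output_dtensor.
    output_dtensor: bool = False

@dataclass
class FakeShardingInfo:
    # Stand-in for per-table fused params that get forwarded to the kernels.
    fused_params: Dict[str, Any] = field(default_factory=dict)

def shard(env: FakeShardingEnv, info: FakeShardingInfo) -> None:
    # Before: info.fused_params.get("output_dtensor", False) plus a pop() so the
    # flag would not leak into the kernel config. After: the env owns the flag
    # and fused_params pass through untouched.
    if env.output_dtensor:
        print("would build DTensorMetadata here")
    print("fused_params forwarded unchanged:", info.fused_params)

shard(FakeShardingEnv(output_dtensor=True), FakeShardingInfo({"learning_rate": 0.01}))
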
torchrec/distributed/sharding/twrw_sharding.py

Lines changed: 4 additions & 3 deletions
@@ -13,7 +13,7 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import Shard
+from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.distributed_c10d import get_process_group_ranks
 from torchrec.distributed.comm import (
     get_local_size,
@@ -165,10 +165,11 @@ def _shard(
 
         dtensor_metadata = None
         if self._env.output_dtensor:
-            placements = (Shard(0),)
             dtensor_metadata = DTensorMetadata(
                 mesh=self._env.device_mesh,
-                placements=placements,
+                placements=(
+                    (Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),)
+                ),
                 size=(
                     info.embedding_config.num_embeddings,
                     info.embedding_config.embedding_dim,

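The twrw_sharding fix reflects that a DTensor needs one placement per device-mesh dimension: on the 2D mesh the table is replicated across the data-parallel mesh dimension and sharded along tensor dim 1 across the sharding dimension, hence (Replicate(), Shard(1)) in 2D versus (Shard(1),) on a 1D mesh. A hedged, runnable illustration on a 1x1 CPU mesh with toy sizes (not the commit's code; the mesh dim names are my own labels):

import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import Replicate, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

# Single-process setup purely so the example runs; real 2D meshes span many ranks.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh_2d = init_device_mesh("cpu", (1, 1), mesh_dim_names=("replicate", "shard"))
table = torch.randn(16, 4)  # (num_embeddings, embedding_dim)

# One placement per mesh dim: replicate over mesh dim 0, shard tensor dim 1 over mesh dim 1.
dt = distribute_tensor(table, mesh_2d, placements=[Replicate(), Shard(1)])
print(dt.placements, dt.size())
dist.destroy_process_group()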