27 changes: 21 additions & 6 deletions torchrec/distributed/embedding.py
@@ -32,6 +32,7 @@
 from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
 from torch.distributed._tensor import DTensor
 from torch.nn.parallel import DistributedDataParallel
+from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_sharding import (
     EmbeddingSharding,
     EmbeddingShardingInfo,
@@ -69,13 +70,16 @@
     QuantizedCommCodecs,
     ShardedTensor,
     ShardingEnv,
+    ShardingEnv2D,
     ShardMetadata,
 )
 from torchrec.distributed.utils import (
     add_params_from_parameter_sharding,
     convert_to_fbgemm_types,
+    create_global_tensor_shape_stride_from_metadata,
     maybe_annotate_embedding_event,
     merge_fused_params,
+    none_throws,
     optimizer_type_to_emb_opt_type,
 )
 from torchrec.modules.embedding_configs import (
@@ -534,12 +538,9 @@ def __init__(
                 if table_name in self._table_names
             },
         )
-        # output parameters as DTensor in state dict
-        self._output_dtensor: bool = (
-            fused_params.get("output_dtensor", False) if fused_params else False
-        )
-
         self._env = env
+        # output parameters as DTensor in state dict
+        self._output_dtensor: bool = env.output_dtensor
         # TODO get rid of get_ec_index_dedup global flag
         self._use_index_dedup: bool = use_index_dedup or get_ec_index_dedup()
         sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
@@ -842,6 +843,14 @@ def _initialize_torch_state(self) -> None:  # noqa
                         )
                     )
                 else:
+                    shape, stride = create_global_tensor_shape_stride_from_metadata(
+                        none_throws(self.module_sharding_plan[table_name]),
+                        (
+                            self._env.node_group_size
+                            if isinstance(self._env, ShardingEnv2D)
+                            else get_local_size(self._env.world_size)
+                        ),
+                    )
                     # empty shard case
                     self._model_parallel_name_to_dtensor[table_name] = (
                         DTensor.from_local(
@@ -851,6 +860,8 @@
                             ),
                             device_mesh=self._env.device_mesh,
                             run_check=False,
+                            shape=shape,
+                            stride=stride,
                         )
                     )
                 else:
@@ -861,7 +872,11 @@ def _initialize_torch_state(self) -> None:  # noqa
                     ShardedTensor._init_from_local_shards(
                         local_shards,
                         self._name_to_table_size[table_name],
-                        process_group=self._env.process_group,
+                        process_group=(
+                            self._env.sharding_pg
+                            if isinstance(self._env, ShardingEnv2D)
+                            else self._env.process_group
+                        ),
                     )
                 )
 
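Note: the sketch below is not part of the diff. It is a minimal, standalone illustration of the dispatch pattern the embedding.py change introduces: under a 2D-parallel environment, the empty-shard DTensor's global shape is derived from the per-node group size, and the ShardedTensor state dict is built over the sharding process group instead of the default process group. The stub classes and helper names are hypothetical stand-ins for torchrec's ShardingEnv / ShardingEnv2D.

# Hypothetical sketch (not from the PR): stub environments stand in for
# torchrec's ShardingEnv / ShardingEnv2D to illustrate the isinstance
# dispatch added in _initialize_torch_state above.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class StubShardingEnv:
    world_size: int
    process_group: Optional[Any] = None


@dataclass
class StubShardingEnv2D(StubShardingEnv):
    node_group_size: int = 8  # ranks in the intra-node sharding group
    sharding_pg: Optional[Any] = None


def local_size_for_empty_shard(env: StubShardingEnv) -> int:
    # 2D env: the global shape/stride of an empty-shard DTensor is computed
    # from the node group size; otherwise from the local (per-node) size.
    if isinstance(env, StubShardingEnv2D):
        return env.node_group_size
    # stand-in for torchrec.distributed.comm.get_local_size(world_size)
    return env.world_size


def process_group_for_sharded_tensor(env: StubShardingEnv) -> Optional[Any]:
    # 2D env: the ShardedTensor is materialized over the sharding process
    # group, not the default (global) process group.
    if isinstance(env, StubShardingEnv2D):
        return env.sharding_pg
    return env.process_group


# Usage:
env_1d = StubShardingEnv(world_size=16, process_group="world_pg")
env_2d = StubShardingEnv2D(
    world_size=16, process_group="world_pg", node_group_size=8, sharding_pg="sharding_pg"
)
assert process_group_for_sharded_tensor(env_1d) == "world_pg"
assert process_group_for_sharded_tensor(env_2d) == "sharding_pg"
assert local_size_for_empty_shard(env_2d) == 8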
7 changes: 4 additions & 3 deletions torchrec/distributed/sharding/twrw_sharding.py
@@ -13,7 +13,7 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import Shard
+from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.distributed_c10d import get_process_group_ranks
 from torchrec.distributed.comm import (
     get_local_size,
@@ -165,10 +165,11 @@ def _shard(
 
         dtensor_metadata = None
         if self._env.output_dtensor:
-            placements = (Shard(0),)
             dtensor_metadata = DTensorMetadata(
                 mesh=self._env.device_mesh,
-                placements=placements,
+                placements=(
+                    (Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),)
+                ),
                 size=(
                     info.embedding_config.num_embeddings,
                     info.embedding_config.embedding_dim,
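Note: the helper below is hypothetical, not part of the PR; it mirrors the placement selection in the _shard hunk above when output_dtensor is enabled. Shard(1) shards tensor dim 1 across the sharding mesh dimension, and under 2D parallelism a leading Replicate() placement covers the outer (replica-group) mesh dimension.

# Hypothetical helper (not from the PR): mirrors the placement choice made
# in the _shard hunk above when output_dtensor is enabled.
from torch.distributed._tensor import Replicate, Shard


def twrw_dtensor_placements(is_2d_parallel: bool):
    if is_2d_parallel:
        # 2D device mesh: replicate over the outer (replica-group) mesh dim,
        # shard tensor dim 1 over the inner (sharding) mesh dim.
        return (Replicate(), Shard(1))
    # 1D device mesh: shard tensor dim 1 over the single sharding mesh dim.
    return (Shard(1),)


print(twrw_dtensor_placements(False))  # e.g. (Shard(dim=1),)
print(twrw_dtensor_placements(True))   # e.g. (Replicate(), Shard(dim=1))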