364 changes: 364 additions & 0 deletions torchrec/distributed/batched_embedding_kernel.py

Large diffs are not rendered by default.
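The 364-line batched_embedding_kernel.py diff is not rendered above, so the sketch below is only a guess at the kernel-side contract the later hunks duck-type against: a fully sharded kernel exposing `get_rs_awaitable()`, whose `wait()` frees the temporarily materialized table by resizing its storage. Apart from the `get_rs_awaitable` name, everything here is hypothetical and not taken from this PR.

```python
import torch


class _ResizeAwaitable:
    """Mimics torchrec's Awaitable contract: the work is deferred until wait()."""

    def __init__(self, gathered_weights: torch.Tensor) -> None:
        self._gathered_weights = gathered_weights

    def wait(self) -> torch.Tensor:
        # Free the temporarily materialized full table by shrinking its storage,
        # the same storage-resize trick FSDP-style sharding uses to release memory.
        self._gathered_weights.untyped_storage().resize_(0)
        return self._gathered_weights


class _FullyShardedFusedEmbeddingSketch(torch.nn.Module):
    """Rough stand-in for a fully sharded fused kernel: materialize the full
    table for the lookup, then hand back an awaitable that releases it later."""

    def __init__(self, local_shard: torch.Tensor) -> None:
        super().__init__()
        self._local_shard = local_shard
        self._gathered: torch.Tensor = torch.empty(0)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # The real kernel would all-gather the shard across the replica group;
        # a clone stands in for that collective here.
        self._gathered = self._local_shard.clone()
        # Advanced indexing copies, so the output stays valid after the resize.
        return self._gathered[ids]

    def get_rs_awaitable(self) -> _ResizeAwaitable:
        return _ResizeAwaitable(self._gathered)


if __name__ == "__main__":
    kernel = _FullyShardedFusedEmbeddingSketch(torch.randn(8, 4))
    out = kernel(torch.tensor([0, 3]))
    kernel.get_rs_awaitable().wait()  # releases the gathered buffer
    print(out.shape)  # torch.Size([2, 4])
```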

14 changes: 14 additions & 0 deletions torchrec/distributed/embedding.py
@@ -343,6 +343,7 @@ def __init__(
module_fqn: Optional[str] = None,
sharding_types: Optional[List[str]] = None,
use_gather_select: bool = False,
resize_awaitables: Optional[List[Awaitable[torch.Tensor]]] = None,
) -> None:
super().__init__()
self._awaitables_per_sharding = awaitables_per_sharding
@@ -354,6 +355,7 @@ def __init__(
self._module_fqn = module_fqn
self._sharding_types = sharding_types
self._use_gather_select = use_gather_select
self._resize_awaitables = resize_awaitables

def _wait_impl(self) -> Dict[str, JaggedTensor]:
jt_dict: Dict[str, JaggedTensor] = {}
@@ -398,6 +400,12 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
use_gather_select=self._use_gather_select,
)
)

# free memory and resize
# pyre-ignore[16]
for awaitable in self._resize_awaitables:
awaitable.wait()

return jt_dict


@@ -1588,6 +1596,8 @@ def compute_and_output_dist(
) -> LazyAwaitable[Dict[str, JaggedTensor]]:
awaitables_per_sharding: List[Awaitable[torch.Tensor]] = []
features_before_all2all_per_sharding: List[KeyedJaggedTensor] = []
resize_awaitables = []

for lookup, odist, features, sharding_ctx, sharding_type in zip(
self._lookups,
self._output_dists,
@@ -1604,6 +1614,9 @@ def compute_and_output_dist(
EmbeddingEvent.LOOKUP, self._module_fqn, sharding_type
):
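# with fully sharded 2D enabled, the lookup returns an awaitable for the reduce scatter and resize operation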
embs = lookup(features)
if hasattr(lookup, "get_resize_awaitables"):
# pyre-ignore[29]
resize_awaitables.extend(lookup.get_resize_awaitables())
if self.post_lookup_tracker_fn is not None:
self.post_lookup_tracker_fn(features, embs, self, None)

@@ -1631,6 +1644,7 @@ def compute_and_output_dist(
module_fqn=self._module_fqn,
sharding_types=list(self._sharding_type_to_sharding.keys()),
use_gather_select=self._use_gather_select,
resize_awaitables=resize_awaitables,
)

def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
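Taken together, the embedding.py hunks above set up the following flow: `compute_and_output_dist` probes each lookup with `hasattr(lookup, "get_resize_awaitables")`, threads the collected awaitables into the returned output awaitable, and `_wait_impl` only runs their `wait()` (the memory free and resize) once the caller actually waits on the output. A minimal toy of that flow, using stand-in classes rather than torchrec's `Awaitable`/`LazyAwaitable` types:

```python
from typing import Callable, List

import torch


class _DeferredResize:
    """Stand-in for the resize awaitable the new kernels hand back."""

    def __init__(self, fn: Callable[[], None]) -> None:
        self._fn = fn

    def wait(self) -> None:
        self._fn()


class _LazyOutput:
    """Stand-in for the lazy output awaitable: resizes run inside wait()."""

    def __init__(
        self, output: torch.Tensor, resize_awaitables: List[_DeferredResize]
    ) -> None:
        self._output = output
        self._resize_awaitables = resize_awaitables

    def wait(self) -> torch.Tensor:
        # free memory and resize, mirroring _wait_impl above
        for awaitable in self._resize_awaitables:
            awaitable.wait()
        return self._output


def compute_and_output_dist_sketch(
    lookups: List[torch.nn.Module], features: torch.Tensor
) -> _LazyOutput:
    resize_awaitables: List[_DeferredResize] = []
    embs = []
    for lookup in lookups:
        embs.append(lookup(features))
        # only the fully sharded 2D lookups expose this hook, hence the hasattr probe
        if hasattr(lookup, "get_resize_awaitables"):
            resize_awaitables.extend(lookup.get_resize_awaitables())
    return _LazyOutput(torch.cat(embs, dim=-1), resize_awaitables)
```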
84 changes: 70 additions & 14 deletions torchrec/distributed/embedding_lookup.py
@@ -39,6 +39,8 @@
BatchedFusedEmbeddingBag,
KeyValueEmbedding,
KeyValueEmbeddingBag,
ShardedBatchedFusedEmbedding,
ShardedBatchedFusedEmbeddingBag,
ZeroCollisionEmbeddingCache,
ZeroCollisionKeyValueEmbedding,
ZeroCollisionKeyValueEmbeddingBag,
@@ -65,7 +67,15 @@
QuantBatchedEmbedding,
QuantBatchedEmbeddingBag,
)
from torchrec.distributed.types import rank_device, ShardedTensor, ShardingType
from torchrec.distributed.types import (
LazyAwaitable,
rank_device,
ShardedTensor,
ShardingEnv,
ShardingEnv2D,
ShardingStrategy,
ShardingType,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

logger: logging.Logger = logging.getLogger(__name__)
@@ -185,12 +195,15 @@ def __init__(
grouped_configs: List[GroupedEmbeddingConfig],
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
env: Optional[ShardingEnv] = None,
) -> None:
super().__init__()
self._emb_modules: nn.ModuleList = nn.ModuleList()
self._need_prefetch: bool = False
for config in grouped_configs:
self._emb_modules.append(self._create_embedding_kernel(config, pg, device))
self._emb_modules.append(
self._create_embedding_kernel(config, pg, device, env)
)

self._feature_splits: List[int] = []
for config in grouped_configs:
@@ -218,6 +231,7 @@ def _create_embedding_kernel(
config: GroupedEmbeddingConfig,
pg: Optional[dist.ProcessGroup],
device: Optional[torch.device],
env: Optional[ShardingEnv] = None,
) -> BaseEmbedding:
for table in config.embedding_tables:
if (
@@ -234,11 +248,20 @@ def _create_embedding_kernel(
device=device,
)
elif config.compute_kernel == EmbeddingComputeKernel.FUSED:
return BatchedFusedEmbedding(
config=config,
pg=pg,
device=device,
)
if (
env
and isinstance(env, ShardingEnv2D)
and env.sharding_strategy == ShardingStrategy.FULLY_SHARDED
):
return ShardedBatchedFusedEmbedding(
config=config, pg=pg, device=device, env=env
)
else:
return BatchedFusedEmbedding(
config=config,
pg=pg,
device=device,
)
elif config.compute_kernel == EmbeddingComputeKernel.KEY_VALUE:
return KeyValueEmbedding(
config=config,
@@ -329,6 +352,14 @@ def forward(

return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor)

def get_resize_awaitables(self) -> List[LazyAwaitable[torch.Tensor]]:
# TODO - we can probably do some smart grouping to make this more efficient
return [
emb_module.get_rs_awaitable() # pyre-ignore[29]
for emb_module in self._emb_modules
if hasattr(emb_module, "get_rs_awaitable")
]

# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
def state_dict(
self,
@@ -512,12 +543,14 @@ def __init__(
feature_processor: Optional[BaseGroupedFeatureProcessor] = None,
scale_weight_gradients: bool = True,
sharding_type: Optional[ShardingType] = None,
env: Optional[ShardingEnv] = None,
) -> None:
super().__init__()
self._env = env
self._emb_modules: nn.ModuleList = nn.ModuleList()
for config in grouped_configs:
self._emb_modules.append(
self._create_embedding_kernel(config, device, pg, sharding_type)
self._create_embedding_kernel(config, device, pg, sharding_type, env)
)

self._feature_splits: List[int] = []
@@ -555,6 +588,7 @@ def _create_embedding_kernel(
device: Optional[torch.device],
pg: Optional[dist.ProcessGroup],
sharding_type: Optional[ShardingType],
env: Optional[ShardingEnv],
) -> BaseEmbedding:
if config.compute_kernel == EmbeddingComputeKernel.DENSE:
return BatchedDenseEmbeddingBag(
@@ -564,12 +598,26 @@ def _create_embedding_kernel(
sharding_type=sharding_type,
)
elif config.compute_kernel == EmbeddingComputeKernel.FUSED:
return BatchedFusedEmbeddingBag(
config=config,
pg=pg,
device=device,
sharding_type=sharding_type,
)
if (
env
and isinstance(env, ShardingEnv2D)
and env.sharding_strategy == ShardingStrategy.FULLY_SHARDED
):
return ShardedBatchedFusedEmbeddingBag(
config=config,
pg=pg,
device=device,
sharding_type=sharding_type,
env=env,
)
else:
return BatchedFusedEmbeddingBag(
config=config,
pg=pg,
device=device,
sharding_type=sharding_type,
env=env,
)
elif config.compute_kernel in {
EmbeddingComputeKernel.KEY_VALUE,
}:
@@ -744,6 +792,14 @@ def forward(
dim=1,
)

def get_resize_awaitables(self) -> List[LazyAwaitable[torch.Tensor]]:
# TODO - we can probably do some smart grouping to make this more efficient
return [
emb_module.get_rs_awaitable() # pyre-ignore[29]
for emb_module in self._emb_modules
if hasattr(emb_module, "get_rs_awaitable")
]

# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
def state_dict(
self,
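The TODO in `get_resize_awaitables` mentions smarter grouping. One possible shape for that (an assumption, not something this PR implements) is a wrapper that fans a single `wait()` out over the collected awaitables, so callers like `compute_and_output_dist` can hold one handle instead of a list:

```python
from typing import Generic, List, Protocol, TypeVar

T = TypeVar("T")


class _Waitable(Protocol[T]):
    def wait(self) -> T: ...


class GroupedResizeAwaitable(Generic[T]):
    """Hypothetical helper: one wait() that drains many resize awaitables."""

    def __init__(self, awaitables: List[_Waitable[T]]) -> None:
        self._awaitables = awaitables

    def wait(self) -> List[T]:
        # Waits each underlying awaitable in order and returns their results.
        return [awaitable.wait() for awaitable in self._awaitables]
```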
14 changes: 14 additions & 0 deletions torchrec/distributed/embeddingbag.py
@@ -407,13 +407,15 @@ def __init__(
embedding_names: List[str],
module_fqn: Optional[str] = None,
sharding_types: Optional[List[str]] = None,
resize_awaitables: Optional[List[Awaitable[torch.Tensor]]] = None,
) -> None:
super().__init__()
self._awaitables = awaitables
self._embedding_dims = embedding_dims
self._embedding_names = embedding_names
self._module_fqn = module_fqn
self._sharding_types = sharding_types
self._resize_awaitables = resize_awaitables

def _wait_impl(self) -> KeyedTensor:
embeddings = []
@@ -425,6 +427,12 @@ def _wait_impl(self) -> KeyedTensor:
):
embeddings.append(w.wait())

# free memory and resize
if self._resize_awaitables is not None:
# pyre-ignore[16]
for awaitable in self._resize_awaitables:
awaitable.wait()

return construct_output_kt(
embeddings=embeddings,
embedding_names=self._embedding_names,
@@ -1655,6 +1663,7 @@ def compute_and_output_dist(
"""
batch_size_per_feature_pre_a2a = []
awaitables = []
resize_awaitables = []

# No usage of zip for dynamo
for i in range(len(self._lookups)):
@@ -1669,7 +1678,11 @@ def compute_and_output_dist(
self._module_fqn,
sharding_type,
):
# with fully sharded 2D enabled, the lookup returns an awaitable for the reduce scatter and resize operation
embs = lookup(features)
if hasattr(lookup, "get_resize_awaitables"):
# pyre-ignore[29]
resize_awaitables.extend(lookup.get_resize_awaitables())
if self.post_lookup_tracker_fn is not None:
self.post_lookup_tracker_fn(features, embs, self, None)

@@ -1710,6 +1723,7 @@ def compute_and_output_dist(
embedding_names=self._embedding_names,
module_fqn=self._module_fqn,
sharding_types=self._sharding_types,
resize_awaitables=resize_awaitables,
)

# register callback if there are features that need mean pooling