
Commit e326403

pytorchbot committed: 2025-09-24 nightly release (c26367f)
1 parent 0c7e33c commit e326403

5 files changed: +22 lines, -0 lines

torchrec/distributed/embedding_sharding.py

Lines changed: 3 additions & 0 deletions
@@ -565,6 +565,7 @@ def _group_tables_per_rank(
                 ),
                 _prefetch_and_cached(table),
                 table.use_virtual_table if is_inference else None,
+                table.enable_embedding_update,
             )
             # micromanage the order of we traverse the groups to ensure backwards compatibility
             if grouping_key not in groups:
@@ -581,6 +582,7 @@ def _group_tables_per_rank(
             _,
             _,
             use_virtual_table,
+            enable_embedding_update,
         ) = grouping_key
         grouped_tables = groups[grouping_key]
         # remove non-native fused params
@@ -602,6 +604,7 @@ def _group_tables_per_rank(
                 compute_kernel=compute_kernel_type,
                 embedding_tables=grouped_tables,
                 fused_params=per_tbe_fused_params,
+                enable_embedding_update=enable_embedding_update,
             )
         )
     return grouped_embedding_configs
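
Net effect of this hunk: enable_embedding_update becomes part of the tuple used as the grouping key, so write-enabled tables never share a group (and therefore a TBE) with read-only tables. A minimal sketch of that grouping pattern, using a hypothetical Table class in place of torchrec's ShardedEmbeddingTable:

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Table:
    # Hypothetical stand-in for ShardedEmbeddingTable, reduced to the fields
    # that matter for this sketch.
    name: str
    data_type: str
    enable_embedding_update: bool = False


def group_tables(tables: List[Table]) -> Dict[Tuple[str, bool], List[Table]]:
    # The flag is part of the key, mirroring the diff above, so writable and
    # read-only tables of the same data type still land in different groups.
    groups: Dict[Tuple[str, bool], List[Table]] = defaultdict(list)
    for table in tables:
        key = (table.data_type, table.enable_embedding_update)
        groups[key].append(table)
    return dict(groups)


tables = [Table("t0", "fp16"), Table("t1", "fp16", True), Table("t2", "fp16")]
print(list(group_tables(tables)))  # [('fp16', False), ('fp16', True)]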

torchrec/distributed/embedding_types.py

Lines changed: 2 additions & 0 deletions
@@ -251,6 +251,8 @@ class GroupedEmbeddingConfig:
     compute_kernel: EmbeddingComputeKernel
     embedding_tables: List[ShardedEmbeddingTable]
     fused_params: Optional[Dict[str, Any]] = None
+    # Write-enabled Embedding Tables cannot be grouped with read-only Embedding Tables; the TBE needs to be separate.
+    enable_embedding_update: bool = False
 
     def feature_hash_sizes(self) -> List[int]:
         feature_hash_sizes = []

torchrec/distributed/sharding/rw_sharding.py

Lines changed: 15 additions & 0 deletions
@@ -223,6 +223,7 @@ def _shard(
                     total_num_buckets=info.embedding_config.total_num_buckets,
                     use_virtual_table=info.embedding_config.use_virtual_table,
                     virtual_table_eviction_policy=info.embedding_config.virtual_table_eviction_policy,
+                    enable_embedding_update=info.embedding_config.enable_embedding_update,
                 )
             )
         return tables_per_rank
@@ -278,6 +279,20 @@ def _get_feature_hash_sizes(self) -> List[int]:
             feature_hash_sizes.extend(group_config.feature_hash_sizes())
         return feature_hash_sizes
 
+    def _get_num_writable_features(self) -> int:
+        return sum(
+            group_config.num_features()
+            for group_config in self._grouped_embedding_configs
+            if group_config.enable_embedding_update
+        )
+
+    def _get_writable_feature_hash_sizes(self) -> List[int]:
+        feature_hash_sizes: List[int] = []
+        for group_config in self._grouped_embedding_configs:
+            if group_config.enable_embedding_update:
+                feature_hash_sizes.extend(group_config.feature_hash_sizes())
+        return feature_hash_sizes
+
 
 class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
     """

torchrec/modules/embedding_configs.py

Lines changed: 1 addition & 0 deletions
@@ -370,6 +370,7 @@ class BaseEmbeddingConfig:
     total_num_buckets: Optional[int] = None
     use_virtual_table: bool = False
     virtual_table_eviction_policy: Optional[VirtualTableEvictionPolicy] = None
+    enable_embedding_update: bool = False
 
     def get_weight_init_max(self) -> float:
         if self.weight_init_max is None:
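
Because the flag is defined on BaseEmbeddingConfig with a False default, existing configs are unaffected, and it should be settable on any derived config. A usage sketch with EmbeddingBagConfig (the surrounding sharding and TBE wiring is omitted):

from torchrec.modules.embedding_configs import EmbeddingBagConfig

writable_table = EmbeddingBagConfig(
    name="t_writable",
    embedding_dim=64,
    num_embeddings=1_000_000,
    feature_names=["f_writable"],
    enable_embedding_update=True,  # new flag added by this commit
)

readonly_table = EmbeddingBagConfig(
    name="t_readonly",
    embedding_dim=64,
    num_embeddings=1_000_000,
    feature_names=["f_readonly"],
    # enable_embedding_update defaults to False, so this table stays read-only
)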

torchrec/schema/api_tests/test_embedding_config_schema.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ class StableEmbeddingBagConfig:
     total_num_buckets: Optional[int] = None
     use_virtual_table: bool = False
     virtual_table_eviction_policy: Optional[VirtualTableEvictionPolicy] = None
+    enable_embedding_update: bool = False
     pooling: PoolingType = PoolingType.SUM
