File tree Expand file tree Collapse file tree 5 files changed +22
-0
lines changed Expand file tree Collapse file tree 5 files changed +22
-0
lines changed Original file line number Diff line number Diff line change @@ -565,6 +565,7 @@ def _group_tables_per_rank(
565
565
),
566
566
_prefetch_and_cached (table ),
567
567
table .use_virtual_table if is_inference else None ,
568
+ table .enable_embedding_update ,
568
569
)
569
570
# micromanage the order of we traverse the groups to ensure backwards compatibility
570
571
if grouping_key not in groups :
@@ -581,6 +582,7 @@ def _group_tables_per_rank(
581
582
_ ,
582
583
_ ,
583
584
use_virtual_table ,
585
+ enable_embedding_update ,
584
586
) = grouping_key
585
587
grouped_tables = groups [grouping_key ]
586
588
# remove non-native fused params
@@ -602,6 +604,7 @@ def _group_tables_per_rank(
602
604
compute_kernel = compute_kernel_type ,
603
605
embedding_tables = grouped_tables ,
604
606
fused_params = per_tbe_fused_params ,
607
+ enable_embedding_update = enable_embedding_update ,
605
608
)
606
609
)
607
610
return grouped_embedding_configs
Original file line number Diff line number Diff line change @@ -251,6 +251,8 @@ class GroupedEmbeddingConfig:
251
251
compute_kernel : EmbeddingComputeKernel
252
252
embedding_tables : List [ShardedEmbeddingTable ]
253
253
fused_params : Optional [Dict [str , Any ]] = None
254
+ # Write-enabled Embedding Tables cannot be grouped with read-only Embedding Tables TBE needs to be separate.
255
+ enable_embedding_update : bool = False
254
256
255
257
def feature_hash_sizes (self ) -> List [int ]:
256
258
feature_hash_sizes = []
Original file line number Diff line number Diff line change @@ -223,6 +223,7 @@ def _shard(
223
223
total_num_buckets = info .embedding_config .total_num_buckets ,
224
224
use_virtual_table = info .embedding_config .use_virtual_table ,
225
225
virtual_table_eviction_policy = info .embedding_config .virtual_table_eviction_policy ,
226
+ enable_embedding_update = info .embedding_config .enable_embedding_update ,
226
227
)
227
228
)
228
229
return tables_per_rank
@@ -278,6 +279,20 @@ def _get_feature_hash_sizes(self) -> List[int]:
278
279
feature_hash_sizes .extend (group_config .feature_hash_sizes ())
279
280
return feature_hash_sizes
280
281
282
+ def _get_num_writable_features (self ) -> int :
283
+ return sum (
284
+ group_config .num_features ()
285
+ for group_config in self ._grouped_embedding_configs
286
+ if group_config .enable_embedding_update
287
+ )
288
+
289
+ def _get_writable_feature_hash_sizes (self ) -> List [int ]:
290
+ feature_hash_sizes : List [int ] = []
291
+ for group_config in self ._grouped_embedding_configs :
292
+ if group_config .enable_embedding_update :
293
+ feature_hash_sizes .extend (group_config .feature_hash_sizes ())
294
+ return feature_hash_sizes
295
+
281
296
282
297
class RwSparseFeaturesDist (BaseSparseFeaturesDist [KeyedJaggedTensor ]):
283
298
"""
Original file line number Diff line number Diff line change @@ -370,6 +370,7 @@ class BaseEmbeddingConfig:
370
370
total_num_buckets : Optional [int ] = None
371
371
use_virtual_table : bool = False
372
372
virtual_table_eviction_policy : Optional [VirtualTableEvictionPolicy ] = None
373
+ enable_embedding_update : bool = False
373
374
374
375
def get_weight_init_max (self ) -> float :
375
376
if self .weight_init_max is None :
Original file line number Diff line number Diff line change @@ -43,6 +43,7 @@ class StableEmbeddingBagConfig:
43
43
total_num_buckets : Optional [int ] = None
44
44
use_virtual_table : bool = False
45
45
virtual_table_eviction_policy : Optional [VirtualTableEvictionPolicy ] = None
46
+ enable_embedding_update : bool = False
46
47
pooling : PoolingType = PoolingType .SUM
47
48
48
49
You can’t perform that action at this time.
0 commit comments