
Commit dbad57e

Authored by levythu, committed by facebook-github-bot
Count shard state in HBM usage (pytorch#203)
Summary:
X-link: pytorch/torchrec#2380
Pull Request resolved: facebookresearch/FBGEMM#203
X-link: pytorch#3114

This PR improves the sparse HBM cost estimate by accounting for the size of the auxiliary state needed to maintain the UVM cache. As noted in the comments in split_table_batched_embeddings_ops_training, the significant space is currently `4 * hash_size + 8 * cache_slot_size + 8 * cache_slot_size`. This becomes nontrivial for tables with many rows but few dimensions.

Impact:
- Non-UVM-offloaded jobs: no-op.
- UVM-offloaded jobs: more balanced memory usage from the more precise estimate. However, for existing UVM jobs that use the scale-up proposer with a fixed reservation percentage, the proposer may make less aggressive cache scale-ups and therefore perform worse; in that case, tune toward a larger slack reservation percentage.

Reviewed By: sarckk

Differential Revision: D61576911

fbshipit-source-id: 6b501dc63cbe86c5274661b1d985af6a7a0a87c6
1 parent: f85dfb8 · commit: dbad57e
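The formula quoted in the summary can be illustrated with a short sketch. This is a minimal illustration under stated assumptions, not FBGEMM's actual estimator; `uvm_cache_aux_state_bytes`, `hash_size`, `cache_load_factor` (the `clf` in the diff comments below), `num_tables`, and `policy` are hypothetical names introduced here.

```python
# Minimal sketch (an illustration, not FBGEMM's actual estimator) of the
# auxiliary UVM-cache state this commit starts counting toward HBM usage.
# All parameter names are illustrative, not FBGEMM identifiers.

def uvm_cache_aux_state_bytes(
    hash_size: int,            # total embedding hash size under the UVM cache
    cache_load_factor: float,  # fraction of rows resident in the HBM cache
    num_tables: int,
    policy: str = "LRU",
) -> int:
    cache_slots = int(hash_size * cache_load_factor)
    total = 8 * num_tables    # cache_hash_size_cumsum (int64), trivial size
    total += 4 * hash_size    # cache_index_table_map (int32)
    total += 8 * cache_slots  # lxu_cache_state (int64)
    # lxu_state: LRU keeps a timestamp per cache slot; LFU keeps a
    # frequency counter per hashed row.
    total += 8 * (cache_slots if policy == "LRU" else hash_size)
    return total
```

As a worked example (illustrative numbers): a single table with 100M rows, dimension 4, fp16 weights, and a 0.2 cache load factor carries roughly 400 MB + 160 MB + 160 MB ≈ 720 MB of auxiliary state under LRU, while the HBM cache itself (`lxu_cache_weights`) is only 20M slots × 4 dims × 2 bytes ≈ 160 MB. This is why the auxiliary state matters most for tables with many rows but few dimensions.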

File tree: 1 file changed (+6, −0 lines)


fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -2286,6 +2286,7 @@ def _apply_cache_state(
         )
 
         self.total_cache_hash_size = cache_state.total_cache_hash_size
+        # 8x of # tables, trivial size
         self.register_buffer(
             "cache_hash_size_cumsum",
             torch.tensor(
@@ -2294,6 +2295,7 @@ def _apply_cache_state(
                 dtype=torch.int64,
             ),
         )
+        # 4x total embedding hash size with uvm cache
         self.register_buffer(
             "cache_index_table_map",
             torch.tensor(
@@ -2302,12 +2304,14 @@ def _apply_cache_state(
                 dtype=torch.int32,
             ),
         )
+        # 8x of total cache slots (embedding hash size * clf)
         self.register_buffer(
             "lxu_cache_state",
             torch.zeros(
                 cache_sets, DEFAULT_ASSOC, device=self.current_device, dtype=torch.int64
             ).fill_(-1),
         )
+        # Cache itself, not auxiliary size
         self.register_buffer(
             "lxu_cache_weights",
             torch.zeros(
@@ -2317,6 +2321,8 @@ def _apply_cache_state(
                 dtype=dtype,
             ),
         )
+        # LRU: 8x of total cache slots (embedding hash size * clf)
+        # LFU: 8x of total embedding hash size with uvm cache
         self.register_buffer(
             "lxu_state",
             torch.zeros(
```
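One way to relate the comments above to actual HBM bytes at runtime is to sum the footprints of the registered buffers. A sketch, assuming the buffers are registered directly on the TBE module as in the diff; `tbe` stands for a hypothetical module instance, and `buffer_bytes` is a helper introduced here:

```python
import torch

def buffer_bytes(module: torch.nn.Module, names: list[str]) -> int:
    # Sum the footprint of the named buffers registered on `module`.
    total = 0
    for name in names:
        buf = module.get_buffer(name)
        total += buf.numel() * buf.element_size()
    return total

# Buffer names touched by this diff, with the sizes its comments assign them.
aux_names = [
    "cache_hash_size_cumsum",  # 8x of # tables, trivial
    "cache_index_table_map",   # 4x total embedding hash size
    "lxu_cache_state",         # 8x of total cache slots
    "lxu_state",               # LRU: 8x cache slots; LFU: 8x hash size
]
# aux_bytes = buffer_bytes(tbe, aux_names)
# cache_bytes = buffer_bytes(tbe, ["lxu_cache_weights"])  # the cache itself
```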
