Commit 97f8dea

Browse files
Caner Gocmen authored and facebook-github-bot committed
Reduce unnecessary GreedyPerfPartitioner calls from MemoryBalancedPartitioner (#2914)
Summary:
Pull Request resolved: #2914

`MemoryBalancedPartitioner` works by adjusting the max memory available on devices and calling `GreedyPerfPartitioner` repeatedly. The max memory is adjusted with a binary search procedure to find a more memory-efficient plan than what `GreedyPerfPartitioner` produces by default. The search boundaries for this binary search were inefficient, which this diff addresses.

* **Upper bound**
  * **Before:** max device HBM (e.g. 80 GB).
  * **After:** max HBM usage of the default plan, since there is no point in searching for plans that use more max memory than the default plan already does.
* **Lower bound**
  * **Before:** `[Avg. HBM per Device] = [Total HBM Needed Across All Shards] / [World Size]`.
  * **After:** `max([Avg. HBM per Device], [Max HBM Needed Across All Shards])`. A feasible solution requires at least as much HBM as the biggest shard needs, so there is no point in searching below that.

These changes can have impact in two ways:
1. The search procedure is more efficient, leading to plans with lower memory.
2. `search_count` can be reduced to get plans comparable to before while calling `GreedyPerfPartitioner` fewer times from `MemoryBalancedPartitioner`.

The default impact without further changes, from #1, should be a marginal max memory improvement.

Reviewed By: iamzainhuda

Differential Revision: D73598477

fbshipit-source-id: 64b001de5a84e5f24afec9684b4602bcbe694e59
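The tightened bounds described above can be sketched as a small standalone function. This is a minimal illustration assuming a default plan represented as `(rank, hbm_bytes)` pairs per shard; `search_bounds`, `shards`, and `world_size` are hypothetical names for this sketch, not TorchRec's actual API.

```python
# Sketch of the tightened binary-search bounds, assuming the default plan is
# given as a list of (rank, hbm_bytes) shard assignments. Illustrative only.

def search_bounds(shards, world_size):
    """Return (min_hbm_per_device, max_hbm_per_device) for the binary search."""
    hbm_by_rank = [0] * world_size
    hbm_requirement = 0
    max_shard_hbm = 0
    for rank, hbm in shards:
        hbm_by_rank[rank] += hbm
        hbm_requirement += hbm
        max_shard_hbm = max(max_shard_hbm, hbm)

    # Upper bound: max HBM used by any rank in the default plan -- no point
    # searching above what the default plan already achieves.
    upper = max(hbm_by_rank)
    # Lower bound: at least the average HBM per device, and at least enough
    # room for the single biggest shard, or no feasible plan exists.
    lower = max(hbm_requirement // world_size, max_shard_hbm)
    return lower, upper

# Example: 2 devices, three shards of 6, 2, and 3 bytes.
# Rank 0 uses 8, rank 1 uses 3; avg is 5 but the biggest shard needs 6.
print(search_bounds([(0, 6), (0, 2), (1, 3)], world_size=2))  # -> (6, 8)
```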
1 parent 33aeafa commit 97f8dea

File tree

1 file changed

+20
-9
lines changed


torchrec/distributed/planner/partitioners.py

Lines changed: 20 additions & 9 deletions
```diff
@@ -599,17 +599,28 @@ def partition(
         default_plan = copy.deepcopy(default_plan)
         original_plan_perf = _perf_model.rate(default_plan)

-        max_hbm_per_device: int = _topology.devices[0].storage.hbm
+        # compute shard and default plan HBM stats
+        hbm_by_rank = [0] * _topology.world_size
+        hbm_requirement: int = 0
+        max_shard_hbm: int = 0
+        for sharding_option in default_plan:
+            for shard in sharding_option.shards:
+                if shard.storage is not None and shard.rank is not None:
+                    hbm_used = shard.storage.hbm
+                    rank = shard.rank
+                    hbm_by_rank[rank] += hbm_used
+                    hbm_requirement += hbm_used
+                    max_shard_hbm = max(max_shard_hbm, hbm_used)
+
+        # Upper bound for the search is the default plan's max HBM usage
+        max_hbm_per_device: int = max(hbm_by_rank)
         logger.info(
-            f"Default plan uses {round(bytes_to_gb(max_hbm_per_device), 3)} GB per device."
+            f"Default plan max HBM is {round(bytes_to_gb(max_hbm_per_device), 3)} GB."
         )

-        hbm_requirement: int = 0
-        for sharding_option in proposal:
-            for shard in sharding_option.shards:
-                if shard.storage is not None:
-                    hbm_requirement += shard.storage.hbm
-        min_hbm_per_device: int = int(hbm_requirement / _topology.world_size)
+        # Lower bound for the search is the maximum of avg. HBM usage or the biggest shard
+        avg_hbm_usage: int = int(hbm_requirement / _topology.world_size)
+        min_hbm_per_device: int = max(avg_hbm_usage, max_shard_hbm)
         logger.info(
             "Searching in the range (min_hbm_per_device, max_hbm_per_device): "
             f"({round(bytes_to_gb(min_hbm_per_device), 3)}, "
@@ -660,7 +671,7 @@ def partition(
                 max_hbm_per_device = mid_hbm_per_device
             except PlannerError:
                 logger.info(
-                    f"Couldn't find a plan with {round(bytes_to_gb(max_hbm_per_device), 3)} "
+                    f"Couldn't find a plan with {round(bytes_to_gb(mid_hbm_per_device), 3)} "
                     f"GB per device for embedding tables."
                 )
                 min_hbm_per_device = mid_hbm_per_device
```
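The bounds computed in this diff feed a binary search over the per-device HBM cap. The sketch below illustrates that loop under stated assumptions: `try_partition` is a stand-in predicate for "calling `GreedyPerfPartitioner` with this memory cap succeeds", and an infeasible cap corresponds to the `PlannerError` branch in the diff; none of these names are TorchRec's real API.

```python
# Hedged sketch of the binary search over the per-device HBM cap.
# `try_partition(cap)` returns True if a plan fits under `cap` bytes per
# device (a stand-in for GreedyPerfPartitioner; False ~ PlannerError).

def memory_balanced_search(try_partition, min_hbm, max_hbm, search_count=5):
    """Binary-search the HBM cap; return the tightest feasible cap found."""
    best = max_hbm  # the default plan is known to fit at max_hbm
    for _ in range(search_count):
        mid = (min_hbm + max_hbm) // 2
        if try_partition(mid):
            best = mid
            max_hbm = mid  # feasible: try an even tighter cap
        else:
            min_hbm = mid  # infeasible: search above mid
    return best

# Example: plans are feasible whenever the cap is at least 6 bytes/device.
print(memory_balanced_search(lambda cap: cap >= 6, 0, 16))  # -> 6
```

Tightening `min_hbm` and `max_hbm` up front, as this diff does, shrinks the interval the loop bisects, so fewer iterations (a smaller `search_count`) reach the same cap.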
