
Commit 798c7be

[EPLB] Refactor balanced_packing to use numpy and optimize GPU-CPU transfers in EPLB (#28369)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
1 parent 4fd4b74 commit 798c7be

File tree

2 files changed: +37 −17 lines

vllm/distributed/eplb/rebalance_algo.py

Lines changed: 28 additions & 12 deletions

@@ -12,6 +12,7 @@
 on how the EPLB algorithm works.
 """

+import numpy as np
 import torch

@@ -34,29 +35,44 @@ def balanced_packing(
     assert num_groups % num_packs == 0
     groups_per_pack = num_groups // num_packs

+    device = weight.device
+
     if groups_per_pack == 1:
         pack_index = torch.arange(
-            weight.size(-1), dtype=torch.int64, device=weight.device
+            weight.size(-1), dtype=torch.int64, device=device
         ).expand(weight.shape)
-        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
+        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
         return pack_index, rank_in_pack

-    indices = weight.float().sort(-1, descending=True).indices.cpu()
-    pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu")
-    rank_in_pack = torch.full_like(pack_index, fill_value=-1)
+    weight_np = weight.cpu().numpy()
+
+    # Sort and get indices in descending order
+    indices_np = np.argsort(-weight_np, axis=-1)
+
+    pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
+    rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
+
+    # Run the packing algorithm
     for i in range(num_layers):
-        pack_weights = [0] * num_packs
+        pack_weights = [0.0] * num_packs
         pack_items = [0] * num_packs
-        for group in indices[i]:
+
+        for group in indices_np[i]:
+            # Find a pack with capacity that has the lowest weight
             pack = min(
-                (i for i in range(num_packs) if pack_items[i] < groups_per_pack),
+                (j for j in range(num_packs) if pack_items[j] < groups_per_pack),
                 key=pack_weights.__getitem__,
             )
+
             assert pack_items[pack] < groups_per_pack
-            pack_index[i, group] = pack
-            rank_in_pack[i, group] = pack_items[pack]
-            pack_weights[pack] += weight[i, group]
+            pack_index_np[i, group] = pack
+            rank_in_pack_np[i, group] = pack_items[pack]
+            pack_weights[pack] += weight_np[i, group]
             pack_items[pack] += 1
+
+    pack_index = torch.from_numpy(pack_index_np).to(device)
+    rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
+
     return pack_index, rank_in_pack

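For readers skimming the diff, here is a minimal self-contained sketch of the greedy strategy this rewrite ports to numpy: sort groups by weight in descending order, then assign each group to the lightest pack that still has capacity. The function name and toy values below are illustrative, not part of the commit.

```python
import numpy as np

def pack_greedy(weight_np: np.ndarray, num_packs: int):
    """Toy re-creation of the packing loop: heaviest groups first,
    each placed on the least-loaded pack with remaining capacity."""
    num_layers, num_groups = weight_np.shape
    groups_per_pack = num_groups // num_packs
    pack_index = np.full((num_layers, num_groups), -1, dtype=np.int64)
    rank_in_pack = np.full((num_layers, num_groups), -1, dtype=np.int64)
    indices = np.argsort(-weight_np, axis=-1)  # descending by weight
    for i in range(num_layers):
        pack_weights = [0.0] * num_packs
        pack_items = [0] * num_packs
        for group in indices[i]:
            # lightest pack that still has room
            pack = min(
                (j for j in range(num_packs) if pack_items[j] < groups_per_pack),
                key=pack_weights.__getitem__,
            )
            pack_index[i, group] = pack
            rank_in_pack[i, group] = pack_items[pack]
            pack_weights[pack] += float(weight_np[i, group])
            pack_items[pack] += 1
    return pack_index, rank_in_pack

# pack_greedy(np.array([[5.0, 4.0, 3.0, 2.0, 1.0, 0.0]]), num_packs=2)
# puts groups {0, 3, 4} in pack 0 (total 8.0) and {1, 2, 5} in pack 1 (7.0).
```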
@@ -212,7 +228,7 @@ def rebalance_experts(
         replicas for each logical expert
     """
     num_layers, num_logical_experts = weight.shape
-    weight = weight.float().cpu()
+    weight = weight.float()
    if num_groups % num_nodes == 0:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
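With the `.cpu()` dropped here, `weight` stays on its original device; the single host copy now happens inside `balanced_packing` (`weight.cpu().numpy()`), and the packed results return to the device in one `.to(device)` call. Below is a hedged sketch of that round trip using stock PyTorch only; the tensor shapes are illustrative.

```python
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
weight = torch.rand(2, 8, device=device)

# One bulk device-to-host copy up front...
weight_np = weight.float().cpu().numpy()

# ...so the inner loop reads plain host memory. Element-by-element
# indexing of a CUDA tensor (the old weight[i, group]) would instead
# issue a tiny synchronizing transfer on every access.
total = float(weight_np[0].sum())

# Results computed on the host go back to the device in one shot too.
result_np = np.argsort(-weight_np, axis=-1)
result = torch.from_numpy(result_np).to(device)
```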

vllm/distributed/eplb/rebalance_execute.py

Lines changed: 9 additions & 5 deletions

@@ -321,15 +321,19 @@ def rearrange_expert_weights_inplace(
         )
         return

+    old_global_expert_indices_cpu = old_global_expert_indices.cpu()
+    new_global_expert_indices_cpu = new_global_expert_indices.cpu()
+
+    # NOTE(bowen): We need this synchronize to run, but I don't know why.
+    # If you figure out the reason, please let me know -- thank you!
+    torch.cuda.synchronize()
+
     for layer in range(num_moe_layers):
-        # NOTE(bowen): We need this synchronize to run, but I don't know why.
-        # If you figure out the reason, please let me know -- thank you!
-        torch.cuda.synchronize()
         shuffle_layer(
             num_local_physical_experts,
             ep_rank,
-            old_global_expert_indices[layer].tolist(),
-            new_global_expert_indices[layer].tolist(),
+            old_global_expert_indices_cpu[layer].tolist(),
+            new_global_expert_indices_cpu[layer].tolist(),
             expert_weights[layer],
             expert_weights_buffer,
             ep_group,
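The change in this file is the same idea applied to the per-layer loop: the index tensors cross to the host once and the CUDA sync happens once, instead of a transfer plus sync on every iteration. A toy sketch of the hoisting pattern follows; the tensor names and shapes are illustrative, not vLLM's.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
num_moe_layers = 4
old_idx = torch.randint(0, 8, (num_moe_layers, 16), device=device)
new_idx = torch.randint(0, 8, (num_moe_layers, 16), device=device)

# Hoisted out of the loop: one bulk device-to-host copy per tensor...
old_idx_cpu = old_idx.cpu()
new_idx_cpu = new_idx.cpu()
# ...and one synchronize, instead of one per layer.
if device == "cuda":
    torch.cuda.synchronize()

for layer in range(num_moe_layers):
    # .tolist() on a CPU tensor is pure host work; calling it on the
    # CUDA tensors inside the loop would hide a transfer on each pass.
    src = old_idx_cpu[layer].tolist()
    dst = new_idx_cpu[layer].tolist()
```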
