1212on how the EPLB algorithm works.
1313"""
1414
15+ import numpy as np
1516import torch
1617
1718
@@ -34,29 +35,44 @@ def balanced_packing(
3435 assert num_groups % num_packs == 0
3536 groups_per_pack = num_groups // num_packs
3637
38+ device = weight .device
39+
3740 if groups_per_pack == 1 :
3841 pack_index = torch .arange (
39- weight .size (- 1 ), dtype = torch .int64 , device = weight . device
42+ weight .size (- 1 ), dtype = torch .int64 , device = device
4043 ).expand (weight .shape )
41- rank_in_pack = torch .zeros_like (weight , dtype = torch .int64 )
44+ rank_in_pack = torch .zeros_like (weight , dtype = torch .int64 , device = device )
4245 return pack_index , rank_in_pack
4346
44- indices = weight .float ().sort (- 1 , descending = True ).indices .cpu ()
45- pack_index = torch .full_like (weight , fill_value = - 1 , dtype = torch .int64 , device = "cpu" )
46- rank_in_pack = torch .full_like (pack_index , fill_value = - 1 )
47+ weight_np = weight .cpu ().numpy ()
48+
49+ # Sort and get indices in decending order
50+ indices_np = np .argsort (- weight_np , axis = - 1 )
51+
52+ pack_index_np = np .full ((num_layers , num_groups ), - 1 , dtype = np .int64 )
53+ rank_in_pack_np = np .full ((num_layers , num_groups ), - 1 , dtype = np .int64 )
54+
55+ # Run the packing algorithm
4756 for i in range (num_layers ):
48- pack_weights = [0 ] * num_packs
57+ pack_weights = [0.0 ] * num_packs
4958 pack_items = [0 ] * num_packs
50- for group in indices [i ]:
59+
60+ for group in indices_np [i ]:
61+ # Find a pack with capacity that has the lowest weight
5162 pack = min (
52- (i for i in range (num_packs ) if pack_items [i ] < groups_per_pack ),
63+ (j for j in range (num_packs ) if pack_items [j ] < groups_per_pack ),
5364 key = pack_weights .__getitem__ ,
5465 )
66+
5567 assert pack_items [pack ] < groups_per_pack
56- pack_index [i , group ] = pack
57- rank_in_pack [i , group ] = pack_items [pack ]
58- pack_weights [pack ] += weight [i , group ]
68+ pack_index_np [i , group ] = pack
69+ rank_in_pack_np [i , group ] = pack_items [pack ]
70+ pack_weights [pack ] += weight_np [i , group ]
5971 pack_items [pack ] += 1
72+
73+ pack_index = torch .from_numpy (pack_index_np ).to (device )
74+ rank_in_pack = torch .from_numpy (rank_in_pack_np ).to (device )
75+
6076 return pack_index , rank_in_pack
6177
6278
@@ -212,7 +228,7 @@ def rebalance_experts(
212228 replicas for each logical expert
213229 """
214230 num_layers , num_logical_experts = weight .shape
215- weight = weight .float (). cpu ()
231+ weight = weight .float ()
216232 if num_groups % num_nodes == 0 :
217233 # use hierarchical load-balance policy
218234 phy2log , phyrank , logcnt = rebalance_experts_hierarchical (
0 commit comments